gdb_driver_mssql.go 9.3 KB


  1. // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
  2. //
  3. // This Source Code Form is subject to the terms of the MIT License.
  4. // If a copy of the MIT was not distributed with this file,
  5. // You can obtain one at https://github.com/gogf/gf.
  6. //
  7. // Note:
  8. // 1. The underlying driver must be imported manually: _ "github.com/denisenkom/go-mssqldb"
  9. // 2. It does not support Save/Replace features.
  10. // 3. It does not support LastInsertId.
  11. package gdb
  12. import (
  13. "context"
  14. "database/sql"
  15. "fmt"
  16. "github.com/gogf/gf/errors/gcode"
  17. "strconv"
  18. "strings"
  19. "github.com/gogf/gf/errors/gerror"
  20. "github.com/gogf/gf/internal/intlog"
  21. "github.com/gogf/gf/text/gstr"
  22. "github.com/gogf/gf/text/gregex"
  23. )
  24. // DriverMssql is the driver for SQL server database.
  25. type DriverMssql struct {
  26. *Core
  27. }
  28. // New creates and returns a database object for SQL server.
  29. // It implements the interface of gdb.Driver for extra database driver installation.
  30. func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) {
  31. return &DriverMssql{
  32. Core: core,
  33. }, nil
  34. }
  35. // Open creates and returns a underlying sql.DB object for mssql.
  36. func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) {
  37. source := ""
  38. if config.Link != "" {
  39. source = config.Link
  40. } else {
  41. source = fmt.Sprintf(
  42. "user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
  43. config.User, config.Pass, config.Host, config.Port, config.Name,
  44. )
  45. }
  46. intlog.Printf(d.GetCtx(), "Open: %s", source)
  47. if db, err := sql.Open("sqlserver", source); err == nil {
  48. return db, nil
  49. } else {
  50. return nil, err
  51. }
  52. }
  53. // FilteredLink retrieves and returns filtered `linkInfo` that can be using for
  54. // logging or tracing purpose.
  55. func (d *DriverMssql) FilteredLink() string {
  56. linkInfo := d.GetConfig().Link
  57. if linkInfo == "" {
  58. return ""
  59. }
  60. s, _ := gregex.ReplaceString(
  61. `(.+);\s*password=(.+);\s*server=(.+)`,
  62. `$1;password=xxx;server=$3`,
  63. d.GetConfig().Link,
  64. )
  65. return s
  66. }
  67. // GetChars returns the security char for this type of database.
  68. func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
  69. return "\"", "\""
  70. }
  71. // DoCommit deals with the sql string before commits it to underlying sql driver.
  72. func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
  73. defer func() {
  74. newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
  75. }()
  76. var index int
  77. // Convert place holder char '?' to string "@px".
  78. str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
  79. index++
  80. return fmt.Sprintf("@p%d", index)
  81. })
  82. str, _ = gregex.ReplaceString("\"", "", str)
  83. return d.parseSql(str), args, nil
  84. }
  85. // parseSql does some replacement of the sql before commits it to underlying driver,
  86. // for support of microsoft sql server.
  87. func (d *DriverMssql) parseSql(sql string) string {
  88. // SELECT * FROM USER WHERE ID=1 LIMIT 1
  89. if m, _ := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, sql); len(m) > 1 {
  90. return fmt.Sprintf(`SELECT TOP 1 %s`, m[1])
  91. }
  92. // SELECT * FROM USER WHERE AGE>18 ORDER BY ID DESC LIMIT 100, 200
  93. patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
  94. if gregex.IsMatchString(patten, sql) == false {
  95. return sql
  96. }
  97. res, err := gregex.MatchAllString(patten, sql)
  98. if err != nil {
  99. return ""
  100. }
  101. index := 0
  102. keyword := strings.TrimSpace(res[index][0])
  103. keyword = strings.ToUpper(keyword)
  104. index++
  105. switch keyword {
  106. case "SELECT":
  107. // LIMIT statement checks.
  108. if len(res) < 2 ||
  109. (strings.HasPrefix(res[index][0], "LIMIT") == false &&
  110. strings.HasPrefix(res[index][0], "limit") == false) {
  111. break
  112. }
  113. if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
  114. break
  115. }
  116. // ORDER BY statement checks.
  117. selectStr := ""
  118. orderStr := ""
  119. haveOrder := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
  120. if haveOrder {
  121. queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
  122. if len(queryExpr) != 4 ||
  123. strings.EqualFold(queryExpr[1], "SELECT") == false ||
  124. strings.EqualFold(queryExpr[3], "ORDER BY") == false {
  125. break
  126. }
  127. selectStr = queryExpr[2]
  128. orderExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql)
  129. if len(orderExpr) != 4 ||
  130. strings.EqualFold(orderExpr[1], "ORDER BY") == false ||
  131. strings.EqualFold(orderExpr[3], "LIMIT") == false {
  132. break
  133. }
  134. orderStr = orderExpr[2]
  135. } else {
  136. queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
  137. if len(queryExpr) != 4 ||
  138. strings.EqualFold(queryExpr[1], "SELECT") == false ||
  139. strings.EqualFold(queryExpr[3], "LIMIT") == false {
  140. break
  141. }
  142. selectStr = queryExpr[2]
  143. }
  144. first, limit := 0, 0
  145. for i := 1; i < len(res[index]); i++ {
  146. if len(strings.TrimSpace(res[index][i])) == 0 {
  147. continue
  148. }
  149. if strings.HasPrefix(res[index][i], "LIMIT") ||
  150. strings.HasPrefix(res[index][i], "limit") {
  151. first, _ = strconv.Atoi(res[index][i+1])
  152. limit, _ = strconv.Atoi(res[index][i+2])
  153. break
  154. }
  155. }
  156. if haveOrder {
  157. sql = fmt.Sprintf(
  158. "SELECT * FROM "+
  159. "(SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ "+
  160. "WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d",
  161. orderStr, selectStr, first, first+limit,
  162. )
  163. } else {
  164. if first == 0 {
  165. first = limit
  166. }
  167. sql = fmt.Sprintf(
  168. "SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ",
  169. limit, first+limit, selectStr,
  170. )
  171. }
  172. default:
  173. }
  174. return sql
  175. }
  176. // Tables retrieves and returns the tables of current schema.
  177. // It's mainly used in cli tool chain for automatically generating the models.
  178. func (d *DriverMssql) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
  179. var result Result
  180. link, err := d.SlaveLink(schema...)
  181. if err != nil {
  182. return nil, err
  183. }
  184. result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
  185. if err != nil {
  186. return
  187. }
  188. for _, m := range result {
  189. for _, v := range m {
  190. tables = append(tables, v.String())
  191. }
  192. }
  193. return
  194. }
  195. // TableFields retrieves and returns the fields information of specified table of current schema.
  196. //
  197. // Also see DriverMysql.TableFields.
  198. func (d *DriverMssql) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) {
  199. charL, charR := d.GetChars()
  200. table = gstr.Trim(table, charL+charR)
  201. if gstr.Contains(table, " ") {
  202. return nil, gerror.NewCode(gcode.CodeInvalidParameter, "function TableFields supports only single table operations")
  203. }
  204. useSchema := d.db.GetSchema()
  205. if len(schema) > 0 && schema[0] != "" {
  206. useSchema = schema[0]
  207. }
  208. tableFieldsCacheKey := fmt.Sprintf(
  209. `mssql_table_fields_%s_%s@group:%s`,
  210. table, useSchema, d.GetGroup(),
  211. )
  212. v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} {
  213. var (
  214. result Result
  215. link, err = d.SlaveLink(useSchema)
  216. )
  217. if err != nil {
  218. return nil
  219. }
  220. structureSql := fmt.Sprintf(`
  221. SELECT
  222. a.name Field,
  223. CASE b.name
  224. WHEN 'datetime' THEN 'datetime'
  225. WHEN 'numeric' THEN b.name + '(' + convert(varchar(20), a.xprec) + ',' + convert(varchar(20), a.xscale) + ')'
  226. WHEN 'char' THEN b.name + '(' + convert(varchar(20), a.length)+ ')'
  227. WHEN 'varchar' THEN b.name + '(' + convert(varchar(20), a.length)+ ')'
  228. ELSE b.name + '(' + convert(varchar(20),a.length)+ ')' END AS Type,
  229. CASE WHEN a.isnullable=1 THEN 'YES' ELSE 'NO' end AS [Null],
  230. CASE WHEN exists (
  231. SELECT 1 FROM sysobjects WHERE xtype='PK' AND name IN (
  232. SELECT name FROM sysindexes WHERE indid IN (
  233. SELECT indid FROM sysindexkeys WHERE id = a.id AND colid=a.colid
  234. )
  235. )
  236. ) THEN 'PRI' ELSE '' END AS [Key],
  237. CASE WHEN COLUMNPROPERTY(a.id,a.name,'IsIdentity')=1 THEN 'auto_increment' ELSE '' END Extra,
  238. isnull(e.text,'') AS [Default],
  239. isnull(g.[value],'') AS [Comment]
  240. FROM syscolumns a
  241. LEFT JOIN systypes b ON a.xtype=b.xtype AND a.xusertype=b.xusertype
  242. INNER JOIN sysobjects d ON a.id=d.id AND d.xtype='U' AND d.name<>'dtproperties'
  243. LEFT JOIN syscomments e ON a.cdefault=e.id
  244. LEFT JOIN sys.extended_properties g ON a.id=g.major_id AND a.colid=g.minor_id
  245. LEFT JOIN sys.extended_properties f ON d.id=f.major_id AND f.minor_id =0
  246. WHERE d.name='%s'
  247. ORDER BY a.id,a.colorder`,
  248. strings.ToUpper(table),
  249. )
  250. structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql))
  251. result, err = d.DoGetAll(ctx, link, structureSql)
  252. if err != nil {
  253. return nil
  254. }
  255. fields = make(map[string]*TableField)
  256. for i, m := range result {
  257. fields[strings.ToLower(m["Field"].String())] = &TableField{
  258. Index: i,
  259. Name: strings.ToLower(m["Field"].String()),
  260. Type: strings.ToLower(m["Type"].String()),
  261. Null: m["Null"].Bool(),
  262. Key: m["Key"].String(),
  263. Default: m["Default"].Val(),
  264. Extra: m["Extra"].String(),
  265. Comment: m["Comment"].String(),
  266. }
  267. }
  268. return fields
  269. })
  270. if v != nil {
  271. fields = v.(map[string]*TableField)
  272. }
  273. return
  274. }
  275. // DoInsert is not supported in mssql.
  276. func (d *DriverMssql) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
  277. switch option.InsertOption {
  278. case insertOptionSave:
  279. return nil, gerror.NewCode(gcode.CodeNotSupported, `Save operation is not supported by mssql driver`)
  280. case insertOptionReplace:
  281. return nil, gerror.NewCode(gcode.CodeNotSupported, `Replace operation is not supported by mssql driver`)
  282. default:
  283. return d.Core.DoInsert(ctx, link, table, list, option)
  284. }
  285. }