statement.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "gorm.io/gorm/clause"
  13. "gorm.io/gorm/logger"
  14. "gorm.io/gorm/schema"
  15. "gorm.io/gorm/utils"
  16. )
  17. // Statement statement
  18. type Statement struct {
  19. *DB
  20. TableExpr *clause.Expr
  21. Table string
  22. Model interface{}
  23. Unscoped bool
  24. Dest interface{}
  25. ReflectValue reflect.Value
  26. Clauses map[string]clause.Clause
  27. Distinct bool
  28. Selects []string // selected columns
  29. Omits []string // omit columns
  30. Joins []join
  31. Preloads map[string][]interface{}
  32. Settings sync.Map
  33. ConnPool ConnPool
  34. Schema *schema.Schema
  35. Context context.Context
  36. RaiseErrorOnNotFound bool
  37. UpdatingColumn bool
  38. SQL strings.Builder
  39. Vars []interface{}
  40. CurDestIndex int
  41. attrs []interface{}
  42. assigns []interface{}
  43. }
  44. type join struct {
  45. Name string
  46. Conds []interface{}
  47. }
  48. // StatementModifier statement modifier interface
  49. type StatementModifier interface {
  50. ModifyStatement(*Statement)
  51. }
  52. // Write write string
  53. func (stmt *Statement) WriteString(str string) (int, error) {
  54. return stmt.SQL.WriteString(str)
  55. }
  56. // Write write string
  57. func (stmt *Statement) WriteByte(c byte) error {
  58. return stmt.SQL.WriteByte(c)
  59. }
  60. // WriteQuoted write quoted value
  61. func (stmt *Statement) WriteQuoted(value interface{}) {
  62. stmt.QuoteTo(&stmt.SQL, value)
  63. }
  64. // QuoteTo write quoted value to writer
  65. func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
  66. switch v := field.(type) {
  67. case clause.Table:
  68. if v.Name == clause.CurrentTable {
  69. if stmt.TableExpr != nil {
  70. stmt.TableExpr.Build(stmt)
  71. } else {
  72. stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
  73. }
  74. } else if v.Raw {
  75. writer.WriteString(v.Name)
  76. } else {
  77. stmt.DB.Dialector.QuoteTo(writer, v.Name)
  78. }
  79. if v.Alias != "" {
  80. writer.WriteByte(' ')
  81. stmt.DB.Dialector.QuoteTo(writer, v.Alias)
  82. }
  83. case clause.Column:
  84. if v.Table != "" {
  85. if v.Table == clause.CurrentTable {
  86. stmt.DB.Dialector.QuoteTo(writer, stmt.Table)
  87. } else {
  88. stmt.DB.Dialector.QuoteTo(writer, v.Table)
  89. }
  90. writer.WriteByte('.')
  91. }
  92. if v.Name == clause.PrimaryKey {
  93. if stmt.Schema == nil {
  94. stmt.DB.AddError(ErrModelValueRequired)
  95. } else if stmt.Schema.PrioritizedPrimaryField != nil {
  96. stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
  97. } else if len(stmt.Schema.DBNames) > 0 {
  98. stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
  99. }
  100. } else if v.Raw {
  101. writer.WriteString(v.Name)
  102. } else {
  103. stmt.DB.Dialector.QuoteTo(writer, v.Name)
  104. }
  105. if v.Alias != "" {
  106. writer.WriteString(" AS ")
  107. stmt.DB.Dialector.QuoteTo(writer, v.Alias)
  108. }
  109. case []clause.Column:
  110. writer.WriteByte('(')
  111. for idx, d := range v {
  112. if idx > 0 {
  113. writer.WriteString(",")
  114. }
  115. stmt.QuoteTo(writer, d)
  116. }
  117. writer.WriteByte(')')
  118. case string:
  119. stmt.DB.Dialector.QuoteTo(writer, v)
  120. case []string:
  121. writer.WriteByte('(')
  122. for idx, d := range v {
  123. if idx > 0 {
  124. writer.WriteString(",")
  125. }
  126. stmt.DB.Dialector.QuoteTo(writer, d)
  127. }
  128. writer.WriteByte(')')
  129. default:
  130. stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
  131. }
  132. }
  133. // Quote returns quoted value
  134. func (stmt *Statement) Quote(field interface{}) string {
  135. var builder strings.Builder
  136. stmt.QuoteTo(&builder, field)
  137. return builder.String()
  138. }
  139. // Write write string
  140. func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
  141. for idx, v := range vars {
  142. if idx > 0 {
  143. writer.WriteByte(',')
  144. }
  145. switch v := v.(type) {
  146. case sql.NamedArg:
  147. stmt.Vars = append(stmt.Vars, v.Value)
  148. case clause.Column, clause.Table:
  149. stmt.QuoteTo(writer, v)
  150. case Valuer:
  151. stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
  152. case clause.Expr:
  153. var varStr strings.Builder
  154. var sql = v.SQL
  155. for _, arg := range v.Vars {
  156. stmt.Vars = append(stmt.Vars, arg)
  157. stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg)
  158. sql = strings.Replace(sql, "?", varStr.String(), 1)
  159. varStr.Reset()
  160. }
  161. writer.WriteString(sql)
  162. case driver.Valuer:
  163. stmt.Vars = append(stmt.Vars, v)
  164. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  165. case []byte:
  166. stmt.Vars = append(stmt.Vars, v)
  167. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  168. case []interface{}:
  169. if len(v) > 0 {
  170. writer.WriteByte('(')
  171. stmt.AddVar(writer, v...)
  172. writer.WriteByte(')')
  173. } else {
  174. writer.WriteString("(NULL)")
  175. }
  176. case *DB:
  177. subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance()
  178. subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
  179. subdb.callbacks.Query().Execute(subdb)
  180. writer.WriteString(subdb.Statement.SQL.String())
  181. stmt.Vars = subdb.Statement.Vars
  182. default:
  183. switch rv := reflect.ValueOf(v); rv.Kind() {
  184. case reflect.Slice, reflect.Array:
  185. if rv.Len() == 0 {
  186. writer.WriteString("(NULL)")
  187. } else {
  188. writer.WriteByte('(')
  189. for i := 0; i < rv.Len(); i++ {
  190. if i > 0 {
  191. writer.WriteByte(',')
  192. }
  193. stmt.AddVar(writer, rv.Index(i).Interface())
  194. }
  195. writer.WriteByte(')')
  196. }
  197. default:
  198. stmt.Vars = append(stmt.Vars, v)
  199. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  200. }
  201. }
  202. }
  203. }
  204. // AddClause add clause
  205. func (stmt *Statement) AddClause(v clause.Interface) {
  206. if optimizer, ok := v.(StatementModifier); ok {
  207. optimizer.ModifyStatement(stmt)
  208. } else {
  209. name := v.Name()
  210. c := stmt.Clauses[name]
  211. c.Name = name
  212. v.MergeClause(&c)
  213. stmt.Clauses[name] = c
  214. }
  215. }
  216. // AddClauseIfNotExists add clause if not exists
  217. func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
  218. if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
  219. stmt.AddClause(v)
  220. }
  221. }
  222. // BuildCondition build condition
  223. func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (conds []clause.Expression) {
  224. if s, ok := query.(string); ok {
  225. // if it is a number, then treats it as primary key
  226. if _, err := strconv.Atoi(s); err != nil {
  227. if s == "" && len(args) == 0 {
  228. return
  229. } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
  230. // looks like a where condition
  231. return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
  232. } else if len(args) > 0 && strings.Contains(s, "@") {
  233. // looks like a named query
  234. return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
  235. } else if len(args) == 1 {
  236. return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
  237. }
  238. }
  239. }
  240. args = append([]interface{}{query}, args...)
  241. for _, arg := range args {
  242. if valuer, ok := arg.(driver.Valuer); ok {
  243. arg, _ = valuer.Value()
  244. }
  245. switch v := arg.(type) {
  246. case clause.Expression:
  247. conds = append(conds, v)
  248. case *DB:
  249. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  250. if where, ok := cs.Expression.(clause.Where); ok {
  251. conds = append(conds, clause.And(where.Exprs...))
  252. } else if cs.Expression != nil {
  253. conds = append(conds, cs.Expression)
  254. }
  255. }
  256. case map[interface{}]interface{}:
  257. for i, j := range v {
  258. conds = append(conds, clause.Eq{Column: i, Value: j})
  259. }
  260. case map[string]string:
  261. var keys = make([]string, 0, len(v))
  262. for i := range v {
  263. keys = append(keys, i)
  264. }
  265. sort.Strings(keys)
  266. for _, key := range keys {
  267. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  268. }
  269. case map[string]interface{}:
  270. var keys = make([]string, 0, len(v))
  271. for i := range v {
  272. keys = append(keys, i)
  273. }
  274. sort.Strings(keys)
  275. for _, key := range keys {
  276. reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
  277. switch reflectValue.Kind() {
  278. case reflect.Slice, reflect.Array:
  279. if _, ok := v[key].(driver.Valuer); ok {
  280. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  281. } else if _, ok := v[key].(Valuer); ok {
  282. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  283. } else {
  284. values := make([]interface{}, reflectValue.Len())
  285. for i := 0; i < reflectValue.Len(); i++ {
  286. values[i] = reflectValue.Index(i).Interface()
  287. }
  288. conds = append(conds, clause.IN{Column: key, Values: values})
  289. }
  290. default:
  291. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  292. }
  293. }
  294. default:
  295. reflectValue := reflect.Indirect(reflect.ValueOf(arg))
  296. if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
  297. switch reflectValue.Kind() {
  298. case reflect.Struct:
  299. for _, field := range s.Fields {
  300. if field.Readable {
  301. if v, isZero := field.ValueOf(reflectValue); !isZero {
  302. if field.DBName != "" {
  303. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
  304. } else if field.DataType != "" {
  305. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
  306. }
  307. }
  308. }
  309. }
  310. case reflect.Slice, reflect.Array:
  311. for i := 0; i < reflectValue.Len(); i++ {
  312. for _, field := range s.Fields {
  313. if field.Readable {
  314. if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero {
  315. if field.DBName != "" {
  316. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
  317. } else if field.DataType != "" {
  318. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
  319. }
  320. }
  321. }
  322. }
  323. }
  324. }
  325. } else if len(conds) == 0 {
  326. if len(args) == 1 {
  327. switch reflectValue.Kind() {
  328. case reflect.Slice, reflect.Array:
  329. values := make([]interface{}, reflectValue.Len())
  330. for i := 0; i < reflectValue.Len(); i++ {
  331. values[i] = reflectValue.Index(i).Interface()
  332. }
  333. if len(values) > 0 {
  334. conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
  335. }
  336. return
  337. }
  338. }
  339. conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
  340. }
  341. }
  342. }
  343. return
  344. }
  345. // Build build sql with clauses names
  346. func (stmt *Statement) Build(clauses ...string) {
  347. var firstClauseWritten bool
  348. for _, name := range clauses {
  349. if c, ok := stmt.Clauses[name]; ok {
  350. if firstClauseWritten {
  351. stmt.WriteByte(' ')
  352. }
  353. firstClauseWritten = true
  354. if b, ok := stmt.DB.ClauseBuilders[name]; ok {
  355. b(c, stmt)
  356. } else {
  357. c.Build(stmt)
  358. }
  359. }
  360. }
  361. }
  362. func (stmt *Statement) Parse(value interface{}) (err error) {
  363. if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
  364. if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
  365. stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
  366. stmt.Table = tables[1]
  367. return
  368. }
  369. stmt.Table = stmt.Schema.Table
  370. }
  371. return err
  372. }
  373. func (stmt *Statement) clone() *Statement {
  374. newStmt := &Statement{
  375. TableExpr: stmt.TableExpr,
  376. Table: stmt.Table,
  377. Model: stmt.Model,
  378. Unscoped: stmt.Unscoped,
  379. Dest: stmt.Dest,
  380. ReflectValue: stmt.ReflectValue,
  381. Clauses: map[string]clause.Clause{},
  382. Distinct: stmt.Distinct,
  383. Selects: stmt.Selects,
  384. Omits: stmt.Omits,
  385. Preloads: map[string][]interface{}{},
  386. ConnPool: stmt.ConnPool,
  387. Schema: stmt.Schema,
  388. Context: stmt.Context,
  389. RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
  390. UpdatingColumn: stmt.UpdatingColumn,
  391. }
  392. for k, c := range stmt.Clauses {
  393. newStmt.Clauses[k] = c
  394. }
  395. for k, p := range stmt.Preloads {
  396. newStmt.Preloads[k] = p
  397. }
  398. if len(stmt.Joins) > 0 {
  399. newStmt.Joins = make([]join, len(stmt.Joins))
  400. copy(newStmt.Joins, stmt.Joins)
  401. }
  402. stmt.Settings.Range(func(k, v interface{}) bool {
  403. newStmt.Settings.Store(k, v)
  404. return true
  405. })
  406. return newStmt
  407. }
  408. // Helpers
  409. // SetColumn set column's value
  410. func (stmt *Statement) SetColumn(name string, value interface{}) {
  411. if v, ok := stmt.Dest.(map[string]interface{}); ok {
  412. v[name] = value
  413. } else if stmt.Schema != nil {
  414. if field := stmt.Schema.LookUpField(name); field != nil {
  415. destValue := reflect.ValueOf(stmt.Dest)
  416. for destValue.Kind() == reflect.Ptr {
  417. destValue = destValue.Elem()
  418. }
  419. if stmt.ReflectValue != destValue {
  420. if !destValue.CanAddr() {
  421. destValueCanAddr := reflect.New(destValue.Type())
  422. destValueCanAddr.Elem().Set(destValue)
  423. stmt.Dest = destValueCanAddr.Interface()
  424. destValue = destValueCanAddr.Elem()
  425. }
  426. switch destValue.Kind() {
  427. case reflect.Struct:
  428. field.Set(destValue, value)
  429. default:
  430. stmt.AddError(ErrInvalidData)
  431. }
  432. }
  433. switch stmt.ReflectValue.Kind() {
  434. case reflect.Slice, reflect.Array:
  435. field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
  436. case reflect.Struct:
  437. field.Set(stmt.ReflectValue, value)
  438. }
  439. } else {
  440. stmt.AddError(ErrInvalidField)
  441. }
  442. } else {
  443. stmt.AddError(ErrInvalidField)
  444. }
  445. }
  446. // Changed check model changed or not when updating
  447. func (stmt *Statement) Changed(fields ...string) bool {
  448. modelValue := stmt.ReflectValue
  449. switch modelValue.Kind() {
  450. case reflect.Slice, reflect.Array:
  451. modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
  452. }
  453. selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
  454. changed := func(field *schema.Field) bool {
  455. fieldValue, _ := field.ValueOf(modelValue)
  456. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
  457. if v, ok := stmt.Dest.(map[string]interface{}); ok {
  458. if fv, ok := v[field.Name]; ok {
  459. return !utils.AssertEqual(fv, fieldValue)
  460. } else if fv, ok := v[field.DBName]; ok {
  461. return !utils.AssertEqual(fv, fieldValue)
  462. }
  463. } else {
  464. destValue := reflect.ValueOf(stmt.Dest)
  465. for destValue.Kind() == reflect.Ptr {
  466. destValue = destValue.Elem()
  467. }
  468. changedValue, zero := field.ValueOf(destValue)
  469. return !zero && !utils.AssertEqual(changedValue, fieldValue)
  470. }
  471. }
  472. return false
  473. }
  474. if len(fields) == 0 {
  475. for _, field := range stmt.Schema.FieldsByDBName {
  476. if changed(field) {
  477. return true
  478. }
  479. }
  480. } else {
  481. for _, name := range fields {
  482. if field := stmt.Schema.LookUpField(name); field != nil {
  483. if changed(field) {
  484. return true
  485. }
  486. }
  487. }
  488. }
  489. return false
  490. }
  491. // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
  492. func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
  493. results := map[string]bool{}
  494. notRestricted := false
  495. // select columns
  496. for _, column := range stmt.Selects {
  497. if column == "*" {
  498. notRestricted = true
  499. for _, dbName := range stmt.Schema.DBNames {
  500. results[dbName] = true
  501. }
  502. } else if column == clause.Associations && stmt.Schema != nil {
  503. for _, rel := range stmt.Schema.Relationships.Relations {
  504. results[rel.Name] = true
  505. }
  506. } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
  507. results[field.DBName] = true
  508. } else {
  509. results[column] = true
  510. }
  511. }
  512. // omit columns
  513. for _, omit := range stmt.Omits {
  514. if omit == clause.Associations {
  515. if stmt.Schema != nil {
  516. for _, rel := range stmt.Schema.Relationships.Relations {
  517. results[rel.Name] = false
  518. }
  519. }
  520. } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
  521. results[field.DBName] = false
  522. } else {
  523. results[omit] = false
  524. }
  525. }
  526. if stmt.Schema != nil {
  527. for _, field := range stmt.Schema.Fields {
  528. name := field.DBName
  529. if name == "" {
  530. name = field.Name
  531. }
  532. if requireCreate && !field.Creatable {
  533. results[name] = false
  534. } else if requireUpdate && !field.Updatable {
  535. results[name] = false
  536. }
  537. }
  538. }
  539. return results, !notRestricted && len(stmt.Selects) > 0
  540. }