finisher_api.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. package gorm
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "reflect"
  7. "strings"
  8. "gorm.io/gorm/clause"
  9. "gorm.io/gorm/logger"
  10. "gorm.io/gorm/schema"
  11. "gorm.io/gorm/utils"
  12. )
  13. // Create insert the value into database
  14. func (db *DB) Create(value interface{}) (tx *DB) {
  15. tx = db.getInstance()
  16. tx.Statement.Dest = value
  17. tx.callbacks.Create().Execute(tx)
  18. return
  19. }
  20. // Save update value in database, if the value doesn't have primary key, will insert it
  21. func (db *DB) Save(value interface{}) (tx *DB) {
  22. tx = db.getInstance()
  23. tx.Statement.Dest = value
  24. reflectValue := reflect.Indirect(reflect.ValueOf(value))
  25. switch reflectValue.Kind() {
  26. case reflect.Slice, reflect.Array:
  27. tx.Statement.UpdatingColumn = true
  28. tx.callbacks.Create().Execute(tx)
  29. case reflect.Struct:
  30. if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
  31. for _, pf := range tx.Statement.Schema.PrimaryFields {
  32. if _, isZero := pf.ValueOf(reflectValue); isZero {
  33. tx.callbacks.Create().Execute(tx)
  34. return
  35. }
  36. }
  37. }
  38. fallthrough
  39. default:
  40. selectedUpdate := len(tx.Statement.Selects) != 0
  41. // when updating, use all fields including those zero-value fields
  42. if !selectedUpdate {
  43. tx.Statement.Selects = append(tx.Statement.Selects, "*")
  44. }
  45. tx.callbacks.Update().Execute(tx)
  46. if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
  47. result := reflect.New(tx.Statement.Schema.ModelType).Interface()
  48. if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
  49. return tx.Create(value)
  50. }
  51. }
  52. }
  53. return
  54. }
  55. // First find first record that match given conditions, order by primary key
  56. func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
  57. tx = db.Limit(1).Order(clause.OrderByColumn{
  58. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  59. })
  60. if len(conds) > 0 {
  61. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
  62. }
  63. tx.Statement.RaiseErrorOnNotFound = true
  64. tx.Statement.Dest = dest
  65. tx.callbacks.Query().Execute(tx)
  66. return
  67. }
  68. // Take return a record that match given conditions, the order will depend on the database implementation
  69. func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
  70. tx = db.Limit(1)
  71. if len(conds) > 0 {
  72. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
  73. }
  74. tx.Statement.RaiseErrorOnNotFound = true
  75. tx.Statement.Dest = dest
  76. tx.callbacks.Query().Execute(tx)
  77. return
  78. }
  79. // Last find last record that match given conditions, order by primary key
  80. func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
  81. tx = db.Limit(1).Order(clause.OrderByColumn{
  82. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  83. Desc: true,
  84. })
  85. if len(conds) > 0 {
  86. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
  87. }
  88. tx.Statement.RaiseErrorOnNotFound = true
  89. tx.Statement.Dest = dest
  90. tx.callbacks.Query().Execute(tx)
  91. return
  92. }
  93. // Find find records that match given conditions
  94. func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
  95. tx = db.getInstance()
  96. if len(conds) > 0 {
  97. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
  98. }
  99. tx.Statement.Dest = dest
  100. tx.callbacks.Query().Execute(tx)
  101. return
  102. }
  103. // FindInBatches find records in batches
  104. func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) {
  105. tx = db.Session(&Session{WithConditions: true})
  106. rowsAffected := int64(0)
  107. batch := 0
  108. for {
  109. result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest)
  110. rowsAffected += result.RowsAffected
  111. batch++
  112. if result.Error == nil && result.RowsAffected != 0 {
  113. tx.AddError(fc(result, batch))
  114. }
  115. if tx.Error != nil || int(result.RowsAffected) < batchSize {
  116. break
  117. }
  118. }
  119. tx.RowsAffected = rowsAffected
  120. return
  121. }
  122. func (tx *DB) assignInterfacesToValue(values ...interface{}) {
  123. for _, value := range values {
  124. switch v := value.(type) {
  125. case []clause.Expression:
  126. for _, expr := range v {
  127. if eq, ok := expr.(clause.Eq); ok {
  128. switch column := eq.Column.(type) {
  129. case string:
  130. if field := tx.Statement.Schema.LookUpField(column); field != nil {
  131. tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
  132. }
  133. case clause.Column:
  134. if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
  135. tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
  136. }
  137. }
  138. } else if andCond, ok := expr.(clause.AndConditions); ok {
  139. tx.assignInterfacesToValue(andCond.Exprs)
  140. }
  141. }
  142. case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
  143. exprs := tx.Statement.BuildCondition(value)
  144. tx.assignInterfacesToValue(exprs)
  145. default:
  146. if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
  147. reflectValue := reflect.Indirect(reflect.ValueOf(value))
  148. switch reflectValue.Kind() {
  149. case reflect.Struct:
  150. for _, f := range s.Fields {
  151. if f.Readable {
  152. if v, isZero := f.ValueOf(reflectValue); !isZero {
  153. if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
  154. tx.AddError(field.Set(tx.Statement.ReflectValue, v))
  155. }
  156. }
  157. }
  158. }
  159. }
  160. } else if len(values) > 0 {
  161. exprs := tx.Statement.BuildCondition(values[0], values[1:]...)
  162. tx.assignInterfacesToValue(exprs)
  163. return
  164. }
  165. }
  166. }
  167. }
  168. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
  169. if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
  170. if c, ok := tx.Statement.Clauses["WHERE"]; ok {
  171. if where, ok := c.Expression.(clause.Where); ok {
  172. tx.assignInterfacesToValue(where.Exprs)
  173. }
  174. }
  175. // initialize with attrs, conds
  176. if len(tx.Statement.attrs) > 0 {
  177. tx.assignInterfacesToValue(tx.Statement.attrs...)
  178. }
  179. tx.Error = nil
  180. }
  181. // initialize with attrs, conds
  182. if len(tx.Statement.assigns) > 0 {
  183. tx.assignInterfacesToValue(tx.Statement.assigns...)
  184. }
  185. return
  186. }
  187. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
  188. if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
  189. tx.Error = nil
  190. if c, ok := tx.Statement.Clauses["WHERE"]; ok {
  191. if where, ok := c.Expression.(clause.Where); ok {
  192. tx.assignInterfacesToValue(where.Exprs)
  193. }
  194. }
  195. // initialize with attrs, conds
  196. if len(tx.Statement.attrs) > 0 {
  197. tx.assignInterfacesToValue(tx.Statement.attrs...)
  198. }
  199. // initialize with attrs, conds
  200. if len(tx.Statement.assigns) > 0 {
  201. tx.assignInterfacesToValue(tx.Statement.assigns...)
  202. }
  203. return tx.Create(dest)
  204. } else if len(db.Statement.assigns) > 0 {
  205. exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
  206. assigns := map[string]interface{}{}
  207. for _, expr := range exprs {
  208. if eq, ok := expr.(clause.Eq); ok {
  209. switch column := eq.Column.(type) {
  210. case string:
  211. assigns[column] = eq.Value
  212. case clause.Column:
  213. assigns[column.Name] = eq.Value
  214. default:
  215. }
  216. }
  217. }
  218. return tx.Model(dest).Updates(assigns)
  219. }
  220. return db
  221. }
  222. // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
  223. func (db *DB) Update(column string, value interface{}) (tx *DB) {
  224. tx = db.getInstance()
  225. tx.Statement.Dest = map[string]interface{}{column: value}
  226. tx.callbacks.Update().Execute(tx)
  227. return
  228. }
  229. // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
  230. func (db *DB) Updates(values interface{}) (tx *DB) {
  231. tx = db.getInstance()
  232. tx.Statement.Dest = values
  233. tx.callbacks.Update().Execute(tx)
  234. return
  235. }
  236. func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
  237. tx = db.getInstance()
  238. tx.Statement.Dest = map[string]interface{}{column: value}
  239. tx.Statement.UpdatingColumn = true
  240. tx.callbacks.Update().Execute(tx)
  241. return
  242. }
  243. func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
  244. tx = db.getInstance()
  245. tx.Statement.Dest = values
  246. tx.Statement.UpdatingColumn = true
  247. tx.callbacks.Update().Execute(tx)
  248. return
  249. }
  250. // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
  251. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
  252. tx = db.getInstance()
  253. if len(conds) > 0 {
  254. tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(conds[0], conds[1:]...)})
  255. }
  256. tx.Statement.Dest = value
  257. tx.callbacks.Delete().Execute(tx)
  258. return
  259. }
  260. func (db *DB) Count(count *int64) (tx *DB) {
  261. tx = db.getInstance()
  262. if tx.Statement.Model == nil {
  263. tx.Statement.Model = tx.Statement.Dest
  264. defer func() {
  265. tx.Statement.Model = nil
  266. }()
  267. }
  268. if len(tx.Statement.Selects) == 0 {
  269. tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
  270. defer delete(tx.Statement.Clauses, "SELECT")
  271. } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
  272. expr := clause.Expr{SQL: "count(1)"}
  273. if len(tx.Statement.Selects) == 1 {
  274. dbName := tx.Statement.Selects[0]
  275. if tx.Statement.Parse(tx.Statement.Model) == nil {
  276. if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
  277. dbName = f.DBName
  278. }
  279. }
  280. if tx.Statement.Distinct {
  281. expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
  282. } else {
  283. expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
  284. }
  285. }
  286. tx.Statement.AddClause(clause.Select{Expression: expr})
  287. defer tx.Statement.AddClause(clause.Select{})
  288. }
  289. if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
  290. if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
  291. delete(db.Statement.Clauses, "ORDER BY")
  292. defer func() {
  293. db.Statement.Clauses["ORDER BY"] = orderByClause
  294. }()
  295. }
  296. }
  297. tx.Statement.Dest = count
  298. tx.callbacks.Query().Execute(tx)
  299. if tx.RowsAffected != 1 {
  300. *count = tx.RowsAffected
  301. }
  302. return
  303. }
  304. func (db *DB) Row() *sql.Row {
  305. tx := db.getInstance().InstanceSet("rows", false)
  306. tx.callbacks.Row().Execute(tx)
  307. row, ok := tx.Statement.Dest.(*sql.Row)
  308. if !ok && tx.DryRun {
  309. db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
  310. }
  311. return row
  312. }
  313. func (db *DB) Rows() (*sql.Rows, error) {
  314. tx := db.getInstance().InstanceSet("rows", true)
  315. tx.callbacks.Row().Execute(tx)
  316. rows, ok := tx.Statement.Dest.(*sql.Rows)
  317. if !ok && tx.DryRun && tx.Error == nil {
  318. tx.Error = ErrDryRunModeUnsupported
  319. }
  320. return rows, tx.Error
  321. }
  322. // Scan scan value to a struct
  323. func (db *DB) Scan(dest interface{}) (tx *DB) {
  324. config := *db.Config
  325. currentLogger, newLogger := config.Logger, logger.Recorder.New()
  326. config.Logger = newLogger
  327. tx = db.getInstance()
  328. tx.Config = &config
  329. if rows, err := tx.Rows(); err != nil {
  330. tx.AddError(err)
  331. } else {
  332. defer rows.Close()
  333. if rows.Next() {
  334. tx.ScanRows(rows, dest)
  335. }
  336. }
  337. currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
  338. return newLogger.SQL, tx.RowsAffected
  339. }, tx.Error)
  340. tx.Logger = currentLogger
  341. return
  342. }
  343. // Pluck used to query single column from a model as a map
  344. // var ages []int64
  345. // db.Find(&users).Pluck("age", &ages)
  346. func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
  347. tx = db.getInstance()
  348. if tx.Statement.Model != nil {
  349. if tx.Statement.Parse(tx.Statement.Model) == nil {
  350. if f := tx.Statement.Schema.LookUpField(column); f != nil {
  351. column = f.DBName
  352. }
  353. }
  354. } else if tx.Statement.Table == "" {
  355. tx.AddError(ErrModelValueRequired)
  356. }
  357. fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
  358. tx.Statement.AddClauseIfNotExists(clause.Select{
  359. Distinct: tx.Statement.Distinct,
  360. Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
  361. })
  362. tx.Statement.Dest = dest
  363. tx.callbacks.Query().Execute(tx)
  364. return
  365. }
  366. func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
  367. tx := db.getInstance()
  368. if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
  369. tx.AddError(err)
  370. }
  371. tx.Statement.Dest = dest
  372. tx.Statement.ReflectValue = reflect.ValueOf(dest)
  373. for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
  374. tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
  375. }
  376. Scan(rows, tx, true)
  377. return tx.Error
  378. }
  379. // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
  380. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
  381. panicked := true
  382. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
  383. // nested transaction
  384. db.SavePoint(fmt.Sprintf("sp%p", fc))
  385. defer func() {
  386. // Make sure to rollback when panic, Block error or Commit error
  387. if panicked || err != nil {
  388. db.RollbackTo(fmt.Sprintf("sp%p", fc))
  389. }
  390. }()
  391. err = fc(db.Session(&Session{WithConditions: true}))
  392. } else {
  393. tx := db.Begin(opts...)
  394. defer func() {
  395. // Make sure to rollback when panic, Block error or Commit error
  396. if panicked || err != nil {
  397. tx.Rollback()
  398. }
  399. }()
  400. err = fc(tx)
  401. if err == nil {
  402. err = tx.Commit().Error
  403. }
  404. }
  405. panicked = false
  406. return
  407. }
  408. // Begin begins a transaction
  409. func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
  410. var (
  411. // clone statement
  412. tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context})
  413. opt *sql.TxOptions
  414. err error
  415. )
  416. if len(opts) > 0 {
  417. opt = opts[0]
  418. }
  419. if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
  420. tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
  421. } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
  422. tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
  423. } else {
  424. err = ErrInvalidTransaction
  425. }
  426. if err != nil {
  427. tx.AddError(err)
  428. }
  429. return tx
  430. }
  431. // Commit commit a transaction
  432. func (db *DB) Commit() *DB {
  433. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
  434. db.AddError(committer.Commit())
  435. } else {
  436. db.AddError(ErrInvalidTransaction)
  437. }
  438. return db
  439. }
  440. // Rollback rollback a transaction
  441. func (db *DB) Rollback() *DB {
  442. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
  443. if !reflect.ValueOf(committer).IsNil() {
  444. db.AddError(committer.Rollback())
  445. }
  446. } else {
  447. db.AddError(ErrInvalidTransaction)
  448. }
  449. return db
  450. }
  451. func (db *DB) SavePoint(name string) *DB {
  452. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  453. db.AddError(savePointer.SavePoint(db, name))
  454. } else {
  455. db.AddError(ErrUnsupportedDriver)
  456. }
  457. return db
  458. }
  459. func (db *DB) RollbackTo(name string) *DB {
  460. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  461. db.AddError(savePointer.RollbackTo(db, name))
  462. } else {
  463. db.AddError(ErrUnsupportedDriver)
  464. }
  465. return db
  466. }
  467. // Exec execute raw sql
  468. func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
  469. tx = db.getInstance()
  470. tx.Statement.SQL = strings.Builder{}
  471. if strings.Contains(sql, "@") {
  472. clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
  473. } else {
  474. clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
  475. }
  476. tx.callbacks.Raw().Execute(tx)
  477. return
  478. }