gorm.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "time"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. "gorm.io/gorm/schema"
  12. )
  // Config is GORM's configuration. Open fills unset fields with defaults
  // (noted per field below), mutating the Config it is given, and stores a
  // pointer to it in the returned *DB.
  type Config struct {
  	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
  	// You can disable it by setting `SkipDefaultTransaction` to true
  	SkipDefaultTransaction bool
  	// NamingStrategy tables, columns naming strategy.
  	// Open defaults it to schema.NamingStrategy{}.
  	NamingStrategy schema.Namer
  	// FullSaveAssociations full save associations. Not read in this file;
  	// presumably enforced by the create/update callbacks — confirm.
  	FullSaveAssociations bool
  	// Logger receives GORM's log output. Open defaults it to logger.Default.
  	Logger logger.Interface
  	// NowFunc the function to be used when creating a new timestamp.
  	// Open defaults it to a function returning time.Now().Local().
  	NowFunc func() time.Time
  	// DryRun generate sql without execute
  	DryRun bool
  	// PrepareStmt executes the given query in cached statement. When true,
  	// Open replaces ConnPool with a *PreparedStmtDB wrapper around it.
  	PrepareStmt bool
  	// DisableAutomaticPing stops Open from pinging the connection pool after
  	// the dialector has been initialized.
  	DisableAutomaticPing bool
  	// DisableForeignKeyConstraintWhenMigrating is not read in this file;
  	// presumably the migrator skips creating foreign key constraints when it
  	// is set — confirm in the migrator.
  	DisableForeignKeyConstraintWhenMigrating bool
  	// AllowGlobalUpdate allow global update. Not enforced in this file;
  	// presumably lets updates/deletes run without conditions — confirm.
  	AllowGlobalUpdate bool
  	// ClauseBuilders clause builder, keyed by name. Open initializes it to
  	// an empty map; how the keys are matched is not visible here.
  	ClauseBuilders map[string]clause.ClauseBuilder
  	// ConnPool db conn pool. Expected to be set by Dialector.Initialize
  	// (assumption — Open only reads it after calling Initialize).
  	ConnPool ConnPool
  	// Dialector database dialector. A non-nil dialector passed to Open
  	// overrides this, and Open calls its Initialize.
  	Dialector
  	// Plugins registered plugins, keyed by plugin name (see DB.Use).
  	// Open initializes it to an empty map.
  	Plugins map[string]Plugin
  	// callbacks is the callback manager built by Open and returned by
  	// DB.Callback.
  	callbacks *callbacks
  	// cacheStore is an internal cache; the pointer is shared by every
  	// session whose Config is copied from this one. Open stores the shared
  	// *PreparedStmtDB under the key "preparedStmt".
  	cacheStore *sync.Map
  }
  // DB GORM DB definition: the handle returned by Open and DB.Session, and
  // by the chainable methods that go through getInstance.
  type DB struct {
  	*Config
  	// Error holds the errors recorded through AddError; later errors are
  	// appended to the message as "<previous>; <new>".
  	Error error
  	// RowsAffected is not assigned in this file; presumably filled in by the
  	// callbacks after a statement executes — confirm.
  	RowsAffected int64
  	// Statement is the statement state this instance builds on.
  	Statement *Statement
  	// clone selects how getInstance derives the instance a chainable call
  	// works on: 0 mutates this DB in place, 1 starts a fresh Statement that
  	// keeps only ConnPool and Context, and any other positive value (2, set
  	// by Session{WithConditions: true}) clones the current Statement.
  	clone int
  }
  // Session session config when create session with Session() method.
  // Boolean fields can only switch a setting on: leaving one false keeps the
  // value inherited from the parent *DB.
  type Session struct {
  	// DryRun enables Config.DryRun on the new session.
  	DryRun bool
  	// PrepareStmt makes the session run through a new *PreparedStmtDB that
  	// shares the statement cache Open stored under "preparedStmt".
  	PrepareStmt bool
  	// WithConditions makes chainable calls on the session clone the current
  	// Statement (keeping its conditions) instead of starting a fresh one.
  	WithConditions bool
  	// SkipDefaultTransaction, AllowGlobalUpdate and FullSaveAssociations
  	// enable the Config fields of the same name on the new session.
  	SkipDefaultTransaction bool
  	AllowGlobalUpdate bool
  	FullSaveAssociations bool
  	// Context, when non-nil, becomes the context of the session's Statement.
  	Context context.Context
  	// Logger, when non-nil, replaces the inherited logger.
  	Logger logger.Interface
  	// NowFunc, when non-nil, replaces the inherited timestamp function.
  	NowFunc func() time.Time
  }
  67. // Open initialize db session based on dialector
  68. func Open(dialector Dialector, config *Config) (db *DB, err error) {
  69. if config == nil {
  70. config = &Config{}
  71. }
  72. if config.NamingStrategy == nil {
  73. config.NamingStrategy = schema.NamingStrategy{}
  74. }
  75. if config.Logger == nil {
  76. config.Logger = logger.Default
  77. }
  78. if config.NowFunc == nil {
  79. config.NowFunc = func() time.Time { return time.Now().Local() }
  80. }
  81. if dialector != nil {
  82. config.Dialector = dialector
  83. }
  84. if config.Plugins == nil {
  85. config.Plugins = map[string]Plugin{}
  86. }
  87. if config.cacheStore == nil {
  88. config.cacheStore = &sync.Map{}
  89. }
  90. db = &DB{Config: config, clone: 1}
  91. db.callbacks = initializeCallbacks(db)
  92. if config.ClauseBuilders == nil {
  93. config.ClauseBuilders = map[string]clause.ClauseBuilder{}
  94. }
  95. if config.Dialector != nil {
  96. err = config.Dialector.Initialize(db)
  97. }
  98. preparedStmt := &PreparedStmtDB{
  99. ConnPool: db.ConnPool,
  100. Stmts: map[string]*sql.Stmt{},
  101. Mux: &sync.RWMutex{},
  102. PreparedSQL: make([]string, 0, 100),
  103. }
  104. db.cacheStore.Store("preparedStmt", preparedStmt)
  105. if config.PrepareStmt {
  106. db.ConnPool = preparedStmt
  107. }
  108. db.Statement = &Statement{
  109. DB: db,
  110. ConnPool: db.ConnPool,
  111. Context: context.Background(),
  112. Clauses: map[string]clause.Clause{},
  113. }
  114. if err == nil && !config.DisableAutomaticPing {
  115. if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
  116. err = pinger.Ping()
  117. }
  118. }
  119. if err != nil {
  120. config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
  121. }
  122. return
  123. }
  124. // Session create new db session
  125. func (db *DB) Session(config *Session) *DB {
  126. var (
  127. txConfig = *db.Config
  128. tx = &DB{
  129. Config: &txConfig,
  130. Statement: db.Statement,
  131. clone: 1,
  132. }
  133. )
  134. if config.SkipDefaultTransaction {
  135. tx.Config.SkipDefaultTransaction = true
  136. }
  137. if config.AllowGlobalUpdate {
  138. txConfig.AllowGlobalUpdate = true
  139. }
  140. if config.FullSaveAssociations {
  141. txConfig.FullSaveAssociations = true
  142. }
  143. if config.Context != nil {
  144. tx.Statement = tx.Statement.clone()
  145. tx.Statement.DB = tx
  146. tx.Statement.Context = config.Context
  147. }
  148. if config.PrepareStmt {
  149. if v, ok := db.cacheStore.Load("preparedStmt"); ok {
  150. tx.Statement = tx.Statement.clone()
  151. preparedStmt := v.(*PreparedStmtDB)
  152. tx.Statement.ConnPool = &PreparedStmtDB{
  153. ConnPool: db.Config.ConnPool,
  154. Mux: preparedStmt.Mux,
  155. Stmts: preparedStmt.Stmts,
  156. }
  157. txConfig.ConnPool = tx.Statement.ConnPool
  158. txConfig.PrepareStmt = true
  159. }
  160. }
  161. if config.WithConditions {
  162. tx.clone = 2
  163. }
  164. if config.DryRun {
  165. tx.Config.DryRun = true
  166. }
  167. if config.Logger != nil {
  168. tx.Config.Logger = config.Logger
  169. }
  170. if config.NowFunc != nil {
  171. tx.Config.NowFunc = config.NowFunc
  172. }
  173. return tx
  174. }
  175. // WithContext change current instance db's context to ctx
  176. func (db *DB) WithContext(ctx context.Context) *DB {
  177. return db.Session(&Session{WithConditions: true, Context: ctx})
  178. }
  179. // Debug start debug mode
  180. func (db *DB) Debug() (tx *DB) {
  181. return db.Session(&Session{
  182. WithConditions: true,
  183. Logger: db.Logger.LogMode(logger.Info),
  184. })
  185. }
  186. // Set store value with key into current db instance's context
  187. func (db *DB) Set(key string, value interface{}) *DB {
  188. tx := db.getInstance()
  189. tx.Statement.Settings.Store(key, value)
  190. return tx
  191. }
  192. // Get get value with key from current db instance's context
  193. func (db *DB) Get(key string) (interface{}, bool) {
  194. return db.Statement.Settings.Load(key)
  195. }
  196. // InstanceSet store value with key into current db instance's context
  197. func (db *DB) InstanceSet(key string, value interface{}) *DB {
  198. tx := db.getInstance()
  199. tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
  200. return tx
  201. }
  202. // InstanceGet get value with key from current db instance's context
  203. func (db *DB) InstanceGet(key string) (interface{}, bool) {
  204. return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
  205. }
  206. // Callback returns callback manager
  207. func (db *DB) Callback() *callbacks {
  208. return db.callbacks
  209. }
  210. // AddError add error to db
  211. func (db *DB) AddError(err error) error {
  212. if db.Error == nil {
  213. db.Error = err
  214. } else if err != nil {
  215. db.Error = fmt.Errorf("%v; %w", db.Error, err)
  216. }
  217. return db.Error
  218. }
  219. // DB returns `*sql.DB`
  220. func (db *DB) DB() (*sql.DB, error) {
  221. connPool := db.ConnPool
  222. if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
  223. connPool = stmtDB.ConnPool
  224. }
  225. if sqldb, ok := connPool.(*sql.DB); ok {
  226. return sqldb, nil
  227. }
  228. return nil, errors.New("invalid db")
  229. }
  230. func (db *DB) getInstance() *DB {
  231. if db.clone > 0 {
  232. tx := &DB{Config: db.Config}
  233. if db.clone == 1 {
  234. // clone with new statement
  235. tx.Statement = &Statement{
  236. DB: tx,
  237. ConnPool: db.Statement.ConnPool,
  238. Context: db.Statement.Context,
  239. Clauses: map[string]clause.Clause{},
  240. }
  241. } else {
  242. // with clone statement
  243. tx.Statement = db.Statement.clone()
  244. tx.Statement.DB = tx
  245. }
  246. return tx
  247. }
  248. return db
  249. }
  250. func Expr(expr string, args ...interface{}) clause.Expr {
  251. return clause.Expr{SQL: expr, Vars: args}
  252. }
  253. func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
  254. var (
  255. tx = db.getInstance()
  256. stmt = tx.Statement
  257. modelSchema, joinSchema *schema.Schema
  258. )
  259. if err := stmt.Parse(model); err == nil {
  260. modelSchema = stmt.Schema
  261. } else {
  262. return err
  263. }
  264. if err := stmt.Parse(joinTable); err == nil {
  265. joinSchema = stmt.Schema
  266. } else {
  267. return err
  268. }
  269. if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
  270. for _, ref := range relation.References {
  271. if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
  272. f.DataType = ref.ForeignKey.DataType
  273. f.GORMDataType = ref.ForeignKey.GORMDataType
  274. if f.Size == 0 {
  275. f.Size = ref.ForeignKey.Size
  276. }
  277. ref.ForeignKey = f
  278. } else {
  279. return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
  280. }
  281. }
  282. for name, rel := range relation.JoinTable.Relationships.Relations {
  283. if _, ok := joinSchema.Relationships.Relations[name]; !ok {
  284. rel.Schema = joinSchema
  285. joinSchema.Relationships.Relations[name] = rel
  286. }
  287. }
  288. relation.JoinTable = joinSchema
  289. } else {
  290. return fmt.Errorf("failed to found relation: %v", field)
  291. }
  292. return nil
  293. }
  294. func (db *DB) Use(plugin Plugin) (err error) {
  295. name := plugin.Name()
  296. if _, ok := db.Plugins[name]; !ok {
  297. if err = plugin.Initialize(db); err == nil {
  298. db.Plugins[name] = plugin
  299. }
  300. } else {
  301. return ErrRegistered
  302. }
  303. return err
  304. }