association.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. package gorm
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/schema"
  9. "gorm.io/gorm/utils"
  10. )
  11. // Association Mode contains some helper methods to handle relationship things easily.
  12. type Association struct {
  13. DB *DB
  14. Relationship *schema.Relationship
  15. Error error
  16. }
  17. func (db *DB) Association(column string) *Association {
  18. association := &Association{DB: db}
  19. table := db.Statement.Table
  20. if err := db.Statement.Parse(db.Statement.Model); err == nil {
  21. db.Statement.Table = table
  22. association.Relationship = db.Statement.Schema.Relationships.Relations[column]
  23. if association.Relationship == nil {
  24. association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
  25. }
  26. db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
  27. for db.Statement.ReflectValue.Kind() == reflect.Ptr {
  28. db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
  29. }
  30. } else {
  31. association.Error = err
  32. }
  33. return association
  34. }
  35. func (association *Association) Find(out interface{}, conds ...interface{}) error {
  36. if association.Error == nil {
  37. association.Error = association.buildCondition().Find(out, conds...).Error
  38. }
  39. return association.Error
  40. }
  41. func (association *Association) Append(values ...interface{}) error {
  42. if association.Error == nil {
  43. switch association.Relationship.Type {
  44. case schema.HasOne, schema.BelongsTo:
  45. if len(values) > 0 {
  46. association.Error = association.Replace(values...)
  47. }
  48. default:
  49. association.saveAssociation( /*clear*/ false, values...)
  50. }
  51. }
  52. return association.Error
  53. }
  54. func (association *Association) Replace(values ...interface{}) error {
  55. if association.Error == nil {
  56. // save associations
  57. association.saveAssociation( /*clear*/ true, values...)
  58. // set old associations's foreign key to null
  59. reflectValue := association.DB.Statement.ReflectValue
  60. rel := association.Relationship
  61. switch rel.Type {
  62. case schema.BelongsTo:
  63. if len(values) == 0 {
  64. updateMap := map[string]interface{}{}
  65. switch reflectValue.Kind() {
  66. case reflect.Slice, reflect.Array:
  67. for i := 0; i < reflectValue.Len(); i++ {
  68. association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
  69. }
  70. case reflect.Struct:
  71. association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
  72. }
  73. for _, ref := range rel.References {
  74. updateMap[ref.ForeignKey.DBName] = nil
  75. }
  76. association.Error = association.DB.UpdateColumns(updateMap).Error
  77. }
  78. case schema.HasOne, schema.HasMany:
  79. var (
  80. primaryFields []*schema.Field
  81. foreignKeys []string
  82. updateMap = map[string]interface{}{}
  83. relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
  84. modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
  85. tx = association.DB.Model(modelValue)
  86. )
  87. if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
  88. if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
  89. tx.Not(clause.IN{Column: column, Values: values})
  90. }
  91. }
  92. for _, ref := range rel.References {
  93. if ref.OwnPrimaryKey {
  94. primaryFields = append(primaryFields, ref.PrimaryKey)
  95. foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
  96. updateMap[ref.ForeignKey.DBName] = nil
  97. } else if ref.PrimaryValue != "" {
  98. tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  99. }
  100. }
  101. if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
  102. column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
  103. tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
  104. }
  105. case schema.Many2Many:
  106. var (
  107. primaryFields, relPrimaryFields []*schema.Field
  108. joinPrimaryKeys, joinRelPrimaryKeys []string
  109. modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
  110. tx = association.DB.Model(modelValue)
  111. )
  112. for _, ref := range rel.References {
  113. if ref.PrimaryValue == "" {
  114. if ref.OwnPrimaryKey {
  115. primaryFields = append(primaryFields, ref.PrimaryKey)
  116. joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
  117. } else {
  118. relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
  119. joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
  120. }
  121. } else {
  122. tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  123. }
  124. }
  125. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  126. if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
  127. tx.Where(clause.IN{Column: column, Values: values})
  128. } else {
  129. return ErrPrimaryKeyRequired
  130. }
  131. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
  132. if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
  133. tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
  134. }
  135. tx.Delete(modelValue)
  136. }
  137. }
  138. return association.Error
  139. }
  140. func (association *Association) Delete(values ...interface{}) error {
  141. if association.Error == nil {
  142. var (
  143. reflectValue = association.DB.Statement.ReflectValue
  144. rel = association.Relationship
  145. primaryFields []*schema.Field
  146. foreignKeys []string
  147. updateAttrs = map[string]interface{}{}
  148. conds []clause.Expression
  149. )
  150. for _, ref := range rel.References {
  151. if ref.PrimaryValue == "" {
  152. primaryFields = append(primaryFields, ref.PrimaryKey)
  153. foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
  154. updateAttrs[ref.ForeignKey.DBName] = nil
  155. } else {
  156. conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  157. }
  158. }
  159. switch rel.Type {
  160. case schema.BelongsTo:
  161. tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
  162. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
  163. pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
  164. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  165. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
  166. relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
  167. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  168. association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
  169. case schema.HasOne, schema.HasMany:
  170. tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
  171. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  172. pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
  173. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  174. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
  175. relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
  176. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  177. association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
  178. case schema.Many2Many:
  179. var (
  180. primaryFields, relPrimaryFields []*schema.Field
  181. joinPrimaryKeys, joinRelPrimaryKeys []string
  182. joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
  183. )
  184. for _, ref := range rel.References {
  185. if ref.PrimaryValue == "" {
  186. if ref.OwnPrimaryKey {
  187. primaryFields = append(primaryFields, ref.PrimaryKey)
  188. joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
  189. } else {
  190. relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
  191. joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
  192. }
  193. } else {
  194. conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  195. }
  196. }
  197. _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
  198. pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
  199. conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
  200. _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
  201. relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
  202. conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
  203. association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
  204. }
  205. if association.Error == nil {
  206. // clean up deleted values's foreign key
  207. relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
  208. cleanUpDeletedRelations := func(data reflect.Value) {
  209. if _, zero := rel.Field.ValueOf(data); !zero {
  210. fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
  211. primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
  212. switch fieldValue.Kind() {
  213. case reflect.Slice, reflect.Array:
  214. validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
  215. for i := 0; i < fieldValue.Len(); i++ {
  216. for idx, field := range rel.FieldSchema.PrimaryFields {
  217. primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
  218. }
  219. if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
  220. validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
  221. }
  222. }
  223. association.Error = rel.Field.Set(data, validFieldValues.Interface())
  224. case reflect.Struct:
  225. for idx, field := range rel.FieldSchema.PrimaryFields {
  226. primaryValues[idx], _ = field.ValueOf(fieldValue)
  227. }
  228. if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
  229. if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
  230. break
  231. }
  232. if rel.JoinTable == nil {
  233. for _, ref := range rel.References {
  234. if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
  235. association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  236. } else {
  237. association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  238. }
  239. }
  240. }
  241. }
  242. }
  243. }
  244. }
  245. switch reflectValue.Kind() {
  246. case reflect.Slice, reflect.Array:
  247. for i := 0; i < reflectValue.Len(); i++ {
  248. cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
  249. }
  250. case reflect.Struct:
  251. cleanUpDeletedRelations(reflectValue)
  252. }
  253. }
  254. }
  255. return association.Error
  256. }
  257. func (association *Association) Clear() error {
  258. return association.Replace()
  259. }
  260. func (association *Association) Count() (count int64) {
  261. if association.Error == nil {
  262. association.Error = association.buildCondition().Count(&count).Error
  263. }
  264. return
  265. }
  // assignBack describes one pending copy of a saved association value back
  // into the argument the caller passed in.
  type assignBack struct {
  	// Source is the model value whose relation field holds the saved data.
  	Source reflect.Value
  	// Index is the 1-based position inside the relation field when that field
  	// is a collection; 0 means the relation field itself is copied.
  	Index int
  	// Dest is the caller-supplied value that receives the copy.
  	Dest reflect.Value
  }
  271. func (association *Association) saveAssociation(clear bool, values ...interface{}) {
  272. var (
  273. reflectValue = association.DB.Statement.ReflectValue
  274. assignBacks []assignBack // assign association values back to arguments after save
  275. )
  276. appendToRelations := func(source, rv reflect.Value, clear bool) {
  277. switch association.Relationship.Type {
  278. case schema.HasOne, schema.BelongsTo:
  279. switch rv.Kind() {
  280. case reflect.Slice, reflect.Array:
  281. if rv.Len() > 0 {
  282. association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
  283. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  284. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
  285. }
  286. }
  287. case reflect.Struct:
  288. association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
  289. if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
  290. assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
  291. }
  292. }
  293. case schema.HasMany, schema.Many2Many:
  294. elemType := association.Relationship.Field.IndirectFieldType.Elem()
  295. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
  296. if clear {
  297. fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
  298. }
  299. appendToFieldValues := func(ev reflect.Value) {
  300. if ev.Type().AssignableTo(elemType) {
  301. fieldValue = reflect.Append(fieldValue, ev)
  302. } else if ev.Type().Elem().AssignableTo(elemType) {
  303. fieldValue = reflect.Append(fieldValue, ev.Elem())
  304. } else {
  305. association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
  306. }
  307. if elemType.Kind() == reflect.Struct {
  308. assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
  309. }
  310. }
  311. switch rv.Kind() {
  312. case reflect.Slice, reflect.Array:
  313. for i := 0; i < rv.Len(); i++ {
  314. appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
  315. }
  316. case reflect.Struct:
  317. appendToFieldValues(rv.Addr())
  318. }
  319. if association.Error == nil {
  320. association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
  321. }
  322. }
  323. }
  324. selectedSaveColumns := []string{association.Relationship.Name}
  325. for _, ref := range association.Relationship.References {
  326. if !ref.OwnPrimaryKey {
  327. selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
  328. }
  329. }
  330. switch reflectValue.Kind() {
  331. case reflect.Slice, reflect.Array:
  332. if len(values) != reflectValue.Len() {
  333. // clear old data
  334. if clear && len(values) == 0 {
  335. for i := 0; i < reflectValue.Len(); i++ {
  336. if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
  337. association.Error = err
  338. break
  339. }
  340. if association.Relationship.JoinTable == nil {
  341. for _, ref := range association.Relationship.References {
  342. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  343. if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
  344. association.Error = err
  345. break
  346. }
  347. }
  348. }
  349. }
  350. }
  351. break
  352. }
  353. association.Error = errors.New("invalid association values, length doesn't match")
  354. return
  355. }
  356. for i := 0; i < reflectValue.Len(); i++ {
  357. appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
  358. // TODO support save slice data, sql with case?
  359. association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error
  360. }
  361. case reflect.Struct:
  362. // clear old data
  363. if clear && len(values) == 0 {
  364. association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
  365. if association.Relationship.JoinTable == nil && association.Error == nil {
  366. for _, ref := range association.Relationship.References {
  367. if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
  368. association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
  369. }
  370. }
  371. }
  372. }
  373. for idx, value := range values {
  374. rv := reflect.Indirect(reflect.ValueOf(value))
  375. appendToRelations(reflectValue, rv, clear && idx == 0)
  376. }
  377. if len(values) > 0 {
  378. association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error
  379. }
  380. }
  381. for _, assignBack := range assignBacks {
  382. fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
  383. if assignBack.Index > 0 {
  384. reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
  385. } else {
  386. reflect.Indirect(assignBack.Dest).Set(fieldValue)
  387. }
  388. }
  389. }
  390. func (association *Association) buildCondition() *DB {
  391. var (
  392. queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
  393. modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
  394. tx = association.DB.Model(modelValue)
  395. )
  396. if association.Relationship.JoinTable != nil {
  397. if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
  398. joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
  399. for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
  400. joinStmt.AddClause(queryClause)
  401. }
  402. joinStmt.Build("WHERE")
  403. tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
  404. }
  405. tx.Clauses(clause.From{Joins: []clause.Join{{
  406. Table: clause.Table{Name: association.Relationship.JoinTable.Table},
  407. ON: clause.Where{Exprs: queryConds},
  408. }}})
  409. } else {
  410. tx.Clauses(clause.Where{Exprs: queryConds})
  411. }
  412. return tx
  413. }