associations.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. package callbacks
  2. import (
  3. "reflect"
  4. "gorm.io/gorm"
  5. "gorm.io/gorm/clause"
  6. "gorm.io/gorm/schema"
  7. )
  8. func SaveBeforeAssociations(db *gorm.DB) {
  9. if db.Error == nil && db.Statement.Schema != nil {
  10. selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
  11. // Save Belongs To associations
  12. for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
  13. if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
  14. continue
  15. }
  16. setupReferences := func(obj reflect.Value, elem reflect.Value) {
  17. for _, ref := range rel.References {
  18. if !ref.OwnPrimaryKey {
  19. pv, _ := ref.PrimaryKey.ValueOf(elem)
  20. db.AddError(ref.ForeignKey.Set(obj, pv))
  21. if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
  22. dest[ref.ForeignKey.DBName] = pv
  23. if _, ok := dest[rel.Name]; ok {
  24. dest[rel.Name] = elem.Interface()
  25. }
  26. }
  27. }
  28. }
  29. }
  30. switch db.Statement.ReflectValue.Kind() {
  31. case reflect.Slice, reflect.Array:
  32. var (
  33. objs []reflect.Value
  34. fieldType = rel.Field.FieldType
  35. isPtr = fieldType.Kind() == reflect.Ptr
  36. )
  37. if !isPtr {
  38. fieldType = reflect.PtrTo(fieldType)
  39. }
  40. elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
  41. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  42. obj := db.Statement.ReflectValue.Index(i)
  43. if reflect.Indirect(obj).Kind() == reflect.Struct {
  44. if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
  45. rv := rel.Field.ReflectValueOf(obj) // relation reflect value
  46. objs = append(objs, obj)
  47. if isPtr {
  48. elems = reflect.Append(elems, rv)
  49. } else {
  50. elems = reflect.Append(elems, rv.Addr())
  51. }
  52. }
  53. } else {
  54. break
  55. }
  56. }
  57. if elems.Len() > 0 {
  58. if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
  59. for i := 0; i < elems.Len(); i++ {
  60. setupReferences(objs[i], elems.Index(i))
  61. }
  62. }
  63. }
  64. case reflect.Struct:
  65. if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
  66. rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
  67. if rv.Kind() != reflect.Ptr {
  68. rv = rv.Addr()
  69. }
  70. if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
  71. setupReferences(db.Statement.ReflectValue, rv)
  72. }
  73. }
  74. }
  75. }
  76. }
  77. }
  78. func SaveAfterAssociations(db *gorm.DB) {
  79. if db.Error == nil && db.Statement.Schema != nil {
  80. selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
  81. // Save Has One associations
  82. for _, rel := range db.Statement.Schema.Relationships.HasOne {
  83. if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
  84. continue
  85. }
  86. switch db.Statement.ReflectValue.Kind() {
  87. case reflect.Slice, reflect.Array:
  88. var (
  89. fieldType = rel.Field.FieldType
  90. isPtr = fieldType.Kind() == reflect.Ptr
  91. )
  92. if !isPtr {
  93. fieldType = reflect.PtrTo(fieldType)
  94. }
  95. elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
  96. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  97. obj := db.Statement.ReflectValue.Index(i)
  98. if reflect.Indirect(obj).Kind() == reflect.Struct {
  99. if _, zero := rel.Field.ValueOf(obj); !zero {
  100. rv := rel.Field.ReflectValueOf(obj)
  101. if rv.Kind() != reflect.Ptr {
  102. rv = rv.Addr()
  103. }
  104. for _, ref := range rel.References {
  105. if ref.OwnPrimaryKey {
  106. fv, _ := ref.PrimaryKey.ValueOf(obj)
  107. db.AddError(ref.ForeignKey.Set(rv, fv))
  108. } else if ref.PrimaryValue != "" {
  109. db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
  110. }
  111. }
  112. elems = reflect.Append(elems, rv)
  113. }
  114. }
  115. }
  116. if elems.Len() > 0 {
  117. assignmentColumns := []string{}
  118. for _, ref := range rel.References {
  119. assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
  120. }
  121. db.AddError(db.Session(&gorm.Session{}).Clauses(
  122. onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
  123. ).Create(elems.Interface()).Error)
  124. }
  125. case reflect.Struct:
  126. if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
  127. f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
  128. if f.Kind() != reflect.Ptr {
  129. f = f.Addr()
  130. }
  131. assignmentColumns := []string{}
  132. for _, ref := range rel.References {
  133. if ref.OwnPrimaryKey {
  134. fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
  135. ref.ForeignKey.Set(f, fv)
  136. } else if ref.PrimaryValue != "" {
  137. ref.ForeignKey.Set(f, ref.PrimaryValue)
  138. }
  139. assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
  140. }
  141. db.AddError(db.Session(&gorm.Session{}).Clauses(
  142. onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
  143. ).Create(f.Interface()).Error)
  144. }
  145. }
  146. }
  147. // Save Has Many associations
  148. for _, rel := range db.Statement.Schema.Relationships.HasMany {
  149. if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
  150. continue
  151. }
  152. fieldType := rel.Field.IndirectFieldType.Elem()
  153. isPtr := fieldType.Kind() == reflect.Ptr
  154. if !isPtr {
  155. fieldType = reflect.PtrTo(fieldType)
  156. }
  157. elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
  158. appendToElems := func(v reflect.Value) {
  159. if _, zero := rel.Field.ValueOf(v); !zero {
  160. f := reflect.Indirect(rel.Field.ReflectValueOf(v))
  161. for i := 0; i < f.Len(); i++ {
  162. elem := f.Index(i)
  163. for _, ref := range rel.References {
  164. if ref.OwnPrimaryKey {
  165. pv, _ := ref.PrimaryKey.ValueOf(v)
  166. ref.ForeignKey.Set(elem, pv)
  167. } else if ref.PrimaryValue != "" {
  168. ref.ForeignKey.Set(elem, ref.PrimaryValue)
  169. }
  170. }
  171. if isPtr {
  172. elems = reflect.Append(elems, elem)
  173. } else {
  174. elems = reflect.Append(elems, elem.Addr())
  175. }
  176. }
  177. }
  178. }
  179. switch db.Statement.ReflectValue.Kind() {
  180. case reflect.Slice, reflect.Array:
  181. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  182. obj := db.Statement.ReflectValue.Index(i)
  183. if reflect.Indirect(obj).Kind() == reflect.Struct {
  184. appendToElems(obj)
  185. }
  186. }
  187. case reflect.Struct:
  188. appendToElems(db.Statement.ReflectValue)
  189. }
  190. if elems.Len() > 0 {
  191. assignmentColumns := []string{}
  192. for _, ref := range rel.References {
  193. assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
  194. }
  195. db.AddError(db.Session(&gorm.Session{}).Clauses(
  196. onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
  197. ).Create(elems.Interface()).Error)
  198. }
  199. }
  200. // Save Many2Many associations
  201. for _, rel := range db.Statement.Schema.Relationships.Many2Many {
  202. if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
  203. continue
  204. }
  205. fieldType := rel.Field.IndirectFieldType.Elem()
  206. isPtr := fieldType.Kind() == reflect.Ptr
  207. if !isPtr {
  208. fieldType = reflect.PtrTo(fieldType)
  209. }
  210. elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
  211. joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0)
  212. objs := []reflect.Value{}
  213. appendToJoins := func(obj reflect.Value, elem reflect.Value) {
  214. joinValue := reflect.New(rel.JoinTable.ModelType)
  215. for _, ref := range rel.References {
  216. if ref.OwnPrimaryKey {
  217. fv, _ := ref.PrimaryKey.ValueOf(obj)
  218. ref.ForeignKey.Set(joinValue, fv)
  219. } else if ref.PrimaryValue != "" {
  220. ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
  221. } else {
  222. fv, _ := ref.PrimaryKey.ValueOf(elem)
  223. ref.ForeignKey.Set(joinValue, fv)
  224. }
  225. }
  226. joins = reflect.Append(joins, joinValue)
  227. }
  228. appendToElems := func(v reflect.Value) {
  229. if _, zero := rel.Field.ValueOf(v); !zero {
  230. f := reflect.Indirect(rel.Field.ReflectValueOf(v))
  231. for i := 0; i < f.Len(); i++ {
  232. elem := f.Index(i)
  233. objs = append(objs, v)
  234. if isPtr {
  235. elems = reflect.Append(elems, elem)
  236. } else {
  237. elems = reflect.Append(elems, elem.Addr())
  238. }
  239. }
  240. }
  241. }
  242. switch db.Statement.ReflectValue.Kind() {
  243. case reflect.Slice, reflect.Array:
  244. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  245. obj := db.Statement.ReflectValue.Index(i)
  246. if reflect.Indirect(obj).Kind() == reflect.Struct {
  247. appendToElems(obj)
  248. }
  249. }
  250. case reflect.Struct:
  251. appendToElems(db.Statement.ReflectValue)
  252. }
  253. if elems.Len() > 0 {
  254. db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
  255. for i := 0; i < elems.Len(); i++ {
  256. appendToJoins(objs[i], elems.Index(i))
  257. }
  258. }
  259. if joins.Len() > 0 {
  260. db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error)
  261. }
  262. }
  263. }
  264. }
  265. func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
  266. if stmt.DB.FullSaveAssociations {
  267. defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
  268. for _, dbName := range s.DBNames {
  269. if !s.LookUpField(dbName).PrimaryKey {
  270. defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
  271. }
  272. }
  273. }
  274. if len(defaultUpdatingColumns) > 0 {
  275. var columns []clause.Column
  276. if s.PrioritizedPrimaryField != nil {
  277. columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}}
  278. } else {
  279. for _, dbName := range s.PrimaryFieldDBNames {
  280. columns = append(columns, clause.Column{Name: dbName})
  281. }
  282. }
  283. return clause.OnConflict{
  284. Columns: columns,
  285. DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
  286. }
  287. }
  288. return clause.OnConflict{DoNothing: true}
  289. }