relationship.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. package schema
  2. import (
  3. "fmt"
  4. "reflect"
  5. "regexp"
  6. "strings"
  7. "github.com/jinzhu/inflection"
  8. "gorm.io/gorm/clause"
  9. )
// RelationshipType names the kind of association that links two schemas.
// It is a plain string type; the supported values are declared as constants
// in this package.
type RelationshipType string
// Supported relationship types.
const (
	HasOne    RelationshipType = "has_one"      // HasOne is the has-one relationship
	HasMany   RelationshipType = "has_many"     // HasMany is the has-many relationship
	BelongsTo RelationshipType = "belongs_to"   // BelongsTo is the belongs-to relationship
	Many2Many RelationshipType = "many_to_many" // Many2Many is the many-to-many relationship
)
// Relationships collects the relationships registered on a schema, grouped
// by relationship type and indexed by name.
type Relationships struct {
	HasOne    []*Relationship          // relationships whose Type is HasOne
	BelongsTo []*Relationship          // relationships whose Type is BelongsTo
	HasMany   []*Relationship          // relationships whose Type is HasMany
	Many2Many []*Relationship          // relationships whose Type is Many2Many
	Relations map[string]*Relationship // relationships keyed by name
}
// Relationship describes one association declared by a struct field.
type Relationship struct {
	Name        string           // name of the declaring field
	Type        RelationshipType // kind of relationship
	Field       *Field           // field that declares the relationship
	Polymorphic *Polymorphic     // set only for polymorphic relationships
	References  []*Reference     // primary key / foreign key pairs that link the two sides
	Schema      *Schema          // schema on which the relationship is declared
	FieldSchema *Schema          // schema parsed from the declaring field's type
	JoinTable   *Schema          // join table schema, set only for many2many relationships
	// foreignKeys and primaryKeys hold the column names read from the field's
	// FOREIGNKEY and REFERENCES tag settings.
	foreignKeys, primaryKeys []string
}
// Polymorphic holds the extra fields and value that identify the owner of a
// polymorphic relationship.
type Polymorphic struct {
	PolymorphicID   *Field // field on the related schema named <polymorphic>ID
	PolymorphicType *Field // field on the related schema named <polymorphic>Type
	Value           string // value stored in the type field; the owner's table name unless POLYMORPHICVALUE overrides it
}
// Reference pairs a foreign key field with either the primary key field it
// points at or a constant value it must equal.
type Reference struct {
	PrimaryKey    *Field // referenced field; nil when PrimaryValue is used instead
	PrimaryValue  string // constant the foreign key must equal (used for the polymorphic type column)
	ForeignKey    *Field // field that stores the reference
	OwnPrimaryKey bool   // true when PrimaryKey belongs to the side that declares the relationship
}
  47. func (schema *Schema) parseRelation(field *Field) {
  48. var (
  49. err error
  50. fieldValue = reflect.New(field.IndirectFieldType).Interface()
  51. relation = &Relationship{
  52. Name: field.Name,
  53. Field: field,
  54. Schema: schema,
  55. foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
  56. primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
  57. }
  58. )
  59. cacheStore := schema.cacheStore
  60. if field.OwnerSchema != nil {
  61. cacheStore = field.OwnerSchema.cacheStore
  62. }
  63. if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil {
  64. schema.err = err
  65. return
  66. }
  67. if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
  68. schema.buildPolymorphicRelation(relation, field, polymorphic)
  69. } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
  70. schema.buildMany2ManyRelation(relation, field, many2many)
  71. } else {
  72. switch field.IndirectFieldType.Kind() {
  73. case reflect.Struct:
  74. schema.guessRelation(relation, field, guessBelongs)
  75. case reflect.Slice:
  76. schema.guessRelation(relation, field, guessHas)
  77. default:
  78. schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
  79. }
  80. }
  81. if relation.Type == "has" {
  82. if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil {
  83. relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
  84. }
  85. switch field.IndirectFieldType.Kind() {
  86. case reflect.Struct:
  87. relation.Type = HasOne
  88. case reflect.Slice:
  89. relation.Type = HasMany
  90. }
  91. }
  92. if schema.err == nil {
  93. schema.Relationships.Relations[relation.Name] = relation
  94. switch relation.Type {
  95. case HasOne:
  96. schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
  97. case HasMany:
  98. schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
  99. case BelongsTo:
  100. schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
  101. case Many2Many:
  102. schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
  103. }
  104. }
  105. }
  106. // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
  107. // type User struct {
  108. // Toys []Toy `gorm:"polymorphic:Owner;"`
  109. // }
  110. // type Pet struct {
  111. // Toy Toy `gorm:"polymorphic:Owner;"`
  112. // }
  113. // type Toy struct {
  114. // OwnerID int
  115. // OwnerType string
  116. // }
  117. func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
  118. relation.Polymorphic = &Polymorphic{
  119. Value: schema.Table,
  120. PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
  121. PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
  122. }
  123. if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
  124. relation.Polymorphic.Value = strings.TrimSpace(value)
  125. }
  126. if relation.Polymorphic.PolymorphicType == nil {
  127. schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
  128. }
  129. if relation.Polymorphic.PolymorphicID == nil {
  130. schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
  131. }
  132. if schema.err == nil {
  133. relation.References = append(relation.References, &Reference{
  134. PrimaryValue: relation.Polymorphic.Value,
  135. ForeignKey: relation.Polymorphic.PolymorphicType,
  136. })
  137. primaryKeyField := schema.PrioritizedPrimaryField
  138. if len(relation.foreignKeys) > 0 {
  139. if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
  140. schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
  141. }
  142. }
  143. // use same data type for foreign keys
  144. relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
  145. relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
  146. if relation.Polymorphic.PolymorphicID.Size == 0 {
  147. relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
  148. }
  149. relation.References = append(relation.References, &Reference{
  150. PrimaryKey: primaryKeyField,
  151. ForeignKey: relation.Polymorphic.PolymorphicID,
  152. OwnPrimaryKey: true,
  153. })
  154. }
  155. relation.Type = "has"
  156. }
  157. func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
  158. relation.Type = Many2Many
  159. var (
  160. err error
  161. joinTableFields []reflect.StructField
  162. fieldsMap = map[string]*Field{}
  163. ownFieldsMap = map[string]bool{} // fix self join many2many
  164. joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
  165. joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
  166. )
  167. ownForeignFields := schema.PrimaryFields
  168. refForeignFields := relation.FieldSchema.PrimaryFields
  169. if len(relation.foreignKeys) > 0 {
  170. ownForeignFields = []*Field{}
  171. for _, foreignKey := range relation.foreignKeys {
  172. if field := schema.LookUpField(foreignKey); field != nil {
  173. ownForeignFields = append(ownForeignFields, field)
  174. } else {
  175. schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
  176. return
  177. }
  178. }
  179. }
  180. if len(relation.primaryKeys) > 0 {
  181. refForeignFields = []*Field{}
  182. for _, foreignKey := range relation.primaryKeys {
  183. if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
  184. refForeignFields = append(refForeignFields, field)
  185. } else {
  186. schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
  187. return
  188. }
  189. }
  190. }
  191. for idx, ownField := range ownForeignFields {
  192. joinFieldName := schema.Name + ownField.Name
  193. if len(joinForeignKeys) > idx {
  194. joinFieldName = strings.Title(joinForeignKeys[idx])
  195. }
  196. ownFieldsMap[joinFieldName] = true
  197. fieldsMap[joinFieldName] = ownField
  198. joinTableFields = append(joinTableFields, reflect.StructField{
  199. Name: joinFieldName,
  200. PkgPath: ownField.StructField.PkgPath,
  201. Type: ownField.StructField.Type,
  202. Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
  203. })
  204. }
  205. for idx, relField := range refForeignFields {
  206. joinFieldName := relation.FieldSchema.Name + relField.Name
  207. if len(joinReferences) > idx {
  208. joinFieldName = strings.Title(joinReferences[idx])
  209. }
  210. if _, ok := ownFieldsMap[joinFieldName]; ok {
  211. if field.Name != relation.FieldSchema.Name {
  212. joinFieldName = inflection.Singular(field.Name) + relField.Name
  213. } else {
  214. joinFieldName += "Reference"
  215. }
  216. }
  217. fieldsMap[joinFieldName] = relField
  218. joinTableFields = append(joinTableFields, reflect.StructField{
  219. Name: joinFieldName,
  220. PkgPath: relField.StructField.PkgPath,
  221. Type: relField.StructField.Type,
  222. Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
  223. })
  224. }
  225. joinTableFields = append(joinTableFields, reflect.StructField{
  226. Name: schema.Name + field.Name,
  227. Type: schema.ModelType,
  228. Tag: `gorm:"-"`,
  229. })
  230. if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
  231. schema.err = err
  232. }
  233. relation.JoinTable.Name = many2many
  234. relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
  235. relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
  236. relName := relation.Schema.Name
  237. relRefName := relation.FieldSchema.Name
  238. if relName == relRefName {
  239. relRefName = relation.Field.Name
  240. }
  241. if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
  242. relation.JoinTable.Relationships.Relations[relName] = &Relationship{
  243. Name: relName,
  244. Type: BelongsTo,
  245. Schema: relation.JoinTable,
  246. FieldSchema: relation.Schema,
  247. }
  248. } else {
  249. relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
  250. }
  251. if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
  252. relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
  253. Name: relRefName,
  254. Type: BelongsTo,
  255. Schema: relation.JoinTable,
  256. FieldSchema: relation.FieldSchema,
  257. }
  258. } else {
  259. relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
  260. }
  261. // build references
  262. for _, f := range relation.JoinTable.Fields {
  263. if f.Creatable || f.Readable || f.Updatable {
  264. // use same data type for foreign keys
  265. f.DataType = fieldsMap[f.Name].DataType
  266. f.GORMDataType = fieldsMap[f.Name].GORMDataType
  267. if f.Size == 0 {
  268. f.Size = fieldsMap[f.Name].Size
  269. }
  270. relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
  271. ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
  272. if ownPriamryField {
  273. joinRel := relation.JoinTable.Relationships.Relations[relName]
  274. joinRel.Field = relation.Field
  275. joinRel.References = append(joinRel.References, &Reference{
  276. PrimaryKey: fieldsMap[f.Name],
  277. ForeignKey: f,
  278. })
  279. } else {
  280. joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
  281. if joinRefRel.Field == nil {
  282. joinRefRel.Field = relation.Field
  283. }
  284. joinRefRel.References = append(joinRefRel.References, &Reference{
  285. PrimaryKey: fieldsMap[f.Name],
  286. ForeignKey: f,
  287. })
  288. }
  289. relation.References = append(relation.References, &Reference{
  290. PrimaryKey: fieldsMap[f.Name],
  291. ForeignKey: f,
  292. OwnPrimaryKey: ownPriamryField,
  293. })
  294. }
  295. }
  296. }
// guessLevel selects which relationship shape guessRelation tries. The
// levels are attempted in declaration order; each failed attempt falls
// through to the next one.
type guessLevel int

const (
	// guessBelongs: the foreign key lives on the declaring schema and
	// references the related schema.
	guessBelongs guessLevel = iota
	// guessEmbeddedBelongs: like guessBelongs, but the foreign key lives on
	// the field's owner schema.
	guessEmbeddedBelongs
	// guessHas: the foreign key lives on the related schema and references
	// the declaring schema.
	guessHas
	// guessEmbeddedHas: like guessHas, but the primary key lives on the
	// field's owner schema.
	guessEmbeddedHas
)
  304. func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
  305. var (
  306. primaryFields, foreignFields []*Field
  307. primarySchema, foreignSchema = schema, relation.FieldSchema
  308. )
  309. reguessOrErr := func() {
  310. switch gl {
  311. case guessBelongs:
  312. schema.guessRelation(relation, field, guessEmbeddedBelongs)
  313. case guessEmbeddedBelongs:
  314. schema.guessRelation(relation, field, guessHas)
  315. case guessHas:
  316. schema.guessRelation(relation, field, guessEmbeddedHas)
  317. // case guessEmbeddedHas:
  318. default:
  319. schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
  320. }
  321. }
  322. switch gl {
  323. case guessBelongs:
  324. primarySchema, foreignSchema = relation.FieldSchema, schema
  325. case guessEmbeddedBelongs:
  326. if field.OwnerSchema != nil {
  327. primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
  328. } else {
  329. reguessOrErr()
  330. return
  331. }
  332. case guessHas:
  333. case guessEmbeddedHas:
  334. if field.OwnerSchema != nil {
  335. primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
  336. } else {
  337. reguessOrErr()
  338. return
  339. }
  340. }
  341. if len(relation.foreignKeys) > 0 {
  342. for _, foreignKey := range relation.foreignKeys {
  343. if f := foreignSchema.LookUpField(foreignKey); f != nil {
  344. foreignFields = append(foreignFields, f)
  345. } else {
  346. reguessOrErr()
  347. return
  348. }
  349. }
  350. } else {
  351. for _, primaryField := range primarySchema.PrimaryFields {
  352. lookUpName := primarySchema.Name + primaryField.Name
  353. if gl == guessBelongs {
  354. lookUpName = field.Name + primaryField.Name
  355. }
  356. if f := foreignSchema.LookUpField(lookUpName); f != nil {
  357. foreignFields = append(foreignFields, f)
  358. primaryFields = append(primaryFields, primaryField)
  359. }
  360. }
  361. }
  362. if len(foreignFields) == 0 {
  363. reguessOrErr()
  364. return
  365. } else if len(relation.primaryKeys) > 0 {
  366. for idx, primaryKey := range relation.primaryKeys {
  367. if f := primarySchema.LookUpField(primaryKey); f != nil {
  368. if len(primaryFields) < idx+1 {
  369. primaryFields = append(primaryFields, f)
  370. } else if f != primaryFields[idx] {
  371. reguessOrErr()
  372. return
  373. }
  374. } else {
  375. reguessOrErr()
  376. return
  377. }
  378. }
  379. } else if len(primaryFields) == 0 {
  380. if len(foreignFields) == 1 {
  381. primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
  382. } else if len(primarySchema.PrimaryFields) == len(foreignFields) {
  383. primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
  384. } else {
  385. reguessOrErr()
  386. return
  387. }
  388. }
  389. // build references
  390. for idx, foreignField := range foreignFields {
  391. // use same data type for foreign keys
  392. foreignField.DataType = primaryFields[idx].DataType
  393. foreignField.GORMDataType = primaryFields[idx].GORMDataType
  394. if foreignField.Size == 0 {
  395. foreignField.Size = primaryFields[idx].Size
  396. }
  397. relation.References = append(relation.References, &Reference{
  398. PrimaryKey: primaryFields[idx],
  399. ForeignKey: foreignField,
  400. OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
  401. })
  402. }
  403. if gl == guessHas || gl == guessEmbeddedHas {
  404. relation.Type = "has"
  405. } else {
  406. relation.Type = BelongsTo
  407. }
  408. }
// Constraint describes the foreign key constraint derived from a
// relationship by ParseConstraint.
type Constraint struct {
	Name            string   // constraint name, taken from the CONSTRAINT tag or generated by the namer
	Field           *Field   // field that declares the relationship
	Schema          *Schema  // schema that holds the foreign keys
	ForeignKeys     []*Field // foreign key fields
	ReferenceSchema *Schema  // schema that holds the referenced fields
	References      []*Field // referenced fields, parallel to ForeignKeys
	OnDelete        string   // value of the ONDELETE setting
	OnUpdate        string   // value of the ONUPDATE setting
}
  419. func (rel *Relationship) ParseConstraint() *Constraint {
  420. str := rel.Field.TagSettings["CONSTRAINT"]
  421. if str == "-" {
  422. return nil
  423. }
  424. var (
  425. name string
  426. idx = strings.Index(str, ",")
  427. settings = ParseTagSetting(str, ",")
  428. )
  429. if idx != -1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(str[0:idx]) {
  430. name = str[0:idx]
  431. } else {
  432. name = rel.Schema.namer.RelationshipFKName(*rel)
  433. }
  434. constraint := Constraint{
  435. Name: name,
  436. Field: rel.Field,
  437. OnUpdate: settings["ONUPDATE"],
  438. OnDelete: settings["ONDELETE"],
  439. }
  440. for _, ref := range rel.References {
  441. if ref.PrimaryKey != nil {
  442. constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
  443. constraint.References = append(constraint.References, ref.PrimaryKey)
  444. if ref.OwnPrimaryKey {
  445. constraint.Schema = ref.ForeignKey.Schema
  446. constraint.ReferenceSchema = rel.Schema
  447. } else {
  448. constraint.Schema = rel.Schema
  449. constraint.ReferenceSchema = ref.PrimaryKey.Schema
  450. }
  451. }
  452. }
  453. if rel.JoinTable != nil {
  454. return nil
  455. }
  456. return &constraint
  457. }
  458. func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
  459. table := rel.FieldSchema.Table
  460. foreignFields := []*Field{}
  461. relForeignKeys := []string{}
  462. if rel.JoinTable != nil {
  463. table = rel.JoinTable.Table
  464. for _, ref := range rel.References {
  465. if ref.OwnPrimaryKey {
  466. foreignFields = append(foreignFields, ref.PrimaryKey)
  467. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  468. } else if ref.PrimaryValue != "" {
  469. conds = append(conds, clause.Eq{
  470. Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
  471. Value: ref.PrimaryValue,
  472. })
  473. } else {
  474. conds = append(conds, clause.Eq{
  475. Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
  476. Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
  477. })
  478. }
  479. }
  480. } else {
  481. for _, ref := range rel.References {
  482. if ref.OwnPrimaryKey {
  483. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  484. foreignFields = append(foreignFields, ref.PrimaryKey)
  485. } else if ref.PrimaryValue != "" {
  486. conds = append(conds, clause.Eq{
  487. Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName},
  488. Value: ref.PrimaryValue,
  489. })
  490. } else {
  491. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  492. foreignFields = append(foreignFields, ref.ForeignKey)
  493. }
  494. }
  495. }
  496. _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
  497. column, values := ToQueryValues(table, relForeignKeys, foreignValues)
  498. conds = append(conds, clause.IN{Column: column, Values: values})
  499. return
  500. }