|
@@ -17,27 +17,38 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
|
return defaultTableName
|
|
return defaultTableName
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// lock for mutating global cached model metadata
|
|
|
|
+var structsLock sync.Mutex
|
|
|
|
+
|
|
|
|
+// global cache of model metadata
|
|
var modelStructsMap sync.Map
|
|
var modelStructsMap sync.Map
|
|
|
|
|
|
// ModelStruct model definition
|
|
// ModelStruct model definition
|
|
type ModelStruct struct {
|
|
type ModelStruct struct {
|
|
- PrimaryFields []*StructField
|
|
|
|
- StructFields []*StructField
|
|
|
|
- ModelType reflect.Type
|
|
|
|
|
|
+ PrimaryFields []*StructField
|
|
|
|
+ StructFields []*StructField
|
|
|
|
+ ModelType reflect.Type
|
|
|
|
+
|
|
defaultTableName string
|
|
defaultTableName string
|
|
|
|
+ l sync.Mutex
|
|
}
|
|
}
|
|
|
|
|
|
// TableName returns model's table name
|
|
// TableName returns model's table name
|
|
func (s *ModelStruct) TableName(db *DB) string {
|
|
func (s *ModelStruct) TableName(db *DB) string {
|
|
|
|
+ s.l.Lock()
|
|
|
|
+ defer s.l.Unlock()
|
|
|
|
+
|
|
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
|
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
|
|
// Set default table name
|
|
// Set default table name
|
|
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
|
if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok {
|
|
s.defaultTableName = tabler.TableName()
|
|
s.defaultTableName = tabler.TableName()
|
|
} else {
|
|
} else {
|
|
tableName := ToTableName(s.ModelType.Name())
|
|
tableName := ToTableName(s.ModelType.Name())
|
|
- if db == nil || !db.parent.singularTable {
|
|
|
|
|
|
+ db.parent.RLock()
|
|
|
|
+ if db == nil || (db.parent != nil && !db.parent.singularTable) {
|
|
tableName = inflection.Plural(tableName)
|
|
tableName = inflection.Plural(tableName)
|
|
}
|
|
}
|
|
|
|
+ db.parent.RUnlock()
|
|
s.defaultTableName = tableName
|
|
s.defaultTableName = tableName
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -65,52 +76,52 @@ type StructField struct {
|
|
}
|
|
}
|
|
|
|
|
|
// TagSettingsSet Sets a tag in the tag settings map
|
|
// TagSettingsSet Sets a tag in the tag settings map
|
|
-func (s *StructField) TagSettingsSet(key, val string) {
|
|
|
|
- s.tagSettingsLock.Lock()
|
|
|
|
- defer s.tagSettingsLock.Unlock()
|
|
|
|
- s.TagSettings[key] = val
|
|
|
|
|
|
+func (sf *StructField) TagSettingsSet(key, val string) {
|
|
|
|
+ sf.tagSettingsLock.Lock()
|
|
|
|
+ defer sf.tagSettingsLock.Unlock()
|
|
|
|
+ sf.TagSettings[key] = val
|
|
}
|
|
}
|
|
|
|
|
|
// TagSettingsGet returns a tag from the tag settings
|
|
// TagSettingsGet returns a tag from the tag settings
|
|
-func (s *StructField) TagSettingsGet(key string) (string, bool) {
|
|
|
|
- s.tagSettingsLock.RLock()
|
|
|
|
- defer s.tagSettingsLock.RUnlock()
|
|
|
|
- val, ok := s.TagSettings[key]
|
|
|
|
|
|
+func (sf *StructField) TagSettingsGet(key string) (string, bool) {
|
|
|
|
+ sf.tagSettingsLock.RLock()
|
|
|
|
+ defer sf.tagSettingsLock.RUnlock()
|
|
|
|
+ val, ok := sf.TagSettings[key]
|
|
return val, ok
|
|
return val, ok
|
|
}
|
|
}
|
|
|
|
|
|
// TagSettingsDelete deletes a tag
|
|
// TagSettingsDelete deletes a tag
|
|
-func (s *StructField) TagSettingsDelete(key string) {
|
|
|
|
- s.tagSettingsLock.Lock()
|
|
|
|
- defer s.tagSettingsLock.Unlock()
|
|
|
|
- delete(s.TagSettings, key)
|
|
|
|
|
|
+func (sf *StructField) TagSettingsDelete(key string) {
|
|
|
|
+ sf.tagSettingsLock.Lock()
|
|
|
|
+ defer sf.tagSettingsLock.Unlock()
|
|
|
|
+ delete(sf.TagSettings, key)
|
|
}
|
|
}
|
|
|
|
|
|
-func (structField *StructField) clone() *StructField {
|
|
|
|
|
|
+func (sf *StructField) clone() *StructField {
|
|
clone := &StructField{
|
|
clone := &StructField{
|
|
- DBName: structField.DBName,
|
|
|
|
- Name: structField.Name,
|
|
|
|
- Names: structField.Names,
|
|
|
|
- IsPrimaryKey: structField.IsPrimaryKey,
|
|
|
|
- IsNormal: structField.IsNormal,
|
|
|
|
- IsIgnored: structField.IsIgnored,
|
|
|
|
- IsScanner: structField.IsScanner,
|
|
|
|
- HasDefaultValue: structField.HasDefaultValue,
|
|
|
|
- Tag: structField.Tag,
|
|
|
|
|
|
+ DBName: sf.DBName,
|
|
|
|
+ Name: sf.Name,
|
|
|
|
+ Names: sf.Names,
|
|
|
|
+ IsPrimaryKey: sf.IsPrimaryKey,
|
|
|
|
+ IsNormal: sf.IsNormal,
|
|
|
|
+ IsIgnored: sf.IsIgnored,
|
|
|
|
+ IsScanner: sf.IsScanner,
|
|
|
|
+ HasDefaultValue: sf.HasDefaultValue,
|
|
|
|
+ Tag: sf.Tag,
|
|
TagSettings: map[string]string{},
|
|
TagSettings: map[string]string{},
|
|
- Struct: structField.Struct,
|
|
|
|
- IsForeignKey: structField.IsForeignKey,
|
|
|
|
|
|
+ Struct: sf.Struct,
|
|
|
|
+ IsForeignKey: sf.IsForeignKey,
|
|
}
|
|
}
|
|
|
|
|
|
- if structField.Relationship != nil {
|
|
|
|
- relationship := *structField.Relationship
|
|
|
|
|
|
+ if sf.Relationship != nil {
|
|
|
|
+ relationship := *sf.Relationship
|
|
clone.Relationship = &relationship
|
|
clone.Relationship = &relationship
|
|
}
|
|
}
|
|
|
|
|
|
// copy the struct field tagSettings, they should be read-locked while they are copied
|
|
// copy the struct field tagSettings, they should be read-locked while they are copied
|
|
- structField.tagSettingsLock.Lock()
|
|
|
|
- defer structField.tagSettingsLock.Unlock()
|
|
|
|
- for key, value := range structField.TagSettings {
|
|
|
|
|
|
+ sf.tagSettingsLock.Lock()
|
|
|
|
+ defer sf.tagSettingsLock.Unlock()
|
|
|
|
+ for key, value := range sf.TagSettings {
|
|
clone.TagSettings[key] = value
|
|
clone.TagSettings[key] = value
|
|
}
|
|
}
|
|
|
|
|
|
@@ -141,6 +152,10 @@ func getForeignField(column string, fields []*StructField) *StructField {
|
|
|
|
|
|
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
|
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
|
|
+ return scope.getModelStruct(scope, make([]*StructField, 0))
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct {
|
|
var modelStruct ModelStruct
|
|
var modelStruct ModelStruct
|
|
// Scope value can't be nil
|
|
// Scope value can't be nil
|
|
if scope.Value == nil {
|
|
if scope.Value == nil {
|
|
@@ -158,7 +173,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
}
|
|
}
|
|
|
|
|
|
// Get Cached model struct
|
|
// Get Cached model struct
|
|
- if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
|
|
|
|
|
|
+ isSingularTable := false
|
|
|
|
+ if scope.db != nil && scope.db.parent != nil {
|
|
|
|
+ scope.db.parent.RLock()
|
|
|
|
+ isSingularTable = scope.db.parent.singularTable
|
|
|
|
+ scope.db.parent.RUnlock()
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ hashKey := struct {
|
|
|
|
+ singularTable bool
|
|
|
|
+ reflectType reflect.Type
|
|
|
|
+ }{isSingularTable, reflectType}
|
|
|
|
+ if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
|
|
return value.(*ModelStruct)
|
|
return value.(*ModelStruct)
|
|
}
|
|
}
|
|
|
|
|
|
@@ -184,7 +210,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
|
}
|
|
}
|
|
|
|
|
|
- if _, ok := field.TagSettingsGet("DEFAULT"); ok {
|
|
|
|
|
|
+ if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey {
|
|
field.HasDefaultValue = true
|
|
field.HasDefaultValue = true
|
|
}
|
|
}
|
|
|
|
|
|
@@ -215,7 +241,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
field.IsNormal = true
|
|
field.IsNormal = true
|
|
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
|
|
} else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous {
|
|
// is embedded struct
|
|
// is embedded struct
|
|
- for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields {
|
|
|
|
|
|
+ for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields {
|
|
subField = subField.clone()
|
|
subField = subField.clone()
|
|
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
|
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
|
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
|
|
if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok {
|
|
@@ -239,6 +265,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
}
|
|
}
|
|
|
|
|
|
modelStruct.StructFields = append(modelStruct.StructFields, subField)
|
|
modelStruct.StructFields = append(modelStruct.StructFields, subField)
|
|
|
|
+ allFields = append(allFields, subField)
|
|
}
|
|
}
|
|
continue
|
|
continue
|
|
} else {
|
|
} else {
|
|
@@ -335,7 +362,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
}
|
|
}
|
|
|
|
|
|
joinTableHandler := JoinTableHandler{}
|
|
joinTableHandler := JoinTableHandler{}
|
|
- joinTableHandler.Setup(relationship, ToTableName(many2many), reflectType, elemType)
|
|
|
|
|
|
+ joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
|
relationship.JoinTableHandler = &joinTableHandler
|
|
relationship.JoinTableHandler = &joinTableHandler
|
|
field.Relationship = relationship
|
|
field.Relationship = relationship
|
|
} else {
|
|
} else {
|
|
@@ -372,7 +399,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
} else {
|
|
} else {
|
|
// generate foreign keys from defined association foreign keys
|
|
// generate foreign keys from defined association foreign keys
|
|
for _, scopeFieldName := range associationForeignKeys {
|
|
for _, scopeFieldName := range associationForeignKeys {
|
|
- if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
|
|
|
|
|
|
+ if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil {
|
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
|
}
|
|
}
|
|
@@ -384,13 +411,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
for _, foreignKey := range foreignKeys {
|
|
for _, foreignKey := range foreignKeys {
|
|
if strings.HasPrefix(foreignKey, associationType) {
|
|
if strings.HasPrefix(foreignKey, associationType) {
|
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
|
- if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
|
|
|
|
|
+ if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
|
- associationForeignKeys = []string{scope.PrimaryKey()}
|
|
|
|
|
|
+ associationForeignKeys = []string{rootScope.PrimaryKey()}
|
|
}
|
|
}
|
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
|
@@ -400,9 +427,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
|
|
|
|
for idx, foreignKey := range foreignKeys {
|
|
for idx, foreignKey := range foreignKeys {
|
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
|
- if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
|
|
|
|
- // source foreign keys
|
|
|
|
|
|
+ if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil {
|
|
|
|
+ // mark field as foreignkey, use global lock to avoid race
|
|
|
|
+ structsLock.Lock()
|
|
foreignField.IsForeignKey = true
|
|
foreignField.IsForeignKey = true
|
|
|
|
+ structsLock.Unlock()
|
|
|
|
+
|
|
|
|
+ // association foreign keys
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
|
|
|
|
|
@@ -476,7 +507,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
} else {
|
|
} else {
|
|
// generate foreign keys form association foreign keys
|
|
// generate foreign keys form association foreign keys
|
|
for _, associationForeignKey := range tagAssociationForeignKeys {
|
|
for _, associationForeignKey := range tagAssociationForeignKeys {
|
|
- if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
|
|
|
|
|
+ if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
|
}
|
|
}
|
|
@@ -488,13 +519,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
for _, foreignKey := range foreignKeys {
|
|
for _, foreignKey := range foreignKeys {
|
|
if strings.HasPrefix(foreignKey, associationType) {
|
|
if strings.HasPrefix(foreignKey, associationType) {
|
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
|
- if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
|
|
|
|
|
+ if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil {
|
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
|
- associationForeignKeys = []string{scope.PrimaryKey()}
|
|
|
|
|
|
+ associationForeignKeys = []string{rootScope.PrimaryKey()}
|
|
}
|
|
}
|
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
|
@@ -504,9 +535,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
|
|
|
|
for idx, foreignKey := range foreignKeys {
|
|
for idx, foreignKey := range foreignKeys {
|
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
|
- if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
|
|
|
|
|
|
+ if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil {
|
|
|
|
+ // mark field as foreignkey, use global lock to avoid race
|
|
|
|
+ structsLock.Lock()
|
|
foreignField.IsForeignKey = true
|
|
foreignField.IsForeignKey = true
|
|
- // source foreign keys
|
|
|
|
|
|
+ structsLock.Unlock()
|
|
|
|
+
|
|
|
|
+ // association foreign keys
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
|
|
|
|
|
@@ -564,7 +599,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
for idx, foreignKey := range foreignKeys {
|
|
for idx, foreignKey := range foreignKeys {
|
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
|
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
|
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
|
|
|
+ // mark field as foreignkey, use global lock to avoid race
|
|
|
|
+ structsLock.Lock()
|
|
foreignField.IsForeignKey = true
|
|
foreignField.IsForeignKey = true
|
|
|
|
+ structsLock.Unlock()
|
|
|
|
|
|
// association foreign keys
|
|
// association foreign keys
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
|
@@ -597,6 +635,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
}
|
|
}
|
|
|
|
|
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
|
|
|
+ allFields = append(allFields, field)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -607,7 +646,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- modelStructsMap.Store(reflectType, &modelStruct)
|
|
|
|
|
|
+ modelStructsMap.Store(hashKey, &modelStruct)
|
|
|
|
|
|
return &modelStruct
|
|
return &modelStruct
|
|
}
|
|
}
|
|
@@ -620,6 +659,9 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
|
setting := map[string]string{}
|
|
setting := map[string]string{}
|
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
|
|
|
+ if str == "" {
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
tags := strings.Split(str, ";")
|
|
tags := strings.Split(str, ";")
|
|
for _, value := range tags {
|
|
for _, value := range tags {
|
|
v := strings.Split(value, ":")
|
|
v := strings.Split(value, ":")
|