migrator.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package mysql
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/migrator"
  8. "gorm.io/gorm/schema"
  9. )
  10. type Migrator struct {
  11. migrator.Migrator
  12. Dialector
  13. }
  // Column is one row of information_schema.columns as scanned by ColumnTypes.
  // Each sql.Null* field is left with Valid == false when the catalog value was
  // NULL, and the accessor methods below report that through their ok result.
  type Column struct {
  	name              string
  	nullable          sql.NullString
  	datatype          string
  	maxlen            sql.NullInt64
  	precision         sql.NullInt64
  	scale             sql.NullInt64
  	datetimeprecision sql.NullInt64
  }

  // Name returns the column name.
  func (c Column) Name() string { return c.name }

  // DatabaseTypeName returns the column's data_type value unchanged.
  func (c Column) DatabaseTypeName() string { return c.datatype }

  // Length returns character_maximum_length. When that value is NULL it
  // reports (0, false), regardless of what the invalid field holds.
  func (c Column) Length() (length int64, ok bool) {
  	if !c.maxlen.Valid {
  		return 0, false
  	}
  	return c.maxlen.Int64, true
  }

  // Nullable reports whether is_nullable equals "YES". ok is false (and
  // nullable false) when is_nullable itself was NULL.
  func (c Column) Nullable() (nullable bool, ok bool) {
  	if !c.nullable.Valid {
  		return false, false
  	}
  	return c.nullable.String == "YES", true
  }

  // DecimalSize returns the column's precision and scale.
  //
  // Precision is numeric_precision when present. Otherwise it falls back to
  // datetime_precision, and then scale is always 0. Scale is numeric_scale, or 0
  // when NULL. ok is false only when neither precision value is available.
  func (c Column) DecimalSize() (precision int64, scale int64, ok bool) {
  	switch {
  	case c.precision.Valid:
  		if c.scale.Valid {
  			scale = c.scale.Int64
  		}
  		return c.precision.Int64, scale, true
  	case c.datetimeprecision.Valid:
  		return c.datetimeprecision.Int64, 0, true
  	}
  	return 0, 0, false
  }
  60. func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
  61. expr := m.Migrator.FullDataTypeOf(field)
  62. if value, ok := field.TagSettings["COMMENT"]; ok {
  63. expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
  64. }
  65. return expr
  66. }
  67. func (m Migrator) AlterColumn(value interface{}, field string) error {
  68. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  69. if field := stmt.Schema.LookUpField(field); field != nil {
  70. return m.DB.Exec(
  71. "ALTER TABLE ? MODIFY COLUMN ? ?",
  72. clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
  73. ).Error
  74. }
  75. return fmt.Errorf("failed to look up field with name: %s", field)
  76. })
  77. }
  78. func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
  79. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  80. if m.Dialector.DontSupportRenameColumn {
  81. var field *schema.Field
  82. if f := stmt.Schema.LookUpField(oldName); f != nil {
  83. oldName = f.DBName
  84. field = f
  85. }
  86. if f := stmt.Schema.LookUpField(newName); f != nil {
  87. newName = f.DBName
  88. field = f
  89. }
  90. if field != nil {
  91. return m.DB.Exec(
  92. "ALTER TABLE ? CHANGE ? ? ?",
  93. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, m.FullDataTypeOf(field),
  94. ).Error
  95. }
  96. } else {
  97. return m.Migrator.RenameColumn(value, oldName, newName)
  98. }
  99. return fmt.Errorf("failed to look up field with name: %s", newName)
  100. })
  101. }
  102. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  103. if m.Dialector.DontSupportRenameIndex {
  104. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  105. err := m.DropIndex(value, oldName)
  106. if err == nil {
  107. if idx := stmt.Schema.LookIndex(newName); idx == nil {
  108. if idx = stmt.Schema.LookIndex(oldName); idx != nil {
  109. opts := m.BuildIndexOptions(idx.Fields, stmt)
  110. values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}
  111. createIndexSQL := "CREATE "
  112. if idx.Class != "" {
  113. createIndexSQL += idx.Class + " "
  114. }
  115. createIndexSQL += "INDEX ? ON ??"
  116. if idx.Type != "" {
  117. createIndexSQL += " USING " + idx.Type
  118. }
  119. return m.DB.Exec(createIndexSQL, values...).Error
  120. }
  121. }
  122. err = m.CreateIndex(value, newName)
  123. }
  124. return err
  125. })
  126. } else {
  127. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  128. return m.DB.Exec(
  129. "ALTER TABLE ? RENAME INDEX ? TO ?",
  130. clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
  131. ).Error
  132. })
  133. }
  134. }
  135. func (m Migrator) DropTable(values ...interface{}) error {
  136. values = m.ReorderModels(values, false)
  137. tx := m.DB.Session(&gorm.Session{})
  138. tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
  139. for i := len(values) - 1; i >= 0; i-- {
  140. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  141. return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
  142. }); err != nil {
  143. return err
  144. }
  145. }
  146. tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
  147. return nil
  148. }
  149. func (m Migrator) DropConstraint(value interface{}, name string) error {
  150. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  151. for _, chk := range stmt.Schema.ParseCheckConstraints() {
  152. if chk.Name == name {
  153. return m.DB.Exec(
  154. "ALTER TABLE ? DROP CHECK ?",
  155. clause.Table{Name: stmt.Table}, clause.Column{Name: name},
  156. ).Error
  157. }
  158. }
  159. return m.DB.Exec(
  160. "ALTER TABLE ? DROP FOREIGN KEY ?",
  161. clause.Table{Name: stmt.Table}, clause.Column{Name: name},
  162. ).Error
  163. })
  164. }
  165. func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
  166. columnTypes = make([]gorm.ColumnType, 0)
  167. err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
  168. currentDatabase := m.DB.Migrator().CurrentDatabase()
  169. columns, err := m.DB.Raw(
  170. "SELECT column_name, is_nullable, data_type, character_maximum_length, "+
  171. "numeric_precision, numeric_scale, datetime_precision "+
  172. "FROM information_schema.columns WHERE table_schema = ? AND table_name = ?",
  173. currentDatabase, stmt.Table).Rows()
  174. if err != nil {
  175. return err
  176. }
  177. defer columns.Close()
  178. for columns.Next() {
  179. var column Column
  180. err = columns.Scan(
  181. &column.name,
  182. &column.nullable,
  183. &column.datatype,
  184. &column.maxlen,
  185. &column.precision,
  186. &column.scale,
  187. &column.datetimeprecision,
  188. )
  189. if err != nil {
  190. return err
  191. }
  192. columnTypes = append(columnTypes, column)
  193. }
  194. return err
  195. })
  196. return
  197. }