expression.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. package clause
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "go/ast"
  6. "reflect"
  7. )
  8. // Expression expression interface
  9. type Expression interface {
  10. Build(builder Builder)
  11. }
  12. // NegationExpressionBuilder negation expression builder
  13. type NegationExpressionBuilder interface {
  14. NegationBuild(builder Builder)
  15. }
  16. // Expr raw expression
  17. type Expr struct {
  18. SQL string
  19. Vars []interface{}
  20. WithoutParentheses bool
  21. }
  22. // Build build raw expression
  23. func (expr Expr) Build(builder Builder) {
  24. var (
  25. afterParenthesis bool
  26. idx int
  27. )
  28. for _, v := range []byte(expr.SQL) {
  29. if v == '?' && len(expr.Vars) > idx {
  30. if afterParenthesis || expr.WithoutParentheses {
  31. if _, ok := expr.Vars[idx].(driver.Valuer); ok {
  32. builder.AddVar(builder, expr.Vars[idx])
  33. } else {
  34. switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
  35. case reflect.Slice, reflect.Array:
  36. if rv.Len() == 0 {
  37. builder.AddVar(builder, nil)
  38. } else {
  39. for i := 0; i < rv.Len(); i++ {
  40. if i > 0 {
  41. builder.WriteByte(',')
  42. }
  43. builder.AddVar(builder, rv.Index(i).Interface())
  44. }
  45. }
  46. default:
  47. builder.AddVar(builder, expr.Vars[idx])
  48. }
  49. }
  50. } else {
  51. builder.AddVar(builder, expr.Vars[idx])
  52. }
  53. idx++
  54. } else {
  55. if v == '(' {
  56. afterParenthesis = true
  57. } else {
  58. afterParenthesis = false
  59. }
  60. builder.WriteByte(v)
  61. }
  62. }
  63. }
  64. // NamedExpr raw expression for named expr
  65. type NamedExpr struct {
  66. SQL string
  67. Vars []interface{}
  68. }
  69. // Build build raw expression
  70. func (expr NamedExpr) Build(builder Builder) {
  71. var (
  72. idx int
  73. inName bool
  74. namedMap = make(map[string]interface{}, len(expr.Vars))
  75. )
  76. for _, v := range expr.Vars {
  77. switch value := v.(type) {
  78. case sql.NamedArg:
  79. namedMap[value.Name] = value.Value
  80. case map[string]interface{}:
  81. for k, v := range value {
  82. namedMap[k] = v
  83. }
  84. default:
  85. var appendFieldsToMap func(reflect.Value)
  86. appendFieldsToMap = func(reflectValue reflect.Value) {
  87. reflectValue = reflect.Indirect(reflectValue)
  88. switch reflectValue.Kind() {
  89. case reflect.Struct:
  90. modelType := reflectValue.Type()
  91. for i := 0; i < modelType.NumField(); i++ {
  92. if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
  93. namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
  94. if fieldStruct.Anonymous {
  95. appendFieldsToMap(reflectValue.Field(i))
  96. }
  97. }
  98. }
  99. }
  100. }
  101. appendFieldsToMap(reflect.ValueOf(value))
  102. }
  103. }
  104. name := make([]byte, 0, 10)
  105. for _, v := range []byte(expr.SQL) {
  106. if v == '@' && !inName {
  107. inName = true
  108. name = []byte{}
  109. } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
  110. if inName {
  111. if nv, ok := namedMap[string(name)]; ok {
  112. builder.AddVar(builder, nv)
  113. } else {
  114. builder.WriteByte('@')
  115. builder.WriteString(string(name))
  116. }
  117. inName = false
  118. }
  119. builder.WriteByte(v)
  120. } else if v == '?' && len(expr.Vars) > idx {
  121. builder.AddVar(builder, expr.Vars[idx])
  122. idx++
  123. } else if inName {
  124. name = append(name, v)
  125. } else {
  126. builder.WriteByte(v)
  127. }
  128. }
  129. if inName {
  130. builder.AddVar(builder, namedMap[string(name)])
  131. }
  132. }
  133. // IN Whether a value is within a set of values
  134. type IN struct {
  135. Column interface{}
  136. Values []interface{}
  137. }
  138. func (in IN) Build(builder Builder) {
  139. builder.WriteQuoted(in.Column)
  140. switch len(in.Values) {
  141. case 0:
  142. builder.WriteString(" IN (NULL)")
  143. case 1:
  144. builder.WriteString(" = ")
  145. builder.AddVar(builder, in.Values...)
  146. default:
  147. builder.WriteString(" IN (")
  148. builder.AddVar(builder, in.Values...)
  149. builder.WriteByte(')')
  150. }
  151. }
  152. func (in IN) NegationBuild(builder Builder) {
  153. switch len(in.Values) {
  154. case 0:
  155. case 1:
  156. builder.WriteQuoted(in.Column)
  157. builder.WriteString(" <> ")
  158. builder.AddVar(builder, in.Values...)
  159. default:
  160. builder.WriteQuoted(in.Column)
  161. builder.WriteString(" NOT IN (")
  162. builder.AddVar(builder, in.Values...)
  163. builder.WriteByte(')')
  164. }
  165. }
  166. // Eq equal to for where
  167. type Eq struct {
  168. Column interface{}
  169. Value interface{}
  170. }
  171. func (eq Eq) Build(builder Builder) {
  172. builder.WriteQuoted(eq.Column)
  173. if eq.Value == nil {
  174. builder.WriteString(" IS NULL")
  175. } else {
  176. builder.WriteString(" = ")
  177. builder.AddVar(builder, eq.Value)
  178. }
  179. }
  180. func (eq Eq) NegationBuild(builder Builder) {
  181. Neq{eq.Column, eq.Value}.Build(builder)
  182. }
  183. // Neq not equal to for where
  184. type Neq Eq
  185. func (neq Neq) Build(builder Builder) {
  186. builder.WriteQuoted(neq.Column)
  187. if neq.Value == nil {
  188. builder.WriteString(" IS NOT NULL")
  189. } else {
  190. builder.WriteString(" <> ")
  191. builder.AddVar(builder, neq.Value)
  192. }
  193. }
  194. func (neq Neq) NegationBuild(builder Builder) {
  195. Eq{neq.Column, neq.Value}.Build(builder)
  196. }
  197. // Gt greater than for where
  198. type Gt Eq
  199. func (gt Gt) Build(builder Builder) {
  200. builder.WriteQuoted(gt.Column)
  201. builder.WriteString(" > ")
  202. builder.AddVar(builder, gt.Value)
  203. }
  204. func (gt Gt) NegationBuild(builder Builder) {
  205. Lte{gt.Column, gt.Value}.Build(builder)
  206. }
  207. // Gte greater than or equal to for where
  208. type Gte Eq
  209. func (gte Gte) Build(builder Builder) {
  210. builder.WriteQuoted(gte.Column)
  211. builder.WriteString(" >= ")
  212. builder.AddVar(builder, gte.Value)
  213. }
  214. func (gte Gte) NegationBuild(builder Builder) {
  215. Lt{gte.Column, gte.Value}.Build(builder)
  216. }
  217. // Lt less than for where
  218. type Lt Eq
  219. func (lt Lt) Build(builder Builder) {
  220. builder.WriteQuoted(lt.Column)
  221. builder.WriteString(" < ")
  222. builder.AddVar(builder, lt.Value)
  223. }
  224. func (lt Lt) NegationBuild(builder Builder) {
  225. Gte{lt.Column, lt.Value}.Build(builder)
  226. }
  227. // Lte less than or equal to for where
  228. type Lte Eq
  229. func (lte Lte) Build(builder Builder) {
  230. builder.WriteQuoted(lte.Column)
  231. builder.WriteString(" <= ")
  232. builder.AddVar(builder, lte.Value)
  233. }
  234. func (lte Lte) NegationBuild(builder Builder) {
  235. Gt{lte.Column, lte.Value}.Build(builder)
  236. }
  237. // Like whether string matches regular expression
  238. type Like Eq
  239. func (like Like) Build(builder Builder) {
  240. builder.WriteQuoted(like.Column)
  241. builder.WriteString(" LIKE ")
  242. builder.AddVar(builder, like.Value)
  243. }
  244. func (like Like) NegationBuild(builder Builder) {
  245. builder.WriteQuoted(like.Column)
  246. builder.WriteString(" NOT LIKE ")
  247. builder.AddVar(builder, like.Value)
  248. }