decoder.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. // Copyright 2012 The Gorilla Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package schema
  5. import (
  6. "encoding"
  7. "errors"
  8. "fmt"
  9. "reflect"
  10. "strings"
  11. )
  12. // NewDecoder returns a new Decoder.
  13. func NewDecoder() *Decoder {
  14. return &Decoder{cache: newCache()}
  15. }
  16. // Decoder decodes values from a map[string][]string to a struct.
  17. type Decoder struct {
  18. cache *cache
  19. zeroEmpty bool
  20. ignoreUnknownKeys bool
  21. }
  22. // AddAliasTag adds a tag used to locate custom field aliases.
  23. // Defaults are "schema", "form" and "url".
  24. func (d *Decoder) AddAliasTag(tag ...string) *Decoder {
  25. d.cache.tags = append(d.cache.tags, tag...)
  26. return d
  27. }
  28. // SetAliasTag overrides the tags.
  29. func (d *Decoder) SetAliasTag(tag ...string) *Decoder {
  30. d.cache.tags = tag
  31. return d
  32. }
  33. // ZeroEmpty controls the behaviour when the decoder encounters empty values
  34. // in a map.
  35. // If z is true and a key in the map has the empty string as a value
  36. // then the corresponding struct field is set to the zero value.
  37. // If z is false then empty strings are ignored.
  38. //
  39. // The default value is false, that is empty values do not change
  40. // the value of the struct field.
  41. func (d *Decoder) ZeroEmpty(z bool) *Decoder {
  42. d.zeroEmpty = z
  43. return d
  44. }
  45. // IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
  46. // keys in the map.
  47. // If i is true and an unknown field is encountered, it is ignored. This is
  48. // similar to how unknown keys are handled by encoding/json.
  49. // If i is false then Decode will return an error. Note that any valid keys
  50. // will still be decoded in to the target struct.
  51. //
  52. // To preserve backwards compatibility, the default value is false.
  53. func (d *Decoder) IgnoreUnknownKeys(i bool) *Decoder {
  54. d.ignoreUnknownKeys = i
  55. return d
  56. }
  57. // RegisterConverter registers a converter function for a custom type.
  58. func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) *Decoder {
  59. d.cache.registerConverter(value, converterFunc)
  60. return d
  61. }
  62. // Decode decodes a map[string][]string to a struct.
  63. //
  64. // The first parameter must be a pointer to a struct.
  65. //
  66. // The second parameter is a map, typically url.Values from an HTTP request.
  67. // Keys are "paths" in dotted notation to the struct fields and nested structs.
  68. //
  69. // See the package documentation for a full explanation of the mechanics.
  70. func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
  71. v := reflect.ValueOf(dst)
  72. if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
  73. return errors.New("schema: interface must be a pointer to struct")
  74. }
  75. v = v.Elem()
  76. t := v.Type()
  77. errors := MultiError{}
  78. for path, values := range src {
  79. if parts, err := d.cache.parsePath(path, t); err == nil {
  80. if err = d.decode(v, path, parts, values); err != nil {
  81. errors[path] = err
  82. }
  83. } else if !d.ignoreUnknownKeys {
  84. errors[path] = UnknownKeyError{Key: path}
  85. }
  86. }
  87. errors.merge(d.checkRequired(t, src))
  88. if len(errors) > 0 {
  89. return errors
  90. }
  91. return nil
  92. }
  93. // checkRequired checks whether required fields are empty
  94. //
  95. // check type t recursively if t has struct fields.
  96. //
  97. // src is the source map for decoding, we use it here to see if those required fields are included in src
  98. func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError {
  99. m, errs := d.findRequiredFields(t, "", "")
  100. for key, fields := range m {
  101. if isEmptyFields(fields, src) {
  102. errs[key] = EmptyFieldError{Key: key}
  103. }
  104. }
  105. return errs
  106. }
  107. // findRequiredFields recursively searches the struct type t for required fields.
  108. //
  109. // canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation
  110. // for nested struct fields. canonicalPrefix is a complete path which never omits
  111. // any embedded struct fields. searchPrefix is a user-friendly path which may omit
  112. // some embedded struct fields to point promoted fields.
  113. func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) {
  114. struc := d.cache.get(t)
  115. if struc == nil {
  116. // unexpect, cache.get never return nil
  117. return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")}
  118. }
  119. m := map[string][]fieldWithPrefix{}
  120. errs := MultiError{}
  121. for _, f := range struc.fields {
  122. if f.typ.Kind() == reflect.Struct {
  123. fcprefix := canonicalPrefix + f.canonicalAlias + "."
  124. for _, fspath := range f.paths(searchPrefix) {
  125. fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".")
  126. for key, fields := range fm {
  127. m[key] = append(m[key], fields...)
  128. }
  129. errs.merge(ferrs)
  130. }
  131. }
  132. if f.isRequired {
  133. key := canonicalPrefix + f.canonicalAlias
  134. m[key] = append(m[key], fieldWithPrefix{
  135. fieldInfo: f,
  136. prefix: searchPrefix,
  137. })
  138. }
  139. }
  140. return m, errs
  141. }
  142. type fieldWithPrefix struct {
  143. *fieldInfo
  144. prefix string
  145. }
  146. // isEmptyFields returns true if all of specified fields are empty.
  147. func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool {
  148. for _, f := range fields {
  149. for _, path := range f.paths(f.prefix) {
  150. if !isEmpty(f.typ, src[path]) {
  151. return false
  152. }
  153. }
  154. }
  155. return true
  156. }
  157. // isEmpty returns true if value is empty for specific type
  158. func isEmpty(t reflect.Type, value []string) bool {
  159. if len(value) == 0 {
  160. return true
  161. }
  162. switch t.Kind() {
  163. case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type:
  164. return len(value[0]) == 0
  165. }
  166. return false
  167. }
  168. // decode fills a struct field using a parsed path.
  169. func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error {
  170. // Get the field walking the struct fields by index.
  171. for _, name := range parts[0].path {
  172. if v.Type().Kind() == reflect.Ptr {
  173. if v.IsNil() {
  174. v.Set(reflect.New(v.Type().Elem()))
  175. }
  176. v = v.Elem()
  177. }
  178. // alloc embedded structs
  179. if v.Type().Kind() == reflect.Struct {
  180. for i := 0; i < v.NumField(); i++ {
  181. field := v.Field(i)
  182. if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous == true {
  183. field.Set(reflect.New(field.Type().Elem()))
  184. }
  185. }
  186. }
  187. v = v.FieldByName(name)
  188. }
  189. // Don't even bother for unexported fields.
  190. if !v.CanSet() {
  191. return nil
  192. }
  193. // Dereference if needed.
  194. t := v.Type()
  195. if t.Kind() == reflect.Ptr {
  196. t = t.Elem()
  197. if v.IsNil() {
  198. v.Set(reflect.New(t))
  199. }
  200. v = v.Elem()
  201. }
  202. // Slice of structs. Let's go recursive.
  203. if len(parts) > 1 {
  204. idx := parts[0].index
  205. if v.IsNil() || v.Len() < idx+1 {
  206. value := reflect.MakeSlice(t, idx+1, idx+1)
  207. if v.Len() < idx+1 {
  208. // Resize it.
  209. reflect.Copy(value, v)
  210. }
  211. v.Set(value)
  212. }
  213. return d.decode(v.Index(idx), path, parts[1:], values)
  214. }
  215. // Get the converter early in case there is one for a slice type.
  216. conv := d.cache.converter(t)
  217. m := isTextUnmarshaler(v)
  218. if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
  219. var items []reflect.Value
  220. elemT := t.Elem()
  221. isPtrElem := elemT.Kind() == reflect.Ptr
  222. if isPtrElem {
  223. elemT = elemT.Elem()
  224. }
  225. // Try to get a converter for the element type.
  226. conv := d.cache.converter(elemT)
  227. if conv == nil {
  228. conv = builtinConverters[elemT.Kind()]
  229. if conv == nil {
  230. // As we are not dealing with slice of structs here, we don't need to check if the type
  231. // implements TextUnmarshaler interface
  232. return fmt.Errorf("schema: converter not found for %v", elemT)
  233. }
  234. }
  235. for key, value := range values {
  236. if value == "" {
  237. if d.zeroEmpty {
  238. items = append(items, reflect.Zero(elemT))
  239. }
  240. } else if m.IsValid {
  241. u := reflect.New(elemT)
  242. if m.IsSliceElementPtr {
  243. u = reflect.New(reflect.PtrTo(elemT).Elem())
  244. }
  245. if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
  246. return ConversionError{
  247. Key: path,
  248. Type: t,
  249. Index: key,
  250. Err: err,
  251. }
  252. }
  253. if m.IsSliceElementPtr {
  254. items = append(items, u.Elem().Addr())
  255. } else if u.Kind() == reflect.Ptr {
  256. items = append(items, u.Elem())
  257. } else {
  258. items = append(items, u)
  259. }
  260. } else if item := conv(value); item.IsValid() {
  261. if isPtrElem {
  262. ptr := reflect.New(elemT)
  263. ptr.Elem().Set(item)
  264. item = ptr
  265. }
  266. if item.Type() != elemT && !isPtrElem {
  267. item = item.Convert(elemT)
  268. }
  269. items = append(items, item)
  270. } else {
  271. if strings.Contains(value, ",") {
  272. values := strings.Split(value, ",")
  273. for _, value := range values {
  274. if value == "" {
  275. if d.zeroEmpty {
  276. items = append(items, reflect.Zero(elemT))
  277. }
  278. } else if item := conv(value); item.IsValid() {
  279. if isPtrElem {
  280. ptr := reflect.New(elemT)
  281. ptr.Elem().Set(item)
  282. item = ptr
  283. }
  284. if item.Type() != elemT && !isPtrElem {
  285. item = item.Convert(elemT)
  286. }
  287. items = append(items, item)
  288. } else {
  289. return ConversionError{
  290. Key: path,
  291. Type: elemT,
  292. Index: key,
  293. }
  294. }
  295. }
  296. } else {
  297. return ConversionError{
  298. Key: path,
  299. Type: elemT,
  300. Index: key,
  301. }
  302. }
  303. }
  304. }
  305. value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
  306. v.Set(value)
  307. } else {
  308. val := ""
  309. // Use the last value provided if any values were provided
  310. if len(values) > 0 {
  311. val = values[len(values)-1]
  312. }
  313. if conv != nil {
  314. if value := conv(val); value.IsValid() {
  315. v.Set(value.Convert(t))
  316. } else {
  317. return ConversionError{
  318. Key: path,
  319. Type: t,
  320. Index: -1,
  321. }
  322. }
  323. } else if m.IsValid {
  324. if m.IsPtr {
  325. u := reflect.New(v.Type())
  326. if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
  327. return ConversionError{
  328. Key: path,
  329. Type: t,
  330. Index: -1,
  331. Err: err,
  332. }
  333. }
  334. v.Set(reflect.Indirect(u))
  335. } else {
  336. // If the value implements the encoding.TextUnmarshaler interface
  337. // apply UnmarshalText as the converter
  338. if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
  339. return ConversionError{
  340. Key: path,
  341. Type: t,
  342. Index: -1,
  343. Err: err,
  344. }
  345. }
  346. }
  347. } else if val == "" {
  348. if d.zeroEmpty {
  349. v.Set(reflect.Zero(t))
  350. }
  351. } else if conv := builtinConverters[t.Kind()]; conv != nil {
  352. if value := conv(val); value.IsValid() {
  353. v.Set(value.Convert(t))
  354. } else {
  355. return ConversionError{
  356. Key: path,
  357. Type: t,
  358. Index: -1,
  359. }
  360. }
  361. } else {
  362. return fmt.Errorf("schema: converter not found for %v", t)
  363. }
  364. }
  365. return nil
  366. }
  367. func isTextUnmarshaler(v reflect.Value) unmarshaler {
  368. // Create a new unmarshaller instance
  369. m := unmarshaler{}
  370. if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
  371. return m
  372. }
  373. // As the UnmarshalText function should be applied to the pointer of the
  374. // type, we check that type to see if it implements the necessary
  375. // method.
  376. if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
  377. m.IsPtr = true
  378. return m
  379. }
  380. // if v is []T or *[]T create new T
  381. t := v.Type()
  382. if t.Kind() == reflect.Ptr {
  383. t = t.Elem()
  384. }
  385. if t.Kind() == reflect.Slice {
  386. // Check if the slice implements encoding.TextUnmarshaller
  387. if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
  388. return m
  389. }
  390. // If t is a pointer slice, check if its elements implement
  391. // encoding.TextUnmarshaler
  392. m.IsSliceElement = true
  393. if t = t.Elem(); t.Kind() == reflect.Ptr {
  394. t = reflect.PtrTo(t.Elem())
  395. v = reflect.Zero(t)
  396. m.IsSliceElementPtr = true
  397. m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
  398. return m
  399. }
  400. }
  401. v = reflect.New(t)
  402. m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
  403. return m
  404. }
  405. // TextUnmarshaler helpers ----------------------------------------------------
  406. // unmarshaller contains information about a TextUnmarshaler type
  407. type unmarshaler struct {
  408. Unmarshaler encoding.TextUnmarshaler
  409. // IsValid indicates whether the resolved type indicated by the other
  410. // flags implements the encoding.TextUnmarshaler interface.
  411. IsValid bool
  412. // IsPtr indicates that the resolved type is the pointer of the original
  413. // type.
  414. IsPtr bool
  415. // IsSliceElement indicates that the resolved type is a slice element of
  416. // the original type.
  417. IsSliceElement bool
  418. // IsSliceElementPtr indicates that the resolved type is a pointer to a
  419. // slice element of the original type.
  420. IsSliceElementPtr bool
  421. }
  422. // Errors ---------------------------------------------------------------------
  423. // ConversionError stores information about a failed conversion.
  424. type ConversionError struct {
  425. Key string // key from the source map.
  426. Type reflect.Type // expected type of elem
  427. Index int // index for multi-value fields; -1 for single-value fields.
  428. Err error // low-level error (when it exists)
  429. }
  430. func (e ConversionError) Error() string {
  431. var output string
  432. if e.Index < 0 {
  433. output = fmt.Sprintf("schema: error converting value for %q", e.Key)
  434. } else {
  435. output = fmt.Sprintf("schema: error converting value for index %d of %q",
  436. e.Index, e.Key)
  437. }
  438. if e.Err != nil {
  439. output = fmt.Sprintf("%s. Details: %s", output, e.Err)
  440. }
  441. return output
  442. }
  443. // UnknownKeyError stores information about an unknown key in the source map.
  444. type UnknownKeyError struct {
  445. Key string // key from the source map.
  446. }
  447. func (e UnknownKeyError) Error() string {
  448. return fmt.Sprintf("schema: invalid path %q", e.Key)
  449. }
  450. // EmptyFieldError stores information about an empty required field.
  451. type EmptyFieldError struct {
  452. Key string // required key in the source map.
  453. }
  454. func (e EmptyFieldError) Error() string {
  455. return fmt.Sprintf("%v is empty", e.Key)
  456. }
  457. // MultiError stores multiple decoding errors.
  458. //
  459. // Borrowed from the App Engine SDK.
  460. type MultiError map[string]error
  461. func (e MultiError) Error() string {
  462. s := ""
  463. for _, err := range e {
  464. s = err.Error()
  465. break
  466. }
  467. switch len(e) {
  468. case 0:
  469. return "(0 errors)"
  470. case 1:
  471. return s
  472. case 2:
  473. return s + " (and 1 other error)"
  474. }
  475. return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
  476. }
  477. func (e MultiError) merge(errors MultiError) {
  478. for key, err := range errors {
  479. if e[key] == nil {
  480. e[key] = err
  481. }
  482. }
  483. }