encode.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. // Copyright 2014 Alvaro J. Genial. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package form
  5. import (
  6. "encoding"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "net/url"
  11. "reflect"
  12. "strconv"
  13. "strings"
  14. "time"
  15. )
  16. // NewEncoder returns a new form Encoder.
  17. func NewEncoder(w io.Writer) *Encoder {
  18. return &Encoder{w, defaultDelimiter, defaultEscape, false}
  19. }
  // Encoder provides a way to encode to a Writer.
//
// Its behavior is configured with DelimitWith, EscapeWith and KeepZeros.
type Encoder struct {
	w    io.Writer // destination of the encoded form
	d, e rune      // delimiter for composite keys, and the escape rune
	z    bool      // whether zero values keep their literal form
}
  27. // DelimitWith sets r as the delimiter used for composite keys by Encoder e and returns the latter; it is '.' by default.
  28. func (e *Encoder) DelimitWith(r rune) *Encoder {
  29. e.d = r
  30. return e
  31. }
  32. // EscapeWith sets r as the escape used for delimiters (and to escape itself) by Encoder e and returns the latter; it is '\\' by default.
  33. func (e *Encoder) EscapeWith(r rune) *Encoder {
  34. e.e = r
  35. return e
  36. }
  37. // KeepZeros sets whether Encoder e should keep zero (default) values in their literal form when encoding, and returns the former; by default zero values are not kept, but are rather encoded as the empty string.
  38. func (e *Encoder) KeepZeros(z bool) *Encoder {
  39. e.z = z
  40. return e
  41. }
  42. // Encode encodes dst as form and writes it out using the Encoder's Writer.
  43. func (e Encoder) Encode(dst interface{}) error {
  44. v := reflect.ValueOf(dst)
  45. n, err := encodeToNode(v, e.z)
  46. if err != nil {
  47. return err
  48. }
  49. s := n.values(e.d, e.e).Encode()
  50. l, err := io.WriteString(e.w, s)
  51. switch {
  52. case err != nil:
  53. return err
  54. case l != len(s):
  55. return errors.New("could not write data completely")
  56. }
  57. return nil
  58. }
  59. // EncodeToString encodes dst as a form and returns it as a string.
  60. func EncodeToString(dst interface{}) (string, error) {
  61. v := reflect.ValueOf(dst)
  62. n, err := encodeToNode(v, false)
  63. if err != nil {
  64. return "", err
  65. }
  66. vs := n.values(defaultDelimiter, defaultEscape)
  67. return vs.Encode(), nil
  68. }
  69. // EncodeToValues encodes dst as a form and returns it as Values.
  70. func EncodeToValues(dst interface{}) (url.Values, error) {
  71. v := reflect.ValueOf(dst)
  72. n, err := encodeToNode(v, false)
  73. if err != nil {
  74. return nil, err
  75. }
  76. vs := n.values(defaultDelimiter, defaultEscape)
  77. return vs, nil
  78. }
  79. func encodeToNode(v reflect.Value, z bool) (n node, err error) {
  80. defer func() {
  81. if e := recover(); e != nil {
  82. err = fmt.Errorf("%v", e)
  83. }
  84. }()
  85. return getNode(encodeValue(v, z)), nil
  86. }
  87. func encodeValue(v reflect.Value, z bool) interface{} {
  88. t := v.Type()
  89. k := v.Kind()
  90. if s, ok := marshalValue(v); ok {
  91. return s
  92. } else if !z && isEmptyValue(v) {
  93. return "" // Treat the zero value as the empty string.
  94. }
  95. switch k {
  96. case reflect.Ptr, reflect.Interface:
  97. return encodeValue(v.Elem(), z)
  98. case reflect.Struct:
  99. if t.ConvertibleTo(timeType) {
  100. return encodeTime(v)
  101. } else if t.ConvertibleTo(urlType) {
  102. return encodeURL(v)
  103. }
  104. return encodeStruct(v, z)
  105. case reflect.Slice:
  106. return encodeSlice(v, z)
  107. case reflect.Array:
  108. return encodeArray(v, z)
  109. case reflect.Map:
  110. return encodeMap(v, z)
  111. case reflect.Invalid, reflect.Uintptr, reflect.UnsafePointer, reflect.Chan, reflect.Func:
  112. panic(t.String() + " has unsupported kind " + t.Kind().String())
  113. default:
  114. return encodeBasic(v)
  115. }
  116. }
  117. func encodeStruct(v reflect.Value, z bool) interface{} {
  118. t := v.Type()
  119. n := node{}
  120. for i := 0; i < t.NumField(); i++ {
  121. f := t.Field(i)
  122. k, oe := fieldInfo(f)
  123. if k == "-" {
  124. continue
  125. } else if fv := v.Field(i); oe && isEmptyValue(fv) {
  126. delete(n, k)
  127. } else {
  128. n[k] = encodeValue(fv, z)
  129. }
  130. }
  131. return n
  132. }
  133. func encodeMap(v reflect.Value, z bool) interface{} {
  134. n := node{}
  135. for _, i := range v.MapKeys() {
  136. k := getString(encodeValue(i, z))
  137. n[k] = encodeValue(v.MapIndex(i), z)
  138. }
  139. return n
  140. }
  141. func encodeArray(v reflect.Value, z bool) interface{} {
  142. n := node{}
  143. for i := 0; i < v.Len(); i++ {
  144. n[strconv.Itoa(i)] = encodeValue(v.Index(i), z)
  145. }
  146. return n
  147. }
  148. func encodeSlice(v reflect.Value, z bool) interface{} {
  149. t := v.Type()
  150. if t.Elem().Kind() == reflect.Uint8 {
  151. return string(v.Bytes()) // Encode byte slices as a single string by default.
  152. }
  153. n := node{}
  154. for i := 0; i < v.Len(); i++ {
  155. n[strconv.Itoa(i)] = encodeValue(v.Index(i), z)
  156. }
  157. return n
  158. }
  159. func encodeTime(v reflect.Value) string {
  160. t := v.Convert(timeType).Interface().(time.Time)
  161. if t.Year() == 0 && (t.Month() == 0 || t.Month() == 1) && (t.Day() == 0 || t.Day() == 1) {
  162. return t.Format("15:04:05.999999999Z07:00")
  163. } else if t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 && t.Nanosecond() == 0 {
  164. return t.Format("2006-01-02")
  165. }
  166. return t.Format("2006-01-02T15:04:05.999999999Z07:00")
  167. }
  168. func encodeURL(v reflect.Value) string {
  169. u := v.Convert(urlType).Interface().(url.URL)
  170. return u.String()
  171. }
  // encodeBasic formats a value of bool, integer, float, complex or string kind
// as a string. Floats and complex numbers use the shortest 'g' representation
// at the value's own precision (32-bit parts for float32/complex64).
// Complex numbers are written without the surrounding parentheses that fmt
// adds, e.g. "1.5-2i". Any other kind panics.
func encodeBasic(v reflect.Value) string {
	t := v.Type()
	switch k := t.Kind(); k {
	case reflect.Bool:
		return strconv.FormatBool(v.Bool())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(v.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(v.Uint(), 10)
	case reflect.Float32:
		return strconv.FormatFloat(v.Float(), 'g', -1, 32)
	case reflect.Float64:
		return strconv.FormatFloat(v.Float(), 'g', -1, 64)
	case reflect.Complex64, reflect.Complex128:
		c := v.Complex()
		var s string
		if k == reflect.Complex64 {
			// v.Complex() widens to complex128. Narrowing back (exact, since
			// the value came from a complex64) makes fmt format each part at
			// float32 precision, matching the Float32 case above.
			s = fmt.Sprintf("%g", complex64(c))
		} else {
			s = fmt.Sprintf("%g", c)
		}
		return strings.TrimSuffix(strings.TrimPrefix(s, "("), ")")
	case reflect.String:
		return v.String()
	}
	panic(t.String() + " has unsupported kind " + t.Kind().String())
}
  201. func isEmptyValue(v reflect.Value) bool {
  202. switch t := v.Type(); v.Kind() {
  203. case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
  204. return v.Len() == 0
  205. case reflect.Bool:
  206. return !v.Bool()
  207. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  208. return v.Int() == 0
  209. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  210. return v.Uint() == 0
  211. case reflect.Float32, reflect.Float64:
  212. return v.Float() == 0
  213. case reflect.Complex64, reflect.Complex128:
  214. return v.Complex() == 0
  215. case reflect.Interface, reflect.Ptr:
  216. return v.IsNil()
  217. case reflect.Struct:
  218. if t.ConvertibleTo(timeType) {
  219. return v.Convert(timeType).Interface().(time.Time).IsZero()
  220. }
  221. return reflect.DeepEqual(v, reflect.Zero(t))
  222. }
  223. return false
  224. }
  // canIndexOrdinally reports whether v holds an ordered sequence of elements.
// It follows pointers and interfaces, then checks whether the value reached is
// a slice or an array. Invalid values, including the result of dereferencing
// a nil pointer or interface, report false.
func canIndexOrdinally(v reflect.Value) bool {
	for v.IsValid() {
		switch v.Kind() {
		case reflect.Ptr, reflect.Interface:
			v = v.Elem()
		case reflect.Slice, reflect.Array:
			return true
		default:
			return false
		}
	}
	return false
}
  238. func fieldInfo(f reflect.StructField) (k string, oe bool) {
  239. if f.PkgPath != "" { // Skip private fields.
  240. return omittedKey, oe
  241. }
  242. k = f.Name
  243. tag := f.Tag.Get("form")
  244. if tag == "" {
  245. return k, oe
  246. }
  247. ps := strings.SplitN(tag, ",", 2)
  248. if ps[0] != "" {
  249. k = ps[0]
  250. }
  251. if len(ps) == 2 {
  252. oe = ps[1] == "omitempty"
  253. }
  254. return k, oe
  255. }
  256. func findField(v reflect.Value, n string, ignoreCase bool) (reflect.Value, bool) {
  257. t := v.Type()
  258. l := v.NumField()
  259. var lowerN string
  260. caseInsensitiveMatch := -1
  261. if ignoreCase {
  262. lowerN = strings.ToLower(n)
  263. }
  264. // First try named fields.
  265. for i := 0; i < l; i++ {
  266. f := t.Field(i)
  267. k, _ := fieldInfo(f)
  268. if k == omittedKey {
  269. continue
  270. } else if n == k {
  271. return v.Field(i), true
  272. } else if ignoreCase && lowerN == strings.ToLower(k) {
  273. caseInsensitiveMatch = i
  274. }
  275. }
  276. // If no exact match was found try case insensitive match.
  277. if caseInsensitiveMatch != -1 {
  278. return v.Field(caseInsensitiveMatch), true
  279. }
  280. // Then try anonymous (embedded) fields.
  281. for i := 0; i < l; i++ {
  282. f := t.Field(i)
  283. k, _ := fieldInfo(f)
  284. if k == omittedKey || !f.Anonymous { // || k != "" ?
  285. continue
  286. }
  287. fv := v.Field(i)
  288. fk := fv.Kind()
  289. for fk == reflect.Ptr || fk == reflect.Interface {
  290. fv = fv.Elem()
  291. fk = fv.Kind()
  292. }
  293. if fk != reflect.Struct {
  294. continue
  295. }
  296. if ev, ok := findField(fv, n, ignoreCase); ok {
  297. return ev, true
  298. }
  299. }
  300. return reflect.Value{}, false
  301. }
  // Reflection types used repeatedly by the encoder, computed once at package
// initialization.
var (
	stringType    = reflect.TypeOf("")
	stringMapType = reflect.TypeOf(map[string]interface{}(nil))
	timeType      = reflect.TypeOf(time.Time{})
	timePtrType   = reflect.TypeOf((*time.Time)(nil))
	urlType       = reflect.TypeOf(url.URL{})
)
  309. func skipTextMarshalling(t reflect.Type) bool {
  310. /*// Skip time.Time because its text unmarshaling is overly rigid:
  311. return t == timeType || t == timePtrType*/
  312. // Skip time.Time & convertibles because its text unmarshaling is overly rigid:
  313. return t.ConvertibleTo(timeType) || t.ConvertibleTo(timePtrType)
  314. }
  315. func unmarshalValue(v reflect.Value, x interface{}) bool {
  316. if skipTextMarshalling(v.Type()) {
  317. return false
  318. }
  319. tu, ok := v.Interface().(encoding.TextUnmarshaler)
  320. if !ok && !v.CanAddr() {
  321. return false
  322. } else if !ok {
  323. return unmarshalValue(v.Addr(), x)
  324. }
  325. s := getString(x)
  326. if err := tu.UnmarshalText([]byte(s)); err != nil {
  327. panic(err)
  328. }
  329. return true
  330. }
  331. func marshalValue(v reflect.Value) (string, bool) {
  332. if skipTextMarshalling(v.Type()) {
  333. return "", false
  334. }
  335. tm, ok := v.Interface().(encoding.TextMarshaler)
  336. if !ok && !v.CanAddr() {
  337. return "", false
  338. } else if !ok {
  339. return marshalValue(v.Addr())
  340. }
  341. bs, err := tm.MarshalText()
  342. if err != nil {
  343. panic(err)
  344. }
  345. return string(bs), true
  346. }