decode_map.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. package msgpack
  2. import (
  3. "errors"
  4. "fmt"
  5. "reflect"
  6. "github.com/vmihailenco/msgpack/v5/msgpcode"
  7. )
  // errArrayStruct is returned by decodeStructValue when an array-encoded
  // struct carries a different number of elements than the Go struct has fields.
  var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed")

  // Cached reflect types for the map shapes that have dedicated decoders.
  // Each pointer type is kept next to its element (map) type so that an
  // addressable reflect.Value can be converted to the matching *map pointer.
  var (
  	// map[string]string
  	mapStringStringPtrType = reflect.TypeOf(new(map[string]string))
  	mapStringStringType    = mapStringStringPtrType.Elem()

  	// map[string]bool
  	mapStringBoolPtrType = reflect.TypeOf(new(map[string]bool))
  	mapStringBoolType    = mapStringBoolPtrType.Elem()

  	// map[string]interface{}
  	mapStringInterfacePtrType = reflect.TypeOf(new(map[string]interface{}))
  	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
  )
  19. func decodeMapValue(d *Decoder, v reflect.Value) error {
  20. n, err := d.DecodeMapLen()
  21. if err != nil {
  22. return err
  23. }
  24. typ := v.Type()
  25. if n == -1 {
  26. v.Set(reflect.Zero(typ))
  27. return nil
  28. }
  29. if v.IsNil() {
  30. ln := n
  31. if d.flags&disableAllocLimitFlag == 0 {
  32. ln = min(ln, maxMapSize)
  33. }
  34. v.Set(reflect.MakeMapWithSize(typ, ln))
  35. }
  36. if n == 0 {
  37. return nil
  38. }
  39. return d.decodeTypedMapValue(v, n)
  40. }
  41. func (d *Decoder) decodeMapDefault() (interface{}, error) {
  42. if d.mapDecoder != nil {
  43. return d.mapDecoder(d)
  44. }
  45. return d.DecodeMap()
  46. }
  47. // DecodeMapLen decodes map length. Length is -1 when map is nil.
  48. func (d *Decoder) DecodeMapLen() (int, error) {
  49. c, err := d.readCode()
  50. if err != nil {
  51. return 0, err
  52. }
  53. if msgpcode.IsExt(c) {
  54. if err = d.skipExtHeader(c); err != nil {
  55. return 0, err
  56. }
  57. c, err = d.readCode()
  58. if err != nil {
  59. return 0, err
  60. }
  61. }
  62. return d.mapLen(c)
  63. }
  64. func (d *Decoder) mapLen(c byte) (int, error) {
  65. if c == msgpcode.Nil {
  66. return -1, nil
  67. }
  68. if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh {
  69. return int(c & msgpcode.FixedMapMask), nil
  70. }
  71. if c == msgpcode.Map16 {
  72. size, err := d.uint16()
  73. return int(size), err
  74. }
  75. if c == msgpcode.Map32 {
  76. size, err := d.uint32()
  77. return int(size), err
  78. }
  79. return 0, unexpectedCodeError{code: c, hint: "map length"}
  80. }
  81. func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
  82. mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
  83. return d.decodeMapStringStringPtr(mptr)
  84. }
  85. func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
  86. size, err := d.DecodeMapLen()
  87. if err != nil {
  88. return err
  89. }
  90. if size == -1 {
  91. *ptr = nil
  92. return nil
  93. }
  94. m := *ptr
  95. if m == nil {
  96. ln := size
  97. if d.flags&disableAllocLimitFlag == 0 {
  98. ln = min(size, maxMapSize)
  99. }
  100. *ptr = make(map[string]string, ln)
  101. m = *ptr
  102. }
  103. for i := 0; i < size; i++ {
  104. mk, err := d.DecodeString()
  105. if err != nil {
  106. return err
  107. }
  108. mv, err := d.DecodeString()
  109. if err != nil {
  110. return err
  111. }
  112. m[mk] = mv
  113. }
  114. return nil
  115. }
  116. func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
  117. ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
  118. return d.decodeMapStringInterfacePtr(ptr)
  119. }
  120. func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
  121. m, err := d.DecodeMap()
  122. if err != nil {
  123. return err
  124. }
  125. *ptr = m
  126. return nil
  127. }
  128. func (d *Decoder) DecodeMap() (map[string]interface{}, error) {
  129. n, err := d.DecodeMapLen()
  130. if err != nil {
  131. return nil, err
  132. }
  133. if n == -1 {
  134. return nil, nil
  135. }
  136. m := make(map[string]interface{}, n)
  137. for i := 0; i < n; i++ {
  138. mk, err := d.DecodeString()
  139. if err != nil {
  140. return nil, err
  141. }
  142. mv, err := d.decodeInterfaceCond()
  143. if err != nil {
  144. return nil, err
  145. }
  146. m[mk] = mv
  147. }
  148. return m, nil
  149. }
  150. func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) {
  151. n, err := d.DecodeMapLen()
  152. if err != nil {
  153. return nil, err
  154. }
  155. if n == -1 {
  156. return nil, nil
  157. }
  158. m := make(map[interface{}]interface{}, n)
  159. for i := 0; i < n; i++ {
  160. mk, err := d.decodeInterfaceCond()
  161. if err != nil {
  162. return nil, err
  163. }
  164. mv, err := d.decodeInterfaceCond()
  165. if err != nil {
  166. return nil, err
  167. }
  168. m[mk] = mv
  169. }
  170. return m, nil
  171. }
  172. // DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values.
  173. // Key and value types may be different.
  174. func (d *Decoder) DecodeTypedMap() (interface{}, error) {
  175. n, err := d.DecodeMapLen()
  176. if err != nil {
  177. return nil, err
  178. }
  179. if n <= 0 {
  180. return nil, nil
  181. }
  182. key, err := d.decodeInterfaceCond()
  183. if err != nil {
  184. return nil, err
  185. }
  186. value, err := d.decodeInterfaceCond()
  187. if err != nil {
  188. return nil, err
  189. }
  190. keyType := reflect.TypeOf(key)
  191. valueType := reflect.TypeOf(value)
  192. if !keyType.Comparable() {
  193. return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
  194. }
  195. mapType := reflect.MapOf(keyType, valueType)
  196. ln := n
  197. if d.flags&disableAllocLimitFlag == 0 {
  198. ln = min(ln, maxMapSize)
  199. }
  200. mapValue := reflect.MakeMapWithSize(mapType, ln)
  201. mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
  202. n--
  203. if err := d.decodeTypedMapValue(mapValue, n); err != nil {
  204. return nil, err
  205. }
  206. return mapValue.Interface(), nil
  207. }
  208. func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error {
  209. var (
  210. typ = v.Type()
  211. keyType = typ.Key()
  212. valueType = typ.Elem()
  213. )
  214. for i := 0; i < n; i++ {
  215. mk := d.newValue(keyType).Elem()
  216. if err := d.DecodeValue(mk); err != nil {
  217. return err
  218. }
  219. mv := d.newValue(valueType).Elem()
  220. if err := d.DecodeValue(mv); err != nil {
  221. return err
  222. }
  223. v.SetMapIndex(mk, mv)
  224. }
  225. return nil
  226. }
  227. func (d *Decoder) skipMap(c byte) error {
  228. n, err := d.mapLen(c)
  229. if err != nil {
  230. return err
  231. }
  232. for i := 0; i < n; i++ {
  233. if err := d.Skip(); err != nil {
  234. return err
  235. }
  236. if err := d.Skip(); err != nil {
  237. return err
  238. }
  239. }
  240. return nil
  241. }
  242. func decodeStructValue(d *Decoder, v reflect.Value) error {
  243. c, err := d.readCode()
  244. if err != nil {
  245. return err
  246. }
  247. n, err := d.mapLen(c)
  248. if err == nil {
  249. return d.decodeStruct(v, n)
  250. }
  251. var err2 error
  252. n, err2 = d.arrayLen(c)
  253. if err2 != nil {
  254. return err
  255. }
  256. if n <= 0 {
  257. v.Set(reflect.Zero(v.Type()))
  258. return nil
  259. }
  260. fields := structs.Fields(v.Type(), d.structTag)
  261. if n != len(fields.List) {
  262. return errArrayStruct
  263. }
  264. for _, f := range fields.List {
  265. if err := f.DecodeValue(d, v); err != nil {
  266. return err
  267. }
  268. }
  269. return nil
  270. }
  271. func (d *Decoder) decodeStruct(v reflect.Value, n int) error {
  272. if n == -1 {
  273. v.Set(reflect.Zero(v.Type()))
  274. return nil
  275. }
  276. fields := structs.Fields(v.Type(), d.structTag)
  277. for i := 0; i < n; i++ {
  278. name, err := d.decodeStringTemp()
  279. if err != nil {
  280. return err
  281. }
  282. if f := fields.Map[name]; f != nil {
  283. if err := f.DecodeValue(d, v); err != nil {
  284. return err
  285. }
  286. continue
  287. }
  288. if d.flags&disallowUnknownFieldsFlag != 0 {
  289. return fmt.Errorf("msgpack: unknown field %q", name)
  290. }
  291. if err := d.Skip(); err != nil {
  292. return err
  293. }
  294. }
  295. return nil
  296. }