// deepobject.go

  1. package runtime
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "net/url"
  7. "reflect"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "github.com/deepmap/oapi-codegen/pkg/types"
  13. )
  // marshalDeepObject flattens a generic JSON-style value into deepObject
  // "[a][b]=value" strings, one per leaf, without any parameter name prefix.
  //
  // in is walked as a tree of []interface{} and map[string]interface{} nodes;
  // anything else is treated as a leaf. path holds the subscripts accumulated
  // so far. Slice elements are emitted in index order and map entries in sorted
  // key order, so the output is deterministic. A non-nil error can only
  // originate from a recursive call and is returned wrapped.
  func marshalDeepObject(in interface{}, path []string) ([]string, error) {
  	var result []string

  	// Capping the capacity forces every append below to allocate a fresh
  	// backing array, so sibling branches can never overwrite each other's path.
  	base := path[:len(path):len(path)]

  	switch t := in.(type) {
  	case []interface{}:
  		// Arrays use numerical subscripts of the form [x], in array order.
  		for i, iface := range t {
  			fields, err := marshalDeepObject(iface, append(base, strconv.Itoa(i)))
  			if err != nil {
  				return nil, fmt.Errorf("error traversing array: %w", err)
  			}
  			result = append(result, fields...)
  		}
  	case map[string]interface{}:
  		// Each key (field name) becomes a member of the path. Keys are sorted
  		// first because map iteration order is random.
  		keys := make([]string, 0, len(t))
  		for k := range t {
  			keys = append(keys, k)
  		}
  		sort.Strings(keys)

  		for _, k := range keys {
  			fields, err := marshalDeepObject(t[k], append(base, k))
  			if err != nil {
  				return nil, fmt.Errorf("error traversing map: %w", err)
  			}
  			result = append(result, fields...)
  		}
  	case float64:
  		// %v would print values such as 1000000 as "1e+06", which integer
  		// parsing on the receiving side rejects; 'f' with precision -1 gives the
  		// shortest plain decimal form instead.
  		prefix := "[" + strings.Join(path, "][") + "]"
  		result = []string{prefix + "=" + strconv.FormatFloat(t, 'f', -1, 64)}
  	default:
  		// For any other concrete value, the path elements become a deepObject
  		// style set of subscripts: [a, b, c] turns into [a][b][c].
  		// NOTE(review): a nil leaf is rendered by %v as "<nil>"; confirm
  		// whether null values should be dropped instead.
  		prefix := "[" + strings.Join(path, "][") + "]"
  		result = []string{prefix + fmt.Sprintf("=%v", t)}
  	}
  	return result, nil
  }
  58. func MarshalDeepObject(i interface{}, paramName string) (string, error) {
  59. // We're going to marshal to JSON and unmarshal into an interface{},
  60. // which will use the json pkg to deal with all the field annotations. We
  61. // can then walk the generic object structure to produce a deepObject. This
  62. // isn't efficient and it would be more efficient to reflect on our own,
  63. // but it's complicated, error-prone code.
  64. buf, err := json.Marshal(i)
  65. if err != nil {
  66. return "", fmt.Errorf("failed to marshal input to JSON: %w", err)
  67. }
  68. var i2 interface{}
  69. err = json.Unmarshal(buf, &i2)
  70. if err != nil {
  71. return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
  72. }
  73. fields, err := marshalDeepObject(i2, nil)
  74. if err != nil {
  75. return "", fmt.Errorf("error traversing JSON structure: %w", err)
  76. }
  77. // Prefix the param name to each subscripted field.
  78. for i := range fields {
  79. fields[i] = paramName + fields[i]
  80. }
  81. return strings.Join(fields, "&"), nil
  82. }
  // fieldOrValue is one node of the tree rebuilt from deepObject query keys.
  // A leaf carries its raw string in value; an interior node maps each
  // subscript to its child node in fields.
  type fieldOrValue struct {
  	fields map[string]fieldOrValue
  	value  string
  }

  // appendPathValue stores value at the location described by path, creating
  // intermediate nodes as needed.
  //
  // Query keys come from untrusted input, so this must not panic on odd
  // combinations: an empty path sets the value on the receiver itself, a nil
  // fields map is allocated lazily, and a key that is both a leaf and a prefix
  // of another key (e.g. "a" and "a][b") keeps both its value and its children
  // regardless of the order in which the two keys are inserted.
  func (f *fieldOrValue) appendPathValue(path []string, value string) {
  	if len(path) == 0 {
  		f.value = value
  		return
  	}
  	if f.fields == nil {
  		f.fields = make(map[string]fieldOrValue)
  	}

  	fieldName := path[0]
  	// The zero node is used when the key is absent.
  	pv := f.fields[fieldName]

  	if len(path) == 1 {
  		// Only the value is replaced, so an existing subtree survives.
  		pv.value = value
  		f.fields[fieldName] = pv
  		return
  	}

  	if pv.fields == nil {
  		// Either the node is new or it was inserted as a leaf earlier. The map
  		// has to be stored back before recursing, because pv is only a copy of
  		// the map entry.
  		pv.fields = make(map[string]fieldOrValue)
  		f.fields[fieldName] = pv
  	}
  	pv.appendPathValue(path[1:], value)
  }
  102. func makeFieldOrValue(paths [][]string, values []string) fieldOrValue {
  103. f := fieldOrValue{
  104. fields: make(map[string]fieldOrValue),
  105. }
  106. for i := range paths {
  107. path := paths[i]
  108. value := values[i]
  109. f.appendPathValue(path, value)
  110. }
  111. return f
  112. }
  113. func UnmarshalDeepObject(dst interface{}, paramName string, params url.Values) error {
  114. // Params are all the query args, so we need those that look like
  115. // "paramName["...
  116. var fieldNames []string
  117. var fieldValues []string
  118. searchStr := paramName + "["
  119. for pName, pValues := range params {
  120. if strings.HasPrefix(pName, searchStr) {
  121. // trim the parameter name from the full name.
  122. pName = pName[len(paramName):]
  123. fieldNames = append(fieldNames, pName)
  124. if len(pValues) != 1 {
  125. return fmt.Errorf("%s has multiple values", pName)
  126. }
  127. fieldValues = append(fieldValues, pValues[0])
  128. }
  129. }
  130. // Now, for each field, reconstruct its subscript path and value
  131. paths := make([][]string, len(fieldNames))
  132. for i, path := range fieldNames {
  133. path = strings.TrimLeft(path, "[")
  134. path = strings.TrimRight(path, "]")
  135. paths[i] = strings.Split(path, "][")
  136. }
  137. fieldPaths := makeFieldOrValue(paths, fieldValues)
  138. err := assignPathValues(dst, fieldPaths)
  139. if err != nil {
  140. return fmt.Errorf("error assigning value to destination: %w", err)
  141. }
  142. return nil
  143. }
  // getFieldName reports the name under which struct field f appears in a
  // deepObject: the part of its json tag before the first comma when that part
  // is non-empty, and the Go field name otherwise.
  func getFieldName(f reflect.StructField) string {
  	tag, hasTag := f.Tag.Lookup("json")
  	if !hasTag {
  		return f.Name
  	}

  	// Everything after the first comma is options (omitempty, string, ...).
  	name := tag
  	if comma := strings.IndexByte(tag, ','); comma >= 0 {
  		name = tag[:comma]
  	}
  	if name == "" {
  		return f.Name
  	}
  	return name
  }
  159. // Create a map of field names that we'll see in the deepObject to reflect
  160. // field indices on the given type.
  161. func fieldIndicesByJsonTag(i interface{}) (map[string]int, error) {
  162. t := reflect.TypeOf(i)
  163. if t.Kind() != reflect.Struct {
  164. return nil, errors.New("expected a struct as input")
  165. }
  166. n := t.NumField()
  167. fieldMap := make(map[string]int)
  168. for i := 0; i < n; i++ {
  169. field := t.Field(i)
  170. fieldName := getFieldName(field)
  171. fieldMap[fieldName] = i
  172. }
  173. return fieldMap, nil
  174. }
  175. func assignPathValues(dst interface{}, pathValues fieldOrValue) error {
  176. //t := reflect.TypeOf(dst)
  177. v := reflect.ValueOf(dst)
  178. iv := reflect.Indirect(v)
  179. it := iv.Type()
  180. switch it.Kind() {
  181. case reflect.Slice:
  182. sliceLength := len(pathValues.fields)
  183. dstSlice := reflect.MakeSlice(it, sliceLength, sliceLength)
  184. err := assignSlice(dstSlice, pathValues)
  185. if err != nil {
  186. return fmt.Errorf("error assigning slice: %w", err)
  187. }
  188. iv.Set(dstSlice)
  189. return nil
  190. case reflect.Struct:
  191. // Some special types we care about are structs. Handle them
  192. // here. They may be redefined, so we need to do some hoop
  193. // jumping. If the types are aliased, we need to type convert
  194. // the pointer, then set the value of the dereference pointer.
  195. // We check to see if the object implements the Binder interface first.
  196. if dst, isBinder := v.Interface().(Binder); isBinder {
  197. return dst.Bind(pathValues.value)
  198. }
  199. // Then check the legacy types
  200. if it.ConvertibleTo(reflect.TypeOf(types.Date{})) {
  201. var date types.Date
  202. var err error
  203. date.Time, err = time.Parse(types.DateFormat, pathValues.value)
  204. if err != nil {
  205. return fmt.Errorf("invalid date format: %w", err)
  206. }
  207. dst := iv
  208. if it != reflect.TypeOf(types.Date{}) {
  209. // Types are aliased, convert the pointers.
  210. ivPtr := iv.Addr()
  211. aPtr := ivPtr.Convert(reflect.TypeOf(&types.Date{}))
  212. dst = reflect.Indirect(aPtr)
  213. }
  214. dst.Set(reflect.ValueOf(date))
  215. }
  216. if it.ConvertibleTo(reflect.TypeOf(time.Time{})) {
  217. var tm time.Time
  218. var err error
  219. tm, err = time.Parse(time.RFC3339Nano, pathValues.value)
  220. if err != nil {
  221. // Fall back to parsing it as a date.
  222. // TODO: why is this marked as an ineffassign?
  223. tm, err = time.Parse(types.DateFormat, pathValues.value) //nolint:ineffassign,staticcheck
  224. if err != nil {
  225. return fmt.Errorf("error parsing tim as RFC3339 or 2006-01-02 time: %s", err)
  226. }
  227. return fmt.Errorf("invalid date format: %w", err)
  228. }
  229. dst := iv
  230. if it != reflect.TypeOf(time.Time{}) {
  231. // Types are aliased, convert the pointers.
  232. ivPtr := iv.Addr()
  233. aPtr := ivPtr.Convert(reflect.TypeOf(&time.Time{}))
  234. dst = reflect.Indirect(aPtr)
  235. }
  236. dst.Set(reflect.ValueOf(tm))
  237. }
  238. fieldMap, err := fieldIndicesByJsonTag(iv.Interface())
  239. if err != nil {
  240. return fmt.Errorf("failed enumerating fields: %w", err)
  241. }
  242. for _, fieldName := range sortedFieldOrValueKeys(pathValues.fields) {
  243. fieldValue := pathValues.fields[fieldName]
  244. fieldIndex, found := fieldMap[fieldName]
  245. if !found {
  246. return fmt.Errorf("field [%s] is not present in destination object", fieldName)
  247. }
  248. field := iv.Field(fieldIndex)
  249. err = assignPathValues(field.Addr().Interface(), fieldValue)
  250. if err != nil {
  251. return fmt.Errorf("error assigning field [%s]: %w", fieldName, err)
  252. }
  253. }
  254. return nil
  255. case reflect.Ptr:
  256. // If we have a pointer after redirecting, it means we're dealing with
  257. // an optional field, such as *string, which was passed in as &foo. We
  258. // will allocate it if necessary, and call ourselves with a different
  259. // interface.
  260. dstVal := reflect.New(it.Elem())
  261. dstPtr := dstVal.Interface()
  262. err := assignPathValues(dstPtr, pathValues)
  263. iv.Set(dstVal)
  264. return err
  265. case reflect.Bool:
  266. val, err := strconv.ParseBool(pathValues.value)
  267. if err != nil {
  268. return fmt.Errorf("expected a valid bool, got %s", pathValues.value)
  269. }
  270. iv.SetBool(val)
  271. return nil
  272. case reflect.Float32:
  273. val, err := strconv.ParseFloat(pathValues.value, 32)
  274. if err != nil {
  275. return fmt.Errorf("expected a valid float, got %s", pathValues.value)
  276. }
  277. iv.SetFloat(val)
  278. return nil
  279. case reflect.Float64:
  280. val, err := strconv.ParseFloat(pathValues.value, 64)
  281. if err != nil {
  282. return fmt.Errorf("expected a valid float, got %s", pathValues.value)
  283. }
  284. iv.SetFloat(val)
  285. return nil
  286. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  287. val, err := strconv.ParseInt(pathValues.value, 10, 64)
  288. if err != nil {
  289. return fmt.Errorf("expected a valid int, got %s", pathValues.value)
  290. }
  291. iv.SetInt(val)
  292. return nil
  293. case reflect.String:
  294. iv.SetString(pathValues.value)
  295. return nil
  296. default:
  297. return errors.New("unhandled type: " + it.String())
  298. }
  299. }
  300. func assignSlice(dst reflect.Value, pathValues fieldOrValue) error {
  301. // Gather up the values
  302. nValues := len(pathValues.fields)
  303. values := make([]string, nValues)
  304. // We expect to have consecutive array indices in the map
  305. for i := 0; i < nValues; i++ {
  306. indexStr := strconv.Itoa(i)
  307. fv, found := pathValues.fields[indexStr]
  308. if !found {
  309. return errors.New("array deepObjects must have consecutive indices")
  310. }
  311. values[i] = fv.value
  312. }
  313. // This could be cleaner, but we can call into assignPathValues to
  314. // avoid recreating this logic.
  315. for i := 0; i < nValues; i++ {
  316. dstElem := dst.Index(i).Addr()
  317. err := assignPathValues(dstElem.Interface(), fieldOrValue{value: values[i]})
  318. if err != nil {
  319. return fmt.Errorf("error binding array: %w", err)
  320. }
  321. }
  322. return nil
  323. }
  324. func sortedFieldOrValueKeys(m map[string]fieldOrValue) []string {
  325. keys := make([]string, 0, len(m))
  326. for k := range m {
  327. keys = append(keys, k)
  328. }
  329. sort.Strings(keys)
  330. return keys
  331. }