bindform.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. package runtime
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "mime/multipart"
  7. "net/url"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "github.com/deepmap/oapi-codegen/pkg/types"
  12. )
  13. const tagName = "json"
  14. const jsonContentType = "application/json"
  15. type RequestBodyEncoding struct {
  16. ContentType string
  17. Style string
  18. Explode *bool
  19. }
  20. func BindMultipart(ptr interface{}, reader multipart.Reader) error {
  21. const defaultMemory = 32 << 20
  22. form, err := reader.ReadForm(defaultMemory)
  23. if err != nil {
  24. return err
  25. }
  26. return BindForm(ptr, form.Value, form.File, nil)
  27. }
  28. func BindForm(ptr interface{}, form map[string][]string, files map[string][]*multipart.FileHeader, encodings map[string]RequestBodyEncoding) error {
  29. ptrVal := reflect.Indirect(reflect.ValueOf(ptr))
  30. if ptrVal.Kind() != reflect.Struct {
  31. return errors.New("form data body should be a struct")
  32. }
  33. tValue := ptrVal.Type()
  34. for i := 0; i < tValue.NumField(); i++ {
  35. field := ptrVal.Field(i)
  36. tag := tValue.Field(i).Tag.Get(tagName)
  37. if !field.CanInterface() || tag == "-" {
  38. continue
  39. }
  40. tag = strings.Split(tag, ",")[0] // extract the name of the tag
  41. if encoding, ok := encodings[tag]; ok {
  42. // custom encoding
  43. values := form[tag]
  44. if len(values) == 0 {
  45. continue
  46. }
  47. value := values[0]
  48. if encoding.ContentType != "" {
  49. if strings.HasPrefix(encoding.ContentType, jsonContentType) {
  50. if err := json.Unmarshal([]byte(value), ptr); err != nil {
  51. return err
  52. }
  53. }
  54. return errors.New("unsupported encoding, only application/json is supported")
  55. } else {
  56. var explode bool
  57. if encoding.Explode != nil {
  58. explode = *encoding.Explode
  59. }
  60. if err := BindStyledParameterWithLocation(encoding.Style, explode, tag, ParamLocationUndefined, value, field.Addr().Interface()); err != nil {
  61. return err
  62. }
  63. }
  64. } else {
  65. // regular form data
  66. if _, err := bindFormImpl(field, form, files, tag); err != nil {
  67. return err
  68. }
  69. }
  70. }
  71. return nil
  72. }
  73. func MarshalForm(ptr interface{}, encodings map[string]RequestBodyEncoding) (url.Values, error) {
  74. ptrVal := reflect.Indirect(reflect.ValueOf(ptr))
  75. if ptrVal.Kind() != reflect.Struct {
  76. return nil, errors.New("form data body should be a struct")
  77. }
  78. tValue := ptrVal.Type()
  79. result := make(url.Values)
  80. for i := 0; i < tValue.NumField(); i++ {
  81. field := ptrVal.Field(i)
  82. tag := tValue.Field(i).Tag.Get(tagName)
  83. if !field.CanInterface() || tag == "-" {
  84. continue
  85. }
  86. omitEmpty := strings.HasSuffix(tag, ",omitempty")
  87. if omitEmpty && field.IsZero() {
  88. continue
  89. }
  90. tag = strings.Split(tag, ",")[0] // extract the name of the tag
  91. if encoding, ok := encodings[tag]; ok && encoding.ContentType != "" {
  92. if strings.HasPrefix(encoding.ContentType, jsonContentType) {
  93. if data, err := json.Marshal(field); err != nil { //nolint:staticcheck
  94. return nil, err
  95. } else {
  96. result[tag] = append(result[tag], string(data))
  97. }
  98. }
  99. return nil, errors.New("unsupported encoding, only application/json is supported")
  100. } else {
  101. marshalFormImpl(field, result, tag)
  102. }
  103. }
  104. return result, nil
  105. }
  106. func bindFormImpl(v reflect.Value, form map[string][]string, files map[string][]*multipart.FileHeader, name string) (bool, error) {
  107. var hasData bool
  108. switch v.Kind() {
  109. case reflect.Interface:
  110. return bindFormImpl(v.Elem(), form, files, name)
  111. case reflect.Ptr:
  112. ptrData := v.Elem()
  113. if !ptrData.IsValid() {
  114. ptrData = reflect.New(v.Type().Elem())
  115. }
  116. ptrHasData, err := bindFormImpl(ptrData, form, files, name)
  117. if err == nil && ptrHasData && !v.Elem().IsValid() {
  118. v.Set(ptrData)
  119. }
  120. return ptrHasData, err
  121. case reflect.Slice:
  122. if files := append(files[name], files[name+"[]"]...); len(files) != 0 {
  123. if _, ok := v.Interface().([]types.File); ok {
  124. result := make([]types.File, len(files))
  125. for i, file := range files {
  126. result[i].InitFromMultipart(file)
  127. }
  128. v.Set(reflect.ValueOf(result))
  129. hasData = true
  130. }
  131. }
  132. indexedElementsCount := indexedElementsCount(form, files, name)
  133. items := append(form[name], form[name+"[]"]...)
  134. if indexedElementsCount+len(items) != 0 {
  135. result := reflect.MakeSlice(v.Type(), indexedElementsCount+len(items), indexedElementsCount+len(items))
  136. for i := 0; i < indexedElementsCount; i++ {
  137. if _, err := bindFormImpl(result.Index(i), form, files, fmt.Sprintf("%s[%v]", name, i)); err != nil {
  138. return false, err
  139. }
  140. }
  141. for i, item := range items {
  142. if err := BindStringToObject(item, result.Index(indexedElementsCount+i).Addr().Interface()); err != nil {
  143. return false, err
  144. }
  145. }
  146. v.Set(result)
  147. hasData = true
  148. }
  149. case reflect.Struct:
  150. if files := files[name]; len(files) != 0 {
  151. if file, ok := v.Interface().(types.File); ok {
  152. file.InitFromMultipart(files[0])
  153. v.Set(reflect.ValueOf(file))
  154. return true, nil
  155. }
  156. }
  157. for i := 0; i < v.NumField(); i++ {
  158. field := v.Type().Field(i)
  159. tag := field.Tag.Get(tagName)
  160. if field.Name == "AdditionalProperties" && field.Type.Kind() == reflect.Map && tag == "-" {
  161. additionalPropertiesHasData, err := bindAdditionalProperties(v.Field(i), v, form, files, name)
  162. if err != nil {
  163. return false, err
  164. }
  165. hasData = hasData || additionalPropertiesHasData
  166. }
  167. if !v.Field(i).CanInterface() || tag == "-" {
  168. continue
  169. }
  170. tag = strings.Split(tag, ",")[0] // extract the name of the tag
  171. fieldHasData, err := bindFormImpl(v.Field(i), form, files, fmt.Sprintf("%s[%s]", name, tag))
  172. if err != nil {
  173. return false, err
  174. }
  175. hasData = hasData || fieldHasData
  176. }
  177. return hasData, nil
  178. default:
  179. value := form[name]
  180. if len(value) != 0 {
  181. return true, BindStringToObject(value[0], v.Addr().Interface())
  182. }
  183. }
  184. return hasData, nil
  185. }
  186. func indexedElementsCount(form map[string][]string, files map[string][]*multipart.FileHeader, name string) int {
  187. name += "["
  188. maxIndex := -1
  189. for k := range form {
  190. if strings.HasPrefix(k, name) {
  191. str := strings.TrimPrefix(k, name)
  192. str = str[:strings.Index(str, "]")]
  193. if idx, err := strconv.Atoi(str); err == nil {
  194. if idx > maxIndex {
  195. maxIndex = idx
  196. }
  197. }
  198. }
  199. }
  200. for k := range files {
  201. if strings.HasPrefix(k, name) {
  202. str := strings.TrimPrefix(k, name)
  203. str = str[:strings.Index(str, "]")]
  204. if idx, err := strconv.Atoi(str); err == nil {
  205. if idx > maxIndex {
  206. maxIndex = idx
  207. }
  208. }
  209. }
  210. }
  211. return maxIndex + 1
  212. }
  213. func bindAdditionalProperties(additionalProperties reflect.Value, parentStruct reflect.Value, form map[string][]string, files map[string][]*multipart.FileHeader, name string) (bool, error) {
  214. hasData := false
  215. valueType := additionalProperties.Type().Elem()
  216. // store all fixed properties in a set
  217. fieldsSet := make(map[string]struct{})
  218. for i := 0; i < parentStruct.NumField(); i++ {
  219. tag := parentStruct.Type().Field(i).Tag.Get(tagName)
  220. if !parentStruct.Field(i).CanInterface() || tag == "-" {
  221. continue
  222. }
  223. tag = strings.Split(tag, ",")[0]
  224. fieldsSet[tag] = struct{}{}
  225. }
  226. result := reflect.MakeMap(additionalProperties.Type())
  227. for k := range form {
  228. if strings.HasPrefix(k, name+"[") {
  229. key := strings.TrimPrefix(k, name+"[")
  230. key = key[:strings.Index(key, "]")]
  231. if _, ok := fieldsSet[key]; ok {
  232. continue
  233. }
  234. value := reflect.New(valueType)
  235. ptrHasData, err := bindFormImpl(value, form, files, fmt.Sprintf("%s[%s]", name, key))
  236. if err != nil {
  237. return false, err
  238. }
  239. result.SetMapIndex(reflect.ValueOf(key), value.Elem())
  240. hasData = hasData || ptrHasData
  241. }
  242. }
  243. for k := range files {
  244. if strings.HasPrefix(k, name+"[") {
  245. key := strings.TrimPrefix(k, name+"[")
  246. key = key[:strings.Index(key, "]")]
  247. if _, ok := fieldsSet[key]; ok {
  248. continue
  249. }
  250. value := reflect.New(valueType)
  251. result.SetMapIndex(reflect.ValueOf(key), value)
  252. ptrHasData, err := bindFormImpl(value, form, files, fmt.Sprintf("%s[%s]", name, key))
  253. if err != nil {
  254. return false, err
  255. }
  256. result.SetMapIndex(reflect.ValueOf(key), value.Elem())
  257. hasData = hasData || ptrHasData
  258. }
  259. }
  260. if hasData {
  261. additionalProperties.Set(result)
  262. }
  263. return hasData, nil
  264. }
  265. func marshalFormImpl(v reflect.Value, result url.Values, name string) {
  266. switch v.Kind() {
  267. case reflect.Interface, reflect.Ptr:
  268. marshalFormImpl(v.Elem(), result, name)
  269. case reflect.Slice:
  270. for i := 0; i < v.Len(); i++ {
  271. elem := v.Index(i)
  272. marshalFormImpl(elem, result, fmt.Sprintf("%s[%v]", name, i))
  273. }
  274. case reflect.Struct:
  275. for i := 0; i < v.NumField(); i++ {
  276. field := v.Type().Field(i)
  277. tag := field.Tag.Get(tagName)
  278. if field.Name == "AdditionalProperties" && tag == "-" {
  279. iter := v.MapRange()
  280. for iter.Next() {
  281. marshalFormImpl(iter.Value(), result, fmt.Sprintf("%s[%s]", name, iter.Key().String()))
  282. }
  283. continue
  284. }
  285. if !v.Field(i).CanInterface() || tag == "-" {
  286. continue
  287. }
  288. tag = strings.Split(tag, ",")[0] // extract the name of the tag
  289. marshalFormImpl(v.Field(i), result, fmt.Sprintf("%s[%s]", name, tag))
  290. }
  291. default:
  292. result[name] = append(result[name], fmt.Sprint(v.Interface()))
  293. }
  294. }