styleparam.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. // Copyright 2019 DeepMap, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package runtime
  15. import (
  16. "bytes"
  17. "encoding"
  18. "encoding/json"
  19. "errors"
  20. "fmt"
  21. "net/url"
  22. "reflect"
  23. "sort"
  24. "strconv"
  25. "strings"
  26. "time"
  27. "github.com/deepmap/oapi-codegen/pkg/types"
  28. )
  29. // Parameter escaping works differently based on where a header is found
  30. type ParamLocation int
  31. const (
  32. ParamLocationUndefined ParamLocation = iota
  33. ParamLocationQuery
  34. ParamLocationPath
  35. ParamLocationHeader
  36. ParamLocationCookie
  37. )
  38. // StyleParam is used by older generated code, and must remain compatible
  39. // with that code. It is not to be used in new templates. Please see the
  40. // function below, which can specialize its output based on the location of
  41. // the parameter.
  42. func StyleParam(style string, explode bool, paramName string, value interface{}) (string, error) {
  43. return StyleParamWithLocation(style, explode, paramName, ParamLocationUndefined, value)
  44. }
  45. // Given an input value, such as a primitive type, array or object, turn it
  46. // into a parameter based on style/explode definition, performing whatever
  47. // escaping is necessary based on parameter location
  48. func StyleParamWithLocation(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  49. t := reflect.TypeOf(value)
  50. v := reflect.ValueOf(value)
  51. // Things may be passed in by pointer, we need to dereference, so return
  52. // error on nil.
  53. if t.Kind() == reflect.Ptr {
  54. if v.IsNil() {
  55. return "", fmt.Errorf("value is a nil pointer")
  56. }
  57. v = reflect.Indirect(v)
  58. t = v.Type()
  59. }
  60. // If the value implements encoding.TextMarshaler we use it for marshaling
  61. // https://github.com/deepmap/oapi-codegen/issues/504
  62. if tu, ok := value.(encoding.TextMarshaler); ok {
  63. t := reflect.Indirect(reflect.ValueOf(value)).Type()
  64. convertableToTime := t.ConvertibleTo(reflect.TypeOf(time.Time{}))
  65. convertableToDate := t.ConvertibleTo(reflect.TypeOf(types.Date{}))
  66. // Since both time.Time and types.Date implement encoding.TextMarshaler
  67. // we should avoid calling theirs MarshalText()
  68. if !convertableToTime && !convertableToDate {
  69. b, err := tu.MarshalText()
  70. if err != nil {
  71. return "", fmt.Errorf("error marshaling '%s' as text: %s", value, err)
  72. }
  73. return stylePrimitive(style, explode, paramName, paramLocation, string(b))
  74. }
  75. }
  76. switch t.Kind() {
  77. case reflect.Slice:
  78. n := v.Len()
  79. sliceVal := make([]interface{}, n)
  80. for i := 0; i < n; i++ {
  81. sliceVal[i] = v.Index(i).Interface()
  82. }
  83. return styleSlice(style, explode, paramName, paramLocation, sliceVal)
  84. case reflect.Struct:
  85. return styleStruct(style, explode, paramName, paramLocation, value)
  86. case reflect.Map:
  87. return styleMap(style, explode, paramName, paramLocation, value)
  88. default:
  89. return stylePrimitive(style, explode, paramName, paramLocation, value)
  90. }
  91. }
  92. func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, values []interface{}) (string, error) {
  93. if style == "deepObject" {
  94. if !explode {
  95. return "", errors.New("deepObjects must be exploded")
  96. }
  97. return MarshalDeepObject(values, paramName)
  98. }
  99. var prefix string
  100. var separator string
  101. switch style {
  102. case "simple":
  103. separator = ","
  104. case "label":
  105. prefix = "."
  106. if explode {
  107. separator = "."
  108. } else {
  109. separator = ","
  110. }
  111. case "matrix":
  112. prefix = fmt.Sprintf(";%s=", paramName)
  113. if explode {
  114. separator = prefix
  115. } else {
  116. separator = ","
  117. }
  118. case "form":
  119. prefix = fmt.Sprintf("%s=", paramName)
  120. if explode {
  121. separator = "&" + prefix
  122. } else {
  123. separator = ","
  124. }
  125. case "spaceDelimited":
  126. prefix = fmt.Sprintf("%s=", paramName)
  127. if explode {
  128. separator = "&" + prefix
  129. } else {
  130. separator = " "
  131. }
  132. case "pipeDelimited":
  133. prefix = fmt.Sprintf("%s=", paramName)
  134. if explode {
  135. separator = "&" + prefix
  136. } else {
  137. separator = "|"
  138. }
  139. default:
  140. return "", fmt.Errorf("unsupported style '%s'", style)
  141. }
  142. // We're going to assume here that the array is one of simple types.
  143. var err error
  144. var part string
  145. parts := make([]string, len(values))
  146. for i, v := range values {
  147. part, err = primitiveToString(v)
  148. part = escapeParameterString(part, paramLocation)
  149. parts[i] = part
  150. if err != nil {
  151. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  152. }
  153. }
  154. return prefix + strings.Join(parts, separator), nil
  155. }
  156. func sortedKeys(strMap map[string]string) []string {
  157. keys := make([]string, len(strMap))
  158. i := 0
  159. for k := range strMap {
  160. keys[i] = k
  161. i++
  162. }
  163. sort.Strings(keys)
  164. return keys
  165. }
  166. // These are special cases. The value may be a date, time, or uuid,
  167. // in which case, marshal it into the correct format.
  168. func marshalKnownTypes(value interface{}) (string, bool) {
  169. v := reflect.Indirect(reflect.ValueOf(value))
  170. t := v.Type()
  171. if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
  172. tt := v.Convert(reflect.TypeOf(time.Time{}))
  173. timeVal := tt.Interface().(time.Time)
  174. return timeVal.Format(time.RFC3339Nano), true
  175. }
  176. if t.ConvertibleTo(reflect.TypeOf(types.Date{})) {
  177. d := v.Convert(reflect.TypeOf(types.Date{}))
  178. dateVal := d.Interface().(types.Date)
  179. return dateVal.Format(types.DateFormat), true
  180. }
  181. if t.ConvertibleTo(reflect.TypeOf(types.UUID{})) {
  182. u := v.Convert(reflect.TypeOf(types.UUID{}))
  183. uuidVal := u.Interface().(types.UUID)
  184. return uuidVal.String(), true
  185. }
  186. return "", false
  187. }
  188. func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  189. if timeVal, ok := marshalKnownTypes(value); ok {
  190. styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, timeVal)
  191. if err != nil {
  192. return "", fmt.Errorf("failed to style time: %w", err)
  193. }
  194. return styledVal, nil
  195. }
  196. if style == "deepObject" {
  197. if !explode {
  198. return "", errors.New("deepObjects must be exploded")
  199. }
  200. return MarshalDeepObject(value, paramName)
  201. }
  202. // If input has Marshaler, such as object has Additional Property or AnyOf,
  203. // We use this Marshaler and convert into interface{} before styling.
  204. if m, ok := value.(json.Marshaler); ok {
  205. buf, err := m.MarshalJSON()
  206. if err != nil {
  207. return "", fmt.Errorf("failed to marshal input to JSON: %w", err)
  208. }
  209. e := json.NewDecoder(bytes.NewReader(buf))
  210. e.UseNumber()
  211. var i2 interface{}
  212. err = e.Decode(&i2)
  213. if err != nil {
  214. return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
  215. }
  216. s, err := StyleParamWithLocation(style, explode, paramName, paramLocation, i2)
  217. if err != nil {
  218. return "", fmt.Errorf("error style JSON structure: %w", err)
  219. }
  220. return s, nil
  221. }
  222. // Otherwise, we need to build a dictionary of the struct's fields. Each
  223. // field may only be a primitive value.
  224. v := reflect.ValueOf(value)
  225. t := reflect.TypeOf(value)
  226. fieldDict := make(map[string]string)
  227. for i := 0; i < t.NumField(); i++ {
  228. fieldT := t.Field(i)
  229. // Find the json annotation on the field, and use the json specified
  230. // name if available, otherwise, just the field name.
  231. tag := fieldT.Tag.Get("json")
  232. fieldName := fieldT.Name
  233. if tag != "" {
  234. tagParts := strings.Split(tag, ",")
  235. name := tagParts[0]
  236. if name != "" {
  237. fieldName = name
  238. }
  239. }
  240. f := v.Field(i)
  241. // Unset optional fields will be nil pointers, skip over those.
  242. if f.Type().Kind() == reflect.Ptr && f.IsNil() {
  243. continue
  244. }
  245. str, err := primitiveToString(f.Interface())
  246. if err != nil {
  247. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  248. }
  249. fieldDict[fieldName] = str
  250. }
  251. return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
  252. }
  253. func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  254. if style == "deepObject" {
  255. if !explode {
  256. return "", errors.New("deepObjects must be exploded")
  257. }
  258. return MarshalDeepObject(value, paramName)
  259. }
  260. dict, ok := value.(map[string]interface{})
  261. if !ok {
  262. return "", errors.New("map not of type map[string]interface{}")
  263. }
  264. fieldDict := make(map[string]string)
  265. for fieldName, value := range dict {
  266. str, err := primitiveToString(value)
  267. if err != nil {
  268. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  269. }
  270. fieldDict[fieldName] = str
  271. }
  272. return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
  273. }
  274. func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, fieldDict map[string]string) (string, error) {
  275. var parts []string
  276. // This works for everything except deepObject. We'll handle that one
  277. // separately.
  278. if style != "deepObject" {
  279. if explode {
  280. for _, k := range sortedKeys(fieldDict) {
  281. v := escapeParameterString(fieldDict[k], paramLocation)
  282. parts = append(parts, k+"="+v)
  283. }
  284. } else {
  285. for _, k := range sortedKeys(fieldDict) {
  286. v := escapeParameterString(fieldDict[k], paramLocation)
  287. parts = append(parts, k)
  288. parts = append(parts, v)
  289. }
  290. }
  291. }
  292. var prefix string
  293. var separator string
  294. switch style {
  295. case "simple":
  296. separator = ","
  297. case "label":
  298. prefix = "."
  299. if explode {
  300. separator = prefix
  301. } else {
  302. separator = ","
  303. }
  304. case "matrix":
  305. if explode {
  306. separator = ";"
  307. prefix = ";"
  308. } else {
  309. separator = ","
  310. prefix = fmt.Sprintf(";%s=", paramName)
  311. }
  312. case "form":
  313. if explode {
  314. separator = "&"
  315. } else {
  316. prefix = fmt.Sprintf("%s=", paramName)
  317. separator = ","
  318. }
  319. case "deepObject":
  320. {
  321. if !explode {
  322. return "", fmt.Errorf("deepObject parameters must be exploded")
  323. }
  324. for _, k := range sortedKeys(fieldDict) {
  325. v := fieldDict[k]
  326. part := fmt.Sprintf("%s[%s]=%s", paramName, k, v)
  327. parts = append(parts, part)
  328. }
  329. separator = "&"
  330. }
  331. default:
  332. return "", fmt.Errorf("unsupported style '%s'", style)
  333. }
  334. return prefix + strings.Join(parts, separator), nil
  335. }
  336. func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  337. strVal, err := primitiveToString(value)
  338. if err != nil {
  339. return "", err
  340. }
  341. var prefix string
  342. switch style {
  343. case "simple":
  344. case "label":
  345. prefix = "."
  346. case "matrix":
  347. prefix = fmt.Sprintf(";%s=", paramName)
  348. case "form":
  349. prefix = fmt.Sprintf("%s=", paramName)
  350. default:
  351. return "", fmt.Errorf("unsupported style '%s'", style)
  352. }
  353. return prefix + escapeParameterString(strVal, paramLocation), nil
  354. }
  355. // Converts a primitive value to a string. We need to do this based on the
  356. // Kind of an interface, not the Type to work with aliased types.
  357. func primitiveToString(value interface{}) (string, error) {
  358. var output string
  359. // sometimes time and date used like primitive types
  360. // it can happen if paramether is object and has time or date as field
  361. if res, ok := marshalKnownTypes(value); ok {
  362. return res, nil
  363. }
  364. // Values may come in by pointer for optionals, so make sure to dereferene.
  365. v := reflect.Indirect(reflect.ValueOf(value))
  366. t := v.Type()
  367. kind := t.Kind()
  368. switch kind {
  369. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  370. output = strconv.FormatInt(v.Int(), 10)
  371. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  372. output = strconv.FormatUint(v.Uint(), 10)
  373. case reflect.Float64:
  374. output = strconv.FormatFloat(v.Float(), 'f', -1, 64)
  375. case reflect.Float32:
  376. output = strconv.FormatFloat(v.Float(), 'f', -1, 32)
  377. case reflect.Bool:
  378. if v.Bool() {
  379. output = "true"
  380. } else {
  381. output = "false"
  382. }
  383. case reflect.String:
  384. output = v.String()
  385. case reflect.Struct:
  386. // If input has Marshaler, such as object has Additional Property or AnyOf,
  387. // We use this Marshaler and convert into interface{} before styling.
  388. if m, ok := value.(json.Marshaler); ok {
  389. buf, err := m.MarshalJSON()
  390. if err != nil {
  391. return "", fmt.Errorf("failed to marshal input to JSON: %w", err)
  392. }
  393. e := json.NewDecoder(bytes.NewReader(buf))
  394. e.UseNumber()
  395. var i2 interface{}
  396. err = e.Decode(&i2)
  397. if err != nil {
  398. return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
  399. }
  400. output, err = primitiveToString(i2)
  401. if err != nil {
  402. return "", fmt.Errorf("error convert JSON structure: %w", err)
  403. }
  404. break
  405. }
  406. fallthrough
  407. default:
  408. v, ok := value.(fmt.Stringer)
  409. if !ok {
  410. return "", fmt.Errorf("unsupported type %s", reflect.TypeOf(value).String())
  411. }
  412. output = v.String()
  413. }
  414. return output, nil
  415. }
  416. // This function escapes a parameter value bas on the location of that parameter.
  417. // Query params and path params need different kinds of escaping, while header
  418. // and cookie params seem not to need escaping.
  419. func escapeParameterString(value string, paramLocation ParamLocation) string {
  420. switch paramLocation {
  421. case ParamLocationQuery:
  422. return url.QueryEscape(value)
  423. case ParamLocationPath:
  424. return url.PathEscape(value)
  425. default:
  426. return value
  427. }
  428. }