styleparam.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. // Copyright 2019 DeepMap, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package runtime
  15. import (
  16. "bytes"
  17. "encoding"
  18. "encoding/json"
  19. "errors"
  20. "fmt"
  21. "net/url"
  22. "reflect"
  23. "sort"
  24. "strconv"
  25. "strings"
  26. "time"
  27. "github.com/oapi-codegen/runtime/types"
  28. "github.com/google/uuid"
  29. )
  30. // Parameter escaping works differently based on where a header is found
  31. type ParamLocation int
  32. const (
  33. ParamLocationUndefined ParamLocation = iota
  34. ParamLocationQuery
  35. ParamLocationPath
  36. ParamLocationHeader
  37. ParamLocationCookie
  38. )
  39. // StyleParam is used by older generated code, and must remain compatible
  40. // with that code. It is not to be used in new templates. Please see the
  41. // function below, which can specialize its output based on the location of
  42. // the parameter.
  43. func StyleParam(style string, explode bool, paramName string, value interface{}) (string, error) {
  44. return StyleParamWithLocation(style, explode, paramName, ParamLocationUndefined, value)
  45. }
  46. // Given an input value, such as a primitive type, array or object, turn it
  47. // into a parameter based on style/explode definition, performing whatever
  48. // escaping is necessary based on parameter location
  49. func StyleParamWithLocation(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  50. t := reflect.TypeOf(value)
  51. v := reflect.ValueOf(value)
  52. // Things may be passed in by pointer, we need to dereference, so return
  53. // error on nil.
  54. if t.Kind() == reflect.Ptr {
  55. if v.IsNil() {
  56. return "", fmt.Errorf("value is a nil pointer")
  57. }
  58. v = reflect.Indirect(v)
  59. t = v.Type()
  60. }
  61. // If the value implements encoding.TextMarshaler we use it for marshaling
  62. // https://github.com/deepmap/oapi-codegen/issues/504
  63. if tu, ok := value.(encoding.TextMarshaler); ok {
  64. t := reflect.Indirect(reflect.ValueOf(value)).Type()
  65. convertableToTime := t.ConvertibleTo(reflect.TypeOf(time.Time{}))
  66. convertableToDate := t.ConvertibleTo(reflect.TypeOf(types.Date{}))
  67. // Since both time.Time and types.Date implement encoding.TextMarshaler
  68. // we should avoid calling theirs MarshalText()
  69. if !convertableToTime && !convertableToDate {
  70. b, err := tu.MarshalText()
  71. if err != nil {
  72. return "", fmt.Errorf("error marshaling '%s' as text: %s", value, err)
  73. }
  74. return stylePrimitive(style, explode, paramName, paramLocation, string(b))
  75. }
  76. }
  77. switch t.Kind() {
  78. case reflect.Slice:
  79. n := v.Len()
  80. sliceVal := make([]interface{}, n)
  81. for i := 0; i < n; i++ {
  82. sliceVal[i] = v.Index(i).Interface()
  83. }
  84. return styleSlice(style, explode, paramName, paramLocation, sliceVal)
  85. case reflect.Struct:
  86. return styleStruct(style, explode, paramName, paramLocation, value)
  87. case reflect.Map:
  88. return styleMap(style, explode, paramName, paramLocation, value)
  89. default:
  90. return stylePrimitive(style, explode, paramName, paramLocation, value)
  91. }
  92. }
  93. func styleSlice(style string, explode bool, paramName string, paramLocation ParamLocation, values []interface{}) (string, error) {
  94. if style == "deepObject" {
  95. if !explode {
  96. return "", errors.New("deepObjects must be exploded")
  97. }
  98. return MarshalDeepObject(values, paramName)
  99. }
  100. var prefix string
  101. var separator string
  102. switch style {
  103. case "simple":
  104. separator = ","
  105. case "label":
  106. prefix = "."
  107. if explode {
  108. separator = "."
  109. } else {
  110. separator = ","
  111. }
  112. case "matrix":
  113. prefix = fmt.Sprintf(";%s=", paramName)
  114. if explode {
  115. separator = prefix
  116. } else {
  117. separator = ","
  118. }
  119. case "form":
  120. prefix = fmt.Sprintf("%s=", paramName)
  121. if explode {
  122. separator = "&" + prefix
  123. } else {
  124. separator = ","
  125. }
  126. case "spaceDelimited":
  127. prefix = fmt.Sprintf("%s=", paramName)
  128. if explode {
  129. separator = "&" + prefix
  130. } else {
  131. separator = " "
  132. }
  133. case "pipeDelimited":
  134. prefix = fmt.Sprintf("%s=", paramName)
  135. if explode {
  136. separator = "&" + prefix
  137. } else {
  138. separator = "|"
  139. }
  140. default:
  141. return "", fmt.Errorf("unsupported style '%s'", style)
  142. }
  143. // We're going to assume here that the array is one of simple types.
  144. var err error
  145. var part string
  146. parts := make([]string, len(values))
  147. for i, v := range values {
  148. part, err = primitiveToString(v)
  149. part = escapeParameterString(part, paramLocation)
  150. parts[i] = part
  151. if err != nil {
  152. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  153. }
  154. }
  155. return prefix + strings.Join(parts, separator), nil
  156. }
  157. func sortedKeys(strMap map[string]string) []string {
  158. keys := make([]string, len(strMap))
  159. i := 0
  160. for k := range strMap {
  161. keys[i] = k
  162. i++
  163. }
  164. sort.Strings(keys)
  165. return keys
  166. }
  167. // These are special cases. The value may be a date, time, or uuid,
  168. // in which case, marshal it into the correct format.
  169. func marshalKnownTypes(value interface{}) (string, bool) {
  170. v := reflect.Indirect(reflect.ValueOf(value))
  171. t := v.Type()
  172. if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
  173. tt := v.Convert(reflect.TypeOf(time.Time{}))
  174. timeVal := tt.Interface().(time.Time)
  175. return timeVal.Format(time.RFC3339Nano), true
  176. }
  177. if t.ConvertibleTo(reflect.TypeOf(types.Date{})) {
  178. d := v.Convert(reflect.TypeOf(types.Date{}))
  179. dateVal := d.Interface().(types.Date)
  180. return dateVal.Format(types.DateFormat), true
  181. }
  182. if t.ConvertibleTo(reflect.TypeOf(types.UUID{})) {
  183. u := v.Convert(reflect.TypeOf(types.UUID{}))
  184. uuidVal := u.Interface().(types.UUID)
  185. return uuidVal.String(), true
  186. }
  187. return "", false
  188. }
  189. func styleStruct(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  190. if timeVal, ok := marshalKnownTypes(value); ok {
  191. styledVal, err := stylePrimitive(style, explode, paramName, paramLocation, timeVal)
  192. if err != nil {
  193. return "", fmt.Errorf("failed to style time: %w", err)
  194. }
  195. return styledVal, nil
  196. }
  197. if style == "deepObject" {
  198. if !explode {
  199. return "", errors.New("deepObjects must be exploded")
  200. }
  201. return MarshalDeepObject(value, paramName)
  202. }
  203. // If input has Marshaler, such as object has Additional Property or AnyOf,
  204. // We use this Marshaler and convert into interface{} before styling.
  205. if m, ok := value.(json.Marshaler); ok {
  206. buf, err := m.MarshalJSON()
  207. if err != nil {
  208. return "", fmt.Errorf("failed to marshal input to JSON: %w", err)
  209. }
  210. e := json.NewDecoder(bytes.NewReader(buf))
  211. e.UseNumber()
  212. var i2 interface{}
  213. err = e.Decode(&i2)
  214. if err != nil {
  215. return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
  216. }
  217. s, err := StyleParamWithLocation(style, explode, paramName, paramLocation, i2)
  218. if err != nil {
  219. return "", fmt.Errorf("error style JSON structure: %w", err)
  220. }
  221. return s, nil
  222. }
  223. // Otherwise, we need to build a dictionary of the struct's fields. Each
  224. // field may only be a primitive value.
  225. v := reflect.ValueOf(value)
  226. t := reflect.TypeOf(value)
  227. fieldDict := make(map[string]string)
  228. for i := 0; i < t.NumField(); i++ {
  229. fieldT := t.Field(i)
  230. // Find the json annotation on the field, and use the json specified
  231. // name if available, otherwise, just the field name.
  232. tag := fieldT.Tag.Get("json")
  233. fieldName := fieldT.Name
  234. if tag != "" {
  235. tagParts := strings.Split(tag, ",")
  236. name := tagParts[0]
  237. if name != "" {
  238. fieldName = name
  239. }
  240. }
  241. f := v.Field(i)
  242. // Unset optional fields will be nil pointers, skip over those.
  243. if f.Type().Kind() == reflect.Ptr && f.IsNil() {
  244. continue
  245. }
  246. str, err := primitiveToString(f.Interface())
  247. if err != nil {
  248. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  249. }
  250. fieldDict[fieldName] = str
  251. }
  252. return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
  253. }
  254. func styleMap(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  255. if style == "deepObject" {
  256. if !explode {
  257. return "", errors.New("deepObjects must be exploded")
  258. }
  259. return MarshalDeepObject(value, paramName)
  260. }
  261. dict, ok := value.(map[string]interface{})
  262. if !ok {
  263. return "", errors.New("map not of type map[string]interface{}")
  264. }
  265. fieldDict := make(map[string]string)
  266. for fieldName, value := range dict {
  267. str, err := primitiveToString(value)
  268. if err != nil {
  269. return "", fmt.Errorf("error formatting '%s': %s", paramName, err)
  270. }
  271. fieldDict[fieldName] = str
  272. }
  273. return processFieldDict(style, explode, paramName, paramLocation, fieldDict)
  274. }
  275. func processFieldDict(style string, explode bool, paramName string, paramLocation ParamLocation, fieldDict map[string]string) (string, error) {
  276. var parts []string
  277. // This works for everything except deepObject. We'll handle that one
  278. // separately.
  279. if style != "deepObject" {
  280. if explode {
  281. for _, k := range sortedKeys(fieldDict) {
  282. v := escapeParameterString(fieldDict[k], paramLocation)
  283. parts = append(parts, k+"="+v)
  284. }
  285. } else {
  286. for _, k := range sortedKeys(fieldDict) {
  287. v := escapeParameterString(fieldDict[k], paramLocation)
  288. parts = append(parts, k)
  289. parts = append(parts, v)
  290. }
  291. }
  292. }
  293. var prefix string
  294. var separator string
  295. switch style {
  296. case "simple":
  297. separator = ","
  298. case "label":
  299. prefix = "."
  300. if explode {
  301. separator = prefix
  302. } else {
  303. separator = ","
  304. }
  305. case "matrix":
  306. if explode {
  307. separator = ";"
  308. prefix = ";"
  309. } else {
  310. separator = ","
  311. prefix = fmt.Sprintf(";%s=", paramName)
  312. }
  313. case "form":
  314. if explode {
  315. separator = "&"
  316. } else {
  317. prefix = fmt.Sprintf("%s=", paramName)
  318. separator = ","
  319. }
  320. case "deepObject":
  321. {
  322. if !explode {
  323. return "", fmt.Errorf("deepObject parameters must be exploded")
  324. }
  325. for _, k := range sortedKeys(fieldDict) {
  326. v := fieldDict[k]
  327. part := fmt.Sprintf("%s[%s]=%s", paramName, k, v)
  328. parts = append(parts, part)
  329. }
  330. separator = "&"
  331. }
  332. default:
  333. return "", fmt.Errorf("unsupported style '%s'", style)
  334. }
  335. return prefix + strings.Join(parts, separator), nil
  336. }
  337. func stylePrimitive(style string, explode bool, paramName string, paramLocation ParamLocation, value interface{}) (string, error) {
  338. strVal, err := primitiveToString(value)
  339. if err != nil {
  340. return "", err
  341. }
  342. var prefix string
  343. switch style {
  344. case "simple":
  345. case "label":
  346. prefix = "."
  347. case "matrix":
  348. prefix = fmt.Sprintf(";%s=", paramName)
  349. case "form":
  350. prefix = fmt.Sprintf("%s=", paramName)
  351. default:
  352. return "", fmt.Errorf("unsupported style '%s'", style)
  353. }
  354. return prefix + escapeParameterString(strVal, paramLocation), nil
  355. }
  356. // Converts a primitive value to a string. We need to do this based on the
  357. // Kind of an interface, not the Type to work with aliased types.
  358. func primitiveToString(value interface{}) (string, error) {
  359. var output string
  360. // sometimes time and date used like primitive types
  361. // it can happen if paramether is object and has time or date as field
  362. if res, ok := marshalKnownTypes(value); ok {
  363. return res, nil
  364. }
  365. // Values may come in by pointer for optionals, so make sure to dereferene.
  366. v := reflect.Indirect(reflect.ValueOf(value))
  367. t := v.Type()
  368. kind := t.Kind()
  369. switch kind {
  370. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  371. output = strconv.FormatInt(v.Int(), 10)
  372. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  373. output = strconv.FormatUint(v.Uint(), 10)
  374. case reflect.Float64:
  375. output = strconv.FormatFloat(v.Float(), 'f', -1, 64)
  376. case reflect.Float32:
  377. output = strconv.FormatFloat(v.Float(), 'f', -1, 32)
  378. case reflect.Bool:
  379. if v.Bool() {
  380. output = "true"
  381. } else {
  382. output = "false"
  383. }
  384. case reflect.String:
  385. output = v.String()
  386. case reflect.Struct:
  387. // If input has Marshaler, such as object has Additional Property or AnyOf,
  388. // We use this Marshaler and convert into interface{} before styling.
  389. if v, ok := value.(uuid.UUID); ok {
  390. output = v.String()
  391. break
  392. }
  393. if m, ok := value.(json.Marshaler); ok {
  394. buf, err := m.MarshalJSON()
  395. if err != nil {
  396. return "", fmt.Errorf("failed to marshal input to JSON: %w", err)
  397. }
  398. e := json.NewDecoder(bytes.NewReader(buf))
  399. e.UseNumber()
  400. var i2 interface{}
  401. err = e.Decode(&i2)
  402. if err != nil {
  403. return "", fmt.Errorf("failed to unmarshal JSON: %w", err)
  404. }
  405. output, err = primitiveToString(i2)
  406. if err != nil {
  407. return "", fmt.Errorf("error convert JSON structure: %w", err)
  408. }
  409. break
  410. }
  411. fallthrough
  412. default:
  413. v, ok := value.(fmt.Stringer)
  414. if !ok {
  415. return "", fmt.Errorf("unsupported type %s", reflect.TypeOf(value).String())
  416. }
  417. output = v.String()
  418. }
  419. return output, nil
  420. }
  421. // escapeParameterString escapes a parameter value bas on the location of that parameter.
  422. // Query params and path params need different kinds of escaping, while header
  423. // and cookie params seem not to need escaping.
  424. func escapeParameterString(value string, paramLocation ParamLocation) string {
  425. switch paramLocation {
  426. case ParamLocationQuery:
  427. return url.QueryEscape(value)
  428. case ParamLocationPath:
  429. return url.PathEscape(value)
  430. default:
  431. return value
  432. }
  433. }