binding.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. package hero
  2. import (
  3. "fmt"
  4. "reflect"
  5. "sort"
  6. "strconv"
  7. "github.com/kataras/iris/v12/context"
  8. )
  9. // binding contains the Dependency and the Input, it's the result of a function or struct + dependencies.
  10. type binding struct {
  11. Dependency *Dependency
  12. Input *Input
  13. }
  // Input describes where a dependency's value is delivered: a function
// parameter (identified by Index) or a struct field (identified by the full
// StructFieldIndex path and StructFieldName).
type Input struct {
	Index            int          // position among a func's inputs; top-level field index for struct fields
	StructFieldIndex []int        // full field index path, so fields of embedded structs are reachable
	StructFieldName  string       // the struct field's name, empty for func inputs
	Type             reflect.Type // the type the input expects

	selfValue reflect.Value // cached reflect.ValueOf(*Input)
}

// newInput allocates an Input for the given type and positions and caches a
// reflect.Value that points back at the new Input.
func newInput(typ reflect.Type, index int, structFieldIndex []int) *Input {
	in := new(Input)
	in.Index = index
	in.StructFieldIndex = structFieldIndex
	in.Type = typ
	in.selfValue = reflect.ValueOf(in)
	return in
}

// newStructFieldInput builds an Input for the struct field f: Index is the
// field's top-level index, StructFieldIndex its full index path.
// f.Index must be non-empty.
func newStructFieldInput(f reflect.StructField) *Input {
	in := newInput(f.Type, f.Index[0], f.Index)
	in.StructFieldName = f.Name
	return in
}
  36. // String returns the string representation of a binding.
  37. func (b *binding) String() string {
  38. var index string
  39. if len(b.Input.StructFieldIndex) > 0 {
  40. index = strconv.Itoa(b.Input.StructFieldIndex[0])
  41. for _, i := range b.Input.StructFieldIndex[1:] {
  42. index += fmt.Sprintf(".%d", i)
  43. }
  44. } else {
  45. index = strconv.Itoa(b.Input.Index)
  46. }
  47. return fmt.Sprintf("[%s:%s] maps to [%s]", index, b.Input.Type.String(), b.Dependency)
  48. }
  49. // Equal compares "b" and "other" bindings and reports whether they are referring to the same values.
  50. func (b *binding) Equal(other *binding) bool {
  51. if b == nil {
  52. return other == nil
  53. }
  54. if other == nil {
  55. return false
  56. }
  57. // if b.String() != other.String() {
  58. // return false
  59. // }
  60. if expected, got := b.Dependency != nil, other.Dependency != nil; expected != got {
  61. return false
  62. }
  63. if expected, got := fmt.Sprintf("%v", b.Dependency.OriginalValue), fmt.Sprintf("%v", other.Dependency.OriginalValue); expected != got {
  64. return false
  65. }
  66. if expected, got := b.Dependency.DestType != nil, other.Dependency.DestType != nil; expected != got {
  67. return false
  68. }
  69. if b.Dependency.DestType != nil {
  70. if expected, got := b.Dependency.DestType.String(), other.Dependency.DestType.String(); expected != got {
  71. return false
  72. }
  73. }
  74. if expected, got := b.Input != nil, other.Input != nil; expected != got {
  75. return false
  76. }
  77. if b.Input != nil {
  78. if expected, got := b.Input.Index, other.Input.Index; expected != got {
  79. return false
  80. }
  81. if expected, got := b.Input.Type.String(), other.Input.Type.String(); expected != got {
  82. return false
  83. }
  84. if expected, got := b.Input.StructFieldIndex, other.Input.StructFieldIndex; !reflect.DeepEqual(expected, got) {
  85. return false
  86. }
  87. }
  88. return true
  89. }
  90. // DependencyMatcher type alias describes a dependency match function.
  91. type DependencyMatcher = func(*Dependency, reflect.Type) bool
  92. // DefaultDependencyMatcher is the default dependency match function for all DI containers.
  93. // It is used to collect dependencies from struct's fields and function's parameters.
  94. var DefaultDependencyMatcher = func(dep *Dependency, in reflect.Type) bool {
  95. if dep.Explicit {
  96. return dep.DestType == in
  97. }
  98. return dep.DestType == nil || equalTypes(dep.DestType, in)
  99. }
  100. // ToDependencyMatchFunc converts a DependencyMatcher (generic for all dependencies)
  101. // to a dependency-specific input matcher.
  102. func ToDependencyMatchFunc(d *Dependency, match DependencyMatcher) DependencyMatchFunc {
  103. return func(in reflect.Type) bool {
  104. return match(d, in)
  105. }
  106. }
  // getBindingsFor matches every type in inputs against deps and returns one
// binding per input it could resolve. Each returned binding's Input.Index is the
// input's position in inputs; StructFieldIndex is left nil.
//
// For each input, deps are scanned from the last element to the first and the
// first match wins, so a dependency with a higher index in deps takes precedence.
// A non-Explicit dependency is bound to at most one input; an Explicit one may
// be bound to several.
//
// Inputs whose type has an entry in context.ParamResolvers are path parameter
// candidates, unless paramsCount is -1 (which disables that detection).
// Unmatched inputs fall back as follows:
//   - a candidate gets paramBinding;
//   - otherwise, when disablePayloadAutoBinding is false and isPayloadType
//     reports true, the input gets payloadBinding;
//   - any other unmatched input gets no binding at all.
// getBindingsForFunc and getBindingsForStruct detect that last case by
// comparing lengths.
//
// NOTE(review): when a path parameter candidate matches a dependency, the
// *Dependency taken from deps is mutated in place. Its Handle is wrapped, and
// Static and OriginalValue are reset. The change is visible to every other user
// of that dependency, and the wrapping repeats on each call. Confirm this is
// intended.
  107. func getBindingsFor(inputs []reflect.Type, deps []*Dependency, disablePayloadAutoBinding bool, paramsCount int) (bindings []*binding) {
  108. // Path parameter start index is the result of [total path parameters] - [total func path parameters inputs],
  109. // moving from last to first path parameters and first to last (probably) available input args.
  110. //
  111. // That way the above will work as expected:
  112. // 1. mvc.New(app.Party("/path/{firstparam}")).Handle(....Controller.GetBy(secondparam string))
  113. // 2. mvc.New(app.Party("/path/{firstparam}/{secondparam}")).Handle(...Controller.GetBy(firstparam, secondparam string))
  114. // 3. usersRouter := app.Party("/users/{id:uint64}"); usersRouter.ConfigureContainer().Handle(method, "/", handler(id uint64))
  115. // 4. usersRouter.Party("/friends").ConfigureContainer().Handle(method, "/{friendID:uint64}", handler(friendID uint64))
  116. //
  117. // Therefore, count the inputs that can be path parameters first.
// shouldBindParams holds the positions (in inputs) of path parameter candidates.
  118. shouldBindParams := make(map[int]struct{})
  119. totalParamsExpected := 0
  120. if paramsCount != -1 {
  121. for i, in := range inputs {
  122. if _, canBePathParameter := context.ParamResolvers[in]; !canBePathParameter {
  123. continue
  124. }
  125. shouldBindParams[i] = struct{}{}
  126. totalParamsExpected++
  127. }
  128. }
// Candidates receive consecutive path parameter indexes starting at
// startParamIndex, in input order; a negative start is clamped to 0.
  129. startParamIndex := paramsCount - totalParamsExpected
  130. if startParamIndex < 0 {
  131. startParamIndex = 0
  132. }
  133. lastParamIndex := startParamIndex
  134. getParamIndex := func() int {
  135. paramIndex := lastParamIndex
  136. lastParamIndex++
  137. return paramIndex
  138. }
// bindedInput is keyed by dependency index (j), not input index: it records
// non-Explicit dependencies that were already consumed by an earlier input.
  139. bindedInput := make(map[int]struct{})
  140. for i, in := range inputs { //order matters.
  141. _, canBePathParameter := shouldBindParams[i]
  142. prevN := len(bindings) // to check if a new binding is attached; a dependency was matched (see below).
  143. for j := len(deps) - 1; j >= 0; j-- {
  144. d := deps[j]
  145. // Note: we could use the same slice to return.
  146. //
  147. // Add all dynamic dependencies (caller-selecting) and the exact typed dependencies.
  148. //
  149. // A dependency can only be matched to 1 value, and 1 value has a single dependency
  150. // (e.g. to avoid conflicting path parameters of the same type).
  151. if _, alreadyBinded := bindedInput[j]; alreadyBinded {
  152. continue
  153. }
  154. match := d.Match(in)
  155. if !match {
  156. continue
  157. }
// The matched dependency becomes a fallback: the path parameter is tried first,
// and prevHandler runs only when it returns an error. paramDependencyHandler
// returns ErrSeeOther when the request has too few path parameters.
  158. if canBePathParameter {
  159. // wrap the existing dependency handler.
  160. paramHandler := paramDependencyHandler(getParamIndex())
  161. prevHandler := d.Handle
  162. d.Handle = func(ctx *context.Context, input *Input) (reflect.Value, error) {
  163. v, err := paramHandler(ctx, input)
  164. if err != nil {
  165. v, err = prevHandler(ctx, input)
  166. }
  167. return v, err
  168. }
  169. d.Static = false
  170. d.OriginalValue = nil
  171. }
  172. bindings = append(bindings, &binding{
  173. Dependency: d,
  174. Input: newInput(in, i, nil),
  175. })
  176. if !d.Explicit { // if explicit then it can be binded to more than one input
  177. bindedInput[j] = struct{}{}
  178. }
  179. break
  180. }
// No dependency matched this input: try the built-in fallbacks.
  181. if prevN == len(bindings) {
  182. if canBePathParameter { // Let's keep that option just for "payload": disablePayloadAutoBinding
  183. // no new dependency added for this input,
  184. // let's check for path parameters.
  185. bindings = append(bindings, paramBinding(i, getParamIndex(), in))
  186. continue
  187. }
  188. // else, if payload binding is not disabled,
  189. // add builtin request bindings that
  190. // could be registered by end-dev but they didn't
  191. if !disablePayloadAutoBinding && isPayloadType(in) {
  192. bindings = append(bindings, payloadBinding(i, in))
  193. continue
  194. }
  195. }
  196. }
  197. return
  198. }
  199. func isPayloadType(in reflect.Type) bool {
  200. switch indirectType(in).Kind() {
  201. case reflect.Struct, reflect.Slice, reflect.Ptr:
  202. return true
  203. default:
  204. return false
  205. }
  206. }
  207. func getBindingsForFunc(fn reflect.Value, dependencies []*Dependency, disablePayloadAutoBinding bool, paramsCount int) []*binding {
  208. fnTyp := fn.Type()
  209. if !isFunc(fnTyp) {
  210. panic(fmt.Sprintf("bindings: unresolved: no a func type: %#+v", fn))
  211. }
  212. n := fnTyp.NumIn()
  213. inputs := make([]reflect.Type, n)
  214. for i := 0; i < n; i++ {
  215. inputs[i] = fnTyp.In(i)
  216. }
  217. bindings := getBindingsFor(inputs, dependencies, disablePayloadAutoBinding, paramsCount)
  218. if expected, got := n, len(bindings); expected != got {
  219. expectedInputs := ""
  220. missingInputs := ""
  221. for i, in := range inputs {
  222. pos := i + 1
  223. typName := in.String()
  224. expectedInputs += fmt.Sprintf("\n - [%d] %s", pos, typName)
  225. found := false
  226. for _, b := range bindings {
  227. if b.Input.Index == i {
  228. found = true
  229. break
  230. }
  231. }
  232. if !found {
  233. missingInputs += fmt.Sprintf("\n - [%d] %s", pos, typName)
  234. }
  235. }
  236. fnName := context.HandlerName(fn)
  237. panic(fmt.Sprintf("expected [%d] bindings (input parameters) but got [%d]\nFunction:\n - %s\nExpected:%s\nMissing:%s",
  238. expected, got, fnName, expectedInputs, missingInputs))
  239. }
  240. return bindings
  241. }
  // getBindingsForStruct resolves the bindings for the struct, or pointer to
// struct, v. It panics if v's indirect type is not a struct.
//
// Bindings are collected in two groups:
//  1. one binding per field returned by lookupNonZeroFieldValues(elem), whose
//     Dependency is built with newDependency from the field's current value;
//  2. the bindings getBindingsFor resolves against dependencies for the fields
//     returned by lookupFields(elem, true, true, nil). When sorter is non-nil
//     and there is more than one field, those fields are reordered first.
//
// It panics if markExportedFieldsAsRequired is true and some field of group 2
// was not bound. If the stateless count from lookupFields is 0 and group 1 has
// at least as many entries as group 2's bindings, only group 1 is returned.
// Otherwise group 2 is appended. Every binding without a StructFieldIndex then
// receives the index path and name of fields[Input.Index].
//
// NOTE(review): matchDependency is not referenced in this body.
  242. func getBindingsForStruct(v reflect.Value, dependencies []*Dependency, markExportedFieldsAsRequired bool, disablePayloadAutoBinding, enableStructDependents bool, matchDependency DependencyMatcher, paramsCount int, sorter Sorter) (bindings []*binding) {
  243. typ := indirectType(v.Type())
  244. if typ.Kind() != reflect.Struct {
  245. panic(fmt.Sprintf("bindings: unresolved: not a struct type: %#+v", v))
  246. }
  247. // get bindings from any struct's non zero values first, including unexported.
  248. elem := reflect.Indirect(v)
  249. nonZero := lookupNonZeroFieldValues(elem)
  250. for _, f := range nonZero {
  251. // fmt.Printf("Controller [%s] | NonZero | Field Index: %v | Field Type: %s\n", typ, f.Index, f.Type)
  252. bindings = append(bindings, &binding{
  253. Dependency: newDependency(elem.FieldByIndex(f.Index).Interface(), disablePayloadAutoBinding, enableStructDependents, nil),
  254. Input: newStructFieldInput(f),
  255. })
  256. }
  257. fields, stateless := lookupFields(elem, true, true, nil)
  258. n := len(fields)
// NOTE(review): sort.Slice is not stable. Fields the sorter treats as equal may
// lose declaration order, and getBindingsFor's matching is order-sensitive.
// sort.SliceStable would preserve it; confirm which is intended.
  259. if n > 1 && sorter != nil {
  260. sort.Slice(fields, func(i, j int) bool {
  261. return sorter(fields[i].Type, fields[j].Type)
  262. })
  263. }
  264. inputs := make([]reflect.Type, n)
  265. for i := 0; i < n; i++ {
  266. // fmt.Printf("Controller [%s] | Field Index: %v | Field Type: %s\n", typ, fields[i].Index, fields[i].Type)
  267. inputs[i] = fields[i].Type
  268. }
  269. exportedBindings := getBindingsFor(inputs, dependencies, disablePayloadAutoBinding, paramsCount)
  270. // fmt.Printf("Controller [%s] | Inputs length: %d vs Bindings length: %d | NonZero: %d | Stateless : %d\n",
  271. // typ, n, len(exportedBindings), len(nonZero), stateless)
  272. // for i, b := range exportedBindings {
  273. // fmt.Printf("[%d] [Static=%v] %#+v\n", i, b.Dependency.Static, b.Dependency.OriginalValue)
  274. // }
  275. if markExportedFieldsAsRequired && len(exportedBindings) != n {
  276. panic(fmt.Sprintf("MarkExportedFieldsAsRequired is true and at least one of struct's (%s) field was not binded to a dependency.\nFields length: %d, matched exported bindings length: %d.\nUse the Reporter for further details", typ.String(), n, len(exportedBindings)))
  277. }
// NOTE(review): this early return drops exportedBindings entirely; only the
// non-zero field bindings collected above are returned.
  278. if stateless == 0 && len(nonZero) >= len(exportedBindings) {
  279. // if we have not a single stateless and fields are defined then just return.
  280. // Note(@kataras): this can accept further improvements.
  281. return
  282. }
  283. // get declared bindings from deps.
  284. bindings = append(bindings, exportedBindings...)
// The loop variable "binding" shadows the binding type within this loop.
// Only bindings from getBindingsFor lack a StructFieldIndex; their Input.Index
// is a position in fields.
  285. for _, binding := range bindings {
  286. // fmt.Printf(""Controller [%s] | Binding: %s\n", typ, binding.String())
  287. if len(binding.Input.StructFieldIndex) == 0 {
  288. // set correctly the input's field index and name.
  289. f := fields[binding.Input.Index]
  290. binding.Input.StructFieldIndex = f.Index
  291. binding.Input.StructFieldName = f.Name
  292. }
  293. // fmt.Printf("Controller [%s] | binding Index: %v | binding Type: %s\n", typ, binding.Input.StructFieldIndex, binding.Input.Type)
  294. // fmt.Printf("Controller [%s] Set [%s] to struct field index: %v\n", typ.String(), binding.Input.Type.String(), binding.Input.StructFieldIndex)
  295. }
  296. return
  297. }
  298. func getStaticInputs(bindings []*binding, numIn int) []reflect.Value {
  299. inputs := make([]reflect.Value, numIn)
  300. for _, b := range bindings {
  301. if d := b.Dependency; d != nil && d.Static {
  302. inputs[b.Input.Index], _ = d.Handle(nil, nil)
  303. }
  304. }
  305. return inputs
  306. }
  307. /*
  308. Builtin dynamic bindings.
  309. */
  310. func paramBinding(index, paramIndex int, typ reflect.Type) *binding {
  311. return &binding{
  312. Dependency: &Dependency{Handle: paramDependencyHandler(paramIndex), DestType: typ, Source: getSource()},
  313. Input: newInput(typ, index, nil),
  314. }
  315. }
  316. func paramDependencyHandler(paramIndex int) DependencyHandler {
  317. return func(ctx *context.Context, input *Input) (reflect.Value, error) {
  318. if ctx.Params().Len() <= paramIndex {
  319. return emptyValue, ErrSeeOther
  320. }
  321. return reflect.ValueOf(ctx.Params().Store[paramIndex].ValueRaw), nil
  322. }
  323. }
  324. // registered if input parameters are more than matched dependencies.
  325. // It binds an input to a request body based on the request content-type header
  326. // (JSON, Protobuf, Msgpack, XML, YAML, Query, Form).
  327. func payloadBinding(index int, typ reflect.Type) *binding {
  328. // fmt.Printf("Register payload binding for index: %d and type: %s\n", index, typ.String())
  329. return &binding{
  330. Dependency: &Dependency{
  331. Handle: func(ctx *context.Context, input *Input) (newValue reflect.Value, err error) {
  332. wasPtr := input.Type.Kind() == reflect.Ptr
  333. if serveDepsV := ctx.Values().Get(context.DependenciesContextKey); serveDepsV != nil {
  334. if serveDeps, ok := serveDepsV.(context.DependenciesMap); ok {
  335. if newValue, ok = serveDeps[typ]; ok {
  336. return
  337. }
  338. }
  339. }
  340. if input.Type.Kind() == reflect.Slice {
  341. newValue = reflect.New(reflect.SliceOf(indirectType(input.Type)))
  342. } else {
  343. newValue = reflect.New(indirectType(input.Type))
  344. }
  345. ptr := newValue.Interface()
  346. err = ctx.ReadBody(ptr)
  347. if !wasPtr {
  348. newValue = newValue.Elem()
  349. }
  350. return
  351. },
  352. Source: getSource(),
  353. },
  354. Input: newInput(typ, index, nil),
  355. }
  356. }