assertion_compare.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. package assert
  2. import (
  3. "fmt"
  4. "reflect"
  5. "time"
  6. )
  // CompareType is the three-way result produced by compare: negative when
  // the first operand orders before the second, zero when they are equal and
  // positive when the first orders after the second.
  type CompareType int

  // The three possible outcomes of compare, spelled out explicitly
  // (-1, 0, +1) rather than derived from iota.
  const (
  	compareLess    CompareType = -1
  	compareEqual   CompareType = 0
  	compareGreater CompareType = 1
  )
  // Reflected base types that compare converts its operands to when a plain
  // type assertion fails, e.g. for a named type such as `type MyInt int`
  // whose kind matches but whose dynamic type differs.
  var (
  	intType     = reflect.TypeOf((*int)(nil)).Elem()
  	int8Type    = reflect.TypeOf((*int8)(nil)).Elem()
  	int16Type   = reflect.TypeOf((*int16)(nil)).Elem()
  	int32Type   = reflect.TypeOf((*int32)(nil)).Elem()
  	int64Type   = reflect.TypeOf((*int64)(nil)).Elem()
  	uintType    = reflect.TypeOf((*uint)(nil)).Elem()
  	uint8Type   = reflect.TypeOf((*uint8)(nil)).Elem()
  	uint16Type  = reflect.TypeOf((*uint16)(nil)).Elem()
  	uint32Type  = reflect.TypeOf((*uint32)(nil)).Elem()
  	uint64Type  = reflect.TypeOf((*uint64)(nil)).Elem()
  	float32Type = reflect.TypeOf((*float32)(nil)).Elem()
  	float64Type = reflect.TypeOf((*float64)(nil)).Elem()
  	stringType  = reflect.TypeOf((*string)(nil)).Elem()
  	timeType    = reflect.TypeOf((*time.Time)(nil)).Elem()
  )
  29. func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
  30. obj1Value := reflect.ValueOf(obj1)
  31. obj2Value := reflect.ValueOf(obj2)
  32. // throughout this switch we try and avoid calling .Convert() if possible,
  33. // as this has a pretty big performance impact
  34. switch kind {
  35. case reflect.Int:
  36. {
  37. intobj1, ok := obj1.(int)
  38. if !ok {
  39. intobj1 = obj1Value.Convert(intType).Interface().(int)
  40. }
  41. intobj2, ok := obj2.(int)
  42. if !ok {
  43. intobj2 = obj2Value.Convert(intType).Interface().(int)
  44. }
  45. if intobj1 > intobj2 {
  46. return compareGreater, true
  47. }
  48. if intobj1 == intobj2 {
  49. return compareEqual, true
  50. }
  51. if intobj1 < intobj2 {
  52. return compareLess, true
  53. }
  54. }
  55. case reflect.Int8:
  56. {
  57. int8obj1, ok := obj1.(int8)
  58. if !ok {
  59. int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
  60. }
  61. int8obj2, ok := obj2.(int8)
  62. if !ok {
  63. int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
  64. }
  65. if int8obj1 > int8obj2 {
  66. return compareGreater, true
  67. }
  68. if int8obj1 == int8obj2 {
  69. return compareEqual, true
  70. }
  71. if int8obj1 < int8obj2 {
  72. return compareLess, true
  73. }
  74. }
  75. case reflect.Int16:
  76. {
  77. int16obj1, ok := obj1.(int16)
  78. if !ok {
  79. int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
  80. }
  81. int16obj2, ok := obj2.(int16)
  82. if !ok {
  83. int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
  84. }
  85. if int16obj1 > int16obj2 {
  86. return compareGreater, true
  87. }
  88. if int16obj1 == int16obj2 {
  89. return compareEqual, true
  90. }
  91. if int16obj1 < int16obj2 {
  92. return compareLess, true
  93. }
  94. }
  95. case reflect.Int32:
  96. {
  97. int32obj1, ok := obj1.(int32)
  98. if !ok {
  99. int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
  100. }
  101. int32obj2, ok := obj2.(int32)
  102. if !ok {
  103. int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
  104. }
  105. if int32obj1 > int32obj2 {
  106. return compareGreater, true
  107. }
  108. if int32obj1 == int32obj2 {
  109. return compareEqual, true
  110. }
  111. if int32obj1 < int32obj2 {
  112. return compareLess, true
  113. }
  114. }
  115. case reflect.Int64:
  116. {
  117. int64obj1, ok := obj1.(int64)
  118. if !ok {
  119. int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
  120. }
  121. int64obj2, ok := obj2.(int64)
  122. if !ok {
  123. int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
  124. }
  125. if int64obj1 > int64obj2 {
  126. return compareGreater, true
  127. }
  128. if int64obj1 == int64obj2 {
  129. return compareEqual, true
  130. }
  131. if int64obj1 < int64obj2 {
  132. return compareLess, true
  133. }
  134. }
  135. case reflect.Uint:
  136. {
  137. uintobj1, ok := obj1.(uint)
  138. if !ok {
  139. uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
  140. }
  141. uintobj2, ok := obj2.(uint)
  142. if !ok {
  143. uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
  144. }
  145. if uintobj1 > uintobj2 {
  146. return compareGreater, true
  147. }
  148. if uintobj1 == uintobj2 {
  149. return compareEqual, true
  150. }
  151. if uintobj1 < uintobj2 {
  152. return compareLess, true
  153. }
  154. }
  155. case reflect.Uint8:
  156. {
  157. uint8obj1, ok := obj1.(uint8)
  158. if !ok {
  159. uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
  160. }
  161. uint8obj2, ok := obj2.(uint8)
  162. if !ok {
  163. uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
  164. }
  165. if uint8obj1 > uint8obj2 {
  166. return compareGreater, true
  167. }
  168. if uint8obj1 == uint8obj2 {
  169. return compareEqual, true
  170. }
  171. if uint8obj1 < uint8obj2 {
  172. return compareLess, true
  173. }
  174. }
  175. case reflect.Uint16:
  176. {
  177. uint16obj1, ok := obj1.(uint16)
  178. if !ok {
  179. uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
  180. }
  181. uint16obj2, ok := obj2.(uint16)
  182. if !ok {
  183. uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
  184. }
  185. if uint16obj1 > uint16obj2 {
  186. return compareGreater, true
  187. }
  188. if uint16obj1 == uint16obj2 {
  189. return compareEqual, true
  190. }
  191. if uint16obj1 < uint16obj2 {
  192. return compareLess, true
  193. }
  194. }
  195. case reflect.Uint32:
  196. {
  197. uint32obj1, ok := obj1.(uint32)
  198. if !ok {
  199. uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
  200. }
  201. uint32obj2, ok := obj2.(uint32)
  202. if !ok {
  203. uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
  204. }
  205. if uint32obj1 > uint32obj2 {
  206. return compareGreater, true
  207. }
  208. if uint32obj1 == uint32obj2 {
  209. return compareEqual, true
  210. }
  211. if uint32obj1 < uint32obj2 {
  212. return compareLess, true
  213. }
  214. }
  215. case reflect.Uint64:
  216. {
  217. uint64obj1, ok := obj1.(uint64)
  218. if !ok {
  219. uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
  220. }
  221. uint64obj2, ok := obj2.(uint64)
  222. if !ok {
  223. uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
  224. }
  225. if uint64obj1 > uint64obj2 {
  226. return compareGreater, true
  227. }
  228. if uint64obj1 == uint64obj2 {
  229. return compareEqual, true
  230. }
  231. if uint64obj1 < uint64obj2 {
  232. return compareLess, true
  233. }
  234. }
  235. case reflect.Float32:
  236. {
  237. float32obj1, ok := obj1.(float32)
  238. if !ok {
  239. float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
  240. }
  241. float32obj2, ok := obj2.(float32)
  242. if !ok {
  243. float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
  244. }
  245. if float32obj1 > float32obj2 {
  246. return compareGreater, true
  247. }
  248. if float32obj1 == float32obj2 {
  249. return compareEqual, true
  250. }
  251. if float32obj1 < float32obj2 {
  252. return compareLess, true
  253. }
  254. }
  255. case reflect.Float64:
  256. {
  257. float64obj1, ok := obj1.(float64)
  258. if !ok {
  259. float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
  260. }
  261. float64obj2, ok := obj2.(float64)
  262. if !ok {
  263. float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
  264. }
  265. if float64obj1 > float64obj2 {
  266. return compareGreater, true
  267. }
  268. if float64obj1 == float64obj2 {
  269. return compareEqual, true
  270. }
  271. if float64obj1 < float64obj2 {
  272. return compareLess, true
  273. }
  274. }
  275. case reflect.String:
  276. {
  277. stringobj1, ok := obj1.(string)
  278. if !ok {
  279. stringobj1 = obj1Value.Convert(stringType).Interface().(string)
  280. }
  281. stringobj2, ok := obj2.(string)
  282. if !ok {
  283. stringobj2 = obj2Value.Convert(stringType).Interface().(string)
  284. }
  285. if stringobj1 > stringobj2 {
  286. return compareGreater, true
  287. }
  288. if stringobj1 == stringobj2 {
  289. return compareEqual, true
  290. }
  291. if stringobj1 < stringobj2 {
  292. return compareLess, true
  293. }
  294. }
  295. // Check for known struct types we can check for compare results.
  296. case reflect.Struct:
  297. {
  298. // All structs enter here. We're not interested in most types.
  299. if !canConvert(obj1Value, timeType) {
  300. break
  301. }
  302. // time.Time can compared!
  303. timeObj1, ok := obj1.(time.Time)
  304. if !ok {
  305. timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
  306. }
  307. timeObj2, ok := obj2.(time.Time)
  308. if !ok {
  309. timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
  310. }
  311. return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
  312. }
  313. }
  314. return compareEqual, false
  315. }
  316. // Greater asserts that the first element is greater than the second
  317. //
  318. // assert.Greater(t, 2, 1)
  319. // assert.Greater(t, float64(2), float64(1))
  320. // assert.Greater(t, "b", "a")
  321. func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  322. if h, ok := t.(tHelper); ok {
  323. h.Helper()
  324. }
  325. return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
  326. }
  327. // GreaterOrEqual asserts that the first element is greater than or equal to the second
  328. //
  329. // assert.GreaterOrEqual(t, 2, 1)
  330. // assert.GreaterOrEqual(t, 2, 2)
  331. // assert.GreaterOrEqual(t, "b", "a")
  332. // assert.GreaterOrEqual(t, "b", "b")
  333. func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  334. if h, ok := t.(tHelper); ok {
  335. h.Helper()
  336. }
  337. return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
  338. }
  339. // Less asserts that the first element is less than the second
  340. //
  341. // assert.Less(t, 1, 2)
  342. // assert.Less(t, float64(1), float64(2))
  343. // assert.Less(t, "a", "b")
  344. func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  345. if h, ok := t.(tHelper); ok {
  346. h.Helper()
  347. }
  348. return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
  349. }
  350. // LessOrEqual asserts that the first element is less than or equal to the second
  351. //
  352. // assert.LessOrEqual(t, 1, 2)
  353. // assert.LessOrEqual(t, 2, 2)
  354. // assert.LessOrEqual(t, "a", "b")
  355. // assert.LessOrEqual(t, "b", "b")
  356. func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  357. if h, ok := t.(tHelper); ok {
  358. h.Helper()
  359. }
  360. return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
  361. }
  362. // Positive asserts that the specified element is positive
  363. //
  364. // assert.Positive(t, 1)
  365. // assert.Positive(t, 1.23)
  366. func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  367. if h, ok := t.(tHelper); ok {
  368. h.Helper()
  369. }
  370. zero := reflect.Zero(reflect.TypeOf(e))
  371. return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
  372. }
  373. // Negative asserts that the specified element is negative
  374. //
  375. // assert.Negative(t, -1)
  376. // assert.Negative(t, -1.23)
  377. func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  378. if h, ok := t.(tHelper); ok {
  379. h.Helper()
  380. }
  381. zero := reflect.Zero(reflect.TypeOf(e))
  382. return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
  383. }
  384. func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
  385. if h, ok := t.(tHelper); ok {
  386. h.Helper()
  387. }
  388. e1Kind := reflect.ValueOf(e1).Kind()
  389. e2Kind := reflect.ValueOf(e2).Kind()
  390. if e1Kind != e2Kind {
  391. return Fail(t, "Elements should be the same type", msgAndArgs...)
  392. }
  393. compareResult, isComparable := compare(e1, e2, e1Kind)
  394. if !isComparable {
  395. return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
  396. }
  397. if !containsValue(allowedComparesResults, compareResult) {
  398. return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
  399. }
  400. return true
  401. }
  402. func containsValue(values []CompareType, value CompareType) bool {
  403. for _, v := range values {
  404. if v == value {
  405. return true
  406. }
  407. }
  408. return false
  409. }