assertion_compare.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. package assert
  2. import (
  3. "bytes"
  4. "fmt"
  5. "reflect"
  6. "time"
  7. )
  // CompareType is the outcome of ordering two values: compareLess when the
  // first operand sorts before the second, compareEqual when they are equal and
  // compareGreater when it sorts after. The numeric values (-1, 0, +1) are
  // chosen so that results of bytes.Compare can be converted directly.
  type CompareType int

  const (
  	compareLess    CompareType = -1
  	compareEqual   CompareType = 0
  	compareGreater CompareType = 1
  )
  // Cached reflect.Type values for the concrete types that the ordering
  // helpers in this file work with. They are package-level variables, so each
  // one is computed once at initialization rather than on every call.
  var (
  	// Signed integers.
  	intType   = reflect.TypeOf(int(0))
  	int8Type  = reflect.TypeOf(int8(0))
  	int16Type = reflect.TypeOf(int16(0))
  	int32Type = reflect.TypeOf(int32(0))
  	int64Type = reflect.TypeOf(int64(0))

  	// Unsigned integers.
  	uintType   = reflect.TypeOf(uint(0))
  	uint8Type  = reflect.TypeOf(uint8(0))
  	uint16Type = reflect.TypeOf(uint16(0))
  	uint32Type = reflect.TypeOf(uint32(0))
  	uint64Type = reflect.TypeOf(uint64(0))

  	// Floating point.
  	float32Type = reflect.TypeOf(float32(0))
  	float64Type = reflect.TypeOf(float64(0))

  	// Other orderable types.
  	stringType = reflect.TypeOf("")
  	timeType   = reflect.TypeOf(time.Time{})
  	bytesType  = reflect.TypeOf([]byte(nil))
  )
  31. func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
  32. obj1Value := reflect.ValueOf(obj1)
  33. obj2Value := reflect.ValueOf(obj2)
  34. // throughout this switch we try and avoid calling .Convert() if possible,
  35. // as this has a pretty big performance impact
  36. switch kind {
  37. case reflect.Int:
  38. {
  39. intobj1, ok := obj1.(int)
  40. if !ok {
  41. intobj1 = obj1Value.Convert(intType).Interface().(int)
  42. }
  43. intobj2, ok := obj2.(int)
  44. if !ok {
  45. intobj2 = obj2Value.Convert(intType).Interface().(int)
  46. }
  47. if intobj1 > intobj2 {
  48. return compareGreater, true
  49. }
  50. if intobj1 == intobj2 {
  51. return compareEqual, true
  52. }
  53. if intobj1 < intobj2 {
  54. return compareLess, true
  55. }
  56. }
  57. case reflect.Int8:
  58. {
  59. int8obj1, ok := obj1.(int8)
  60. if !ok {
  61. int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
  62. }
  63. int8obj2, ok := obj2.(int8)
  64. if !ok {
  65. int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
  66. }
  67. if int8obj1 > int8obj2 {
  68. return compareGreater, true
  69. }
  70. if int8obj1 == int8obj2 {
  71. return compareEqual, true
  72. }
  73. if int8obj1 < int8obj2 {
  74. return compareLess, true
  75. }
  76. }
  77. case reflect.Int16:
  78. {
  79. int16obj1, ok := obj1.(int16)
  80. if !ok {
  81. int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
  82. }
  83. int16obj2, ok := obj2.(int16)
  84. if !ok {
  85. int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
  86. }
  87. if int16obj1 > int16obj2 {
  88. return compareGreater, true
  89. }
  90. if int16obj1 == int16obj2 {
  91. return compareEqual, true
  92. }
  93. if int16obj1 < int16obj2 {
  94. return compareLess, true
  95. }
  96. }
  97. case reflect.Int32:
  98. {
  99. int32obj1, ok := obj1.(int32)
  100. if !ok {
  101. int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
  102. }
  103. int32obj2, ok := obj2.(int32)
  104. if !ok {
  105. int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
  106. }
  107. if int32obj1 > int32obj2 {
  108. return compareGreater, true
  109. }
  110. if int32obj1 == int32obj2 {
  111. return compareEqual, true
  112. }
  113. if int32obj1 < int32obj2 {
  114. return compareLess, true
  115. }
  116. }
  117. case reflect.Int64:
  118. {
  119. int64obj1, ok := obj1.(int64)
  120. if !ok {
  121. int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
  122. }
  123. int64obj2, ok := obj2.(int64)
  124. if !ok {
  125. int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
  126. }
  127. if int64obj1 > int64obj2 {
  128. return compareGreater, true
  129. }
  130. if int64obj1 == int64obj2 {
  131. return compareEqual, true
  132. }
  133. if int64obj1 < int64obj2 {
  134. return compareLess, true
  135. }
  136. }
  137. case reflect.Uint:
  138. {
  139. uintobj1, ok := obj1.(uint)
  140. if !ok {
  141. uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
  142. }
  143. uintobj2, ok := obj2.(uint)
  144. if !ok {
  145. uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
  146. }
  147. if uintobj1 > uintobj2 {
  148. return compareGreater, true
  149. }
  150. if uintobj1 == uintobj2 {
  151. return compareEqual, true
  152. }
  153. if uintobj1 < uintobj2 {
  154. return compareLess, true
  155. }
  156. }
  157. case reflect.Uint8:
  158. {
  159. uint8obj1, ok := obj1.(uint8)
  160. if !ok {
  161. uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
  162. }
  163. uint8obj2, ok := obj2.(uint8)
  164. if !ok {
  165. uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
  166. }
  167. if uint8obj1 > uint8obj2 {
  168. return compareGreater, true
  169. }
  170. if uint8obj1 == uint8obj2 {
  171. return compareEqual, true
  172. }
  173. if uint8obj1 < uint8obj2 {
  174. return compareLess, true
  175. }
  176. }
  177. case reflect.Uint16:
  178. {
  179. uint16obj1, ok := obj1.(uint16)
  180. if !ok {
  181. uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
  182. }
  183. uint16obj2, ok := obj2.(uint16)
  184. if !ok {
  185. uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
  186. }
  187. if uint16obj1 > uint16obj2 {
  188. return compareGreater, true
  189. }
  190. if uint16obj1 == uint16obj2 {
  191. return compareEqual, true
  192. }
  193. if uint16obj1 < uint16obj2 {
  194. return compareLess, true
  195. }
  196. }
  197. case reflect.Uint32:
  198. {
  199. uint32obj1, ok := obj1.(uint32)
  200. if !ok {
  201. uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
  202. }
  203. uint32obj2, ok := obj2.(uint32)
  204. if !ok {
  205. uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
  206. }
  207. if uint32obj1 > uint32obj2 {
  208. return compareGreater, true
  209. }
  210. if uint32obj1 == uint32obj2 {
  211. return compareEqual, true
  212. }
  213. if uint32obj1 < uint32obj2 {
  214. return compareLess, true
  215. }
  216. }
  217. case reflect.Uint64:
  218. {
  219. uint64obj1, ok := obj1.(uint64)
  220. if !ok {
  221. uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
  222. }
  223. uint64obj2, ok := obj2.(uint64)
  224. if !ok {
  225. uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
  226. }
  227. if uint64obj1 > uint64obj2 {
  228. return compareGreater, true
  229. }
  230. if uint64obj1 == uint64obj2 {
  231. return compareEqual, true
  232. }
  233. if uint64obj1 < uint64obj2 {
  234. return compareLess, true
  235. }
  236. }
  237. case reflect.Float32:
  238. {
  239. float32obj1, ok := obj1.(float32)
  240. if !ok {
  241. float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
  242. }
  243. float32obj2, ok := obj2.(float32)
  244. if !ok {
  245. float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
  246. }
  247. if float32obj1 > float32obj2 {
  248. return compareGreater, true
  249. }
  250. if float32obj1 == float32obj2 {
  251. return compareEqual, true
  252. }
  253. if float32obj1 < float32obj2 {
  254. return compareLess, true
  255. }
  256. }
  257. case reflect.Float64:
  258. {
  259. float64obj1, ok := obj1.(float64)
  260. if !ok {
  261. float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
  262. }
  263. float64obj2, ok := obj2.(float64)
  264. if !ok {
  265. float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
  266. }
  267. if float64obj1 > float64obj2 {
  268. return compareGreater, true
  269. }
  270. if float64obj1 == float64obj2 {
  271. return compareEqual, true
  272. }
  273. if float64obj1 < float64obj2 {
  274. return compareLess, true
  275. }
  276. }
  277. case reflect.String:
  278. {
  279. stringobj1, ok := obj1.(string)
  280. if !ok {
  281. stringobj1 = obj1Value.Convert(stringType).Interface().(string)
  282. }
  283. stringobj2, ok := obj2.(string)
  284. if !ok {
  285. stringobj2 = obj2Value.Convert(stringType).Interface().(string)
  286. }
  287. if stringobj1 > stringobj2 {
  288. return compareGreater, true
  289. }
  290. if stringobj1 == stringobj2 {
  291. return compareEqual, true
  292. }
  293. if stringobj1 < stringobj2 {
  294. return compareLess, true
  295. }
  296. }
  297. // Check for known struct types we can check for compare results.
  298. case reflect.Struct:
  299. {
  300. // All structs enter here. We're not interested in most types.
  301. if !canConvert(obj1Value, timeType) {
  302. break
  303. }
  304. // time.Time can compared!
  305. timeObj1, ok := obj1.(time.Time)
  306. if !ok {
  307. timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
  308. }
  309. timeObj2, ok := obj2.(time.Time)
  310. if !ok {
  311. timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
  312. }
  313. return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
  314. }
  315. case reflect.Slice:
  316. {
  317. // We only care about the []byte type.
  318. if !canConvert(obj1Value, bytesType) {
  319. break
  320. }
  321. // []byte can be compared!
  322. bytesObj1, ok := obj1.([]byte)
  323. if !ok {
  324. bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
  325. }
  326. bytesObj2, ok := obj2.([]byte)
  327. if !ok {
  328. bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
  329. }
  330. return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
  331. }
  332. }
  333. return compareEqual, false
  334. }
  335. // Greater asserts that the first element is greater than the second
  336. //
  337. // assert.Greater(t, 2, 1)
  338. // assert.Greater(t, float64(2), float64(1))
  339. // assert.Greater(t, "b", "a")
  340. func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  341. if h, ok := t.(tHelper); ok {
  342. h.Helper()
  343. }
  344. return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
  345. }
  346. // GreaterOrEqual asserts that the first element is greater than or equal to the second
  347. //
  348. // assert.GreaterOrEqual(t, 2, 1)
  349. // assert.GreaterOrEqual(t, 2, 2)
  350. // assert.GreaterOrEqual(t, "b", "a")
  351. // assert.GreaterOrEqual(t, "b", "b")
  352. func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  353. if h, ok := t.(tHelper); ok {
  354. h.Helper()
  355. }
  356. return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
  357. }
  358. // Less asserts that the first element is less than the second
  359. //
  360. // assert.Less(t, 1, 2)
  361. // assert.Less(t, float64(1), float64(2))
  362. // assert.Less(t, "a", "b")
  363. func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  364. if h, ok := t.(tHelper); ok {
  365. h.Helper()
  366. }
  367. return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
  368. }
  369. // LessOrEqual asserts that the first element is less than or equal to the second
  370. //
  371. // assert.LessOrEqual(t, 1, 2)
  372. // assert.LessOrEqual(t, 2, 2)
  373. // assert.LessOrEqual(t, "a", "b")
  374. // assert.LessOrEqual(t, "b", "b")
  375. func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
  376. if h, ok := t.(tHelper); ok {
  377. h.Helper()
  378. }
  379. return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
  380. }
  381. // Positive asserts that the specified element is positive
  382. //
  383. // assert.Positive(t, 1)
  384. // assert.Positive(t, 1.23)
  385. func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  386. if h, ok := t.(tHelper); ok {
  387. h.Helper()
  388. }
  389. zero := reflect.Zero(reflect.TypeOf(e))
  390. return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
  391. }
  392. // Negative asserts that the specified element is negative
  393. //
  394. // assert.Negative(t, -1)
  395. // assert.Negative(t, -1.23)
  396. func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
  397. if h, ok := t.(tHelper); ok {
  398. h.Helper()
  399. }
  400. zero := reflect.Zero(reflect.TypeOf(e))
  401. return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
  402. }
  403. func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
  404. if h, ok := t.(tHelper); ok {
  405. h.Helper()
  406. }
  407. e1Kind := reflect.ValueOf(e1).Kind()
  408. e2Kind := reflect.ValueOf(e2).Kind()
  409. if e1Kind != e2Kind {
  410. return Fail(t, "Elements should be the same type", msgAndArgs...)
  411. }
  412. compareResult, isComparable := compare(e1, e2, e1Kind)
  413. if !isComparable {
  414. return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
  415. }
  416. if !containsValue(allowedComparesResults, compareResult) {
  417. return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
  418. }
  419. return true
  420. }
  421. func containsValue(values []CompareType, value CompareType) bool {
  422. for _, v := range values {
  423. if v == value {
  424. return true
  425. }
  426. }
  427. return false
  428. }