rate.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. package redis_rate
  2. import (
  3. "context"
  4. "fmt"
  5. "strconv"
  6. "time"
  7. "github.com/go-redis/redis/v8"
  8. )
  // redisPrefix is prepended to every caller-supplied key before the key is
  // handed to the rate-limiting scripts.
  const (
  	redisPrefix = "rate:"
  )
  10. type rediser interface {
  11. Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
  12. EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
  13. ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
  14. ScriptLoad(ctx context.Context, script string) *redis.StringCmd
  15. }
  // Limit describes a rate limit: Rate events per Period, with Burst passed
  // to the limiter script alongside them.
  type Limit struct {
  	// Rate is the number of events permitted per Period.
  	Rate int
  	// Burst is the burst size handed to the limiter script. The PerSecond,
  	// PerMinute and PerHour constructors set it equal to Rate.
  	Burst int
  	// Period is the time window that Rate applies to.
  	Period time.Duration
  }
  21. func (l Limit) String() string {
  22. return fmt.Sprintf("%d req/%s (burst %d)", l.Rate, fmtDur(l.Period), l.Burst)
  23. }
  24. func (l Limit) IsZero() bool {
  25. return l == Limit{}
  26. }
  // fmtDur formats d for display inside Limit.String. Exactly one second,
  // minute or hour is abbreviated to "s", "m" or "h"; any other duration falls
  // back to time.Duration's own String form.
  func fmtDur(d time.Duration) string {
  	if d == time.Second {
  		return "s"
  	}
  	if d == time.Minute {
  		return "m"
  	}
  	if d == time.Hour {
  		return "h"
  	}
  	return d.String()
  }
  38. func PerSecond(rate int) Limit {
  39. return Limit{
  40. Rate: rate,
  41. Period: time.Second,
  42. Burst: rate,
  43. }
  44. }
  45. func PerMinute(rate int) Limit {
  46. return Limit{
  47. Rate: rate,
  48. Period: time.Minute,
  49. Burst: rate,
  50. }
  51. }
  52. func PerHour(rate int) Limit {
  53. return Limit{
  54. Rate: rate,
  55. Period: time.Hour,
  56. Burst: rate,
  57. }
  58. }
  59. //------------------------------------------------------------------------------
  60. // Limiter controls how frequently events are allowed to happen.
  61. type Limiter struct {
  62. rdb rediser
  63. }
  64. // NewLimiter returns a new Limiter.
  65. func NewLimiter(rdb rediser) *Limiter {
  66. return &Limiter{
  67. rdb: rdb,
  68. }
  69. }
  70. // Allow is a shortcut for AllowN(ctx, key, limit, 1).
  71. func (l Limiter) Allow(ctx context.Context, key string, limit Limit) (*Result, error) {
  72. return l.AllowN(ctx, key, limit, 1)
  73. }
  74. // AllowN reports whether n events may happen at time now.
  75. func (l Limiter) AllowN(
  76. ctx context.Context,
  77. key string,
  78. limit Limit,
  79. n int,
  80. ) (*Result, error) {
  81. values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
  82. v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
  83. if err != nil {
  84. return nil, err
  85. }
  86. values = v.([]interface{})
  87. retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
  88. if err != nil {
  89. return nil, err
  90. }
  91. resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
  92. if err != nil {
  93. return nil, err
  94. }
  95. res := &Result{
  96. Limit: limit,
  97. Allowed: int(values[0].(int64)),
  98. Remaining: int(values[1].(int64)),
  99. RetryAfter: dur(retryAfter),
  100. ResetAfter: dur(resetAfter),
  101. }
  102. return res, nil
  103. }
  104. // AllowAtMost reports whether at most n events may happen at time now.
  105. // It returns number of allowed events that is less than or equal to n.
  106. func (l Limiter) AllowAtMost(
  107. ctx context.Context,
  108. key string,
  109. limit Limit,
  110. n int,
  111. ) (*Result, error) {
  112. values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n}
  113. v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result()
  114. if err != nil {
  115. return nil, err
  116. }
  117. values = v.([]interface{})
  118. retryAfter, err := strconv.ParseFloat(values[2].(string), 64)
  119. if err != nil {
  120. return nil, err
  121. }
  122. resetAfter, err := strconv.ParseFloat(values[3].(string), 64)
  123. if err != nil {
  124. return nil, err
  125. }
  126. res := &Result{
  127. Limit: limit,
  128. Allowed: int(values[0].(int64)),
  129. Remaining: int(values[1].(int64)),
  130. RetryAfter: dur(retryAfter),
  131. ResetAfter: dur(resetAfter),
  132. }
  133. return res, nil
  134. }
  // dur converts a number of seconds, as reported by the limiter scripts,
  // into a time.Duration. The sentinel value -1 is passed through unchanged
  // as a Duration of -1 rather than being scaled to -1s.
  func dur(f float64) time.Duration {
  	const sentinel = -1
  	if f != sentinel {
  		return time.Duration(f * float64(time.Second))
  	}
  	return sentinel
  }
  141. type Result struct {
  142. // Limit is the limit that was used to obtain this result.
  143. Limit Limit
  144. // Allowed is the number of events that may happen at time now.
  145. Allowed int
  146. // Remaining is the maximum number of requests that could be
  147. // permitted instantaneously for this key given the current
  148. // state. For example, if a rate limiter allows 10 requests per
  149. // second and has already received 6 requests for this key this
  150. // second, Remaining would be 4.
  151. Remaining int
  152. // RetryAfter is the time until the next request will be permitted.
  153. // It should be -1 unless the rate limit has been exceeded.
  154. RetryAfter time.Duration
  155. // ResetAfter is the time until the RateLimiter returns to its
  156. // initial state for a given key. For example, if a rate limiter
  157. // manages requests per second and received one request 200ms ago,
  158. // Reset would return 800ms. You can also think of this as the time
  159. // until Limit and Remaining will be equal.
  160. ResetAfter time.Duration
  161. }