mw_rate_limiter.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package middleware
  2. import (
  3. "github.com/go-redis/redis/v8"
  4. "github.com/go-redis/redis_rate/v9"
  5. "github.com/gogf/gf/v2/net/ghttp"
  6. "gxt-api-frame/app/errors"
  7. "gxt-api-frame/library/gplus"
  8. "gxt-api-frame/library/logger"
  9. "gxt-api-frame/library/utils"
  10. "strconv"
  11. )
  12. // RateLimiterMiddleware 请求频率限制中间件
  13. func RateLimiterMiddleware(skippers ...SkipperFunc) ghttp.HandlerFunc {
  14. if !utils.GetConfig("rate_limiter.enable").Bool() {
  15. return EmptyMiddleware
  16. }
  17. // check enable redis
  18. if !utils.GetConfig("redis.enable").Bool() {
  19. return func(r *ghttp.Request) {
  20. logger.Warnf(gplus.NewContext(r), "限流中间件无法正常使用,请启用redis配置[redis.enable]")
  21. r.Middleware.Next()
  22. return
  23. }
  24. }
  25. addr := utils.GetConfig("redis.addr").String()
  26. password := utils.GetConfig("redis.password").String()
  27. db := utils.GetConfig("redis.db").Int()
  28. ring := redis.NewRing(&redis.RingOptions{
  29. Addrs: map[string]string{
  30. "server1": addr,
  31. },
  32. Password: password,
  33. DB: db,
  34. })
  35. limiter := redis_rate.NewLimiter(ring)
  36. return func(r *ghttp.Request) {
  37. if SkipHandler(r, skippers...) {
  38. r.Middleware.Next()
  39. return
  40. }
  41. userID := gplus.GetUserID(r)
  42. if userID == "" {
  43. r.Middleware.Next()
  44. return
  45. }
  46. ctx := gplus.NewContext(r)
  47. limit := utils.GetConfig("rate_limiter.count").Int()
  48. result, err := limiter.Allow(ctx,
  49. userID, redis_rate.PerMinute(limit))
  50. if err != nil {
  51. gplus.ResError(r, errors.ErrInternalServer)
  52. }
  53. if result != nil {
  54. if result.Allowed == 0 {
  55. h := r.Response.Header()
  56. h.Set("X-RateLimit-Limit", strconv.FormatInt(int64(result.Limit.Burst), 10))
  57. h.Set("X-RateLimit-Remaining", strconv.FormatInt(int64(result.Remaining), 10))
  58. h.Set("X-RateLimit-Reset", strconv.FormatInt(int64(result.ResetAfter.Seconds()), 10))
  59. gplus.ResError(r, errors.ErrTooManyRequests)
  60. return
  61. }
  62. }
  63. r.Middleware.Next()
  64. }
  65. }