mw_rate_limiter.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package middleware
  2. import (
  3. "github.com/go-redis/redis/v8"
  4. "github.com/go-redis/redis_rate/v9"
  5. "github.com/gogf/gf/frame/g"
  6. "github.com/gogf/gf/net/ghttp"
  7. "gxt-file-server/app/errors"
  8. "gxt-file-server/pkg/gplus"
  9. "gxt-file-server/pkg/logger"
  10. "strconv"
  11. )
  12. // RateLimiterMiddleware 请求频率限制中间件
  13. func RateLimiterMiddleware(skippers ...SkipperFunc) ghttp.HandlerFunc {
  14. if !g.Cfg().GetBool("rate_limiter.enable") {
  15. return EmptyMiddleware
  16. }
  17. // check enable redis
  18. if !g.Cfg().GetBool("redis.enable") {
  19. return func(r *ghttp.Request) {
  20. logger.Warnf(gplus.NewContext(r), "限流中间件无法正常使用,请启用redis配置[redis.enable]")
  21. r.Middleware.Next()
  22. }
  23. }
  24. addr := g.Cfg().GetString("redis.addr")
  25. password := g.Cfg().GetString("redis.password")
  26. db := g.Cfg().GetInt("redis.db")
  27. ring := redis.NewRing(&redis.RingOptions{
  28. Addrs: map[string]string{
  29. "server1": addr,
  30. },
  31. Password: password,
  32. DB: db,
  33. })
  34. limiter := redis_rate.NewLimiter(ring)
  35. return func(r *ghttp.Request) {
  36. if SkipHandler(r, skippers...) {
  37. r.Middleware.Next()
  38. return
  39. }
  40. userID := gplus.GetUserID(r)
  41. if userID == "" {
  42. r.Middleware.Next()
  43. return
  44. }
  45. ctx := gplus.NewContext(r)
  46. limit := g.Cfg().GetInt("rate_limiter.count")
  47. result, err := limiter.Allow(ctx,
  48. userID, redis_rate.PerMinute(limit))
  49. if err != nil {
  50. gplus.ResError(r, errors.ErrInternalServer)
  51. }
  52. if result != nil {
  53. if result.Allowed == 0 {
  54. h := r.Response.Header()
  55. h.Set("X-RateLimit-Limit", strconv.FormatInt(int64(result.Limit.Burst), 10))
  56. h.Set("X-RateLimit-Remaining", strconv.FormatInt(int64(result.Remaining), 10))
  57. h.Set("X-RateLimit-Reset", strconv.FormatInt(int64(result.ResetAfter.Seconds()), 10))
  58. gplus.ResError(r, errors.ErrTooManyRequests)
  59. return
  60. }
  61. }
  62. r.Middleware.Next()
  63. }
  64. }