jwt.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package jwt
  2. import (
  3. "errors"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/kataras/iris/v12"
  8. "github.com/kataras/iris/v12/context"
  9. "github.com/golang-jwt/jwt/v4"
  10. )
  11. func init() {
  12. context.SetHandlerName("github.com/iris-contrib/middleware/jwt.*", "iris-contrib.jwt")
  13. }
  14. type (
  15. // Token for JWT. Different fields will be used depending on whether you're
  16. // creating or parsing/verifying a token.
  17. //
  18. // A type alias for jwt.Token.
  19. Token = jwt.Token
  20. // MapClaims type that uses the map[string]interface{} for JSON decoding
  21. // This is the default claims type if you don't supply one
  22. //
  23. // A type alias for jwt.MapClaims.
  24. MapClaims = jwt.MapClaims
  25. // Claims must just have a Valid method that determines
  26. // if the token is invalid for any supported reason.
  27. //
  28. // A type alias for jwt.Claims.
  29. Claims = jwt.Claims
  30. )
  31. // Shortcuts to create a new Token.
  32. var (
  33. NewToken = jwt.New
  34. NewTokenWithClaims = jwt.NewWithClaims
  35. )
  36. // HS256 and company.
  37. var (
  38. SigningMethodHS256 = jwt.SigningMethodHS256
  39. SigningMethodHS384 = jwt.SigningMethodHS384
  40. SigningMethodHS512 = jwt.SigningMethodHS512
  41. )
  42. // ECDSA - EC256 and company.
  43. var (
  44. SigningMethodES256 = jwt.SigningMethodES256
  45. SigningMethodES384 = jwt.SigningMethodES384
  46. SigningMethodES512 = jwt.SigningMethodES512
  47. )
  48. // A function called whenever an error is encountered
  49. type errorHandler func(iris.Context, error)
  50. // TokenExtractor is a function that takes a context as input and returns
  51. // either a token or an error. An error should only be returned if an attempt
  52. // to specify a token was found, but the information was somehow incorrectly
  53. // formed. In the case where a token is simply not present, this should not
  54. // be treated as an error. An empty string should be returned in that case.
  55. type TokenExtractor func(iris.Context) (string, error)
  56. // Middleware the middleware for JSON Web tokens authentication method
  57. type Middleware struct {
  58. Config Config
  59. }
  60. // OnError is the default error handler.
  61. // Use it to change the behavior for each error.
  62. // See `Config.ErrorHandler`.
  63. func OnError(ctx iris.Context, err error) {
  64. if err == nil {
  65. return
  66. }
  67. ctx.StopExecution()
  68. ctx.StatusCode(iris.StatusUnauthorized)
  69. ctx.WriteString(err.Error())
  70. }
  71. // New constructs a new Secure instance with supplied options.
  72. func New(cfg ...Config) *Middleware {
  73. var c Config
  74. if len(cfg) == 0 {
  75. c = Config{}
  76. } else {
  77. c = cfg[0]
  78. }
  79. if c.ContextKey == "" {
  80. c.ContextKey = DefaultContextKey
  81. }
  82. if c.ErrorHandler == nil {
  83. c.ErrorHandler = OnError
  84. }
  85. if c.Extractor == nil {
  86. c.Extractor = FromAuthHeader
  87. }
  88. return &Middleware{Config: c}
  89. }
  90. func logf(ctx iris.Context, format string, args ...interface{}) {
  91. ctx.Application().Logger().Debugf(format, args...)
  92. }
  93. // Get returns the user (&token) information for this client/request
  94. func (m *Middleware) Get(ctx iris.Context) *jwt.Token {
  95. v := ctx.Values().Get(m.Config.ContextKey)
  96. if v == nil {
  97. return nil
  98. }
  99. return v.(*jwt.Token)
  100. }
  101. // Serve the middleware's action
  102. func (m *Middleware) Serve(ctx iris.Context) {
  103. if err := m.CheckJWT(ctx); err != nil {
  104. m.Config.ErrorHandler(ctx, err)
  105. return
  106. }
  107. // If everything ok then call next.
  108. ctx.Next()
  109. }
  110. // FromAuthHeader is a "TokenExtractor" that takes a give context and extracts
  111. // the JWT token from the Authorization header.
  112. func FromAuthHeader(ctx iris.Context) (string, error) {
  113. authHeader := ctx.GetHeader("Authorization")
  114. if authHeader == "" {
  115. return "", nil // No error, just no token
  116. }
  117. // TODO: Make this a bit more robust, parsing-wise
  118. authHeaderParts := strings.Split(authHeader, " ")
  119. if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
  120. return "", fmt.Errorf("authorization header format must be Bearer {token}")
  121. }
  122. return authHeaderParts[1], nil
  123. }
  124. // FromParameter returns a function that extracts the token from the specified
  125. // query string parameter
  126. func FromParameter(param string) TokenExtractor {
  127. return func(ctx iris.Context) (string, error) {
  128. return ctx.URLParam(param), nil
  129. }
  130. }
  131. // FromFirst returns a function that runs multiple token extractors and takes the
  132. // first token it finds
  133. func FromFirst(extractors ...TokenExtractor) TokenExtractor {
  134. return func(ctx iris.Context) (string, error) {
  135. for _, ex := range extractors {
  136. token, err := ex(ctx)
  137. if err != nil {
  138. return "", err
  139. }
  140. if token != "" {
  141. return token, nil
  142. }
  143. }
  144. return "", nil
  145. }
  146. }
  // ErrTokenMissing is returned by CheckJWT when the token extractor finds no
// token and credentials are not optional.
var ErrTokenMissing = errors.New("required authorization token not found")

// ErrTokenInvalid is returned by CheckJWT when a parsed token is not marked
// as valid.
var ErrTokenInvalid = errors.New("token is invalid")

// ErrTokenExpired is returned by CheckJWT when expiration checking is enabled
// and the token's "exp" claim fails verification.
var ErrTokenExpired = errors.New("token is expired")
  158. var jwtParser = new(jwt.Parser)
  159. // CheckJWT the main functionality, checks for token
  160. func (m *Middleware) CheckJWT(ctx iris.Context) error {
  161. if !m.Config.EnableAuthOnOptions {
  162. if ctx.Method() == iris.MethodOptions {
  163. return nil
  164. }
  165. }
  166. // Use the specified token extractor to extract a token from the request
  167. token, err := m.Config.Extractor(ctx)
  168. // If debugging is turned on, log the outcome
  169. if err != nil {
  170. logf(ctx, "Error extracting JWT: %v", err)
  171. return err
  172. }
  173. logf(ctx, "Token extracted: %s", token)
  174. // If the token is empty...
  175. if token == "" {
  176. // Check if it was required
  177. if m.Config.CredentialsOptional {
  178. logf(ctx, "No credentials found (CredentialsOptional=true)")
  179. // No error, just no token (and that is ok given that CredentialsOptional is true)
  180. return nil
  181. }
  182. // If we get here, the required token is missing
  183. logf(ctx, "Error: No credentials found (CredentialsOptional=false)")
  184. return ErrTokenMissing
  185. }
  186. // Now parse the token
  187. parsedToken, err := jwtParser.Parse(token, m.Config.ValidationKeyGetter)
  188. // Check if there was an error in parsing...
  189. if err != nil {
  190. logf(ctx, "Error parsing token: %v", err)
  191. return err
  192. }
  193. if m.Config.SigningMethod != nil && m.Config.SigningMethod.Alg() != parsedToken.Header["alg"] {
  194. err := fmt.Errorf("expected %s signing method but token specified %s",
  195. m.Config.SigningMethod.Alg(),
  196. parsedToken.Header["alg"])
  197. logf(ctx, "Error validating token algorithm: %v", err)
  198. return err
  199. }
  200. // Check if the parsed token is valid...
  201. if !parsedToken.Valid {
  202. logf(ctx, "Token is invalid")
  203. // m.Config.ErrorHandler(ctx, ErrTokenInvalid)
  204. return ErrTokenInvalid
  205. }
  206. if m.Config.Expiration {
  207. if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
  208. if expired := claims.VerifyExpiresAt(time.Now().Unix(), true); !expired {
  209. logf(ctx, "Token is expired")
  210. return ErrTokenExpired
  211. }
  212. }
  213. }
  214. logf(ctx, "JWT: %v", parsedToken)
  215. // If we get here, everything worked and we can set the
  216. // user property in context.
  217. ctx.Values().Set(m.Config.ContextKey, parsedToken)
  218. return nil
  219. }