jwt.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package main
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "net/http"
  7. "strings"
  8. "time"
  9. "github.com/go-martini/martini"
  10. jwt "github.com/dgrijalva/jwt-go"
  11. )
  12. const (
  13. //DefaultContextKey jwt
  14. DefaultContextKey = "jwt"
  15. )
  16. // errorHandler error callback
  17. type errorHandler func(http.ResponseWriter, string)
  18. // TokenExtractor TokenExtractor
  19. type TokenExtractor func(*http.Request) (string, error)
  20. // Config config
  21. type Config struct {
  22. ValidationKeyGetter jwt.Keyfunc
  23. ContextKey string
  24. ErrorHandler errorHandler
  25. CredentialsOptional bool
  26. Extractor TokenExtractor
  27. // Default: true
  28. Debug bool
  29. // Default: false
  30. EnableAuthOnOptions bool
  31. // Default: nil
  32. SigningMethod jwt.SigningMethod
  33. // Default: false
  34. Expiration bool
  35. }
  36. // Middleware the middleware for JSON Web tokens authentication method
  37. type Middleware struct {
  38. Config Config
  39. }
  40. // OnError default error handler
  41. func OnError(res http.ResponseWriter, err string) {
  42. http.Error(res, err, http.StatusUnauthorized)
  43. }
  44. // New constructs a new Secure instance with supplied options.
  45. func New(cfg ...Config) *Middleware {
  46. var c Config
  47. if len(cfg) == 0 {
  48. c = Config{}
  49. } else {
  50. c = cfg[0]
  51. }
  52. if c.ContextKey == "" {
  53. c.ContextKey = DefaultContextKey
  54. }
  55. if c.ErrorHandler == nil {
  56. c.ErrorHandler = OnError
  57. }
  58. if c.Extractor == nil {
  59. c.Extractor = FromAuthHeader
  60. }
  61. return &Middleware{Config: c}
  62. }
  63. func (m *Middleware) logf(format string, args ...interface{}) {
  64. if m.Config.Debug {
  65. log.Printf(format, args...)
  66. }
  67. }
  68. // Get returns the user (&token) information for this client/request
  69. // func (m *Middleware) Get(req *http.Request) *jwt.Token {
  70. // return req.Header.Get(m.Config.ContextKey).(*jwt.Token)
  71. // }
  72. // Serve the middleware's action
  73. func (m *Middleware) Serve(ctx martini.Context, res http.ResponseWriter, req *http.Request) {
  74. if err := m.CheckJWT(req, res, ctx); err != nil {
  75. return
  76. }
  77. // If everything ok then call next.
  78. ctx.Next()
  79. }
  80. // FromAuthHeader is a "TokenExtractor" that takes a give context and extracts
  81. // the JWT token from the Authorization header.
  82. func FromAuthHeader(req *http.Request) (string, error) {
  83. authHeader := req.Header.Get("Authorization")
  84. if authHeader == "" {
  85. return "", nil // No error, just no token
  86. }
  87. // TODO: Make this a bit more robust, parsing-wise
  88. authHeaderParts := strings.Split(authHeader, " ")
  89. if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
  90. return "", fmt.Errorf("Authorization header format must be Bearer {token}")
  91. }
  92. return authHeaderParts[1], nil
  93. }
  94. // CheckJWT the main functionality, checks for token
  95. func (m *Middleware) CheckJWT(req *http.Request, res http.ResponseWriter, ctx martini.Context) error {
  96. if !m.Config.EnableAuthOnOptions {
  97. if req.Method == http.MethodOptions {
  98. return nil
  99. }
  100. }
  101. // Use the specified token extractor to extract a token from the request
  102. token, err := m.Config.Extractor(req)
  103. // If debugging is turned on, log the outcome
  104. if err != nil {
  105. m.logf("Error extracting JWT: %v", err)
  106. } else {
  107. m.logf("Token extracted: %s", token)
  108. }
  109. // If an error occurs, call the error handler and return an error
  110. if err != nil {
  111. m.Config.ErrorHandler(res, err.Error())
  112. return fmt.Errorf("Error extracting token: %v", err)
  113. }
  114. // If the token is empty...
  115. if token == "" {
  116. // Check if it was required
  117. if m.Config.CredentialsOptional {
  118. m.logf(" No credentials found (CredentialsOptional=true)")
  119. // No error, just no token (and that is ok given that CredentialsOptional is true)
  120. return nil
  121. }
  122. // If we get here, the required token is missing
  123. errorMsg := "Required authorization token not found"
  124. m.Config.ErrorHandler(res, errorMsg)
  125. m.logf(" Error: No credentials found (CredentialsOptional=false)")
  126. return fmt.Errorf(errorMsg)
  127. }
  128. // Now parse the token
  129. parsedToken, err := jwt.Parse(token, m.Config.ValidationKeyGetter)
  130. // Check if there was an error in parsing...
  131. if err != nil {
  132. m.logf("Error parsing token: %v", err)
  133. m.Config.ErrorHandler(res, err.Error())
  134. return fmt.Errorf("Error parsing token: %v", err)
  135. }
  136. if m.Config.SigningMethod != nil && m.Config.SigningMethod.Alg() != parsedToken.Header["alg"] {
  137. message := fmt.Sprintf("Expected %s signing method but token specified %s",
  138. m.Config.SigningMethod.Alg(),
  139. parsedToken.Header["alg"])
  140. m.logf("Error validating token algorithm: %s", message)
  141. m.Config.ErrorHandler(res, errors.New(message).Error())
  142. return fmt.Errorf("Error validating token algorithm: %s", message)
  143. }
  144. // Check if the parsed token is valid...
  145. if !parsedToken.Valid {
  146. m.logf("Token is invalid")
  147. m.Config.ErrorHandler(res, "The token isn't valid")
  148. return fmt.Errorf("Token is invalid")
  149. }
  150. if m.Config.Expiration {
  151. if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
  152. if expired := claims.VerifyExpiresAt(time.Now().Unix(), true); !expired {
  153. return fmt.Errorf("Token is expired")
  154. }
  155. }
  156. }
  157. m.logf("JWT: %v", parsedToken)
  158. // If we get here, everything worked and we can set the
  159. // user property in context.
  160. ctx.Map(parseToken(parsedToken))
  161. return nil
  162. }
  163. // 把token解析成一个usertoken对象
  164. func parseToken(token *jwt.Token) *UserToken {
  165. claims := token.Claims.(jwt.MapClaims)
  166. ut := &UserToken{
  167. UserID: uint(claims["UserID"].(float64)),
  168. UserName: claims["UserName"].(string),
  169. UserType: claims["UserType"].(string),
  170. RoleCode: int(claims["RoleCode"].(float64)),
  171. VendorID: uint(claims["VendorID"].(float64)),
  172. }
  173. return ut
  174. }