|
@@ -0,0 +1,214 @@
|
|
|
|
+package main
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ "errors"
|
|
|
|
+ "fmt"
|
|
|
|
+ "log"
|
|
|
|
+ "net/http"
|
|
|
|
+ "strings"
|
|
|
|
+ "time"
|
|
|
|
+
|
|
|
|
+ "github.com/go-martini/martini"
|
|
|
|
+
|
|
|
|
+ jwt "github.com/dgrijalva/jwt-go"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+const (
|
|
|
|
+ //DefaultContextKey jwt
|
|
|
|
+ DefaultContextKey = "jwt"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+// errorHandler error callback
|
|
|
|
+type errorHandler func(http.ResponseWriter, string)
|
|
|
|
+
|
|
|
|
+// TokenExtractor TokenExtractor
|
|
|
|
+type TokenExtractor func(*http.Request) (string, error)
|
|
|
|
+
|
|
|
|
+// Config config
|
|
|
|
+type Config struct {
|
|
|
|
+ ValidationKeyGetter jwt.Keyfunc
|
|
|
|
+
|
|
|
|
+ ContextKey string
|
|
|
|
+
|
|
|
|
+ ErrorHandler errorHandler
|
|
|
|
+
|
|
|
|
+ CredentialsOptional bool
|
|
|
|
+
|
|
|
|
+ Extractor TokenExtractor
|
|
|
|
+ // Default: true
|
|
|
|
+ Debug bool
|
|
|
|
+
|
|
|
|
+ // Default: false
|
|
|
|
+ EnableAuthOnOptions bool
|
|
|
|
+
|
|
|
|
+ // Default: nil
|
|
|
|
+ SigningMethod jwt.SigningMethod
|
|
|
|
+
|
|
|
|
+ // Default: false
|
|
|
|
+ Expiration bool
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Middleware the middleware for JSON Web tokens authentication method
|
|
|
|
+type Middleware struct {
|
|
|
|
+ Config Config
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// OnError default error handler
|
|
|
|
+func OnError(res http.ResponseWriter, err string) {
|
|
|
|
+ http.Error(res, err, http.StatusUnauthorized)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// New constructs a new Secure instance with supplied options.
|
|
|
|
+func New(cfg ...Config) *Middleware {
|
|
|
|
+
|
|
|
|
+ var c Config
|
|
|
|
+ if len(cfg) == 0 {
|
|
|
|
+ c = Config{}
|
|
|
|
+ } else {
|
|
|
|
+ c = cfg[0]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if c.ContextKey == "" {
|
|
|
|
+ c.ContextKey = DefaultContextKey
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if c.ErrorHandler == nil {
|
|
|
|
+ c.ErrorHandler = OnError
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if c.Extractor == nil {
|
|
|
|
+ c.Extractor = FromAuthHeader
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return &Middleware{Config: c}
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (m *Middleware) logf(format string, args ...interface{}) {
|
|
|
|
+ if m.Config.Debug {
|
|
|
|
+ log.Printf(format, args...)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Get returns the user (&token) information for this client/request
|
|
|
|
+// func (m *Middleware) Get(req *http.Request) *jwt.Token {
|
|
|
|
+// return req.Header.Get(m.Config.ContextKey).(*jwt.Token)
|
|
|
|
+// }
|
|
|
|
+
|
|
|
|
+// Serve the middleware's action
|
|
|
|
+func (m *Middleware) Serve(ctx martini.Context, res http.ResponseWriter, req *http.Request) {
|
|
|
|
+ if err := m.CheckJWT(req, res, ctx); err != nil {
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ // If everything ok then call next.
|
|
|
|
+ ctx.Next()
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// FromAuthHeader is a "TokenExtractor" that takes a give context and extracts
|
|
|
|
+// the JWT token from the Authorization header.
|
|
|
|
+func FromAuthHeader(req *http.Request) (string, error) {
|
|
|
|
+ authHeader := req.Header.Get("Authorization")
|
|
|
|
+ if authHeader == "" {
|
|
|
|
+ return "", nil // No error, just no token
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // TODO: Make this a bit more robust, parsing-wise
|
|
|
|
+ authHeaderParts := strings.Split(authHeader, " ")
|
|
|
|
+ if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
|
|
|
+ return "", fmt.Errorf("Authorization header format must be Bearer {token}")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return authHeaderParts[1], nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// CheckJWT the main functionality, checks for token
|
|
|
|
+func (m *Middleware) CheckJWT(req *http.Request, res http.ResponseWriter, ctx martini.Context) error {
|
|
|
|
+ if !m.Config.EnableAuthOnOptions {
|
|
|
|
+ if req.Method == http.MethodOptions {
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Use the specified token extractor to extract a token from the request
|
|
|
|
+ token, err := m.Config.Extractor(req)
|
|
|
|
+ // If debugging is turned on, log the outcome
|
|
|
|
+ if err != nil {
|
|
|
|
+ m.logf("Error extracting JWT: %v", err)
|
|
|
|
+ } else {
|
|
|
|
+ m.logf("Token extracted: %s", token)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // If an error occurs, call the error handler and return an error
|
|
|
|
+ if err != nil {
|
|
|
|
+ m.Config.ErrorHandler(res, err.Error())
|
|
|
|
+ return fmt.Errorf("Error extracting token: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // If the token is empty...
|
|
|
|
+ if token == "" {
|
|
|
|
+ // Check if it was required
|
|
|
|
+ if m.Config.CredentialsOptional {
|
|
|
|
+ m.logf(" No credentials found (CredentialsOptional=true)")
|
|
|
|
+ // No error, just no token (and that is ok given that CredentialsOptional is true)
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // If we get here, the required token is missing
|
|
|
|
+ errorMsg := "Required authorization token not found"
|
|
|
|
+ m.Config.ErrorHandler(res, errorMsg)
|
|
|
|
+ m.logf(" Error: No credentials found (CredentialsOptional=false)")
|
|
|
|
+ return fmt.Errorf(errorMsg)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Now parse the token
|
|
|
|
+
|
|
|
|
+ parsedToken, err := jwt.Parse(token, m.Config.ValidationKeyGetter)
|
|
|
|
+ // Check if there was an error in parsing...
|
|
|
|
+ if err != nil {
|
|
|
|
+ m.logf("Error parsing token: %v", err)
|
|
|
|
+ m.Config.ErrorHandler(res, err.Error())
|
|
|
|
+ return fmt.Errorf("Error parsing token: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if m.Config.SigningMethod != nil && m.Config.SigningMethod.Alg() != parsedToken.Header["alg"] {
|
|
|
|
+ message := fmt.Sprintf("Expected %s signing method but token specified %s",
|
|
|
|
+ m.Config.SigningMethod.Alg(),
|
|
|
|
+ parsedToken.Header["alg"])
|
|
|
|
+ m.logf("Error validating token algorithm: %s", message)
|
|
|
|
+ m.Config.ErrorHandler(res, errors.New(message).Error())
|
|
|
|
+ return fmt.Errorf("Error validating token algorithm: %s", message)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Check if the parsed token is valid...
|
|
|
|
+ if !parsedToken.Valid {
|
|
|
|
+ m.logf("Token is invalid")
|
|
|
|
+ m.Config.ErrorHandler(res, "The token isn't valid")
|
|
|
|
+ return fmt.Errorf("Token is invalid")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if m.Config.Expiration {
|
|
|
|
+ if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
|
|
|
+ if expired := claims.VerifyExpiresAt(time.Now().Unix(), true); !expired {
|
|
|
|
+ return fmt.Errorf("Token is expired")
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ m.logf("JWT: %v", parsedToken)
|
|
|
|
+
|
|
|
|
+ // If we get here, everything worked and we can set the
|
|
|
|
+ // user property in context.
|
|
|
|
+ ctx.Map(parseToken(parsedToken))
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 把token解析成一个usertoken对象
|
|
|
|
+func parseToken(token *jwt.Token) *UserToken {
|
|
|
|
+ claims := token.Claims.(jwt.MapClaims)
|
|
|
|
+ ut := &UserToken{
|
|
|
|
+ UserID: uint(claims["UserID"].(float64)),
|
|
|
|
+ UserName: claims["UserName"].(string),
|
|
|
|
+ UserType: claims["UserType"].(string),
|
|
|
|
+ RoleCode: int(claims["RoleCode"].(float64)),
|
|
|
|
+ }
|
|
|
|
+ return ut
|
|
|
|
+}
|