jwt_auth.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package auth
  2. import (
  3. "encoding/json"
  4. "github.com/dgrijalva/jwt-go"
  5. "time"
  6. "yx-dataset-server/app/errors"
  7. )
  8. // 定义错误
  9. var (
  10. ErrInvalidToken = errors.ErrInvalidToken
  11. )
  12. // TokenInfo 令牌信息
  13. //goland:noinspection ALL
  14. type TokenInfo interface {
  15. // 获取访问令牌
  16. GetAccessToken() string
  17. // 获取令牌类型
  18. GetTokenType() string
  19. // 获取令牌到期时间戳
  20. GetExpiresAt() int64
  21. // JSON编码
  22. EncodeToJSON() ([]byte, error)
  23. }
  24. // tokenInfo 令牌信息
  25. type tokenInfo struct {
  26. AccessToken string `json:"access_token"` // 访问令牌
  27. TokenType string `json:"token_type"` // 令牌类型
  28. ExpiresAt int64 `json:"expires_at"` // 令牌到期时间
  29. }
  30. func (t *tokenInfo) GetAccessToken() string {
  31. return t.AccessToken
  32. }
  33. func (t *tokenInfo) GetTokenType() string {
  34. return t.TokenType
  35. }
  36. func (t *tokenInfo) GetExpiresAt() int64 {
  37. return t.ExpiresAt
  38. }
  39. func (t *tokenInfo) EncodeToJSON() ([]byte, error) {
  40. return json.Marshal(t)
  41. }
  42. type options struct {
  43. signingMethod jwt.SigningMethod
  44. signingKey interface{}
  45. keyfunc jwt.Keyfunc
  46. expired int
  47. tokenType string
  48. }
  49. const defaultKey = "yx-dataset-server"
  50. var defaultOptions = options{
  51. tokenType: "Bearer",
  52. expired: 86400,
  53. signingMethod: jwt.SigningMethodHS512,
  54. signingKey: []byte(defaultKey),
  55. keyfunc: func(t *jwt.Token) (interface{}, error) {
  56. if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
  57. return nil, errors.ErrInvalidToken
  58. }
  59. return []byte(defaultKey), nil
  60. },
  61. }
  62. // SetSigningMethod 设定签名方式
  63. func SetSigningMethod(method jwt.SigningMethod) Option {
  64. return func(o *options) {
  65. o.signingMethod = method
  66. }
  67. }
  68. // SetSigningKey 设定签名key
  69. func SetSigningKey(key interface{}) Option {
  70. return func(o *options) {
  71. o.signingKey = key
  72. }
  73. }
  74. // SetKeyfunc 设定验证key的回调函数
  75. func SetKeyfunc(keyFunc jwt.Keyfunc) Option {
  76. return func(o *options) {
  77. o.keyfunc = keyFunc
  78. }
  79. }
  80. // SetExpired 设定令牌过期时长(单位秒,默认7200)
  81. func SetExpired(expired int) Option {
  82. return func(o *options) {
  83. o.expired = expired
  84. }
  85. }
  86. type Option func(*options)
  87. type JWTAuth struct {
  88. opts *options
  89. store Storer
  90. }
  91. func New(store Storer, opts ...Option) *JWTAuth {
  92. o := defaultOptions
  93. for _, opt := range opts {
  94. opt(&o)
  95. }
  96. return &JWTAuth{
  97. opts: &o,
  98. store: store,
  99. }
  100. }
  101. // GenerateToken 生成令牌
  102. func (a *JWTAuth) GenerateToken(userID string) (TokenInfo, error) {
  103. now := time.Now()
  104. expiresAt := now.Add(time.Duration(a.opts.expired) * time.Second).Unix()
  105. token := jwt.NewWithClaims(a.opts.signingMethod, &jwt.StandardClaims{
  106. IssuedAt: now.Unix(),
  107. ExpiresAt: expiresAt,
  108. NotBefore: now.Unix(),
  109. Subject: userID,
  110. })
  111. tokenString, err := token.SignedString(a.opts.signingKey)
  112. if err != nil {
  113. return nil, err
  114. }
  115. err = a.callStore(func(store Storer) error {
  116. return store.Set(userID, tokenString, time.Duration(a.opts.expired)*time.Second)
  117. })
  118. tokenInfo := &tokenInfo{
  119. ExpiresAt: expiresAt,
  120. TokenType: a.opts.tokenType,
  121. AccessToken: tokenString,
  122. }
  123. return tokenInfo, nil
  124. }
  125. // 解析令牌
  126. func (a *JWTAuth) parseToken(tokenString string) (*jwt.StandardClaims, error) {
  127. token, _ := jwt.ParseWithClaims(tokenString, &jwt.StandardClaims{}, a.opts.keyfunc)
  128. if !token.Valid {
  129. return nil, errors.ErrInvalidToken
  130. }
  131. return token.Claims.(*jwt.StandardClaims), nil
  132. }
  133. func (a *JWTAuth) callStore(fn func(store Storer) error) error {
  134. if store := a.store; store != nil {
  135. return fn(store)
  136. }
  137. return nil
  138. }
  139. // DestroyToken 销毁令牌
  140. func (a *JWTAuth) DestroyToken(tokenString string) error {
  141. claims, err := a.parseToken(tokenString)
  142. if err != nil {
  143. return err
  144. }
  145. // 如果设定了存储,则将未过期的令牌放入
  146. err = a.callStore(func(store Storer) error {
  147. return store.Del(claims.Subject)
  148. })
  149. return err
  150. }
  151. // ParseUserID 解析用户ID
  152. func (a *JWTAuth) ParseUserID(tokenString string) (string, error) {
  153. claims, err := a.parseToken(tokenString)
  154. if err != nil {
  155. return "", err
  156. }
  157. err = a.callStore(func(store Storer) error {
  158. exists, err := store.Check(claims.Subject)
  159. if err != nil {
  160. return err
  161. } else if !exists {
  162. return ErrInvalidToken
  163. }
  164. return nil
  165. })
  166. if err != nil {
  167. return "", err
  168. }
  169. return claims.Subject, nil
  170. }
  171. // Auther 认证接口
  172. type Auther interface {
  173. // 生成令牌
  174. GenerateToken(userID string) (TokenInfo, error)
  175. // 销毁令牌
  176. DestroyToken(accessToken string) error
  177. // 解析用户ID
  178. ParseUserID(accessToken string) (string, error)
  179. // 释放资源
  180. //Release() error
  181. //GenerateApiToken(userID string) (TokenInfo, error)
  182. }