token.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package dingtalk
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. )
  // Tunables for the DingTalk access_token cache.
  const (
  	// tokenRefreshBuffer is the safety margin taken off a token's lifetime so
  	// that it is refreshed before it actually expires.
  	tokenRefreshBuffer = 120 * time.Second

  	// tokenMaxTTL is the fallback lifetime used when no usable TTL is known
  	// (DingTalk tokens are normally valid for 7200s).
  	tokenMaxTTL = 7200 * time.Second
  )
  // TokenStorage is a pluggable external token store (for example Redis). The
  // interface is deliberately kept minimal.
  //
  // Implementations must be safe for concurrent use. An empty token returned
  // by GetToken is treated as "miss / already expired".
  type TokenStorage interface {
  	// GetToken looks up the token stored for clientID.
  	GetToken(ctx context.Context, clientID string) (token string, err error)

  	// SetToken stores token for clientID with the given time to live.
  	SetToken(ctx context.Context, clientID, token string, ttl time.Duration) error
  }
  // tokenEntry is a single record of the in-memory (L1) cache.
  type tokenEntry struct {
  	// token is the cached access_token value.
  	token string
  	// expiresAt is the instant after which the entry is no longer served.
  	expiresAt time.Time
  }
  25. // tokenManager Token 管理器:L1 内存 + L2 可插拔(如 Redis),并带 singleflight 去重
  26. type tokenManager struct {
  27. mu sync.RWMutex
  28. entries map[string]tokenEntry // key: clientID
  29. inflight map[string]*inflightCall
  30. inflightMu sync.Mutex
  31. storage TokenStorage // 可为 nil
  32. }
  // inflightCall is the shared state of one running refresh. The refresh
  // goroutine fills in token or err and then closes done; waiters read the
  // result only after done has been closed.
  type inflightCall struct {
  	// done is closed once the refresh has finished.
  	done chan struct{}
  	// token is the refreshed token; set only on success.
  	token string
  	// err is the refresh failure, if any.
  	err error
  }
  38. var (
  39. globalTokenManager = &tokenManager{
  40. entries: make(map[string]tokenEntry),
  41. inflight: make(map[string]*inflightCall),
  42. }
  43. tokenStorageMu sync.RWMutex
  44. )
  45. // SetTokenStorage 注入外部 Token 存储(如 Redis),通常在 boot 阶段调用一次
  46. func SetTokenStorage(storage TokenStorage) {
  47. tokenStorageMu.Lock()
  48. defer tokenStorageMu.Unlock()
  49. globalTokenManager.storage = storage
  50. }
  51. // getStorage 安全读取当前 TokenStorage
  52. func (m *tokenManager) getStorage() TokenStorage {
  53. tokenStorageMu.RLock()
  54. defer tokenStorageMu.RUnlock()
  55. return m.storage
  56. }
  57. // getFromMem 从 L1 读取有效 token
  58. func (m *tokenManager) getFromMem(clientID string) (string, bool) {
  59. m.mu.RLock()
  60. defer m.mu.RUnlock()
  61. e, ok := m.entries[clientID]
  62. if !ok {
  63. return "", false
  64. }
  65. if time.Now().After(e.expiresAt) {
  66. return "", false
  67. }
  68. return e.token, true
  69. }
  70. // setToMem 写入 L1
  71. func (m *tokenManager) setToMem(clientID, token string, ttl time.Duration) {
  72. m.mu.Lock()
  73. defer m.mu.Unlock()
  74. m.entries[clientID] = tokenEntry{
  75. token: token,
  76. expiresAt: time.Now().Add(ttl),
  77. }
  78. }
  79. // Fetch 获取 access_token:按 L1 → L2(storage) → refresh 顺序查找,带 singleflight
  80. // refreshFn 由调用方(dingtalk.Client)注入,返回 (token, ttl, err)
  81. func (m *tokenManager) Fetch(
  82. ctx context.Context,
  83. clientID string,
  84. refreshFn func(ctx context.Context) (string, time.Duration, error),
  85. ) (string, error) {
  86. // L1
  87. if tk, ok := m.getFromMem(clientID); ok {
  88. return tk, nil
  89. }
  90. // L2
  91. if storage := m.getStorage(); storage != nil {
  92. if tk, err := storage.GetToken(ctx, clientID); err == nil && tk != "" {
  93. // 回填 L1(保守给一个最大 TTL,由 L2 过期决定真实生命周期)
  94. m.setToMem(clientID, tk, tokenMaxTTL-tokenRefreshBuffer)
  95. return tk, nil
  96. }
  97. }
  98. // Refresh (singleflight)
  99. m.inflightMu.Lock()
  100. call, exists := m.inflight[clientID]
  101. if !exists {
  102. call = &inflightCall{done: make(chan struct{})}
  103. m.inflight[clientID] = call
  104. m.inflightMu.Unlock()
  105. go func() {
  106. defer func() {
  107. m.inflightMu.Lock()
  108. delete(m.inflight, clientID)
  109. m.inflightMu.Unlock()
  110. close(call.done)
  111. }()
  112. token, ttl, err := refreshFn(ctx)
  113. if err != nil {
  114. call.err = err
  115. return
  116. }
  117. effectiveTTL := ttl - tokenRefreshBuffer
  118. if effectiveTTL <= 0 {
  119. effectiveTTL = tokenMaxTTL - tokenRefreshBuffer
  120. }
  121. m.setToMem(clientID, token, effectiveTTL)
  122. if storage := m.getStorage(); storage != nil {
  123. _ = storage.SetToken(ctx, clientID, token, effectiveTTL)
  124. }
  125. call.token = token
  126. }()
  127. } else {
  128. m.inflightMu.Unlock()
  129. }
  130. select {
  131. case <-call.done:
  132. return call.token, call.err
  133. case <-ctx.Done():
  134. return "", ctx.Err()
  135. }
  136. }
  137. // Invalidate 主动失效某个 clientID 的缓存(例如收到 40014 InvalidAuthentication 时)
  138. func (m *tokenManager) Invalidate(ctx context.Context, clientID string) {
  139. m.mu.Lock()
  140. delete(m.entries, clientID)
  141. m.mu.Unlock()
  142. if storage := m.getStorage(); storage != nil {
  143. _ = storage.SetToken(ctx, clientID, "", time.Second)
  144. }
  145. }