// oauth.go (8.7 KB)

  1. package providers
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "time"
  9. httputil "github.com/aliyun/credentials-go/credentials/internal/http"
  10. "github.com/aliyun/credentials-go/credentials/internal/utils"
  11. )
  // OAuthTokenUpdateCallback is invoked by OAuthCredentialsProvider after it
  // has obtained fresh session credentials, so that the caller can persist the
  // current OAuth and session state (for example, write it back to a
  // configuration file).
  //
  // refreshToken and accessToken are the OAuth tokens currently held by the
  // provider. accessKey, secret and securityToken are the newly exchanged
  // session credentials. accessTokenExpire and stsExpire are Unix timestamps,
  // in seconds, at which the access token and the session credentials expire.
  type OAuthTokenUpdateCallback func(
  	refreshToken, accessToken, accessKey, secret, securityToken string,
  	accessTokenExpire, stsExpire int64,
  ) error
  type (
  	// oauthCredentialResponse is the JSON body returned by the sign-in
  	// service's /v1/exchange endpoint, which trades an OAuth access token for
  	// session credentials.
  	oauthCredentialResponse struct {
  		AccessKeyId     string `json:"accessKeyId"`
  		AccessKeySecret string `json:"accessKeySecret"`
  		SecurityToken   string `json:"securityToken"`
  		// Expiration is consumed with the layout "2006-01-02T15:04:05Z".
  		Expiration string `json:"expiration"`
  		RequestId  string `json:"requestId"`
  	}

  	// oauthRefreshTokenResponse is the JSON body returned by the sign-in
  	// service's /v1/token endpoint for a refresh_token grant.
  	oauthRefreshTokenResponse struct {
  		AccessToken  string `json:"access_token"`
  		RefreshToken string `json:"refresh_token"`
  		// ExpiresIn is the access-token lifetime in seconds from now.
  		ExpiresIn int64  `json:"expires_in"`
  		TokenType string `json:"token_type"`
  	}
  )
  27. type OAuthCredentialsProvider struct {
  28. clientId string
  29. signInUrl string
  30. refreshToken string
  31. accessToken string
  32. accessTokenExpire int64
  33. lastUpdateTimestamp int64
  34. expirationTimestamp int64
  35. sessionCredentials *sessionCredentials
  36. // for http options
  37. httpOptions *HttpOptions
  38. // OAuth token call back
  39. tokenUpdateCallback OAuthTokenUpdateCallback
  40. }
  41. type OAuthCredentialsProviderBuilder struct {
  42. provider *OAuthCredentialsProvider
  43. }
  44. func NewOAuthCredentialsProviderBuilder() *OAuthCredentialsProviderBuilder {
  45. return &OAuthCredentialsProviderBuilder{
  46. provider: &OAuthCredentialsProvider{},
  47. }
  48. }
  49. func (b *OAuthCredentialsProviderBuilder) WithClientId(clientId string) *OAuthCredentialsProviderBuilder {
  50. b.provider.clientId = clientId
  51. return b
  52. }
  53. func (b *OAuthCredentialsProviderBuilder) WithSignInUrl(signInUrl string) *OAuthCredentialsProviderBuilder {
  54. b.provider.signInUrl = signInUrl
  55. return b
  56. }
  57. func (b *OAuthCredentialsProviderBuilder) WithRefreshToken(refreshToken string) *OAuthCredentialsProviderBuilder {
  58. b.provider.refreshToken = refreshToken
  59. return b
  60. }
  61. func (b *OAuthCredentialsProviderBuilder) WithAccessToken(accessToken string) *OAuthCredentialsProviderBuilder {
  62. b.provider.accessToken = accessToken
  63. return b
  64. }
  65. func (b *OAuthCredentialsProviderBuilder) WithAccessTokenExpire(accessTokenExpire int64) *OAuthCredentialsProviderBuilder {
  66. b.provider.accessTokenExpire = accessTokenExpire
  67. return b
  68. }
  69. func (b *OAuthCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *OAuthCredentialsProviderBuilder {
  70. b.provider.httpOptions = httpOptions
  71. return b
  72. }
  73. func (b *OAuthCredentialsProviderBuilder) WithTokenUpdateCallback(callback OAuthTokenUpdateCallback) *OAuthCredentialsProviderBuilder {
  74. b.provider.tokenUpdateCallback = callback
  75. return b
  76. }
  77. func (b *OAuthCredentialsProviderBuilder) Build() (provider *OAuthCredentialsProvider, err error) {
  78. if b.provider.clientId == "" {
  79. err = errors.New("the ClientId is empty")
  80. return
  81. }
  82. if b.provider.signInUrl == "" {
  83. err = errors.New("the url for sign-in is empty")
  84. return
  85. }
  86. provider = b.provider
  87. return
  88. }
  89. func (provider *OAuthCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
  90. // 仅在 refreshToken 存在时尝试刷新 accessToken
  91. // 若 refreshToken 不存在,则直接使用当前 accessToken 去交换 accessKeyId,由服务端判断是否有效
  92. if provider.refreshToken != "" && (provider.accessToken == "" || provider.accessTokenExpire == 0 || provider.accessTokenExpire-time.Now().Unix() <= 1200) {
  93. err = provider.tryRefreshOauthToken()
  94. if err != nil {
  95. return nil, err
  96. }
  97. }
  98. url, err := url.Parse(provider.signInUrl)
  99. if err != nil {
  100. return nil, err
  101. }
  102. req := &httputil.Request{
  103. Method: "POST",
  104. Protocol: url.Scheme,
  105. Host: url.Host,
  106. Path: "/v1/exchange",
  107. Headers: map[string]string{},
  108. }
  109. connectTimeout := 5 * time.Second
  110. readTimeout := 10 * time.Second
  111. if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
  112. connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
  113. }
  114. if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
  115. readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
  116. }
  117. if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
  118. req.Proxy = provider.httpOptions.Proxy
  119. }
  120. req.ConnectTimeout = connectTimeout
  121. req.ReadTimeout = readTimeout
  122. // set headers
  123. req.Headers["Content-Type"] = "application/json"
  124. req.Headers["Authorization"] = fmt.Sprintf("Bearer %s", provider.accessToken)
  125. res, err := httpDo(req)
  126. if err != nil {
  127. return
  128. }
  129. if res.StatusCode != http.StatusOK {
  130. message := "get session token from OAuth failed: "
  131. err = errors.New(message + string(res.Body))
  132. return
  133. }
  134. var data oauthCredentialResponse
  135. err = json.Unmarshal(res.Body, &data)
  136. if err != nil {
  137. err = fmt.Errorf("get session token from OAuth failed, json.Unmarshal fail: %s", err.Error())
  138. return
  139. }
  140. if data.AccessKeyId == "" || data.AccessKeySecret == "" || data.SecurityToken == "" {
  141. err = fmt.Errorf("refresh session token err, fail to get credentials from OAuth: " + string(res.Body))
  142. return
  143. }
  144. session = &sessionCredentials{
  145. AccessKeyId: data.AccessKeyId,
  146. AccessKeySecret: data.AccessKeySecret,
  147. SecurityToken: data.SecurityToken,
  148. Expiration: data.Expiration,
  149. }
  150. return
  151. }
  152. func (provider *OAuthCredentialsProvider) tryRefreshOauthToken() (err error) {
  153. refreshToken := provider.refreshToken
  154. clientId := provider.clientId
  155. url, err := url.Parse(provider.signInUrl)
  156. if err != nil {
  157. return
  158. }
  159. req := &httputil.Request{
  160. Method: "POST",
  161. Protocol: url.Scheme,
  162. Host: url.Host,
  163. Path: "/v1/token",
  164. Headers: map[string]string{},
  165. }
  166. connectTimeout := 5 * time.Second
  167. readTimeout := 10 * time.Second
  168. if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
  169. connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
  170. }
  171. if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
  172. readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
  173. }
  174. if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
  175. req.Proxy = provider.httpOptions.Proxy
  176. }
  177. req.ConnectTimeout = connectTimeout
  178. req.ReadTimeout = readTimeout
  179. bodyForm := make(map[string]string)
  180. bodyForm["grant_type"] = "refresh_token"
  181. bodyForm["refresh_token"] = refreshToken
  182. bodyForm["client_id"] = clientId
  183. bodyForm["Timestamp"] = utils.GetTimeInFormatISO8601()
  184. req.Form = bodyForm
  185. req.Headers["Content-Type"] = "application/x-www-form-urlencoded"
  186. resp, err := httpDo(req)
  187. if err != nil {
  188. return
  189. }
  190. if resp.StatusCode != http.StatusOK {
  191. return fmt.Errorf("failed to refresh token, status code: %d", resp.StatusCode)
  192. }
  193. var tokenResp oauthRefreshTokenResponse
  194. err = json.Unmarshal(resp.Body, &tokenResp)
  195. if err != nil {
  196. err = fmt.Errorf("get refresh token from OAuth failed, json.Unmarshal fail: %s", err.Error())
  197. return
  198. }
  199. if tokenResp.RefreshToken == "" || tokenResp.AccessToken == "" {
  200. err = fmt.Errorf("failed to refresh token from OAuth: " + string(resp.Body))
  201. return
  202. }
  203. provider.accessToken = tokenResp.AccessToken
  204. provider.refreshToken = tokenResp.RefreshToken
  205. provider.accessTokenExpire = time.Now().Unix() + tokenResp.ExpiresIn
  206. return nil
  207. }
  208. func (provider *OAuthCredentialsProvider) needUpdateCredential() (result bool) {
  209. if provider.expirationTimestamp == 0 {
  210. return true
  211. }
  212. return provider.expirationTimestamp-time.Now().Unix() <= 180
  213. }
  214. func (provider *OAuthCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
  215. if provider.sessionCredentials == nil || provider.needUpdateCredential() {
  216. sessionCredentials, err1 := provider.getCredentials()
  217. if err1 != nil {
  218. return nil, err1
  219. }
  220. provider.sessionCredentials = sessionCredentials
  221. expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
  222. if err2 != nil {
  223. return nil, err2
  224. }
  225. provider.lastUpdateTimestamp = time.Now().Unix()
  226. provider.expirationTimestamp = expirationTime.Unix()
  227. // 如果设置了回调函数,则调用回调函数写回配置文件
  228. if provider.tokenUpdateCallback != nil {
  229. err1 := provider.tokenUpdateCallback(provider.refreshToken, provider.accessToken, sessionCredentials.AccessKeyId, sessionCredentials.AccessKeySecret, sessionCredentials.SecurityToken, provider.accessTokenExpire, provider.expirationTimestamp)
  230. if err1 != nil {
  231. fmt.Printf("Warning: failed to update OAuth tokens in config file: %v\n", err)
  232. }
  233. }
  234. }
  235. cc = &Credentials{
  236. AccessKeyId: provider.sessionCredentials.AccessKeyId,
  237. AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
  238. SecurityToken: provider.sessionCredentials.SecurityToken,
  239. ProviderName: provider.GetProviderName(),
  240. }
  241. return
  242. }
  243. func (provider *OAuthCredentialsProvider) GetProviderName() string {
  244. return "oauth"
  245. }