client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. package client
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "mime/multipart"
  10. "net/http"
  11. "net/url"
  12. "os"
  13. "strconv"
  14. "strings"
  15. "golang.org/x/time/rate"
  16. )
// A Client is an HTTP client. Initialize with the New package-level function.
type Client struct {
	// HTTPClient is the standard library client that actually sends requests.
	HTTPClient *http.Client
	// BaseURL prepends to all requests. It is concatenated verbatim with the
	// request path passed to Do (no slash normalization is performed).
	BaseURL string
	// A list of persistent request options, applied by Do to every request
	// before the per-call options.
	PersistentRequestOptions []RequestOption
	// Optional rate limiter instance initialized by the RateLimit method.
	// When non-nil, Do waits on it before building each request.
	rateLimiter *rate.Limiter
	// Optional handlers that are being fired before and after each new request
	// (BeginRequest before sending, EndRequest after the response arrives).
	requestHandlers []RequestHandler
	// keepAlive records whether New found keep-alives enabled on the client's
	// transport. It is not read anywhere in this file; stored for future use.
	keepAlive bool
}
  31. // New returns a new Iris HTTP Client.
  32. // Available options:
  33. // - BaseURL
  34. // - Timeout
  35. // - PersistentRequestOptions
  36. // - RateLimit
  37. //
  38. // Look the Client.Do/JSON/... methods to send requests and
  39. // ReadXXX methods to read responses.
  40. //
  41. // The default content type to send and receive data is JSON.
  42. func New(opts ...Option) *Client {
  43. c := &Client{
  44. HTTPClient: &http.Client{},
  45. PersistentRequestOptions: defaultRequestOptions,
  46. requestHandlers: defaultRequestHandlers,
  47. }
  48. for _, opt := range opts {
  49. opt(c)
  50. }
  51. if transport, ok := c.HTTPClient.Transport.(*http.Transport); ok {
  52. c.keepAlive = !transport.DisableKeepAlives
  53. }
  54. return c
  55. }
  56. // RegisterRequestHandler registers one or more request handlers
  57. // to be ran before and after of each new request.
  58. //
  59. // Request handler's BeginRequest method run after each request constructed
  60. // and right before sent to the server.
  61. //
  62. // Request handler's EndRequest method run after response each received
  63. // and right before methods return back to the caller.
  64. //
  65. // Any request handlers MUST be set right after the Client's initialization.
  66. func (c *Client) RegisterRequestHandler(reqHandlers ...RequestHandler) {
  67. reqHandlersToRegister := make([]RequestHandler, 0, len(reqHandlers))
  68. for _, h := range reqHandlers {
  69. if h == nil {
  70. continue
  71. }
  72. reqHandlersToRegister = append(reqHandlersToRegister, h)
  73. }
  74. c.requestHandlers = append(c.requestHandlers, reqHandlersToRegister...)
  75. }
  76. func (c *Client) emitBeginRequest(ctx context.Context, req *http.Request) error {
  77. if len(c.requestHandlers) == 0 {
  78. return nil
  79. }
  80. for _, h := range c.requestHandlers {
  81. if hErr := h.BeginRequest(ctx, req); hErr != nil {
  82. return hErr
  83. }
  84. }
  85. return nil
  86. }
  87. func (c *Client) emitEndRequest(ctx context.Context, resp *http.Response, err error) error {
  88. if len(c.requestHandlers) == 0 {
  89. return nil
  90. }
  91. for _, h := range c.requestHandlers {
  92. if hErr := h.EndRequest(ctx, resp, err); hErr != nil {
  93. return hErr
  94. }
  95. }
  96. return err
  97. }
// RequestOption declares the type of option one can pass
// to the Do methods (JSON, Form, ReadJSON...).
//
// Do applies request options after the *http.Request has been constructed
// and before it is sent, so they can modify headers, the URL query, etc.
// A non-nil error aborts the request and is returned by Do.
type RequestOption = func(*http.Request) error
// defaultRequestOptions is the initial PersistentRequestOptions of every
// Client created by New. It adds an "Accept: application/json" header.
// The header is added, not overridden, so a later custom Accept option
// appends an extra value. To replace it, use RequestHeader(true, ...) or
// replace the Client's PersistentRequestOptions.
var defaultRequestOptions = []RequestOption{
	RequestHeader(false, acceptKey, contentTypeJSON),
}
  106. // RequestHeader adds or sets (if overridePrev is true) a header to the request.
  107. func RequestHeader(overridePrev bool, key string, values ...string) RequestOption {
  108. key = http.CanonicalHeaderKey(key)
  109. return func(req *http.Request) error {
  110. if overridePrev { // upsert.
  111. req.Header[key] = values
  112. } else { // just insert.
  113. req.Header[key] = append(req.Header[key], values...)
  114. }
  115. return nil
  116. }
  117. }
  118. // RequestAuthorization sets an Authorization request header.
  119. // Note that we could do the same with a Transport RoundDrip too.
  120. func RequestAuthorization(value string) RequestOption {
  121. return RequestHeader(true, "Authorization", value)
  122. }
  123. // RequestAuthorizationBearer sets an Authorization: Bearer $token request header.
  124. func RequestAuthorizationBearer(accessToken string) RequestOption {
  125. headerValue := "Bearer " + accessToken
  126. return RequestAuthorization(headerValue)
  127. }
  128. // RequestQuery adds a set of URL query parameters to the request.
  129. func RequestQuery(query url.Values) RequestOption {
  130. return func(req *http.Request) error {
  131. q := req.URL.Query()
  132. for k, v := range query {
  133. q[k] = v
  134. }
  135. req.URL.RawQuery = q.Encode()
  136. return nil
  137. }
  138. }
  139. // RequestParam sets a single URL query parameter to the request.
  140. func RequestParam(key string, values ...string) RequestOption {
  141. return RequestQuery(url.Values{
  142. key: values,
  143. })
  144. }
  145. // Do sends an HTTP request and returns an HTTP response.
  146. //
  147. // The payload can be:
  148. // - io.Reader
  149. // - raw []byte
  150. // - JSON raw message
  151. // - string
  152. // - struct (JSON).
  153. //
  154. // If method is empty then it defaults to "GET".
  155. // The final variadic, optional input argument sets
  156. // the custom request options to use before the request.
  157. //
  158. // Any HTTP returned error will be of type APIError
  159. // or a timeout error if the given context was canceled.
  160. func (c *Client) Do(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (*http.Response, error) {
  161. if ctx == nil {
  162. ctx = context.Background()
  163. }
  164. if c.rateLimiter != nil {
  165. if err := c.rateLimiter.Wait(ctx); err != nil {
  166. return nil, err
  167. }
  168. }
  169. // Method defaults to GET.
  170. if method == "" {
  171. method = http.MethodGet
  172. }
  173. // Find the payload, if any.
  174. var body io.Reader
  175. if payload != nil {
  176. switch v := payload.(type) {
  177. case io.Reader:
  178. body = v
  179. case []byte:
  180. body = bytes.NewBuffer(v)
  181. case json.RawMessage:
  182. body = bytes.NewBuffer(v)
  183. case string:
  184. body = strings.NewReader(v)
  185. case url.Values:
  186. body = strings.NewReader(v.Encode())
  187. default:
  188. w := new(bytes.Buffer)
  189. // We assume it's a struct, we wont make use of reflection to find out though.
  190. err := json.NewEncoder(w).Encode(v)
  191. if err != nil {
  192. return nil, err
  193. }
  194. body = w
  195. }
  196. }
  197. if c.BaseURL != "" {
  198. urlpath = c.BaseURL + urlpath // note that we don't do any special checks here, the caller is responsible.
  199. }
  200. // Initialize the request.
  201. req, err := http.NewRequestWithContext(ctx, method, urlpath, body)
  202. if err != nil {
  203. return nil, err
  204. }
  205. // We separate the error for the default options for now.
  206. for i, opt := range c.PersistentRequestOptions {
  207. if opt == nil {
  208. continue
  209. }
  210. if err = opt(req); err != nil {
  211. return nil, fmt.Errorf("client.Do: default request option[%d]: %w", i, err)
  212. }
  213. }
  214. // Apply any custom request options (e.g. content type, accept headers, query...)
  215. for _, opt := range opts {
  216. if opt == nil {
  217. continue
  218. }
  219. if err = opt(req); err != nil {
  220. return nil, err
  221. }
  222. }
  223. if err = c.emitBeginRequest(ctx, req); err != nil {
  224. return nil, err
  225. }
  226. // Caller is responsible for closing the response body.
  227. // Also note that the gzip compression is handled automatically nowadays.
  228. resp, respErr := c.HTTPClient.Do(req)
  229. if err = c.emitEndRequest(ctx, resp, respErr); err != nil {
  230. return nil, err
  231. }
  232. return resp, respErr
  233. }
  234. // DrainResponseBody drains response body and close it, allowing the transport to reuse TCP connections.
  235. // It's automatically called on Client.ReadXXX methods on the end.
  236. func (c *Client) DrainResponseBody(resp *http.Response) {
  237. _, _ = io.Copy(io.Discard, resp.Body)
  238. resp.Body.Close()
  239. }
  240. const (
  241. acceptKey = "Accept"
  242. contentTypeKey = "Content-Type"
  243. contentLengthKey = "Content-Length"
  244. contentTypePlainText = "plain/text"
  245. contentTypeJSON = "application/json"
  246. contentTypeFormURLEncoded = "application/x-www-form-urlencoded"
  247. )
  248. // JSON writes data as JSON to the server.
  249. func (c *Client) JSON(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (*http.Response, error) {
  250. opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON))
  251. return c.Do(ctx, method, urlpath, payload, opts...)
  252. }
  253. // JSON writes form data to the server.
  254. func (c *Client) Form(ctx context.Context, method, urlpath string, formValues url.Values, opts ...RequestOption) (*http.Response, error) {
  255. payload := formValues.Encode()
  256. opts = append(opts,
  257. RequestHeader(true, contentTypeKey, contentTypeFormURLEncoded),
  258. RequestHeader(true, contentLengthKey, strconv.Itoa(len(payload))),
  259. )
  260. return c.Do(ctx, method, urlpath, payload, opts...)
  261. }
// Uploader holds the necessary information for upload requests: an in-memory
// buffer and the multipart writer that fills it.
//
// Look the Client.NewUploader method.
type Uploader struct {
	// client sends the request when Upload is called.
	client *Client
	// body accumulates the multipart-encoded form in memory.
	body *bytes.Buffer
	// Writer writes form fields and files into body. It is exported so callers
	// can add parts directly; Upload closes it.
	Writer *multipart.Writer
}
  270. // AddFileSource adds a form field to the uploader with the given key.
  271. func (u *Uploader) AddField(key, value string) error {
  272. f, err := u.Writer.CreateFormField(key)
  273. if err != nil {
  274. return err
  275. }
  276. _, err = io.Copy(f, strings.NewReader(value))
  277. return err
  278. }
  279. // AddFileSource adds a form file to the uploader with the given key.
  280. func (u *Uploader) AddFileSource(key, filename string, source io.Reader) error {
  281. f, err := u.Writer.CreateFormFile(key, filename)
  282. if err != nil {
  283. return err
  284. }
  285. _, err = io.Copy(f, source)
  286. return err
  287. }
  288. // AddFile adds a local form file to the uploader with the given key.
  289. func (u *Uploader) AddFile(key, filename string) error {
  290. source, err := os.Open(filename)
  291. if err != nil {
  292. return err
  293. }
  294. return u.AddFileSource(key, filename, source)
  295. }
  296. // Uploads sends local data to the server.
  297. func (u *Uploader) Upload(ctx context.Context, method, urlpath string, opts ...RequestOption) (*http.Response, error) {
  298. err := u.Writer.Close()
  299. if err != nil {
  300. return nil, err
  301. }
  302. payload := bytes.NewReader(u.body.Bytes())
  303. opts = append(opts, RequestHeader(true, contentTypeKey, u.Writer.FormDataContentType()))
  304. return u.client.Do(ctx, method, urlpath, payload, opts...)
  305. }
  306. // NewUploader returns a structure which is responsible for sending
  307. // file and form data to the server.
  308. func (c *Client) NewUploader() *Uploader {
  309. body := new(bytes.Buffer)
  310. writer := multipart.NewWriter(body)
  311. return &Uploader{
  312. client: c,
  313. body: body,
  314. Writer: writer,
  315. }
  316. }
  317. // ReadJSON binds "dest" to the response's body.
  318. // After this call, the response body reader is closed.
  319. func (c *Client) ReadJSON(ctx context.Context, dest interface{}, method, urlpath string, payload interface{}, opts ...RequestOption) error {
  320. if payload != nil {
  321. opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON))
  322. }
  323. resp, err := c.Do(ctx, method, urlpath, payload, opts...)
  324. if err != nil {
  325. return err
  326. }
  327. defer c.DrainResponseBody(resp)
  328. if resp.StatusCode >= http.StatusBadRequest {
  329. return ExtractError(resp)
  330. }
  331. // DBUG
  332. // b, _ := io.ReadAll(resp.Body)
  333. // println(string(b))
  334. // return json.Unmarshal(b, &dest)
  335. return json.NewDecoder(resp.Body).Decode(&dest)
  336. }
  337. // ReadPlain like ReadJSON but it accepts a pointer to a string or byte slice or integer
  338. // and it reads the body as plain text.
  339. func (c *Client) ReadPlain(ctx context.Context, dest interface{}, method, urlpath string, payload interface{}, opts ...RequestOption) error {
  340. resp, err := c.Do(ctx, method, urlpath, payload, opts...)
  341. if err != nil {
  342. return err
  343. }
  344. defer c.DrainResponseBody(resp)
  345. if resp.StatusCode >= http.StatusBadRequest {
  346. return ExtractError(resp)
  347. }
  348. body, err := io.ReadAll(resp.Body)
  349. if err != nil {
  350. return err
  351. }
  352. switch ptr := dest.(type) {
  353. case *[]byte:
  354. *ptr = body
  355. return nil
  356. case *string:
  357. *ptr = string(body)
  358. return nil
  359. case *int:
  360. *ptr, err = strconv.Atoi(string(body))
  361. return err
  362. default:
  363. return fmt.Errorf("unsupported response body type: %T", ptr)
  364. }
  365. }
  366. // GetPlainUnquote reads the response body as raw text and tries to unquote it,
  367. // useful when the remote server sends a single key as a value but due to backend mistake
  368. // it sends it as JSON (quoted) instead of plain text.
  369. func (c *Client) GetPlainUnquote(ctx context.Context, method, urlpath string, payload interface{}, opts ...RequestOption) (string, error) {
  370. var bodyStr string
  371. if err := c.ReadPlain(ctx, &bodyStr, method, urlpath, payload, opts...); err != nil {
  372. return "", err
  373. }
  374. s, err := strconv.Unquote(bodyStr)
  375. if err == nil {
  376. bodyStr = s
  377. }
  378. return bodyStr, nil
  379. }
  380. // WriteTo reads the response and then copies its data to the "dest" writer.
  381. // If the "dest" is a type of HTTP response writer then it writes the
  382. // content-type and content-length of the original request.
  383. //
  384. // Returns the amount of bytes written to "dest".
  385. func (c *Client) WriteTo(ctx context.Context, dest io.Writer, method, urlpath string, payload interface{}, opts ...RequestOption) (int64, error) {
  386. if payload != nil {
  387. opts = append(opts, RequestHeader(true, contentTypeKey, contentTypeJSON))
  388. }
  389. resp, err := c.Do(ctx, method, urlpath, payload, opts...)
  390. if err != nil {
  391. return 0, err
  392. }
  393. defer resp.Body.Close()
  394. if w, ok := dest.(http.ResponseWriter); ok {
  395. // Copy the content type and content-length.
  396. w.Header().Set("Content-Type", resp.Header.Get("Content-Type"))
  397. if resp.ContentLength > 0 {
  398. w.Header().Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
  399. }
  400. }
  401. return io.Copy(dest, resp.Body)
  402. }
  403. // BindResponse consumes the response's body and binds the result to the "dest" pointer,
  404. // closing the response's body is up to the caller.
  405. //
  406. // The "dest" will be binded based on the response's content type header.
  407. // Note that this is strict in order to catch bad actioners fast,
  408. // e.g. it wont try to read plain text if not specified on
  409. // the response headers and the dest is a *string.
  410. func BindResponse(resp *http.Response, dest interface{}) (err error) {
  411. contentType := trimHeader(resp.Header.Get(contentTypeKey))
  412. switch contentType {
  413. case contentTypeJSON: // the most common scenario on successful responses.
  414. return json.NewDecoder(resp.Body).Decode(&dest)
  415. case contentTypePlainText:
  416. b, err := io.ReadAll(resp.Body)
  417. if err != nil {
  418. return err
  419. }
  420. switch v := dest.(type) {
  421. case *string:
  422. *v = string(b)
  423. case *[]byte:
  424. *v = b
  425. default:
  426. return fmt.Errorf("plain text response should accept a *string or a *[]byte")
  427. }
  428. default:
  429. acceptContentType := trimHeader(resp.Request.Header.Get(acceptKey))
  430. msg := ""
  431. if acceptContentType == contentType {
  432. // Here we make a special case, if the content type
  433. // was explicitly set by the request but we cannot handle it.
  434. msg = fmt.Sprintf("current implementation can not handle the received (and accepted) mime type: %s", contentType)
  435. } else {
  436. msg = fmt.Sprintf("unexpected mime type received: %s", contentType)
  437. }
  438. err = errors.New(msg)
  439. }
  440. return
  441. }
  442. func trimHeader(v string) string {
  443. for i, char := range v {
  444. if char == ' ' || char == ';' {
  445. return v[:i]
  446. }
  447. }
  448. return v
  449. }