package utils import ( "bytes" "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "net/url" "regexp" "time" ) // 封装一个可以支持重试的http request client // admin server集群在某些机器宕机或者超时的情况下轮询重试 // var ( // 默认重试等待时间 defaultRetryWaitTime = 1 * time.Second // 默认重试次数 defaultRetryCount = 4 // 重定向次数过多的错误 redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`) // 不支持的协议类型错误 schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) // 默认的客户端 DefaultClient = NewHttpClient() // 默认log // defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} // 请求超时时间 defaultRequestTimeout = 3 * time.Second ) // ErrorResult 错误结果 type ErrorResult struct { Code int `json:"code"` Message string `json:"message"` } func (r *ErrorResult) Error() string { return r.Message } type ReaderFunc func() (io.Reader, error) type Request struct { body ReaderFunc *http.Request } // Responser response interface type Responser interface { String() (string, error) Bytes() ([]byte, error) JSON(v interface{}) error Response() *http.Response Close() } func newResponse(resp *http.Response) *response { return &response{resp} } type response struct { resp *http.Response } func (r *response) Response() *http.Response { return r.resp } func (r *response) String() (string, error) { b, err := r.Bytes() if err != nil { return "", err } return string(b), nil } func (r *response) Bytes() ([]byte, error) { defer r.resp.Body.Close() buf, err := ioutil.ReadAll(r.resp.Body) if err != nil { return nil, err } return buf, nil } func (r *response) JSON(v interface{}) error { defer r.resp.Body.Close() return json.NewDecoder(r.resp.Body).Decode(v) } func (r *response) Close() { if !r.resp.Close { r.resp.Body.Close() } } // ParseResponseJSON 解析响应JSON func ParseResponseJSON(resp Responser, result interface{}) error { if resp.Response().StatusCode != 200 { buf, err := resp.Bytes() if err != nil { return err } errResult := &ErrorResult{} err = json.Unmarshal(buf, errResult) if err == nil && (errResult.Code != 0 || errResult.Message != "") { return errResult } return fmt.Errorf("%s", buf) } else if result == nil { resp.Close() return nil } return resp.JSON(result) } func NewRequest(method, url string, rawBody interface{}) (*Request, error) { bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) if err != nil { return nil, err } httpReq, err := http.NewRequest(method, url, nil) if err != nil { return nil, err } httpReq.ContentLength = contentLength return &Request{bodyReader, httpReq}, nil } type LenReader interface { Len() int } func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) { var bodyReader ReaderFunc var contentLength int64 if rawBody != nil { switch body := rawBody.(type) { // 如果注册了ReaderFunc,直接调用 case ReaderFunc: bodyReader = body tmp, err := body() if err != nil { return nil, 0, err } if lr, ok := tmp.(LenReader); ok { contentLength = int64(lr.Len()) } if c, ok := tmp.(io.Closer); ok { _ = c.Close() } case func() (io.Reader, error): bodyReader = body tmp, err := body() if err != nil { return nil, 0, err } if lr, ok := tmp.(LenReader); ok { contentLength = int64(lr.Len()) } if c, ok := tmp.(io.Closer); ok { _ = c.Close() } case []byte: buf := body bodyReader = func() (io.Reader, error) { return bytes.NewReader(buf), nil } contentLength = int64(len(buf)) case *bytes.Buffer: buf := body bodyReader = func() (io.Reader, error) { return bytes.NewReader(buf.Bytes()), nil } contentLength = int64(buf.Len()) case *bytes.Reader: buf, err := ioutil.ReadAll(body) if err != nil { return nil, 0, err } bodyReader = func() (io.Reader, error) { return bytes.NewReader(buf), nil } contentLength = int64(len(buf)) case io.ReadSeeker: raw := body bodyReader = func() (io.Reader, error) { _, err := raw.Seek(0, 0) return ioutil.NopCloser(raw), err } if lr, ok := raw.(LenReader); ok { contentLength = int64(lr.Len()) } case io.Reader: buf, err := ioutil.ReadAll(body) if err != nil { return nil, 0, err } bodyReader = func() (io.Reader, error) { return bytes.NewReader(buf), nil } contentLength = int64(len(buf)) default: return nil, 0, fmt.Errorf("无法处理的的body类型 %T", rawBody) } } return bodyReader, contentLength, nil } func (r *Request) WithContext(ctx context.Context) *Request { r.Request = r.Request.WithContext(ctx) return r } func (r *Request) BodyBytes() ([]byte, error) { if r.body == nil { return nil, nil } body, err := r.body() if err != nil { return nil, err } buf := new(bytes.Buffer) _, err = buf.ReadFrom(body) if err != nil { return nil, err } return buf.Bytes(), nil } // 指定是否可以重试的策略,如果返回false,则客户端停止重试。 type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error) // DefaultCheckRetry 默认的重试策略 func DefaultCheckRetry(ctx context.Context, resp *http.Response, err error) (bool, error) { if ctx.Err() != nil { return false, ctx.Err() } if err != nil { if v, ok := err.(*url.Error); ok { if redirectsErrorRe.MatchString(v.Error()) { return false, nil } if schemeErrorRe.MatchString(v.Error()) { return false, nil } // 超时不重试 if v.Timeout() { return false, nil } } return true, nil } if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != 501) { return true, nil } return false, nil } type logger interface { Info(v ...interface{}) } // LogWriter log writer interface type LogWriter interface { Info(v ...interface{}) } // Logger default logger type Logger struct { LogWriter } // Print format & print log func (logger Logger) Print(values ...interface{}) { logger.Info(values...) } type HttpClient struct { // 默认的http client httpClient *http.Client // 重试等待时长 retryWaitTime time.Duration // 重试次数 retryCount int // 重试判定策略 canRetry CheckRetry // logger logger logger // request time out seconds timeOut time.Duration } // NewHttpClient new http client with retry func NewHttpClient() *HttpClient { return &HttpClient{ httpClient: &http.Client{ Timeout: defaultRequestTimeout, Transport: &http.Transport{ MaxIdleConns: 0, MaxIdleConnsPerHost: 0, MaxConnsPerHost: 0, }, }, retryWaitTime: defaultRetryWaitTime, retryCount: defaultRetryCount, canRetry: DefaultCheckRetry, //logger: defaultLogger, } } func NewHttpClientWithConfig(timeOut time.Duration, retry int, retryWaitTime time.Duration) *HttpClient { return &HttpClient{ httpClient: &http.Client{ Timeout: timeOut, }, retryWaitTime: retryWaitTime, retryCount: retry, canRetry: DefaultCheckRetry, } } // 设置外部注入logger,只要实现print方法 func (a *HttpClient) SetLogger(l logger) { a.logger = l } // SetLogger func SetLogger(l logger) { DefaultClient.SetLogger(l) } // Do do http method with retries func (a *HttpClient) Do(req *Request) (Responser, error) { if a.httpClient == nil { a.httpClient = http.DefaultClient } var resp *http.Response var err error for i := 0; ; i++ { var code int if req.body != nil { body, err := req.body() if err != nil { a.httpClient.CloseIdleConnections() return newResponse(resp), err } if c, ok := body.(io.ReadCloser); ok { req.Body = c } else { req.Body = ioutil.NopCloser(body) } } resp, err = a.httpClient.Do(req.Request) if resp != nil { code = resp.StatusCode } checkOk, checkErr := a.canRetry(req.Context(), resp, err) if err != nil { a.logger.Info(fmt.Sprintf("请求出错:[URL]:%s [Method]:%s, [错误]:%s", req.URL, req.Method, err.Error())) } if !checkOk { if checkErr != nil { err = checkErr } a.httpClient.CloseIdleConnections() return newResponse(resp), nil } remain := a.retryCount - i if remain <= 0 { break } //if err == nil &&resp !=nil { // //} desc := fmt.Sprintf("%s %s", req.Method, req.URL) if code > 0 { desc = fmt.Sprintf("%s [status:%d]", desc, code) } a.logger.Info(desc) select { case <-req.Context().Done(): a.httpClient.CloseIdleConnections() return nil, req.Context().Err() case <-time.After(a.retryWaitTime): } } if resp != nil { _ = resp.Body.Close() } a.httpClient.CloseIdleConnections() return nil, fmt.Errorf("%s %s giving up after %d attempts", req.Method, req.URL, a.retryCount+1) } func Get(url string) (Responser, error) { return DefaultClient.Get(url) } func (a *HttpClient) Get(url string) (Responser, error) { req, err := NewRequest("GET", url, nil) if err != nil { return nil, err } return a.Do(req) } func Post(url string, body interface{}) (Responser, error) { return DefaultClient.Post(url, "application/json", body) } func Put(url string, body interface{}) (Responser, error) { return DefaultClient.Put(url, "application/json", body) } func Delete(url string, body interface{}) (Responser, error) { return DefaultClient.Delete(url, "application/json", body) } func (a *HttpClient) Post(url, bodyType string, body interface{}) (Responser, error) { w := new(bytes.Buffer) if err := json.NewEncoder(w).Encode(body); err != nil { return nil, err } req, err := NewRequest("POST", url, w) if err != nil { return nil, err } req.Header.Set("Content-Type", bodyType) return a.Do(req) } func (a *HttpClient) Put(url, bodyType string, body interface{}) (Responser, error) { w := new(bytes.Buffer) if err := json.NewEncoder(w).Encode(body); err != nil { return nil, err } req, err := NewRequest(http.MethodPut, url, w) if err != nil { return nil, err } req.Header.Set("Content-Type", bodyType) return a.Do(req) } func (a *HttpClient) Delete(url, bodyType string, body interface{}) (Responser, error) { req, err := NewRequest(http.MethodDelete, url, nil) if err != nil { return nil, err } req.Header.Set("Content-Type", bodyType) return a.Do(req) } // GetForObject http get a json obj func GetForObject(url string, result interface{}) error { resp, err := Get(url) if err != nil { return err } return ParseResponseJSON(resp, result) }