binder.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package httpexpect
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "fmt"
  6. "io/ioutil"
  7. "net"
  8. "net/http"
  9. "net/http/httptest"
  10. )
  11. // Binder implements networkless http.RoundTripper attached directly to
  12. // http.Handler.
  13. //
  14. // Binder emulates network communication by invoking given http.Handler
  15. // directly. It passes httptest.ResponseRecorder as http.ResponseWriter
  16. // to the handler, and then constructs http.Response from recorded data.
  17. type Binder struct {
  18. // HTTP handler invoked for every request.
  19. Handler http.Handler
  20. // TLS connection state used for https:// requests.
  21. TLS *tls.ConnectionState
  22. }
  23. // NewBinder returns a new Binder given a http.Handler.
  24. //
  25. // Example:
  26. // client := &http.Client{
  27. // Transport: NewBinder(handler),
  28. // }
  29. func NewBinder(handler http.Handler) Binder {
  30. return Binder{Handler: handler}
  31. }
  32. // RoundTrip implements http.RoundTripper.RoundTrip.
  33. func (binder Binder) RoundTrip(req *http.Request) (*http.Response, error) {
  34. if req.Proto == "" {
  35. req.Proto = fmt.Sprintf("HTTP/%d.%d", req.ProtoMajor, req.ProtoMinor)
  36. }
  37. if req.Body != nil {
  38. if req.ContentLength == -1 {
  39. req.TransferEncoding = []string{"chunked"}
  40. }
  41. } else {
  42. req.Body = ioutil.NopCloser(bytes.NewReader(nil))
  43. }
  44. if req.URL != nil && req.URL.Scheme == "https" && binder.TLS != nil {
  45. req.TLS = binder.TLS
  46. }
  47. if req.RequestURI == "" {
  48. req.RequestURI = req.URL.RequestURI()
  49. }
  50. recorder := httptest.NewRecorder()
  51. binder.Handler.ServeHTTP(recorder, req)
  52. resp := http.Response{
  53. Request: req,
  54. StatusCode: recorder.Code,
  55. Status: http.StatusText(recorder.Code),
  56. Header: recorder.HeaderMap,
  57. }
  58. if recorder.Flushed {
  59. resp.TransferEncoding = []string{"chunked"}
  60. }
  61. if recorder.Body != nil {
  62. resp.Body = ioutil.NopCloser(recorder.Body)
  63. }
  64. return &resp, nil
  65. }
  66. type connNonTLS struct {
  67. net.Conn
  68. }
  69. func (connNonTLS) RemoteAddr() net.Addr {
  70. return &net.TCPAddr{IP: net.IPv4zero}
  71. }
  72. func (connNonTLS) LocalAddr() net.Addr {
  73. return &net.TCPAddr{IP: net.IPv4zero}
  74. }
  75. type connTLS struct {
  76. connNonTLS
  77. state *tls.ConnectionState
  78. }
  79. func (c connTLS) ConnectionState() tls.ConnectionState {
  80. return *c.state
  81. }