binder.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package httpexpect
  2. import (
  3. "crypto/tls"
  4. "fmt"
  5. "io/ioutil"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. )
  // Binder is a networkless http.RoundTripper that is wired straight to an
  // http.Handler.
  //
  // Instead of performing network I/O, Binder calls the configured handler
  // in-process: the handler receives an httptest.ResponseRecorder as its
  // http.ResponseWriter, and the http.Response handed back to the client is
  // assembled from whatever the recorder captured.
  type Binder struct {
  	// Handler is the HTTP handler invoked for every request.
  	Handler http.Handler

  	// TLS is the connection state attached to https:// requests.
  	TLS *tls.ConnectionState
  }
  22. // NewBinder returns a new Binder given a http.Handler.
  23. //
  24. // Example:
  25. //
  26. // client := &http.Client{
  27. // Transport: NewBinder(handler),
  28. // }
  29. func NewBinder(handler http.Handler) Binder {
  30. return Binder{Handler: handler}
  31. }
  32. // RoundTrip implements http.RoundTripper.RoundTrip.
  33. func (binder Binder) RoundTrip(origReq *http.Request) (*http.Response, error) {
  34. req := *origReq
  35. if req.Proto == "" {
  36. req.Proto = fmt.Sprintf("HTTP/%d.%d", req.ProtoMajor, req.ProtoMinor)
  37. }
  38. if req.Body != nil && req.Body != http.NoBody {
  39. if req.ContentLength == -1 {
  40. req.TransferEncoding = []string{"chunked"}
  41. }
  42. } else {
  43. req.Body = http.NoBody
  44. }
  45. if req.URL != nil && req.URL.Scheme == "https" && binder.TLS != nil {
  46. req.TLS = binder.TLS
  47. }
  48. if req.RequestURI == "" {
  49. req.RequestURI = req.URL.RequestURI()
  50. }
  51. recorder := httptest.NewRecorder()
  52. binder.Handler.ServeHTTP(recorder, &req)
  53. resp := http.Response{
  54. Request: &req,
  55. StatusCode: recorder.Code,
  56. Status: http.StatusText(recorder.Code),
  57. Header: recorder.Result().Header,
  58. }
  59. if recorder.Flushed {
  60. resp.TransferEncoding = []string{"chunked"}
  61. }
  62. if recorder.Body != nil {
  63. resp.Body = ioutil.NopCloser(recorder.Body)
  64. }
  65. return &resp, nil
  66. }
  // connNonTLS wraps a net.Conn by embedding it. The RemoteAddr and LocalAddr
  // methods declared on connNonTLS shadow the embedded ones.
  type connNonTLS struct{ net.Conn }
  70. func (connNonTLS) RemoteAddr() net.Addr {
  71. return &net.TCPAddr{IP: net.IPv4zero}
  72. }
  73. func (connNonTLS) LocalAddr() net.Addr {
  74. return &net.TCPAddr{IP: net.IPv4zero}
  75. }
  76. type connTLS struct {
  77. connNonTLS
  78. state *tls.ConnectionState
  79. }
  80. func (c connTLS) Handshake() error {
  81. return nil
  82. }
  83. func (c connTLS) ConnectionState() tls.ConnectionState {
  84. return *c.state
  85. }