websocket_dialer.go 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package httpexpect
  2. import (
  3. "bufio"
  4. "net"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync"
  8. "github.com/gorilla/websocket"
  9. )
  10. // NewWebsocketDialer produces new websocket.Dialer which dials to bound
  11. // http.Handler without creating a real net.Conn.
  12. func NewWebsocketDialer(handler http.Handler) *websocket.Dialer {
  13. return &websocket.Dialer{
  14. NetDial: func(network, addr string) (net.Conn, error) {
  15. hc := newHandlerConn()
  16. hc.runHandler(handler)
  17. return hc, nil
  18. },
  19. }
  20. }
  // handlerConn is one end of an in-memory connection; the opposite end is
  // kept alongside it for the background goroutine that serves requests.
  type handlerConn struct {
  	// Conn is the end returned from the dialer.
  	net.Conn

  	// backConn is the opposite end, passed to the background goroutine.
  	backConn net.Conn

  	// wg tracks the background goroutine so Close can wait for it.
  	wg sync.WaitGroup
  }
  26. func newHandlerConn() *handlerConn {
  27. dialConn, backConn := net.Pipe()
  28. return &handlerConn{
  29. Conn: dialConn,
  30. backConn: backConn,
  31. }
  32. }
  33. func (hc *handlerConn) Close() error {
  34. err := hc.Conn.Close()
  35. hc.wg.Wait() // wait the background goroutine
  36. return err
  37. }
  38. func (hc *handlerConn) runHandler(handler http.Handler) {
  39. hc.wg.Add(1)
  40. go func() {
  41. defer hc.wg.Done()
  42. recorder := &hijackRecorder{conn: hc.backConn}
  43. for {
  44. req, err := http.ReadRequest(bufio.NewReader(hc.backConn))
  45. if err != nil {
  46. return
  47. }
  48. handler.ServeHTTP(recorder, req)
  49. }
  50. }()
  51. }
  // hijackRecorder is similar to httptest.ResponseRecorder, but additionally
  // supports hijacking of the underlying connection.
  //
  // The original idea is taken from https://github.com/posener/wstest
  type hijackRecorder struct {
  	httptest.ResponseRecorder

  	// conn is the connection the recorder is bound to.
  	conn net.Conn
  }
  60. // Hijack the connection for caller.
  61. //
  62. // Implements http.Hijacker interface.
  63. func (r *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  64. rw := bufio.NewReadWriter(bufio.NewReader(r.conn), bufio.NewWriter(r.conn))
  65. return r.conn, rw, nil
  66. }
  67. // WriteHeader write HTTP header to the client and closes the connection
  68. //
  69. // Implements http.ResponseWriter interface.
  70. func (r *hijackRecorder) WriteHeader(code int) {
  71. resp := http.Response{StatusCode: code, Header: r.Header()}
  72. _ = resp.Write(r.conn)
  73. }