dialer.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. package ws
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/tls"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/gobwas/httphead"
  16. "github.com/gobwas/pool/pbufio"
  17. )
  18. // Constants used by Dialer.
  19. const (
  20. DefaultClientReadBufferSize = 4096
  21. DefaultClientWriteBufferSize = 4096
  22. )
  23. // Handshake represents handshake result.
  24. type Handshake struct {
  25. // Protocol is the subprotocol selected during handshake.
  26. Protocol string
  27. // Extensions is the list of negotiated extensions.
  28. Extensions []httphead.Option
  29. }
  30. // Errors used by the websocket client.
  31. var (
  32. ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
  33. ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
  34. ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
  35. )
  36. // DefaultDialer is dialer that holds no options and is used by Dial function.
  37. var DefaultDialer Dialer
  38. // Dial is like Dialer{}.Dial().
  39. func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
  40. return DefaultDialer.Dial(ctx, urlstr)
  41. }
  42. // Dialer contains options for establishing websocket connection to an url.
  43. type Dialer struct {
  44. // ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
  45. // They used to read and write http data while upgrading to WebSocket.
  46. // Allocated buffers are pooled with sync.Pool to avoid extra allocations.
  47. //
  48. // If a size is zero then default value is used.
  49. ReadBufferSize, WriteBufferSize int
  50. // Timeout is the maximum amount of time a Dial() will wait for a connect
  51. // and an handshake to complete.
  52. //
  53. // The default is no timeout.
  54. Timeout time.Duration
  55. // Protocols is the list of subprotocols that the client wants to speak,
  56. // ordered by preference.
  57. //
  58. // See https://tools.ietf.org/html/rfc6455#section-4.1
  59. Protocols []string
  60. // Extensions is the list of extensions that client wants to speak.
  61. //
  62. // Note that if server decides to use some of this extensions, Dial() will
  63. // return Handshake struct containing a slice of items, which are the
  64. // shallow copies of the items from this list. That is, internals of
  65. // Extensions items are shared during Dial().
  66. //
  67. // See https://tools.ietf.org/html/rfc6455#section-4.1
  68. // See https://tools.ietf.org/html/rfc6455#section-9.1
  69. Extensions []httphead.Option
  70. // Header is an optional HandshakeHeader instance that could be used to
  71. // write additional headers to the handshake request.
  72. //
  73. // It used instead of any key-value mappings to avoid allocations in user
  74. // land.
  75. Header HandshakeHeader
  76. // OnStatusError is the callback that will be called after receiving non
  77. // "101 Continue" HTTP response status. It receives an io.Reader object
  78. // representing server response bytes. That is, it gives ability to parse
  79. // HTTP response somehow (probably with http.ReadResponse call) and make a
  80. // decision of further logic.
  81. //
  82. // The arguments are only valid until the callback returns.
  83. OnStatusError func(status int, reason []byte, resp io.Reader)
  84. // OnHeader is the callback that will be called after successful parsing of
  85. // header, that is not used during WebSocket handshake procedure. That is,
  86. // it will be called with non-websocket headers, which could be relevant
  87. // for application-level logic.
  88. //
  89. // The arguments are only valid until the callback returns.
  90. //
  91. // Returned value could be used to prevent processing response.
  92. OnHeader func(key, value []byte) (err error)
  93. // NetDial is the function that is used to get plain tcp connection.
  94. // If it is not nil, then it is used instead of net.Dialer.
  95. NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
  96. // TLSClient is the callback that will be called after successful dial with
  97. // received connection and its remote host name. If it is nil, then the
  98. // default tls.Client() will be used.
  99. // If it is not nil, then TLSConfig field is ignored.
  100. TLSClient func(conn net.Conn, hostname string) net.Conn
  101. // TLSConfig is passed to tls.Client() to start TLS over established
  102. // connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
  103. // non-nil and its ServerName is empty, then for every Dial() it will be
  104. // cloned and appropriate ServerName will be set.
  105. TLSConfig *tls.Config
  106. // WrapConn is the optional callback that will be called when connection is
  107. // ready for an i/o. That is, it will be called after successful dial and
  108. // TLS initialization (for "wss" schemes). It may be helpful for different
  109. // user land purposes such as end to end encryption.
  110. //
  111. // Note that for debugging purposes of an http handshake (e.g. sent request
  112. // and received response), there is an wsutil.DebugDialer struct.
  113. WrapConn func(conn net.Conn) net.Conn
  114. }
  115. // Dial connects to the url host and upgrades connection to WebSocket.
  116. //
  117. // If server has sent frames right after successful handshake then returned
  118. // buffer will be non-nil. In other cases buffer is always nil. For better
  119. // memory efficiency received non-nil bufio.Reader should be returned to the
  120. // inner pool with PutReader() function after use.
  121. //
  122. // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
  123. // If you want to dial non-ascii host name, take care of its name serialization
  124. // avoiding bad request issues. For more info see net/http Request.Write()
  125. // implementation, especially cleanHost() function.
  126. func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
  127. u, err := url.ParseRequestURI(urlstr)
  128. if err != nil {
  129. return nil, nil, hs, err
  130. }
  131. // Prepare context to dial with. Initially it is the same as original, but
  132. // if d.Timeout is non-zero and points to time that is before ctx.Deadline,
  133. // we use more shorter context for dial.
  134. dialctx := ctx
  135. var deadline time.Time
  136. if t := d.Timeout; t != 0 {
  137. deadline = time.Now().Add(t)
  138. if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
  139. var cancel context.CancelFunc
  140. dialctx, cancel = context.WithDeadline(ctx, deadline)
  141. defer cancel()
  142. }
  143. }
  144. if conn, err = d.dial(dialctx, u); err != nil {
  145. return conn, nil, hs, err
  146. }
  147. defer func() {
  148. if err != nil {
  149. conn.Close()
  150. }
  151. }()
  152. if ctx == context.Background() {
  153. // No need to start I/O interrupter goroutine which is not zero-cost.
  154. conn.SetDeadline(deadline)
  155. defer conn.SetDeadline(noDeadline)
  156. } else {
  157. // Context could be canceled or its deadline could be exceeded.
  158. // Start the interrupter goroutine to handle context cancelation.
  159. done := setupContextDeadliner(ctx, conn)
  160. defer func() {
  161. // Map Upgrade() error to a possible context expiration error. That
  162. // is, even if Upgrade() err is nil, context could be already
  163. // expired and connection be "poisoned" by SetDeadline() call.
  164. // In that case we must not return ctx.Err() error.
  165. done(&err)
  166. }()
  167. }
  168. br, hs, err = d.Upgrade(conn, u)
  169. return conn, br, hs, err
  170. }
  171. var (
  172. // netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
  173. // Dialer.NetDial is not provided.
  174. netEmptyDialer net.Dialer
  175. // tlsEmptyConfig is an empty tls.Config used as default one.
  176. tlsEmptyConfig tls.Config
  177. )
  178. func tlsDefaultConfig() *tls.Config {
  179. return &tlsEmptyConfig
  180. }
  181. func hostport(host, defaultPort string) (hostname, addr string) {
  182. var (
  183. colon = strings.LastIndexByte(host, ':')
  184. bracket = strings.IndexByte(host, ']')
  185. )
  186. if colon > bracket {
  187. return host[:colon], host
  188. }
  189. return host, host + defaultPort
  190. }
  191. func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
  192. dial := d.NetDial
  193. if dial == nil {
  194. dial = netEmptyDialer.DialContext
  195. }
  196. switch u.Scheme {
  197. case "ws":
  198. _, addr := hostport(u.Host, ":80")
  199. conn, err = dial(ctx, "tcp", addr)
  200. case "wss":
  201. hostname, addr := hostport(u.Host, ":443")
  202. conn, err = dial(ctx, "tcp", addr)
  203. if err != nil {
  204. return nil, err
  205. }
  206. tlsClient := d.TLSClient
  207. if tlsClient == nil {
  208. tlsClient = d.tlsClient
  209. }
  210. conn = tlsClient(conn, hostname)
  211. default:
  212. return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
  213. }
  214. if wrap := d.WrapConn; wrap != nil {
  215. conn = wrap(conn)
  216. }
  217. return conn, err
  218. }
  219. func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
  220. config := d.TLSConfig
  221. if config == nil {
  222. config = tlsDefaultConfig()
  223. }
  224. if config.ServerName == "" {
  225. config = tlsCloneConfig(config)
  226. config.ServerName = hostname
  227. }
  228. // Do not make conn.Handshake() here because downstairs we will prepare
  229. // i/o on this conn with proper context's timeout handling.
  230. return tls.Client(conn, config)
  231. }
  232. var (
  233. // This variables are set like in net/net.go.
  234. // noDeadline is just zero value for readability.
  235. noDeadline = time.Time{}
  236. // aLongTimeAgo is a non-zero time, far in the past, used for immediate
  237. // cancelation of dials.
  238. aLongTimeAgo = time.Unix(42, 0)
  239. )
  240. // Upgrade writes an upgrade request to the given io.ReadWriter conn at given
  241. // url u and reads a response from it.
  242. //
  243. // It is a caller responsibility to manage I/O deadlines on conn.
  244. //
  245. // It returns handshake info and some bytes which could be written by the peer
  246. // right after response and be caught by us during buffered read.
  247. func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
  248. // headerSeen constants helps to report whether or not some header was seen
  249. // during reading request bytes.
  250. const (
  251. headerSeenUpgrade = 1 << iota
  252. headerSeenConnection
  253. headerSeenSecAccept
  254. // headerSeenAll is the value that we expect to receive at the end of
  255. // headers read/parse loop.
  256. headerSeenAll = 0 |
  257. headerSeenUpgrade |
  258. headerSeenConnection |
  259. headerSeenSecAccept
  260. )
  261. br = pbufio.GetReader(conn,
  262. nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
  263. )
  264. bw := pbufio.GetWriter(conn,
  265. nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
  266. )
  267. defer func() {
  268. pbufio.PutWriter(bw)
  269. if br.Buffered() == 0 || err != nil {
  270. // Server does not wrote additional bytes to the connection or
  271. // error occurred. That is, no reason to return buffer.
  272. pbufio.PutReader(br)
  273. br = nil
  274. }
  275. }()
  276. nonce := make([]byte, nonceSize)
  277. initNonce(nonce)
  278. httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
  279. if err := bw.Flush(); err != nil {
  280. return br, hs, err
  281. }
  282. // Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
  283. sl, err := readLine(br)
  284. if err != nil {
  285. return br, hs, err
  286. }
  287. // Begin validation of the response.
  288. // See https://tools.ietf.org/html/rfc6455#section-4.2.2
  289. // Parse request line data like HTTP version, uri and method.
  290. resp, err := httpParseResponseLine(sl)
  291. if err != nil {
  292. return br, hs, err
  293. }
  294. // Even if RFC says "1.1 or higher" without mentioning the part of the
  295. // version, we apply it only to minor part.
  296. if resp.major != 1 || resp.minor < 1 {
  297. err = ErrHandshakeBadProtocol
  298. return br, hs, err
  299. }
  300. if resp.status != http.StatusSwitchingProtocols {
  301. err = StatusError(resp.status)
  302. if onStatusError := d.OnStatusError; onStatusError != nil {
  303. // Invoke callback with multireader of status-line bytes br.
  304. onStatusError(resp.status, resp.reason,
  305. io.MultiReader(
  306. bytes.NewReader(sl),
  307. strings.NewReader(crlf),
  308. br,
  309. ),
  310. )
  311. }
  312. return br, hs, err
  313. }
  314. // If response status is 101 then we expect all technical headers to be
  315. // valid. If not, then we stop processing response without giving user
  316. // ability to read non-technical headers. That is, we do not distinguish
  317. // technical errors (such as parsing error) and protocol errors.
  318. var headerSeen byte
  319. for {
  320. line, e := readLine(br)
  321. if e != nil {
  322. err = e
  323. return br, hs, err
  324. }
  325. if len(line) == 0 {
  326. // Blank line, no more lines to read.
  327. break
  328. }
  329. k, v, ok := httpParseHeaderLine(line)
  330. if !ok {
  331. err = ErrMalformedResponse
  332. return br, hs, err
  333. }
  334. switch btsToString(k) {
  335. case headerUpgradeCanonical:
  336. headerSeen |= headerSeenUpgrade
  337. if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
  338. err = ErrHandshakeBadUpgrade
  339. return br, hs, err
  340. }
  341. case headerConnectionCanonical:
  342. headerSeen |= headerSeenConnection
  343. // Note that as RFC6455 says:
  344. // > A |Connection| header field with value "Upgrade".
  345. // That is, in server side, "Connection" header could contain
  346. // multiple token. But in response it must contains exactly one.
  347. if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
  348. err = ErrHandshakeBadConnection
  349. return br, hs, err
  350. }
  351. case headerSecAcceptCanonical:
  352. headerSeen |= headerSeenSecAccept
  353. if !checkAcceptFromNonce(v, nonce) {
  354. err = ErrHandshakeBadSecAccept
  355. return br, hs, err
  356. }
  357. case headerSecProtocolCanonical:
  358. // RFC6455 1.3:
  359. // "The server selects one or none of the acceptable protocols
  360. // and echoes that value in its handshake to indicate that it has
  361. // selected that protocol."
  362. for _, want := range d.Protocols {
  363. if string(v) == want {
  364. hs.Protocol = want
  365. break
  366. }
  367. }
  368. if hs.Protocol == "" {
  369. // Server echoed subprotocol that is not present in client
  370. // requested protocols.
  371. err = ErrHandshakeBadSubProtocol
  372. return br, hs, err
  373. }
  374. case headerSecExtensionsCanonical:
  375. hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
  376. if err != nil {
  377. return br, hs, err
  378. }
  379. default:
  380. if onHeader := d.OnHeader; onHeader != nil {
  381. if e := onHeader(k, v); e != nil {
  382. err = e
  383. return br, hs, err
  384. }
  385. }
  386. }
  387. }
  388. if err == nil && headerSeen != headerSeenAll {
  389. switch {
  390. case headerSeen&headerSeenUpgrade == 0:
  391. err = ErrHandshakeBadUpgrade
  392. case headerSeen&headerSeenConnection == 0:
  393. err = ErrHandshakeBadConnection
  394. case headerSeen&headerSeenSecAccept == 0:
  395. err = ErrHandshakeBadSecAccept
  396. default:
  397. panic("unknown headers state")
  398. }
  399. }
  400. return br, hs, err
  401. }
  402. // PutReader returns bufio.Reader instance to the inner reuse pool.
  403. // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
  404. // contains unprocessed buffered data, that was sent by the server quickly
  405. // right after handshake.
  406. func PutReader(br *bufio.Reader) {
  407. pbufio.PutReader(br)
  408. }
  409. // StatusError contains an unexpected status-line code from the server.
  410. type StatusError int
  411. func (s StatusError) Error() string {
  412. return "unexpected HTTP response status: " + strconv.Itoa(int(s))
  413. }
  414. func isTimeoutError(err error) bool {
  415. t, ok := err.(net.Error)
  416. return ok && t.Timeout()
  417. }
  418. func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
  419. if len(selected) == 0 {
  420. return received, nil
  421. }
  422. var (
  423. index int
  424. option httphead.Option
  425. err error
  426. )
  427. index = -1
  428. match := func() (ok bool) {
  429. for _, want := range wanted {
  430. // A server accepts one or more extensions by including a
  431. // |Sec-WebSocket-Extensions| header field containing one or more
  432. // extensions that were requested by the client.
  433. //
  434. // The interpretation of any extension parameters, and what
  435. // constitutes a valid response by a server to a requested set of
  436. // parameters by a client, will be defined by each such extension.
  437. if bytes.Equal(option.Name, want.Name) {
  438. // Check parsed extension to be present in client
  439. // requested extensions. We move matched extension
  440. // from client list to avoid allocation of httphead.Option.Name,
  441. // httphead.Option.Parameters have to be copied from the header
  442. want.Parameters, _ = option.Parameters.Copy(make([]byte, option.Parameters.Size()))
  443. received = append(received, want)
  444. return true
  445. }
  446. }
  447. return false
  448. }
  449. ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
  450. if i != index {
  451. // Met next option.
  452. index = i
  453. if i != 0 && !match() {
  454. // Server returned non-requested extension.
  455. err = ErrHandshakeBadExtensions
  456. return httphead.ControlBreak
  457. }
  458. option = httphead.Option{Name: name}
  459. }
  460. if attr != nil {
  461. option.Parameters.Set(attr, val)
  462. }
  463. return httphead.ControlContinue
  464. })
  465. if !ok {
  466. err = ErrMalformedResponse
  467. return received, err
  468. }
  469. if !match() {
  470. return received, ErrHandshakeBadExtensions
  471. }
  472. return received, err
  473. }
  474. // setupContextDeadliner is a helper function that starts connection I/O
  475. // interrupter goroutine.
  476. //
  477. // Started goroutine calls SetDeadline() with long time ago value when context
  478. // become expired to make any I/O operations failed. It returns done function
  479. // that stops started goroutine and maps error received from conn I/O methods
  480. // to possible context expiration error.
  481. //
  482. // In concern with possible SetDeadline() call inside interrupter goroutine,
  483. // caller passes pointer to its I/O error (even if it is nil) to done(&err).
  484. // That is, even if I/O error is nil, context could be already expired and
  485. // connection "poisoned" by SetDeadline() call. In that case done(&err) will
  486. // store at *err ctx.Err() result. If err is caused not by timeout, it will
  487. // leaved untouched.
  488. func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
  489. var (
  490. quit = make(chan struct{})
  491. interrupt = make(chan error, 1)
  492. )
  493. go func() {
  494. select {
  495. case <-quit:
  496. interrupt <- nil
  497. case <-ctx.Done():
  498. // Cancel i/o immediately.
  499. conn.SetDeadline(aLongTimeAgo)
  500. interrupt <- ctx.Err()
  501. }
  502. }()
  503. return func(err *error) {
  504. close(quit)
  505. // If ctx.Err() is non-nil and the original err is net.Error with
  506. // Timeout() == true, then it means that I/O was canceled by us by
  507. // SetDeadline(aLongTimeAgo) call, or by somebody else previously
  508. // by conn.SetDeadline(x).
  509. //
  510. // Even on race condition when both deadlines are expired
  511. // (SetDeadline() made not by us and context's), we prefer ctx.Err() to
  512. // be returned.
  513. if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
  514. *err = ctxErr
  515. }
  516. }
  517. }