conn.go 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "encoding/binary"
  8. "errors"
  9. "io"
  10. "io/ioutil"
  11. "math/rand"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "time"
  17. "unicode/utf8"
  18. )
  // Bits of the first frame header byte (FIN and RSV1-3), RFC 6455 section 5.2.
  const (
  	finalBit = 0x80
  	rsv1Bit  = 0x40
  	rsv2Bit  = 0x20
  	rsv3Bit  = 0x10
  )

  // maskBit is the MASK flag in the second frame header byte, RFC 6455 section 5.2.
  const maskBit = 0x80

  // Framing limits and buffering defaults.
  const (
  	// Largest possible header: 2 fixed bytes, up to 8 bytes of extended
  	// payload length and a 4 byte masking key.
  	maxFrameHeaderSize = 2 + 8 + 4

  	// Upper bound on a control frame payload.
  	maxControlFramePayloadSize = 125

  	// How far in the future internally generated close frames are given as
  	// their write deadline.
  	writeWait = time.Second

  	// Buffer sizes used when the caller does not choose one.
  	defaultReadBufferSize  = 4 << 10
  	defaultWriteBufferSize = 4 << 10
  )

  // Frame-type sentinels: continuationFrame is the opcode of a continuation
  // frame; noFrame is returned by the frame reader together with an error.
  const (
  	continuationFrame = 0
  	noFrame           = -1
  )
  // Close codes defined in RFC 6455, section 11.7.
  const (
  	CloseNormalClosure           = 1000
  	CloseGoingAway               = 1001
  	CloseProtocolError           = 1002
  	CloseUnsupportedData         = 1003
  	CloseNoStatusReceived        = 1005
  	CloseAbnormalClosure         = 1006
  	CloseInvalidFramePayloadData = 1007
  	ClosePolicyViolation         = 1008
  	CloseMessageTooBig           = 1009
  	CloseMandatoryExtension      = 1010
  	CloseInternalServerErr       = 1011
  	CloseServiceRestart          = 1012
  	CloseTryAgainLater           = 1013
  	CloseTLSHandshake            = 1015
  )

  // The message types are defined in RFC 6455, section 11.8.
  const (
  	// TextMessage denotes a text data message. The text message payload is
  	// interpreted as UTF-8 encoded text data.
  	TextMessage = 1

  	// BinaryMessage denotes a binary data message.
  	BinaryMessage = 2

  	// CloseMessage denotes a close control message. The optional message
  	// payload contains a numeric code and text. Use the FormatCloseMessage
  	// function to format a close message payload.
  	CloseMessage = 8

  	// PingMessage denotes a ping control message. The optional message payload
  	// is UTF-8 encoded text.
  	PingMessage = 9

  	// PongMessage denotes a pong control message. The optional message payload
  	// is UTF-8 encoded text.
  	PongMessage = 10
  )

  // ErrCloseSent is returned when the application writes a message to the
  // connection after sending a close message.
  var ErrCloseSent = errors.New("websocket: close sent")

  // ErrReadLimit is returned when reading a message that is larger than the
  // read limit set for the connection.
  var ErrReadLimit = errors.New("websocket: read limit exceeded")

  // netError is a minimal error value satisfying the net.Error interface; its
  // Temporary and Timeout results are the stored flags.
  type netError struct {
  	msg       string
  	temporary bool
  	timeout   bool
  }

  func (e *netError) Error() string   { return e.msg }
  func (e *netError) Temporary() bool { return e.temporary }
  func (e *netError) Timeout() bool   { return e.timeout }

  // CloseError represents a close message.
  type CloseError struct {
  	// Code is defined in RFC 6455, section 11.7.
  	Code int

  	// Text is the optional text payload.
  	Text string
  }

  // Error formats the close code, a short description for every close code
  // constant declared in this package, and the optional text, e.g.
  // "websocket: close 1012 (service restart): reason".
  func (e *CloseError) Error() string {
  	s := []byte("websocket: close ")
  	s = strconv.AppendInt(s, int64(e.Code), 10)
  	switch e.Code {
  	case CloseNormalClosure:
  		s = append(s, " (normal)"...)
  	case CloseGoingAway:
  		s = append(s, " (going away)"...)
  	case CloseProtocolError:
  		s = append(s, " (protocol error)"...)
  	case CloseUnsupportedData:
  		s = append(s, " (unsupported data)"...)
  	case CloseNoStatusReceived:
  		s = append(s, " (no status)"...)
  	case CloseAbnormalClosure:
  		s = append(s, " (abnormal closure)"...)
  	case CloseInvalidFramePayloadData:
  		s = append(s, " (invalid payload data)"...)
  	case ClosePolicyViolation:
  		s = append(s, " (policy violation)"...)
  	case CloseMessageTooBig:
  		s = append(s, " (message too big)"...)
  	case CloseMandatoryExtension:
  		s = append(s, " (mandatory extension missing)"...)
  	case CloseInternalServerErr:
  		s = append(s, " (internal server error)"...)
  	case CloseServiceRestart:
  		// Previously missing: the constant was declared but never described.
  		s = append(s, " (service restart)"...)
  	case CloseTryAgainLater:
  		// Previously missing: the constant was declared but never described.
  		s = append(s, " (try again later)"...)
  	case CloseTLSHandshake:
  		s = append(s, " (TLS handshake error)"...)
  	}
  	if e.Text != "" {
  		s = append(s, ": "...)
  		s = append(s, e.Text...)
  	}
  	return string(s)
  }
  127. // IsCloseError returns boolean indicating whether the error is a *CloseError
  128. // with one of the specified codes.
  129. func IsCloseError(err error, codes ...int) bool {
  130. if e, ok := err.(*CloseError); ok {
  131. for _, code := range codes {
  132. if e.Code == code {
  133. return true
  134. }
  135. }
  136. }
  137. return false
  138. }
  139. // IsUnexpectedCloseError returns boolean indicating whether the error is a
  140. // *CloseError with a code not in the list of expected codes.
  141. func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
  142. if e, ok := err.(*CloseError); ok {
  143. for _, code := range expectedCodes {
  144. if e.Code == code {
  145. return false
  146. }
  147. }
  148. return true
  149. }
  150. return false
  151. }
  152. var (
  153. errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
  154. errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
  155. errBadWriteOpCode = errors.New("websocket: bad write message type")
  156. errWriteClosed = errors.New("websocket: write closed")
  157. errInvalidControlFrame = errors.New("websocket: invalid control frame")
  158. )
  // newMaskKey returns a fresh 4-byte frame masking key built from one
  // math/rand Uint32, least-significant byte first.
  func newMaskKey() [4]byte {
  	var key [4]byte
  	for i, v := 0, rand.Uint32(); i < len(key); i, v = i+1, v>>8 {
  		key[i] = byte(v)
  	}
  	return key
  }
  163. func hideTempErr(err error) error {
  164. if e, ok := err.(net.Error); ok && e.Temporary() {
  165. err = &netError{msg: e.Error(), timeout: e.Timeout()}
  166. }
  167. return err
  168. }
  169. func isControl(frameType int) bool {
  170. return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
  171. }
  172. func isData(frameType int) bool {
  173. return frameType == TextMessage || frameType == BinaryMessage
  174. }
  175. var validReceivedCloseCodes = map[int]bool{
  176. // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
  177. CloseNormalClosure: true,
  178. CloseGoingAway: true,
  179. CloseProtocolError: true,
  180. CloseUnsupportedData: true,
  181. CloseNoStatusReceived: false,
  182. CloseAbnormalClosure: false,
  183. CloseInvalidFramePayloadData: true,
  184. ClosePolicyViolation: true,
  185. CloseMessageTooBig: true,
  186. CloseMandatoryExtension: true,
  187. CloseInternalServerErr: true,
  188. CloseServiceRestart: true,
  189. CloseTryAgainLater: true,
  190. CloseTLSHandshake: false,
  191. }
  192. func isValidReceivedCloseCode(code int) bool {
  193. return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
  194. }
  // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
  // interface. The type of the value stored in a pool is not specified.
  type BufferPool interface {
  	// Get returns a value from the pool, or nil when the pool is empty.
  	Get() interface{}

  	// Put hands a value back to the pool for later reuse.
  	Put(interface{})
  }

  // writePoolData wraps the write buffers this package stores in a BufferPool
  // so that applications cannot peek at, or come to depend on, those values.
  type writePoolData struct {
  	buf []byte
  }
  207. // The Conn type represents a WebSocket connection.
  208. type Conn struct {
  209. conn net.Conn
  210. isServer bool
  211. subprotocol string
  212. // Write fields
  213. mu chan struct{} // used as mutex to protect write to conn
  214. writeBuf []byte // frame is constructed in this buffer.
  215. writePool BufferPool
  216. writeBufSize int
  217. writeDeadline time.Time
  218. writer io.WriteCloser // the current writer returned to the application
  219. isWriting bool // for best-effort concurrent write detection
  220. writeErrMu sync.Mutex
  221. writeErr error
  222. enableWriteCompression bool
  223. compressionLevel int
  224. newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
  225. // Read fields
  226. reader io.ReadCloser // the current reader returned to the application
  227. readErr error
  228. br *bufio.Reader
  229. // bytes remaining in current frame.
  230. // set setReadRemaining to safely update this value and prevent overflow
  231. readRemaining int64
  232. readFinal bool // true the current message has more frames.
  233. readLength int64 // Message size.
  234. readLimit int64 // Maximum message size.
  235. readMaskPos int
  236. readMaskKey [4]byte
  237. handlePong func(string) error
  238. handlePing func(string) error
  239. handleClose func(int, string) error
  240. readErrCount int
  241. messageReader *messageReader // the current low-level reader
  242. readDecompress bool // whether last read frame had RSV1 set
  243. newDecompressionReader func(io.Reader) io.ReadCloser
  244. }
  245. func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
  246. if br == nil {
  247. if readBufferSize == 0 {
  248. readBufferSize = defaultReadBufferSize
  249. } else if readBufferSize < maxControlFramePayloadSize {
  250. // must be large enough for control frame
  251. readBufferSize = maxControlFramePayloadSize
  252. }
  253. br = bufio.NewReaderSize(conn, readBufferSize)
  254. }
  255. if writeBufferSize <= 0 {
  256. writeBufferSize = defaultWriteBufferSize
  257. }
  258. writeBufferSize += maxFrameHeaderSize
  259. if writeBuf == nil && writeBufferPool == nil {
  260. writeBuf = make([]byte, writeBufferSize)
  261. }
  262. mu := make(chan struct{}, 1)
  263. mu <- struct{}{}
  264. c := &Conn{
  265. isServer: isServer,
  266. br: br,
  267. conn: conn,
  268. mu: mu,
  269. readFinal: true,
  270. writeBuf: writeBuf,
  271. writePool: writeBufferPool,
  272. writeBufSize: writeBufferSize,
  273. enableWriteCompression: true,
  274. compressionLevel: defaultCompressionLevel,
  275. }
  276. c.SetCloseHandler(nil)
  277. c.SetPingHandler(nil)
  278. c.SetPongHandler(nil)
  279. return c
  280. }
  281. // setReadRemaining tracks the number of bytes remaining on the connection. If n
  282. // overflows, an ErrReadLimit is returned.
  283. func (c *Conn) setReadRemaining(n int64) error {
  284. if n < 0 {
  285. return ErrReadLimit
  286. }
  287. c.readRemaining = n
  288. return nil
  289. }
  290. // Subprotocol returns the negotiated protocol for the connection.
  291. func (c *Conn) Subprotocol() string {
  292. return c.subprotocol
  293. }
  294. // Close closes the underlying network connection without sending or waiting
  295. // for a close message.
  296. func (c *Conn) Close() error {
  297. return c.conn.Close()
  298. }
  299. // LocalAddr returns the local network address.
  300. func (c *Conn) LocalAddr() net.Addr {
  301. return c.conn.LocalAddr()
  302. }
  303. // RemoteAddr returns the remote network address.
  304. func (c *Conn) RemoteAddr() net.Addr {
  305. return c.conn.RemoteAddr()
  306. }
  307. // Write methods
  308. func (c *Conn) writeFatal(err error) error {
  309. err = hideTempErr(err)
  310. c.writeErrMu.Lock()
  311. if c.writeErr == nil {
  312. c.writeErr = err
  313. }
  314. c.writeErrMu.Unlock()
  315. return err
  316. }
  317. func (c *Conn) read(n int) ([]byte, error) {
  318. p, err := c.br.Peek(n)
  319. if err == io.EOF {
  320. err = errUnexpectedEOF
  321. }
  322. c.br.Discard(len(p))
  323. return p, err
  324. }
  325. func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
  326. <-c.mu
  327. defer func() { c.mu <- struct{}{} }()
  328. c.writeErrMu.Lock()
  329. err := c.writeErr
  330. c.writeErrMu.Unlock()
  331. if err != nil {
  332. return err
  333. }
  334. c.conn.SetWriteDeadline(deadline)
  335. if len(buf1) == 0 {
  336. _, err = c.conn.Write(buf0)
  337. } else {
  338. err = c.writeBufs(buf0, buf1)
  339. }
  340. if err != nil {
  341. return c.writeFatal(err)
  342. }
  343. if frameType == CloseMessage {
  344. c.writeFatal(ErrCloseSent)
  345. }
  346. return nil
  347. }
  348. func (c *Conn) writeBufs(bufs ...[]byte) error {
  349. b := net.Buffers(bufs)
  350. _, err := b.WriteTo(c.conn)
  351. return err
  352. }
  353. // WriteControl writes a control message with the given deadline. The allowed
  354. // message types are CloseMessage, PingMessage and PongMessage.
  355. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
  356. if !isControl(messageType) {
  357. return errBadWriteOpCode
  358. }
  359. if len(data) > maxControlFramePayloadSize {
  360. return errInvalidControlFrame
  361. }
  362. b0 := byte(messageType) | finalBit
  363. b1 := byte(len(data))
  364. if !c.isServer {
  365. b1 |= maskBit
  366. }
  367. buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
  368. buf = append(buf, b0, b1)
  369. if c.isServer {
  370. buf = append(buf, data...)
  371. } else {
  372. key := newMaskKey()
  373. buf = append(buf, key[:]...)
  374. buf = append(buf, data...)
  375. maskBytes(key, 0, buf[6:])
  376. }
  377. d := 1000 * time.Hour
  378. if !deadline.IsZero() {
  379. d = deadline.Sub(time.Now())
  380. if d < 0 {
  381. return errWriteTimeout
  382. }
  383. }
  384. timer := time.NewTimer(d)
  385. select {
  386. case <-c.mu:
  387. timer.Stop()
  388. case <-timer.C:
  389. return errWriteTimeout
  390. }
  391. defer func() { c.mu <- struct{}{} }()
  392. c.writeErrMu.Lock()
  393. err := c.writeErr
  394. c.writeErrMu.Unlock()
  395. if err != nil {
  396. return err
  397. }
  398. c.conn.SetWriteDeadline(deadline)
  399. _, err = c.conn.Write(buf)
  400. if err != nil {
  401. return c.writeFatal(err)
  402. }
  403. if messageType == CloseMessage {
  404. c.writeFatal(ErrCloseSent)
  405. }
  406. return err
  407. }
  408. // beginMessage prepares a connection and message writer for a new message.
  409. func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
  410. // Close previous writer if not already closed by the application. It's
  411. // probably better to return an error in this situation, but we cannot
  412. // change this without breaking existing applications.
  413. if c.writer != nil {
  414. c.writer.Close()
  415. c.writer = nil
  416. }
  417. if !isControl(messageType) && !isData(messageType) {
  418. return errBadWriteOpCode
  419. }
  420. c.writeErrMu.Lock()
  421. err := c.writeErr
  422. c.writeErrMu.Unlock()
  423. if err != nil {
  424. return err
  425. }
  426. mw.c = c
  427. mw.frameType = messageType
  428. mw.pos = maxFrameHeaderSize
  429. if c.writeBuf == nil {
  430. wpd, ok := c.writePool.Get().(writePoolData)
  431. if ok {
  432. c.writeBuf = wpd.buf
  433. } else {
  434. c.writeBuf = make([]byte, c.writeBufSize)
  435. }
  436. }
  437. return nil
  438. }
  439. // NextWriter returns a writer for the next message to send. The writer's Close
  440. // method flushes the complete message to the network.
  441. //
  442. // There can be at most one open writer on a connection. NextWriter closes the
  443. // previous writer if the application has not already done so.
  444. //
  445. // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
  446. // PongMessage) are supported.
  447. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
  448. var mw messageWriter
  449. if err := c.beginMessage(&mw, messageType); err != nil {
  450. return nil, err
  451. }
  452. c.writer = &mw
  453. if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
  454. w := c.newCompressionWriter(c.writer, c.compressionLevel)
  455. mw.compress = true
  456. c.writer = w
  457. }
  458. return c.writer, nil
  459. }
  460. type messageWriter struct {
  461. c *Conn
  462. compress bool // whether next call to flushFrame should set RSV1
  463. pos int // end of data in writeBuf.
  464. frameType int // type of the current frame.
  465. err error
  466. }
  467. func (w *messageWriter) endMessage(err error) error {
  468. if w.err != nil {
  469. return err
  470. }
  471. c := w.c
  472. w.err = err
  473. c.writer = nil
  474. if c.writePool != nil {
  475. c.writePool.Put(writePoolData{buf: c.writeBuf})
  476. c.writeBuf = nil
  477. }
  478. return err
  479. }
  // flushFrame writes buffered data and extra as a frame to the network. The
  // final argument indicates that this is the last frame in the message.
  //
  // The buffered payload is c.writeBuf[maxFrameHeaderSize:w.pos]. The header is
  // written into the maxFrameHeaderSize bytes reserved in front of it, shifted
  // right so that the header ends exactly where the payload begins. For
  // clients, the 4-byte masking key comes last in the header. The frame is
  // then sent as c.writeBuf[framePos:w.pos].
  //
  // extra, when non-empty, is sent after the buffered payload without being
  // copied. It is rejected in client mode because it would not be masked.
  //
  // Failure modes:
  //   - Any error ends the message via endMessage.
  //   - On a successful final frame, the message is ended with
  //     errWriteClosed, so further writes on this messageWriter fail.
  //   - On a successful non-final frame, the writer is reset for a
  //     continuation frame.
  func (w *messageWriter) flushFrame(final bool, extra []byte) error {
  	c := w.c
  	length := w.pos - maxFrameHeaderSize + len(extra)
  	// Check for invalid control frames.
  	if isControl(w.frameType) &&
  		(!final || length > maxControlFramePayloadSize) {
  		return w.endMessage(errInvalidControlFrame)
  	}
  	b0 := byte(w.frameType)
  	if final {
  		b0 |= finalBit
  	}
  	if w.compress {
  		b0 |= rsv1Bit
  	}
  	// RSV1 is only set on the first frame of a compressed message.
  	w.compress = false
  	b1 := byte(0)
  	if !c.isServer {
  		b1 |= maskBit
  	}
  	// Assume that the frame starts at beginning of c.writeBuf.
  	framePos := 0
  	if c.isServer {
  		// Adjust up if mask not included in the header.
  		framePos = 4
  	}
  	// Pick the payload-length encoding (8-byte, 2-byte or inline 7-bit).
  	// Shorter encodings move framePos right so the header stays contiguous
  	// with the payload at offset maxFrameHeaderSize.
  	switch {
  	case length >= 65536:
  		c.writeBuf[framePos] = b0
  		c.writeBuf[framePos+1] = b1 | 127
  		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
  	case length > 125:
  		framePos += 6
  		c.writeBuf[framePos] = b0
  		c.writeBuf[framePos+1] = b1 | 126
  		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
  	default:
  		framePos += 8
  		c.writeBuf[framePos] = b0
  		c.writeBuf[framePos+1] = b1 | byte(length)
  	}
  	if !c.isServer {
  		// Place the masking key immediately before the payload and mask the
  		// buffered payload with it (maskBytes is defined elsewhere in the
  		// package; its return value is not used here).
  		key := newMaskKey()
  		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
  		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
  		if len(extra) > 0 {
  			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
  		}
  	}
  	// Write the buffers to the connection with best-effort detection of
  	// concurrent writes. See the concurrency section in the package
  	// documentation for more info.
  	if c.isWriting {
  		panic("concurrent write to websocket connection")
  	}
  	c.isWriting = true
  	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
  	if !c.isWriting {
  		panic("concurrent write to websocket connection")
  	}
  	c.isWriting = false
  	if err != nil {
  		return w.endMessage(err)
  	}
  	if final {
  		w.endMessage(errWriteClosed)
  		return nil
  	}
  	// Setup for next frame.
  	w.pos = maxFrameHeaderSize
  	w.frameType = continuationFrame
  	return nil
  }
  555. func (w *messageWriter) ncopy(max int) (int, error) {
  556. n := len(w.c.writeBuf) - w.pos
  557. if n <= 0 {
  558. if err := w.flushFrame(false, nil); err != nil {
  559. return 0, err
  560. }
  561. n = len(w.c.writeBuf) - w.pos
  562. }
  563. if n > max {
  564. n = max
  565. }
  566. return n, nil
  567. }
  568. func (w *messageWriter) Write(p []byte) (int, error) {
  569. if w.err != nil {
  570. return 0, w.err
  571. }
  572. if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
  573. // Don't buffer large messages.
  574. err := w.flushFrame(false, p)
  575. if err != nil {
  576. return 0, err
  577. }
  578. return len(p), nil
  579. }
  580. nn := len(p)
  581. for len(p) > 0 {
  582. n, err := w.ncopy(len(p))
  583. if err != nil {
  584. return 0, err
  585. }
  586. copy(w.c.writeBuf[w.pos:], p[:n])
  587. w.pos += n
  588. p = p[n:]
  589. }
  590. return nn, nil
  591. }
  592. func (w *messageWriter) WriteString(p string) (int, error) {
  593. if w.err != nil {
  594. return 0, w.err
  595. }
  596. nn := len(p)
  597. for len(p) > 0 {
  598. n, err := w.ncopy(len(p))
  599. if err != nil {
  600. return 0, err
  601. }
  602. copy(w.c.writeBuf[w.pos:], p[:n])
  603. w.pos += n
  604. p = p[n:]
  605. }
  606. return nn, nil
  607. }
  608. func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
  609. if w.err != nil {
  610. return 0, w.err
  611. }
  612. for {
  613. if w.pos == len(w.c.writeBuf) {
  614. err = w.flushFrame(false, nil)
  615. if err != nil {
  616. break
  617. }
  618. }
  619. var n int
  620. n, err = r.Read(w.c.writeBuf[w.pos:])
  621. w.pos += n
  622. nn += int64(n)
  623. if err != nil {
  624. if err == io.EOF {
  625. err = nil
  626. }
  627. break
  628. }
  629. }
  630. return nn, err
  631. }
  632. func (w *messageWriter) Close() error {
  633. if w.err != nil {
  634. return w.err
  635. }
  636. return w.flushFrame(true, nil)
  637. }
  638. // WritePreparedMessage writes prepared message into connection.
  639. func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
  640. frameType, frameData, err := pm.frame(prepareKey{
  641. isServer: c.isServer,
  642. compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
  643. compressionLevel: c.compressionLevel,
  644. })
  645. if err != nil {
  646. return err
  647. }
  648. if c.isWriting {
  649. panic("concurrent write to websocket connection")
  650. }
  651. c.isWriting = true
  652. err = c.write(frameType, c.writeDeadline, frameData, nil)
  653. if !c.isWriting {
  654. panic("concurrent write to websocket connection")
  655. }
  656. c.isWriting = false
  657. return err
  658. }
  659. // WriteMessage is a helper method for getting a writer using NextWriter,
  660. // writing the message and closing the writer.
  661. func (c *Conn) WriteMessage(messageType int, data []byte) error {
  662. if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
  663. // Fast path with no allocations and single frame.
  664. var mw messageWriter
  665. if err := c.beginMessage(&mw, messageType); err != nil {
  666. return err
  667. }
  668. n := copy(c.writeBuf[mw.pos:], data)
  669. mw.pos += n
  670. data = data[n:]
  671. return mw.flushFrame(true, data)
  672. }
  673. w, err := c.NextWriter(messageType)
  674. if err != nil {
  675. return err
  676. }
  677. if _, err = w.Write(data); err != nil {
  678. return err
  679. }
  680. return w.Close()
  681. }
  682. // SetWriteDeadline sets the write deadline on the underlying network
  683. // connection. After a write has timed out, the websocket state is corrupt and
  684. // all future writes will return an error. A zero value for t means writes will
  685. // not time out.
  686. func (c *Conn) SetWriteDeadline(t time.Time) error {
  687. c.writeDeadline = t
  688. return nil
  689. }
  // Read methods

  // advanceFrame discards any unread remainder of the previous frame, then
  // reads the next frame header and returns its opcode.
  //
  // Data frames (text, binary, continuation): only the header (and masking
  // key) is consumed. c.readRemaining is left holding the payload length for
  // the message reader, and the accumulated c.readLength is checked against
  // c.readLimit.
  //
  // Control frames: the payload is read in full and passed to the pong, ping
  // or close handler. A close frame always ends in an error, either the close
  // handler's error or a *CloseError.
  //
  // Errors: it returns noFrame together with a non-nil error. Header
  // violations are reported to the peer with handleProtocolError.
  func (c *Conn) advanceFrame() (int, error) {
  	// 1. Skip remainder of previous frame.
  	if c.readRemaining > 0 {
  		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
  			return noFrame, err
  		}
  	}
  	// 2. Read and parse first two bytes of frame header.
  	// To aid debugging, collect and report all errors in the first two bytes
  	// of the header.
  	// Note: this local shadows the errors package for the rest of the function.
  	var errors []string
  	p, err := c.read(2)
  	if err != nil {
  		return noFrame, err
  	}
  	frameType := int(p[0] & 0xf)
  	final := p[0]&finalBit != 0
  	rsv1 := p[0]&rsv1Bit != 0
  	rsv2 := p[0]&rsv2Bit != 0
  	rsv3 := p[0]&rsv3Bit != 0
  	mask := p[1]&maskBit != 0
  	// A 7-bit length is never negative, so the error is ignored. Values 126
  	// and 127 are placeholders that step 3 replaces with the extended length.
  	c.setReadRemaining(int64(p[1] & 0x7f))
  	// RSV1 is accepted (and marks the frame for decompression) only when a
  	// decompression reader is configured.
  	c.readDecompress = false
  	if rsv1 {
  		if c.newDecompressionReader != nil {
  			c.readDecompress = true
  		} else {
  			errors = append(errors, "RSV1 set")
  		}
  	}
  	if rsv2 {
  		errors = append(errors, "RSV2 set")
  	}
  	if rsv3 {
  		errors = append(errors, "RSV3 set")
  	}
  	switch frameType {
  	case CloseMessage, PingMessage, PongMessage:
  		// Control frames must fit in 125 bytes and must not be fragmented.
  		if c.readRemaining > maxControlFramePayloadSize {
  			errors = append(errors, "len > 125 for control")
  		}
  		if !final {
  			errors = append(errors, "FIN not set on control")
  		}
  	case TextMessage, BinaryMessage:
  		// A new data message may only start after the previous one finished.
  		if !c.readFinal {
  			errors = append(errors, "data before FIN")
  		}
  		c.readFinal = final
  	case continuationFrame:
  		// A continuation is only valid inside an unfinished data message.
  		if c.readFinal {
  			errors = append(errors, "continuation after FIN")
  		}
  		c.readFinal = final
  	default:
  		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
  	}
  	// A server must receive masked frames and a client unmasked ones.
  	if mask != c.isServer {
  		errors = append(errors, "bad MASK")
  	}
  	if len(errors) > 0 {
  		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
  	}
  	// 3. Read and parse frame length as per
  	// https://tools.ietf.org/html/rfc6455#section-5.2
  	//
  	// The length of the "Payload data", in bytes: if 0-125, that is the payload
  	// length.
  	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
  	// integer are the payload length.
  	// - If 127, the following 8 bytes interpreted as
  	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
  	// payload length. Multibyte length quantities are expressed in network byte
  	// order.
  	switch c.readRemaining {
  	case 126:
  		p, err := c.read(2)
  		if err != nil {
  			return noFrame, err
  		}
  		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
  			return noFrame, err
  		}
  	case 127:
  		p, err := c.read(8)
  		if err != nil {
  			return noFrame, err
  		}
  		// A set most significant bit becomes a negative int64, which
  		// setReadRemaining rejects with ErrReadLimit.
  		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
  			return noFrame, err
  		}
  	}
  	// 4. Handle frame masking.
  	if mask {
  		c.readMaskPos = 0
  		p, err := c.read(len(c.readMaskKey))
  		if err != nil {
  			return noFrame, err
  		}
  		copy(c.readMaskKey[:], p)
  	}
  	// 5. For text and binary messages, enforce read limit and return.
  	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
  		c.readLength += c.readRemaining
  		// Don't allow readLength to overflow in the presence of a large readRemaining
  		// counter.
  		if c.readLength < 0 {
  			return noFrame, ErrReadLimit
  		}
  		if c.readLimit > 0 && c.readLength > c.readLimit {
  			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
  			return noFrame, ErrReadLimit
  		}
  		return frameType, nil
  	}
  	// 6. Read control frame payload.
  	var payload []byte
  	if c.readRemaining > 0 {
  		payload, err = c.read(int(c.readRemaining))
  		c.setReadRemaining(0)
  		if err != nil {
  			return noFrame, err
  		}
  		// Only servers receive masked frames (checked in step 2), so unmask here.
  		if c.isServer {
  			maskBytes(c.readMaskKey, 0, payload)
  		}
  	}
  	// 7. Process control frame payload.
  	switch frameType {
  	case PongMessage:
  		if err := c.handlePong(string(payload)); err != nil {
  			return noFrame, err
  		}
  	case PingMessage:
  		if err := c.handlePing(string(payload)); err != nil {
  			return noFrame, err
  		}
  	case CloseMessage:
  		closeCode := CloseNoStatusReceived
  		closeText := ""
  		// NOTE(review): a 1-byte close payload is treated like an empty one
  		// (CloseNoStatusReceived); RFC 6455 section 5.5.1 requires a body to
  		// start with a 2-byte status code, so this may warrant a protocol
  		// error instead — confirm before changing.
  		if len(payload) >= 2 {
  			closeCode = int(binary.BigEndian.Uint16(payload))
  			if !isValidReceivedCloseCode(closeCode) {
  				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
  			}
  			closeText = string(payload[2:])
  			if !utf8.ValidString(closeText) {
  				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
  			}
  		}
  		if err := c.handleClose(closeCode, closeText); err != nil {
  			return noFrame, err
  		}
  		return noFrame, &CloseError{Code: closeCode, Text: closeText}
  	}
  	return frameType, nil
  }
  848. func (c *Conn) handleProtocolError(message string) error {
  849. data := FormatCloseMessage(CloseProtocolError, message)
  850. if len(data) > maxControlFramePayloadSize {
  851. data = data[:maxControlFramePayloadSize]
  852. }
  853. c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
  854. return errors.New("websocket: " + message)
  855. }
  856. // NextReader returns the next data message received from the peer. The
  857. // returned messageType is either TextMessage or BinaryMessage.
  858. //
  859. // There can be at most one open reader on a connection. NextReader discards
  860. // the previous message if the application has not already consumed it.
  861. //
  862. // Applications must break out of the application's read loop when this method
  863. // returns a non-nil error value. Errors returned from this method are
  864. // permanent. Once this method returns a non-nil error, all subsequent calls to
  865. // this method return the same error.
  866. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
  867. // Close previous reader, only relevant for decompression.
  868. if c.reader != nil {
  869. c.reader.Close()
  870. c.reader = nil
  871. }
  872. c.messageReader = nil
  873. c.readLength = 0
  874. for c.readErr == nil {
  875. frameType, err := c.advanceFrame()
  876. if err != nil {
  877. c.readErr = hideTempErr(err)
  878. break
  879. }
  880. if frameType == TextMessage || frameType == BinaryMessage {
  881. c.messageReader = &messageReader{c}
  882. c.reader = c.messageReader
  883. if c.readDecompress {
  884. c.reader = c.newDecompressionReader(c.reader)
  885. }
  886. return frameType, c.reader, nil
  887. }
  888. }
  889. // Applications that do handle the error returned from this method spin in
  890. // tight loop on connection failure. To help application developers detect
  891. // this error, panic on repeated reads to the failed connection.
  892. c.readErrCount++
  893. if c.readErrCount >= 1000 {
  894. panic("repeated read on failed websocket connection")
  895. }
  896. return noFrame, nil, c.readErr
  897. }
  898. type messageReader struct{ c *Conn }
  899. func (r *messageReader) Read(b []byte) (int, error) {
  900. c := r.c
  901. if c.messageReader != r {
  902. return 0, io.EOF
  903. }
  904. for c.readErr == nil {
  905. if c.readRemaining > 0 {
  906. if int64(len(b)) > c.readRemaining {
  907. b = b[:c.readRemaining]
  908. }
  909. n, err := c.br.Read(b)
  910. c.readErr = hideTempErr(err)
  911. if c.isServer {
  912. c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
  913. }
  914. rem := c.readRemaining
  915. rem -= int64(n)
  916. c.setReadRemaining(rem)
  917. if c.readRemaining > 0 && c.readErr == io.EOF {
  918. c.readErr = errUnexpectedEOF
  919. }
  920. return n, c.readErr
  921. }
  922. if c.readFinal {
  923. c.messageReader = nil
  924. return 0, io.EOF
  925. }
  926. frameType, err := c.advanceFrame()
  927. switch {
  928. case err != nil:
  929. c.readErr = hideTempErr(err)
  930. case frameType == TextMessage || frameType == BinaryMessage:
  931. c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
  932. }
  933. }
  934. err := c.readErr
  935. if err == io.EOF && c.messageReader == r {
  936. err = errUnexpectedEOF
  937. }
  938. return 0, err
  939. }
  940. func (r *messageReader) Close() error {
  941. return nil
  942. }
  943. // ReadMessage is a helper method for getting a reader using NextReader and
  944. // reading from that reader to a buffer.
  945. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
  946. var r io.Reader
  947. messageType, r, err = c.NextReader()
  948. if err != nil {
  949. return messageType, nil, err
  950. }
  951. p, err = ioutil.ReadAll(r)
  952. return messageType, p, err
  953. }
  954. // SetReadDeadline sets the read deadline on the underlying network connection.
  955. // After a read has timed out, the websocket connection state is corrupt and
  956. // all future reads will return an error. A zero value for t means reads will
  957. // not time out.
  958. func (c *Conn) SetReadDeadline(t time.Time) error {
  959. return c.conn.SetReadDeadline(t)
  960. }
  961. // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
  962. // message exceeds the limit, the connection sends a close message to the peer
  963. // and returns ErrReadLimit to the application.
  964. func (c *Conn) SetReadLimit(limit int64) {
  965. c.readLimit = limit
  966. }
  967. // CloseHandler returns the current close handler
  968. func (c *Conn) CloseHandler() func(code int, text string) error {
  969. return c.handleClose
  970. }
  971. // SetCloseHandler sets the handler for close messages received from the peer.
  972. // The code argument to h is the received close code or CloseNoStatusReceived
  973. // if the close message is empty. The default close handler sends a close
  974. // message back to the peer.
  975. //
  976. // The handler function is called from the NextReader, ReadMessage and message
  977. // reader Read methods. The application must read the connection to process
  978. // close messages as described in the section on Control Messages above.
  979. //
  980. // The connection read methods return a CloseError when a close message is
  981. // received. Most applications should handle close messages as part of their
  982. // normal error handling. Applications should only set a close handler when the
  983. // application must perform some action before sending a close message back to
  984. // the peer.
  985. func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
  986. if h == nil {
  987. h = func(code int, text string) error {
  988. message := FormatCloseMessage(code, "")
  989. c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
  990. return nil
  991. }
  992. }
  993. c.handleClose = h
  994. }
  995. // PingHandler returns the current ping handler
  996. func (c *Conn) PingHandler() func(appData string) error {
  997. return c.handlePing
  998. }
  999. // SetPingHandler sets the handler for ping messages received from the peer.
  1000. // The appData argument to h is the PING message application data. The default
  1001. // ping handler sends a pong to the peer.
  1002. //
  1003. // The handler function is called from the NextReader, ReadMessage and message
  1004. // reader Read methods. The application must read the connection to process
  1005. // ping messages as described in the section on Control Messages above.
  1006. func (c *Conn) SetPingHandler(h func(appData string) error) {
  1007. if h == nil {
  1008. h = func(message string) error {
  1009. err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
  1010. if err == ErrCloseSent {
  1011. return nil
  1012. } else if e, ok := err.(net.Error); ok && e.Temporary() {
  1013. return nil
  1014. }
  1015. return err
  1016. }
  1017. }
  1018. c.handlePing = h
  1019. }
  1020. // PongHandler returns the current pong handler
  1021. func (c *Conn) PongHandler() func(appData string) error {
  1022. return c.handlePong
  1023. }
  1024. // SetPongHandler sets the handler for pong messages received from the peer.
  1025. // The appData argument to h is the PONG message application data. The default
  1026. // pong handler does nothing.
  1027. //
  1028. // The handler function is called from the NextReader, ReadMessage and message
  1029. // reader Read methods. The application must read the connection to process
  1030. // pong messages as described in the section on Control Messages above.
  1031. func (c *Conn) SetPongHandler(h func(appData string) error) {
  1032. if h == nil {
  1033. h = func(string) error { return nil }
  1034. }
  1035. c.handlePong = h
  1036. }
  1037. // UnderlyingConn returns the internal net.Conn. This can be used to further
  1038. // modifications to connection specific flags.
  1039. func (c *Conn) UnderlyingConn() net.Conn {
  1040. return c.conn
  1041. }
  1042. // EnableWriteCompression enables and disables write compression of
  1043. // subsequent text and binary messages. This function is a noop if
  1044. // compression was not negotiated with the peer.
  1045. func (c *Conn) EnableWriteCompression(enable bool) {
  1046. c.enableWriteCompression = enable
  1047. }
  1048. // SetCompressionLevel sets the flate compression level for subsequent text and
  1049. // binary messages. This function is a noop if compression was not negotiated
  1050. // with the peer. See the compress/flate package for a description of
  1051. // compression levels.
  1052. func (c *Conn) SetCompressionLevel(level int) error {
  1053. if !isValidCompressionLevel(level) {
  1054. return errors.New("websocket: invalid compression level")
  1055. }
  1056. c.compressionLevel = level
  1057. return nil
  1058. }
  1059. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
  1060. // An empty message is returned for code CloseNoStatusReceived.
  1061. func FormatCloseMessage(closeCode int, text string) []byte {
  1062. if closeCode == CloseNoStatusReceived {
  1063. // Return empty message because it's illegal to send
  1064. // CloseNoStatusReceived. Return non-nil value in case application
  1065. // checks for nil.
  1066. return []byte{}
  1067. }
  1068. buf := make([]byte, 2+len(text))
  1069. binary.BigEndian.PutUint16(buf, uint16(closeCode))
  1070. copy(buf[2:], text)
  1071. return buf
  1072. }