reader.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package wsutil
  2. import (
  3. "errors"
  4. "io"
  5. "io/ioutil"
  6. "github.com/gobwas/ws"
  7. )
  // Sentinel errors reported by Reader.
var (
	// ErrNoFrameAdvance is returned by Reader's Read() method when there is
	// no frame to read from and no fragmented message in progress, that is,
	// when Read() was called without a preceding NextFrame() call.
	ErrNoFrameAdvance = errors.New("no frame advance")

	// ErrFrameTooLarge is returned by Reader's NextFrame() method when the
	// payload length announced by a frame header is greater than a positive
	// MaxFrameSize.
	ErrFrameTooLarge = errors.New("frame too large")
)
  14. // FrameHandlerFunc handles parsed frame header and its body represented by
  15. // io.Reader.
  16. //
  17. // Note that reader represents already unmasked body.
  18. type FrameHandlerFunc func(ws.Header, io.Reader) error
  19. // Reader is a wrapper around source io.Reader which represents WebSocket
  20. // connection. It contains options for reading messages from source.
  21. //
  22. // Reader implements io.Reader, which Read() method reads payload of incoming
  23. // WebSocket frames. It also takes care on fragmented frames and possibly
  24. // intermediate control frames between them.
  25. //
  26. // Note that Reader's methods are not goroutine safe.
  27. type Reader struct {
  28. Source io.Reader
  29. State ws.State
  30. // SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
  31. SkipHeaderCheck bool
  32. // CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
  33. // bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
  34. CheckUTF8 bool
  35. // Extensions is a list of negotiated extensions for reader Source.
  36. // It is used to meet the specs and clear appropriate bits in fragment
  37. // header RSV segment.
  38. Extensions []RecvExtension
  39. // MaxFrameSize controls the maximum frame size in bytes
  40. // that can be read. A message exceeding that size will return
  41. // a ErrFrameTooLarge to the application.
  42. //
  43. // Not setting this field means there is no limit.
  44. MaxFrameSize int64
  45. OnContinuation FrameHandlerFunc
  46. OnIntermediate FrameHandlerFunc
  47. opCode ws.OpCode // Used to store message op code on fragmentation.
  48. frame io.Reader // Used to as frame reader.
  49. raw io.LimitedReader // Used to discard frames without cipher.
  50. utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
  51. }
  52. // NewReader creates new frame reader that reads from r keeping given state to
  53. // make some protocol validity checks when it needed.
  54. func NewReader(r io.Reader, s ws.State) *Reader {
  55. return &Reader{
  56. Source: r,
  57. State: s,
  58. }
  59. }
  60. // NewClientSideReader is a helper function that calls NewReader with r and
  61. // ws.StateClientSide.
  62. func NewClientSideReader(r io.Reader) *Reader {
  63. return NewReader(r, ws.StateClientSide)
  64. }
  65. // NewServerSideReader is a helper function that calls NewReader with r and
  66. // ws.StateServerSide.
  67. func NewServerSideReader(r io.Reader) *Reader {
  68. return NewReader(r, ws.StateServerSide)
  69. }
  70. // Read implements io.Reader. It reads the next message payload into p.
  71. // It takes care on fragmented messages.
  72. //
  73. // The error is io.EOF only if all of message bytes were read.
  74. // If an io.EOF happens during reading some but not all the message bytes
  75. // Read() returns io.ErrUnexpectedEOF.
  76. //
  77. // The error is ErrNoFrameAdvance if no NextFrame() call was made before
  78. // reading next message bytes.
  79. func (r *Reader) Read(p []byte) (n int, err error) {
  80. if r.frame == nil {
  81. if !r.fragmented() {
  82. // Every new Read() must be preceded by NextFrame() call.
  83. return 0, ErrNoFrameAdvance
  84. }
  85. // Read next continuation or intermediate control frame.
  86. _, err := r.NextFrame()
  87. if err != nil {
  88. return 0, err
  89. }
  90. if r.frame == nil {
  91. // We handled intermediate control and now got nothing to read.
  92. return 0, nil
  93. }
  94. }
  95. n, err = r.frame.Read(p)
  96. if err != nil && err != io.EOF {
  97. return n, err
  98. }
  99. if err == nil && r.raw.N != 0 {
  100. return n, nil
  101. }
  102. // EOF condition (either err is io.EOF or r.raw.N is zero).
  103. switch {
  104. case r.raw.N != 0:
  105. err = io.ErrUnexpectedEOF
  106. case r.fragmented():
  107. err = nil
  108. r.resetFragment()
  109. case r.CheckUTF8 && !r.utf8.Valid():
  110. // NOTE: check utf8 only when full message received, since partial
  111. // reads may be invalid.
  112. n = r.utf8.Accepted()
  113. err = ErrInvalidUTF8
  114. default:
  115. r.reset()
  116. err = io.EOF
  117. }
  118. return n, err
  119. }
  120. // Discard discards current message unread bytes.
  121. // It discards all frames of fragmented message.
  122. func (r *Reader) Discard() (err error) {
  123. for {
  124. _, err = io.Copy(ioutil.Discard, &r.raw)
  125. if err != nil {
  126. break
  127. }
  128. if !r.fragmented() {
  129. break
  130. }
  131. if _, err = r.NextFrame(); err != nil {
  132. break
  133. }
  134. }
  135. r.reset()
  136. return err
  137. }
  138. // NextFrame prepares r to read next message. It returns received frame header
  139. // and non-nil error on failure.
  140. //
  141. // Note that next NextFrame() call must be done after receiving or discarding
  142. // all current message bytes.
  143. func (r *Reader) NextFrame() (hdr ws.Header, err error) {
  144. hdr, err = ws.ReadHeader(r.Source)
  145. if err == io.EOF && r.fragmented() {
  146. // If we are in fragmented state EOF means that is was totally
  147. // unexpected.
  148. //
  149. // NOTE: This is necessary to prevent callers such that
  150. // ioutil.ReadAll to receive some amount of bytes without an error.
  151. // ReadAll() ignores an io.EOF error, thus caller may think that
  152. // whole message fetched, but actually only part of it.
  153. err = io.ErrUnexpectedEOF
  154. }
  155. if err == nil && !r.SkipHeaderCheck {
  156. err = ws.CheckHeader(hdr, r.State)
  157. }
  158. if err != nil {
  159. return hdr, err
  160. }
  161. if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
  162. return hdr, ErrFrameTooLarge
  163. }
  164. // Save raw reader to use it on discarding frame without ciphering and
  165. // other streaming checks.
  166. r.raw = io.LimitedReader{
  167. R: r.Source,
  168. N: hdr.Length,
  169. }
  170. frame := io.Reader(&r.raw)
  171. if hdr.Masked {
  172. frame = NewCipherReader(frame, hdr.Mask)
  173. }
  174. for _, x := range r.Extensions {
  175. hdr, err = x.UnsetBits(hdr)
  176. if err != nil {
  177. return hdr, err
  178. }
  179. }
  180. if r.fragmented() {
  181. if hdr.OpCode.IsControl() {
  182. if cb := r.OnIntermediate; cb != nil {
  183. err = cb(hdr, frame)
  184. }
  185. if err == nil {
  186. // Ensure that src is empty.
  187. _, err = io.Copy(ioutil.Discard, &r.raw)
  188. }
  189. return hdr, err
  190. }
  191. } else {
  192. r.opCode = hdr.OpCode
  193. }
  194. if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
  195. r.utf8.Source = frame
  196. frame = &r.utf8
  197. }
  198. // Save reader with ciphering and other streaming checks.
  199. r.frame = frame
  200. if hdr.OpCode == ws.OpContinuation {
  201. if cb := r.OnContinuation; cb != nil {
  202. err = cb(hdr, frame)
  203. }
  204. }
  205. if hdr.Fin {
  206. r.State = r.State.Clear(ws.StateFragmented)
  207. } else {
  208. r.State = r.State.Set(ws.StateFragmented)
  209. }
  210. return hdr, err
  211. }
  212. func (r *Reader) fragmented() bool {
  213. return r.State.Fragmented()
  214. }
  215. func (r *Reader) resetFragment() {
  216. r.raw = io.LimitedReader{}
  217. r.frame = nil
  218. // Reset source of the UTF8Reader, but not the state.
  219. r.utf8.Source = nil
  220. }
  221. func (r *Reader) reset() {
  222. r.raw = io.LimitedReader{}
  223. r.frame = nil
  224. r.utf8 = UTF8Reader{}
  225. r.opCode = 0
  226. }
  227. // NextReader prepares next message read from r. It returns header that
  228. // describes the message and io.Reader to read message's payload. It returns
  229. // non-nil error when it is not possible to read message's initial frame.
  230. //
  231. // Note that next NextReader() on the same r should be done after reading all
  232. // bytes from previously returned io.Reader. For more performant way to discard
  233. // message use Reader and its Discard() method.
  234. //
  235. // Note that it will not handle any "intermediate" frames, that possibly could
  236. // be received between text/binary continuation frames. That is, if peer sent
  237. // text/binary frame with fin flag "false", then it could send ping frame, and
  238. // eventually remaining part of text/binary frame with fin "true" – with
  239. // NextReader() the ping frame will be dropped without any notice. To handle
  240. // this rare, but possible situation (and if you do not know exactly which
  241. // frames peer could send), you could use Reader with OnIntermediate field set.
  242. func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
  243. rd := &Reader{
  244. Source: r,
  245. State: s,
  246. }
  247. header, err := rd.NextFrame()
  248. if err != nil {
  249. return header, nil, err
  250. }
  251. return header, rd, nil
  252. }