websocket.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
  1. package httpexpect
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "time"
  7. "github.com/gorilla/websocket"
  8. )
  9. var (
  10. noDuration = time.Duration(0)
  11. infiniteTime = time.Time{}
  12. )
  13. // WebsocketConn is used by Websocket to communicate with actual WebSocket connection.
  14. type WebsocketConn interface {
  15. ReadMessage() (messageType int, p []byte, err error)
  16. WriteMessage(messageType int, data []byte) error
  17. Close() error
  18. SetReadDeadline(t time.Time) error
  19. SetWriteDeadline(t time.Time) error
  20. Subprotocol() string
  21. }
  22. // Websocket provides methods to read from, write into and close WebSocket
  23. // connection.
  24. type Websocket struct {
  25. noCopy noCopy
  26. config Config
  27. chain *chain
  28. conn WebsocketConn
  29. readTimeout time.Duration
  30. writeTimeout time.Duration
  31. isClosed bool
  32. }
  33. // Deprecated: use NewWebsocketC instead.
  34. func NewWebsocket(config Config, conn WebsocketConn) *Websocket {
  35. return NewWebsocketC(config, conn)
  36. }
  37. // NewWebsocketC returns a new Websocket instance.
  38. //
  39. // Requirements for config are same as for WithConfig function.
  40. func NewWebsocketC(config Config, conn WebsocketConn) *Websocket {
  41. config = config.withDefaults()
  42. return newWebsocket(
  43. newChainWithConfig("Websocket()", config),
  44. config,
  45. conn,
  46. )
  47. }
  48. func newWebsocket(parent *chain, config Config, conn WebsocketConn) *Websocket {
  49. config.validate()
  50. return &Websocket{
  51. config: config,
  52. chain: parent.clone(),
  53. conn: conn,
  54. }
  55. }
  56. // Conn returns underlying WebsocketConn object.
  57. // This is the value originally passed to NewConnection.
  58. func (ws *Websocket) Conn() WebsocketConn {
  59. return ws.conn
  60. }
  61. // Deprecated: use Conn instead.
  62. func (ws *Websocket) Raw() *websocket.Conn {
  63. if ws.conn == nil {
  64. return nil
  65. }
  66. conn, ok := ws.conn.(*websocket.Conn)
  67. if !ok {
  68. return nil
  69. }
  70. return conn
  71. }
  72. // Alias is similar to Value.Alias.
  73. func (ws *Websocket) Alias(name string) *Websocket {
  74. opChain := ws.chain.enter("Alias(%q)", name)
  75. defer opChain.leave()
  76. ws.chain.setAlias(name)
  77. return ws
  78. }
  79. // WithReadTimeout sets timeout duration for WebSocket connection reads.
  80. //
  81. // By default no timeout is used.
  82. func (ws *Websocket) WithReadTimeout(timeout time.Duration) *Websocket {
  83. opChain := ws.chain.enter("WithReadTimeout()")
  84. defer opChain.leave()
  85. if opChain.failed() {
  86. return ws
  87. }
  88. ws.readTimeout = timeout
  89. return ws
  90. }
  91. // WithoutReadTimeout removes timeout for WebSocket connection reads.
  92. func (ws *Websocket) WithoutReadTimeout() *Websocket {
  93. opChain := ws.chain.enter("WithoutReadTimeout()")
  94. defer opChain.leave()
  95. if opChain.failed() {
  96. return ws
  97. }
  98. ws.readTimeout = noDuration
  99. return ws
  100. }
  101. // WithWriteTimeout sets timeout duration for WebSocket connection writes.
  102. //
  103. // By default no timeout is used.
  104. func (ws *Websocket) WithWriteTimeout(timeout time.Duration) *Websocket {
  105. opChain := ws.chain.enter("WithWriteTimeout()")
  106. defer opChain.leave()
  107. if opChain.failed() {
  108. return ws
  109. }
  110. ws.writeTimeout = timeout
  111. return ws
  112. }
  113. // WithoutWriteTimeout removes timeout for WebSocket connection writes.
  114. //
  115. // If not used then DefaultWebsocketTimeout will be used.
  116. func (ws *Websocket) WithoutWriteTimeout() *Websocket {
  117. opChain := ws.chain.enter("WithoutWriteTimeout()")
  118. defer opChain.leave()
  119. if opChain.failed() {
  120. return ws
  121. }
  122. ws.writeTimeout = noDuration
  123. return ws
  124. }
  125. // Subprotocol returns a new String instance with negotiated protocol
  126. // for the connection.
  127. func (ws *Websocket) Subprotocol() *String {
  128. opChain := ws.chain.enter("Subprotocol()")
  129. defer opChain.leave()
  130. if opChain.failed() {
  131. return newString(opChain, "")
  132. }
  133. if ws.conn == nil {
  134. return newString(opChain, "")
  135. }
  136. return newString(opChain, ws.conn.Subprotocol())
  137. }
  138. // Expect reads next message from WebSocket connection and
  139. // returns a new WebsocketMessage instance.
  140. //
  141. // Example:
  142. //
  143. // msg := conn.Expect()
  144. // msg.JSON().Object().HasValue("message", "hi")
  145. func (ws *Websocket) Expect() *WebsocketMessage {
  146. opChain := ws.chain.enter("Expect()")
  147. defer opChain.leave()
  148. if ws.checkUnusable(opChain, "Expect()") {
  149. return newEmptyWebsocketMessage(opChain)
  150. }
  151. m := ws.readMessage(opChain)
  152. if m == nil {
  153. return newEmptyWebsocketMessage(opChain)
  154. }
  155. return m
  156. }
  157. // Disconnect closes the underlying WebSocket connection without sending or
  158. // waiting for a close message.
  159. //
  160. // It's okay to call this function multiple times.
  161. //
  162. // It's recommended to always call this function after connection usage is over
  163. // to ensure that no resource leaks will happen.
  164. //
  165. // Example:
  166. //
  167. // conn := resp.Connection()
  168. // defer conn.Disconnect()
  169. func (ws *Websocket) Disconnect() *Websocket {
  170. opChain := ws.chain.enter("Disconnect()")
  171. defer opChain.leave()
  172. if ws.conn == nil || ws.isClosed {
  173. return ws
  174. }
  175. ws.isClosed = true
  176. if err := ws.conn.Close(); err != nil {
  177. opChain.fail(AssertionFailure{
  178. Type: AssertOperation,
  179. Errors: []error{
  180. errors.New("got close error when disconnecting websocket"),
  181. err,
  182. },
  183. })
  184. }
  185. return ws
  186. }
  187. // Close cleanly closes the underlying WebSocket connection
  188. // by sending an empty close message and then waiting (with timeout)
  189. // for the server to close the connection.
  190. //
  191. // WebSocket close code may be optionally specified.
  192. // If not, then "1000 - Normal Closure" will be used.
  193. //
  194. // WebSocket close codes are defined in RFC 6455, section 11.7.
  195. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  196. //
  197. // It's okay to call this function multiple times.
  198. //
  199. // Example:
  200. //
  201. // conn := resp.Connection()
  202. // conn.Close(websocket.CloseUnsupportedData)
  203. func (ws *Websocket) Close(code ...int) *Websocket {
  204. opChain := ws.chain.enter("Close()")
  205. defer opChain.leave()
  206. switch {
  207. case ws.checkUnusable(opChain, "Close()"):
  208. return ws
  209. case len(code) > 1:
  210. opChain.fail(AssertionFailure{
  211. Type: AssertUsage,
  212. Errors: []error{
  213. errors.New("unexpected multiple code arguments"),
  214. },
  215. })
  216. return ws
  217. }
  218. ws.writeMessage(opChain, websocket.CloseMessage, nil, code...)
  219. return ws
  220. }
  221. // CloseWithBytes cleanly closes the underlying WebSocket connection
  222. // by sending given slice of bytes as a close message and then waiting
  223. // (with timeout) for the server to close the connection.
  224. //
  225. // WebSocket close code may be optionally specified.
  226. // If not, then "1000 - Normal Closure" will be used.
  227. //
  228. // WebSocket close codes are defined in RFC 6455, section 11.7.
  229. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  230. //
  231. // It's okay to call this function multiple times.
  232. //
  233. // Example:
  234. //
  235. // conn := resp.Connection()
  236. // conn.CloseWithBytes([]byte("bye!"), websocket.CloseGoingAway)
  237. func (ws *Websocket) CloseWithBytes(b []byte, code ...int) *Websocket {
  238. opChain := ws.chain.enter("CloseWithBytes()")
  239. defer opChain.leave()
  240. switch {
  241. case ws.checkUnusable(opChain, "CloseWithBytes()"):
  242. return ws
  243. case len(code) > 1:
  244. opChain.fail(AssertionFailure{
  245. Type: AssertUsage,
  246. Errors: []error{
  247. errors.New("unexpected multiple code arguments"),
  248. },
  249. })
  250. return ws
  251. }
  252. ws.writeMessage(opChain, websocket.CloseMessage, b, code...)
  253. return ws
  254. }
  255. // CloseWithJSON cleanly closes the underlying WebSocket connection
  256. // by sending given object (marshaled using json.Marshal()) as a close message
  257. // and then waiting (with timeout) for the server to close the connection.
  258. //
  259. // WebSocket close code may be optionally specified.
  260. // If not, then "1000 - Normal Closure" will be used.
  261. //
  262. // WebSocket close codes are defined in RFC 6455, section 11.7.
  263. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  264. //
  265. // It's okay to call this function multiple times.
  266. //
  267. // Example:
  268. //
  269. // type MyJSON struct {
  270. // Foo int `json:"foo"`
  271. // }
  272. //
  273. // conn := resp.Connection()
  274. // conn.CloseWithJSON(MyJSON{Foo: 123}, websocket.CloseUnsupportedData)
  275. func (ws *Websocket) CloseWithJSON(
  276. object interface{}, code ...int,
  277. ) *Websocket {
  278. opChain := ws.chain.enter("CloseWithJSON()")
  279. defer opChain.leave()
  280. switch {
  281. case ws.checkUnusable(opChain, "CloseWithJSON()"):
  282. return ws
  283. case len(code) > 1:
  284. opChain.fail(AssertionFailure{
  285. Type: AssertUsage,
  286. Errors: []error{
  287. errors.New("unexpected multiple code arguments"),
  288. },
  289. })
  290. return ws
  291. }
  292. b, err := json.Marshal(object)
  293. if err != nil {
  294. opChain.fail(AssertionFailure{
  295. Type: AssertValid,
  296. Actual: &AssertionValue{object},
  297. Errors: []error{
  298. errors.New("invalid json object"),
  299. err,
  300. },
  301. })
  302. return ws
  303. }
  304. ws.writeMessage(opChain, websocket.CloseMessage, b, code...)
  305. return ws
  306. }
  307. // CloseWithText cleanly closes the underlying WebSocket connection
  308. // by sending given text as a close message and then waiting (with timeout)
  309. // for the server to close the connection.
  310. //
  311. // WebSocket close code may be optionally specified.
  312. // If not, then "1000 - Normal Closure" will be used.
  313. //
  314. // WebSocket close codes are defined in RFC 6455, section 11.7.
  315. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  316. //
  317. // It's okay to call this function multiple times.
  318. //
  319. // Example:
  320. //
  321. // conn := resp.Connection()
  322. // conn.CloseWithText("bye!")
  323. func (ws *Websocket) CloseWithText(s string, code ...int) *Websocket {
  324. opChain := ws.chain.enter("CloseWithText()")
  325. defer opChain.leave()
  326. switch {
  327. case ws.checkUnusable(opChain, "CloseWithText()"):
  328. return ws
  329. case len(code) > 1:
  330. opChain.fail(AssertionFailure{
  331. Type: AssertUsage,
  332. Errors: []error{
  333. errors.New("unexpected multiple code arguments"),
  334. },
  335. })
  336. return ws
  337. }
  338. ws.writeMessage(opChain, websocket.CloseMessage, []byte(s), code...)
  339. return ws
  340. }
  341. // WriteMessage writes to the underlying WebSocket connection a message
  342. // of given type with given content.
  343. // Additionally, WebSocket close code may be specified for close messages.
  344. //
  345. // WebSocket message types are defined in RFC 6455, section 11.8.
  346. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  347. //
  348. // WebSocket close codes are defined in RFC 6455, section 11.7.
  349. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
  350. //
  351. // Example:
  352. //
  353. // conn := resp.Connection()
  354. // conn.WriteMessage(websocket.CloseMessage, []byte("Namárië..."))
  355. func (ws *Websocket) WriteMessage(typ int, content []byte, closeCode ...int) *Websocket {
  356. opChain := ws.chain.enter("WriteMessage()")
  357. defer opChain.leave()
  358. if ws.checkUnusable(opChain, "WriteMessage()") {
  359. return ws
  360. }
  361. ws.writeMessage(opChain, typ, content, closeCode...)
  362. return ws
  363. }
  364. // WriteBytesBinary is a shorthand for c.WriteMessage(websocket.BinaryMessage, b).
  365. func (ws *Websocket) WriteBytesBinary(b []byte) *Websocket {
  366. opChain := ws.chain.enter("WriteBytesBinary()")
  367. defer opChain.leave()
  368. if ws.checkUnusable(opChain, "WriteBytesBinary()") {
  369. return ws
  370. }
  371. ws.writeMessage(opChain, websocket.BinaryMessage, b)
  372. return ws
  373. }
  374. // WriteBytesText is a shorthand for c.WriteMessage(websocket.TextMessage, b).
  375. func (ws *Websocket) WriteBytesText(b []byte) *Websocket {
  376. opChain := ws.chain.enter("WriteBytesText()")
  377. defer opChain.leave()
  378. if ws.checkUnusable(opChain, "WriteBytesText()") {
  379. return ws
  380. }
  381. ws.writeMessage(opChain, websocket.TextMessage, b)
  382. return ws
  383. }
  384. // WriteText is a shorthand for
  385. // c.WriteMessage(websocket.TextMessage, []byte(s)).
  386. func (ws *Websocket) WriteText(s string) *Websocket {
  387. opChain := ws.chain.enter("WriteText()")
  388. defer opChain.leave()
  389. if ws.checkUnusable(opChain, "WriteText()") {
  390. return ws
  391. }
  392. return ws.WriteMessage(websocket.TextMessage, []byte(s))
  393. }
  394. // WriteJSON writes to the underlying WebSocket connection given object,
  395. // marshaled using json.Marshal().
  396. func (ws *Websocket) WriteJSON(object interface{}) *Websocket {
  397. opChain := ws.chain.enter("WriteJSON()")
  398. defer opChain.leave()
  399. if ws.checkUnusable(opChain, "WriteJSON()") {
  400. return ws
  401. }
  402. b, err := json.Marshal(object)
  403. if err != nil {
  404. opChain.fail(AssertionFailure{
  405. Type: AssertValid,
  406. Actual: &AssertionValue{object},
  407. Errors: []error{
  408. errors.New("invalid json object"),
  409. err,
  410. },
  411. })
  412. return ws
  413. }
  414. ws.writeMessage(opChain, websocket.TextMessage, b)
  415. return ws
  416. }
  417. func (ws *Websocket) checkUnusable(opChain *chain, where string) bool {
  418. switch {
  419. case opChain.failed():
  420. return true
  421. case ws.conn == nil:
  422. opChain.fail(AssertionFailure{
  423. Type: AssertUsage,
  424. Errors: []error{
  425. fmt.Errorf("unexpected %s call for failed websocket connection", where),
  426. },
  427. })
  428. return true
  429. case ws.isClosed:
  430. opChain.fail(AssertionFailure{
  431. Type: AssertUsage,
  432. Errors: []error{
  433. fmt.Errorf("unexpected %s call for closed websocket connection", where),
  434. },
  435. })
  436. return true
  437. }
  438. return false
  439. }
  440. func (ws *Websocket) readMessage(opChain *chain) *WebsocketMessage {
  441. wm := newEmptyWebsocketMessage(opChain)
  442. if !ws.setReadDeadline(opChain) {
  443. return nil
  444. }
  445. var err error
  446. wm.typ, wm.content, err = ws.conn.ReadMessage()
  447. if err != nil {
  448. closeErr, ok := err.(*websocket.CloseError)
  449. if !ok {
  450. opChain.fail(AssertionFailure{
  451. Type: AssertOperation,
  452. Errors: []error{
  453. errors.New("failed to read from websocket"),
  454. err,
  455. },
  456. })
  457. return nil
  458. }
  459. wm.typ = websocket.CloseMessage
  460. wm.closeCode = closeErr.Code
  461. wm.content = []byte(closeErr.Text)
  462. }
  463. ws.printRead(wm.typ, wm.content, wm.closeCode)
  464. return wm
  465. }
  466. func (ws *Websocket) writeMessage(
  467. opChain *chain, typ int, content []byte, closeCode ...int,
  468. ) {
  469. switch typ {
  470. case websocket.TextMessage, websocket.BinaryMessage:
  471. ws.printWrite(typ, content, 0)
  472. case websocket.CloseMessage:
  473. if len(closeCode) > 1 {
  474. opChain.fail(AssertionFailure{
  475. Type: AssertUsage,
  476. Errors: []error{
  477. errors.New("unexpected multiple closeCode arguments"),
  478. },
  479. })
  480. return
  481. }
  482. code := websocket.CloseNormalClosure
  483. if len(closeCode) > 0 {
  484. code = closeCode[0]
  485. }
  486. ws.printWrite(typ, content, code)
  487. content = websocket.FormatCloseMessage(code, string(content))
  488. default:
  489. opChain.fail(AssertionFailure{
  490. Type: AssertUsage,
  491. Errors: []error{
  492. fmt.Errorf("unexpected websocket message type %s",
  493. wsMessageType(typ)),
  494. },
  495. })
  496. return
  497. }
  498. if !ws.setWriteDeadline(opChain) {
  499. return
  500. }
  501. if err := ws.conn.WriteMessage(typ, content); err != nil {
  502. opChain.fail(AssertionFailure{
  503. Type: AssertOperation,
  504. Errors: []error{
  505. errors.New("failed to write to websocket"),
  506. err,
  507. },
  508. })
  509. return
  510. }
  511. }
  512. func (ws *Websocket) setReadDeadline(opChain *chain) bool {
  513. deadline := infiniteTime
  514. if ws.readTimeout != noDuration {
  515. deadline = time.Now().Add(ws.readTimeout)
  516. }
  517. if err := ws.conn.SetReadDeadline(deadline); err != nil {
  518. opChain.fail(AssertionFailure{
  519. Type: AssertOperation,
  520. Errors: []error{
  521. errors.New("failed to set read deadline for websocket"),
  522. err,
  523. },
  524. })
  525. return false
  526. }
  527. return true
  528. }
  529. func (ws *Websocket) setWriteDeadline(opChain *chain) bool {
  530. deadline := infiniteTime
  531. if ws.writeTimeout != noDuration {
  532. deadline = time.Now().Add(ws.writeTimeout)
  533. }
  534. if err := ws.conn.SetWriteDeadline(deadline); err != nil {
  535. opChain.fail(AssertionFailure{
  536. Type: AssertOperation,
  537. Errors: []error{
  538. errors.New("failed to set write deadline for websocket"),
  539. err,
  540. },
  541. })
  542. return false
  543. }
  544. return true
  545. }
  546. func (ws *Websocket) printRead(typ int, content []byte, closeCode int) {
  547. for _, printer := range ws.config.Printers {
  548. if p, ok := printer.(WebsocketPrinter); ok {
  549. p.WebsocketRead(typ, content, closeCode)
  550. }
  551. }
  552. }
  553. func (ws *Websocket) printWrite(typ int, content []byte, closeCode int) {
  554. for _, printer := range ws.config.Printers {
  555. if p, ok := printer.(WebsocketPrinter); ok {
  556. p.WebsocketWrite(typ, content, closeCode)
  557. }
  558. }
  559. }