123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675 |
- package httpexpect
- import (
- "encoding/json"
- "errors"
- "fmt"
- "time"
- "github.com/gorilla/websocket"
- )
// noDuration is the zero duration; a timeout equal to it means
// "no timeout configured".
var noDuration time.Duration

// infiniteTime is the zero time.Time; it is passed to the connection's
// deadline setters when no timeout is configured.
var infiniteTime time.Time
// WebsocketConn is the connection abstraction that Websocket drives.
// Websocket reads and writes messages, sets deadlines, queries the
// negotiated subprotocol and closes the connection through this interface.
type WebsocketConn interface {
	// Message I/O.
	ReadMessage() (messageType int, p []byte, err error)
	WriteMessage(messageType int, data []byte) error

	// Deadlines applied before each read / write.
	SetReadDeadline(t time.Time) error
	SetWriteDeadline(t time.Time) error

	// Connection metadata and teardown.
	Subprotocol() string
	Close() error
}
- // Websocket provides methods to read from, write into and close WebSocket
- // connection.
- type Websocket struct {
- noCopy noCopy
- config Config
- chain *chain
- conn WebsocketConn
- readTimeout time.Duration
- writeTimeout time.Duration
- isClosed bool
- }
- // Deprecated: use NewWebsocketC instead.
- func NewWebsocket(config Config, conn WebsocketConn) *Websocket {
- return NewWebsocketC(config, conn)
- }
- // NewWebsocketC returns a new Websocket instance.
- //
- // Requirements for config are same as for WithConfig function.
- func NewWebsocketC(config Config, conn WebsocketConn) *Websocket {
- config = config.withDefaults()
- return newWebsocket(
- newChainWithConfig("Websocket()", config),
- config,
- conn,
- )
- }
- func newWebsocket(parent *chain, config Config, conn WebsocketConn) *Websocket {
- config.validate()
- return &Websocket{
- config: config,
- chain: parent.clone(),
- conn: conn,
- }
- }
- // Conn returns underlying WebsocketConn object.
- // This is the value originally passed to NewConnection.
- func (ws *Websocket) Conn() WebsocketConn {
- return ws.conn
- }
- // Deprecated: use Conn instead.
- func (ws *Websocket) Raw() *websocket.Conn {
- if ws.conn == nil {
- return nil
- }
- conn, ok := ws.conn.(*websocket.Conn)
- if !ok {
- return nil
- }
- return conn
- }
- // Alias is similar to Value.Alias.
- func (ws *Websocket) Alias(name string) *Websocket {
- opChain := ws.chain.enter("Alias(%q)", name)
- defer opChain.leave()
- ws.chain.setAlias(name)
- return ws
- }
- // WithReadTimeout sets timeout duration for WebSocket connection reads.
- //
- // By default no timeout is used.
- func (ws *Websocket) WithReadTimeout(timeout time.Duration) *Websocket {
- opChain := ws.chain.enter("WithReadTimeout()")
- defer opChain.leave()
- if opChain.failed() {
- return ws
- }
- ws.readTimeout = timeout
- return ws
- }
- // WithoutReadTimeout removes timeout for WebSocket connection reads.
- func (ws *Websocket) WithoutReadTimeout() *Websocket {
- opChain := ws.chain.enter("WithoutReadTimeout()")
- defer opChain.leave()
- if opChain.failed() {
- return ws
- }
- ws.readTimeout = noDuration
- return ws
- }
- // WithWriteTimeout sets timeout duration for WebSocket connection writes.
- //
- // By default no timeout is used.
- func (ws *Websocket) WithWriteTimeout(timeout time.Duration) *Websocket {
- opChain := ws.chain.enter("WithWriteTimeout()")
- defer opChain.leave()
- if opChain.failed() {
- return ws
- }
- ws.writeTimeout = timeout
- return ws
- }
- // WithoutWriteTimeout removes timeout for WebSocket connection writes.
- //
- // If not used then DefaultWebsocketTimeout will be used.
- func (ws *Websocket) WithoutWriteTimeout() *Websocket {
- opChain := ws.chain.enter("WithoutWriteTimeout()")
- defer opChain.leave()
- if opChain.failed() {
- return ws
- }
- ws.writeTimeout = noDuration
- return ws
- }
- // Subprotocol returns a new String instance with negotiated protocol
- // for the connection.
- func (ws *Websocket) Subprotocol() *String {
- opChain := ws.chain.enter("Subprotocol()")
- defer opChain.leave()
- if opChain.failed() {
- return newString(opChain, "")
- }
- if ws.conn == nil {
- return newString(opChain, "")
- }
- return newString(opChain, ws.conn.Subprotocol())
- }
- // Expect reads next message from WebSocket connection and
- // returns a new WebsocketMessage instance.
- //
- // Example:
- //
- // msg := conn.Expect()
- // msg.JSON().Object().HasValue("message", "hi")
- func (ws *Websocket) Expect() *WebsocketMessage {
- opChain := ws.chain.enter("Expect()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "Expect()") {
- return newEmptyWebsocketMessage(opChain)
- }
- m := ws.readMessage(opChain)
- if m == nil {
- return newEmptyWebsocketMessage(opChain)
- }
- return m
- }
- // Disconnect closes the underlying WebSocket connection without sending or
- // waiting for a close message.
- //
- // It's okay to call this function multiple times.
- //
- // It's recommended to always call this function after connection usage is over
- // to ensure that no resource leaks will happen.
- //
- // Example:
- //
- // conn := resp.Connection()
- // defer conn.Disconnect()
- func (ws *Websocket) Disconnect() *Websocket {
- opChain := ws.chain.enter("Disconnect()")
- defer opChain.leave()
- if ws.conn == nil || ws.isClosed {
- return ws
- }
- ws.isClosed = true
- if err := ws.conn.Close(); err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertOperation,
- Errors: []error{
- errors.New("got close error when disconnecting websocket"),
- err,
- },
- })
- }
- return ws
- }
- // Close cleanly closes the underlying WebSocket connection
- // by sending an empty close message and then waiting (with timeout)
- // for the server to close the connection.
- //
- // WebSocket close code may be optionally specified.
- // If not, then "1000 - Normal Closure" will be used.
- //
- // WebSocket close codes are defined in RFC 6455, section 11.7.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // It's okay to call this function multiple times.
- //
- // Example:
- //
- // conn := resp.Connection()
- // conn.Close(websocket.CloseUnsupportedData)
- func (ws *Websocket) Close(code ...int) *Websocket {
- opChain := ws.chain.enter("Close()")
- defer opChain.leave()
- switch {
- case ws.checkUnusable(opChain, "Close()"):
- return ws
- case len(code) > 1:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- errors.New("unexpected multiple code arguments"),
- },
- })
- return ws
- }
- ws.writeMessage(opChain, websocket.CloseMessage, nil, code...)
- return ws
- }
- // CloseWithBytes cleanly closes the underlying WebSocket connection
- // by sending given slice of bytes as a close message and then waiting
- // (with timeout) for the server to close the connection.
- //
- // WebSocket close code may be optionally specified.
- // If not, then "1000 - Normal Closure" will be used.
- //
- // WebSocket close codes are defined in RFC 6455, section 11.7.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // It's okay to call this function multiple times.
- //
- // Example:
- //
- // conn := resp.Connection()
- // conn.CloseWithBytes([]byte("bye!"), websocket.CloseGoingAway)
- func (ws *Websocket) CloseWithBytes(b []byte, code ...int) *Websocket {
- opChain := ws.chain.enter("CloseWithBytes()")
- defer opChain.leave()
- switch {
- case ws.checkUnusable(opChain, "CloseWithBytes()"):
- return ws
- case len(code) > 1:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- errors.New("unexpected multiple code arguments"),
- },
- })
- return ws
- }
- ws.writeMessage(opChain, websocket.CloseMessage, b, code...)
- return ws
- }
- // CloseWithJSON cleanly closes the underlying WebSocket connection
- // by sending given object (marshaled using json.Marshal()) as a close message
- // and then waiting (with timeout) for the server to close the connection.
- //
- // WebSocket close code may be optionally specified.
- // If not, then "1000 - Normal Closure" will be used.
- //
- // WebSocket close codes are defined in RFC 6455, section 11.7.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // It's okay to call this function multiple times.
- //
- // Example:
- //
- // type MyJSON struct {
- // Foo int `json:"foo"`
- // }
- //
- // conn := resp.Connection()
- // conn.CloseWithJSON(MyJSON{Foo: 123}, websocket.CloseUnsupportedData)
- func (ws *Websocket) CloseWithJSON(
- object interface{}, code ...int,
- ) *Websocket {
- opChain := ws.chain.enter("CloseWithJSON()")
- defer opChain.leave()
- switch {
- case ws.checkUnusable(opChain, "CloseWithJSON()"):
- return ws
- case len(code) > 1:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- errors.New("unexpected multiple code arguments"),
- },
- })
- return ws
- }
- b, err := json.Marshal(object)
- if err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertValid,
- Actual: &AssertionValue{object},
- Errors: []error{
- errors.New("invalid json object"),
- err,
- },
- })
- return ws
- }
- ws.writeMessage(opChain, websocket.CloseMessage, b, code...)
- return ws
- }
- // CloseWithText cleanly closes the underlying WebSocket connection
- // by sending given text as a close message and then waiting (with timeout)
- // for the server to close the connection.
- //
- // WebSocket close code may be optionally specified.
- // If not, then "1000 - Normal Closure" will be used.
- //
- // WebSocket close codes are defined in RFC 6455, section 11.7.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // It's okay to call this function multiple times.
- //
- // Example:
- //
- // conn := resp.Connection()
- // conn.CloseWithText("bye!")
- func (ws *Websocket) CloseWithText(s string, code ...int) *Websocket {
- opChain := ws.chain.enter("CloseWithText()")
- defer opChain.leave()
- switch {
- case ws.checkUnusable(opChain, "CloseWithText()"):
- return ws
- case len(code) > 1:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- errors.New("unexpected multiple code arguments"),
- },
- })
- return ws
- }
- ws.writeMessage(opChain, websocket.CloseMessage, []byte(s), code...)
- return ws
- }
- // WriteMessage writes to the underlying WebSocket connection a message
- // of given type with given content.
- // Additionally, WebSocket close code may be specified for close messages.
- //
- // WebSocket message types are defined in RFC 6455, section 11.8.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // WebSocket close codes are defined in RFC 6455, section 11.7.
- // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
- //
- // Example:
- //
- // conn := resp.Connection()
- // conn.WriteMessage(websocket.CloseMessage, []byte("Namárië..."))
- func (ws *Websocket) WriteMessage(typ int, content []byte, closeCode ...int) *Websocket {
- opChain := ws.chain.enter("WriteMessage()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "WriteMessage()") {
- return ws
- }
- ws.writeMessage(opChain, typ, content, closeCode...)
- return ws
- }
- // WriteBytesBinary is a shorthand for c.WriteMessage(websocket.BinaryMessage, b).
- func (ws *Websocket) WriteBytesBinary(b []byte) *Websocket {
- opChain := ws.chain.enter("WriteBytesBinary()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "WriteBytesBinary()") {
- return ws
- }
- ws.writeMessage(opChain, websocket.BinaryMessage, b)
- return ws
- }
- // WriteBytesText is a shorthand for c.WriteMessage(websocket.TextMessage, b).
- func (ws *Websocket) WriteBytesText(b []byte) *Websocket {
- opChain := ws.chain.enter("WriteBytesText()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "WriteBytesText()") {
- return ws
- }
- ws.writeMessage(opChain, websocket.TextMessage, b)
- return ws
- }
- // WriteText is a shorthand for
- // c.WriteMessage(websocket.TextMessage, []byte(s)).
- func (ws *Websocket) WriteText(s string) *Websocket {
- opChain := ws.chain.enter("WriteText()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "WriteText()") {
- return ws
- }
- return ws.WriteMessage(websocket.TextMessage, []byte(s))
- }
- // WriteJSON writes to the underlying WebSocket connection given object,
- // marshaled using json.Marshal().
- func (ws *Websocket) WriteJSON(object interface{}) *Websocket {
- opChain := ws.chain.enter("WriteJSON()")
- defer opChain.leave()
- if ws.checkUnusable(opChain, "WriteJSON()") {
- return ws
- }
- b, err := json.Marshal(object)
- if err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertValid,
- Actual: &AssertionValue{object},
- Errors: []error{
- errors.New("invalid json object"),
- err,
- },
- })
- return ws
- }
- ws.writeMessage(opChain, websocket.TextMessage, b)
- return ws
- }
- func (ws *Websocket) checkUnusable(opChain *chain, where string) bool {
- switch {
- case opChain.failed():
- return true
- case ws.conn == nil:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- fmt.Errorf("unexpected %s call for failed websocket connection", where),
- },
- })
- return true
- case ws.isClosed:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- fmt.Errorf("unexpected %s call for closed websocket connection", where),
- },
- })
- return true
- }
- return false
- }
- func (ws *Websocket) readMessage(opChain *chain) *WebsocketMessage {
- wm := newEmptyWebsocketMessage(opChain)
- if !ws.setReadDeadline(opChain) {
- return nil
- }
- var err error
- wm.typ, wm.content, err = ws.conn.ReadMessage()
- if err != nil {
- closeErr, ok := err.(*websocket.CloseError)
- if !ok {
- opChain.fail(AssertionFailure{
- Type: AssertOperation,
- Errors: []error{
- errors.New("failed to read from websocket"),
- err,
- },
- })
- return nil
- }
- wm.typ = websocket.CloseMessage
- wm.closeCode = closeErr.Code
- wm.content = []byte(closeErr.Text)
- }
- ws.printRead(wm.typ, wm.content, wm.closeCode)
- return wm
- }
- func (ws *Websocket) writeMessage(
- opChain *chain, typ int, content []byte, closeCode ...int,
- ) {
- switch typ {
- case websocket.TextMessage, websocket.BinaryMessage:
- ws.printWrite(typ, content, 0)
- case websocket.CloseMessage:
- if len(closeCode) > 1 {
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- errors.New("unexpected multiple closeCode arguments"),
- },
- })
- return
- }
- code := websocket.CloseNormalClosure
- if len(closeCode) > 0 {
- code = closeCode[0]
- }
- ws.printWrite(typ, content, code)
- content = websocket.FormatCloseMessage(code, string(content))
- default:
- opChain.fail(AssertionFailure{
- Type: AssertUsage,
- Errors: []error{
- fmt.Errorf("unexpected websocket message type %s",
- wsMessageType(typ)),
- },
- })
- return
- }
- if !ws.setWriteDeadline(opChain) {
- return
- }
- if err := ws.conn.WriteMessage(typ, content); err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertOperation,
- Errors: []error{
- errors.New("failed to write to websocket"),
- err,
- },
- })
- return
- }
- }
- func (ws *Websocket) setReadDeadline(opChain *chain) bool {
- deadline := infiniteTime
- if ws.readTimeout != noDuration {
- deadline = time.Now().Add(ws.readTimeout)
- }
- if err := ws.conn.SetReadDeadline(deadline); err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertOperation,
- Errors: []error{
- errors.New("failed to set read deadline for websocket"),
- err,
- },
- })
- return false
- }
- return true
- }
- func (ws *Websocket) setWriteDeadline(opChain *chain) bool {
- deadline := infiniteTime
- if ws.writeTimeout != noDuration {
- deadline = time.Now().Add(ws.writeTimeout)
- }
- if err := ws.conn.SetWriteDeadline(deadline); err != nil {
- opChain.fail(AssertionFailure{
- Type: AssertOperation,
- Errors: []error{
- errors.New("failed to set write deadline for websocket"),
- err,
- },
- })
- return false
- }
- return true
- }
- func (ws *Websocket) printRead(typ int, content []byte, closeCode int) {
- for _, printer := range ws.config.Printers {
- if p, ok := printer.(WebsocketPrinter); ok {
- p.WebsocketRead(typ, content, closeCode)
- }
- }
- }
- func (ws *Websocket) printWrite(typ int, content []byte, closeCode int) {
- for _, printer := range ws.config.Printers {
- if p, ok := printer.(WebsocketPrinter); ok {
- p.WebsocketWrite(typ, content, closeCode)
- }
- }
- }
|