stub.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package radix
  2. import (
  3. "bufio"
  4. "bytes"
  5. "net"
  6. "sync"
  7. "time"
  8. "errors"
  9. "github.com/mediocregopher/radix/v3/resp"
  10. "github.com/mediocregopher/radix/v3/resp/resp2"
  11. )
  // bufferAddr is a minimal net.Addr whose network name and address string are
  // returned exactly as they were stored.
  type bufferAddr struct {
  	network string
  	addr    string
  }

  // Network returns the stored network name.
  func (a bufferAddr) Network() string { return a.network }

  // String returns the stored address string.
  func (a bufferAddr) String() string { return a.addr }
  21. type buffer struct {
  22. net.Conn // always nil
  23. remoteAddr bufferAddr
  24. bufL *sync.Cond
  25. buf *bytes.Buffer
  26. bufbr *bufio.Reader
  27. closed bool
  28. readDeadline time.Time
  29. }
  30. func newBuffer(remoteNetwork, remoteAddr string) *buffer {
  31. buf := new(bytes.Buffer)
  32. return &buffer{
  33. remoteAddr: bufferAddr{network: remoteNetwork, addr: remoteAddr},
  34. bufL: sync.NewCond(new(sync.Mutex)),
  35. buf: buf,
  36. bufbr: bufio.NewReader(buf),
  37. }
  38. }
  39. func (b *buffer) Encode(m resp.Marshaler) error {
  40. b.bufL.L.Lock()
  41. var err error
  42. if b.closed {
  43. err = b.err("write", errClosed)
  44. } else {
  45. err = m.MarshalRESP(b.buf)
  46. }
  47. b.bufL.L.Unlock()
  48. if err != nil {
  49. return err
  50. }
  51. b.bufL.Broadcast()
  52. return nil
  53. }
  54. func (b *buffer) Decode(u resp.Unmarshaler) error {
  55. b.bufL.L.Lock()
  56. defer b.bufL.L.Unlock()
  57. var timeoutCh chan struct{}
  58. if b.readDeadline.IsZero() {
  59. // no readDeadline, timeoutCh will never be written to
  60. } else if now := time.Now(); b.readDeadline.Before(now) {
  61. return b.err("read", new(timeoutError))
  62. } else {
  63. timeoutCh = make(chan struct{}, 2)
  64. sleep := b.readDeadline.Sub(now)
  65. go func() {
  66. time.Sleep(sleep)
  67. timeoutCh <- struct{}{}
  68. b.bufL.Broadcast()
  69. }()
  70. }
  71. for b.buf.Len() == 0 && b.bufbr.Buffered() == 0 {
  72. if b.closed {
  73. return b.err("read", errClosed)
  74. }
  75. select {
  76. case <-timeoutCh:
  77. return b.err("read", new(timeoutError))
  78. default:
  79. }
  80. // we have to periodically wakeup to double-check the timeoutCh, if
  81. // there is one
  82. if timeoutCh != nil {
  83. go func() {
  84. time.Sleep(1 * time.Second)
  85. b.bufL.Broadcast()
  86. }()
  87. }
  88. b.bufL.Wait()
  89. }
  90. return u.UnmarshalRESP(b.bufbr)
  91. }
  92. func (b *buffer) Close() error {
  93. b.bufL.L.Lock()
  94. defer b.bufL.L.Unlock()
  95. if b.closed {
  96. return b.err("close", errClosed)
  97. }
  98. b.closed = true
  99. b.bufL.Broadcast()
  100. return nil
  101. }
  102. func (b *buffer) RemoteAddr() net.Addr {
  103. return b.remoteAddr
  104. }
  105. func (b *buffer) SetDeadline(t time.Time) error {
  106. return b.SetReadDeadline(t)
  107. }
  108. func (b *buffer) SetReadDeadline(t time.Time) error {
  109. b.bufL.L.Lock()
  110. defer b.bufL.L.Unlock()
  111. if b.closed {
  112. return b.err("set", errClosed)
  113. }
  114. b.readDeadline = t
  115. return nil
  116. }
  117. func (b *buffer) err(op string, err error) error {
  118. return &net.OpError{
  119. Op: op,
  120. Net: "tcp",
  121. Source: nil,
  122. Addr: b.remoteAddr,
  123. Err: err,
  124. }
  125. }
  // errClosed is wrapped into the error returned by operations on a buffer
  // that has already been closed.
  var errClosed = errors.New("use of closed network connection")

  // timeoutError is wrapped into the error Decode returns when its read
  // deadline passes. It reports itself as both a timeout and temporary.
  type timeoutError struct{}

  func (*timeoutError) Error() string   { return "i/o timeout" }
  func (*timeoutError) Timeout() bool   { return true }
  func (*timeoutError) Temporary() bool { return true }
  131. ////////////////////////////////////////////////////////////////////////////////
  132. type stub struct {
  133. *buffer
  134. fn func([]string) interface{}
  135. }
  136. // Stub returns a (fake) Conn which pretends it is a Conn to a real redis
  137. // instance, but is instead using the given callback to service requests. It is
  138. // primarily useful for writing tests.
  139. //
  140. // When Encode is called the given value is marshalled into bytes then
  141. // unmarshalled into a []string, which is passed to the callback. The return
  142. // from the callback is then marshalled and buffered interanlly, and will be
  143. // unmarshalled in the next call to Decode.
  144. //
  145. // remoteNetwork and remoteAddr can be empty, but if given will be used as the
  146. // return from the RemoteAddr method.
  147. //
  148. // If the internal buffer is empty then Decode will block until Encode is called
  149. // in a separate go-routine. The SetDeadline and SetReadDeadline methods can be
  150. // used as usual to limit how long Decode blocks. All other inherited net.Conn
  151. // methods will panic.
  152. func Stub(remoteNetwork, remoteAddr string, fn func([]string) interface{}) Conn {
  153. return &stub{
  154. buffer: newBuffer(remoteNetwork, remoteAddr),
  155. fn: fn,
  156. }
  157. }
  158. func (s *stub) Do(a Action) error {
  159. return a.Run(s)
  160. }
  161. func (s *stub) Encode(m resp.Marshaler) error {
  162. // first marshal into a RawMessage
  163. buf := new(bytes.Buffer)
  164. if err := m.MarshalRESP(buf); err != nil {
  165. return err
  166. }
  167. br := bufio.NewReader(buf)
  168. var rm resp2.RawMessage
  169. for {
  170. if buf.Len() == 0 && br.Buffered() == 0 {
  171. break
  172. } else if err := rm.UnmarshalRESP(br); err != nil {
  173. return err
  174. }
  175. // unmarshal that into a string slice
  176. var ss []string
  177. if err := rm.UnmarshalInto(resp2.Any{I: &ss}); err != nil {
  178. return err
  179. }
  180. // get return from callback. Results implementing resp.Marshaler are
  181. // assumed to be wanting to be written in all cases, otherwise if the
  182. // result is an error it is assumed to want to be returned directly.
  183. ret := s.fn(ss)
  184. if m, ok := ret.(resp.Marshaler); ok {
  185. return s.buffer.Encode(m)
  186. } else if err, _ := ret.(error); err != nil {
  187. return err
  188. } else if err = s.buffer.Encode(resp2.Any{I: ret}); err != nil {
  189. return err
  190. }
  191. }
  192. return nil
  193. }
  194. func (s *stub) NetConn() net.Conn {
  195. return s.buffer
  196. }