pubsub_stub.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package radix
  2. import (
  3. "fmt"
  4. "io"
  5. "strings"
  6. "sync"
  7. "errors"
  8. "github.com/mediocregopher/radix/v3/resp"
  9. "github.com/mediocregopher/radix/v3/resp/resp2"
  10. )
  11. var errPubSubMode = resp2.Error{
  12. E: errors.New("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context"),
  13. }
  14. type multiMarshal []resp.Marshaler
  15. func (mm multiMarshal) MarshalRESP(w io.Writer) error {
  16. for _, m := range mm {
  17. if err := m.MarshalRESP(w); err != nil {
  18. return err
  19. }
  20. }
  21. return nil
  22. }
  23. type pubSubStub struct {
  24. Conn
  25. fn func([]string) interface{}
  26. inCh <-chan PubSubMessage
  27. closeOnce sync.Once
  28. closeCh chan struct{}
  29. closeErr error
  30. l sync.Mutex
  31. pubsubMode bool
  32. subbed, psubbed map[string]bool
  33. // this is only used for tests
  34. mDoneCh chan struct{}
  35. }
  36. // PubSubStub returns a (fake) Conn, much like Stub does, which pretends it is a
  37. // Conn to a real redis instance, but is instead using the given callback to
  38. // service requests. It is primarily useful for writing tests.
  39. //
  40. // PubSubStub differes from Stub in that Encode calls for (P)SUBSCRIBE,
  41. // (P)UNSUBSCRIBE, MESSAGE, and PING will be intercepted and handled as per
  42. // redis' expected pubsub functionality. A PubSubMessage may be written to the
  43. // returned channel at any time, and if the PubSubStub has had (P)SUBSCRIBE
  44. // called matching that PubSubMessage it will be written to the PubSubStub's
  45. // internal buffer as expected.
  46. //
  47. // This is intended to be used so that it can mock services which can perform
  48. // both normal redis commands and pubsub (e.g. a real redis instance, redis
  49. // sentinel). Once created this stub can be passed into PubSub and treated like
  50. // a real connection.
  51. func PubSubStub(remoteNetwork, remoteAddr string, fn func([]string) interface{}) (Conn, chan<- PubSubMessage) {
  52. ch := make(chan PubSubMessage)
  53. s := &pubSubStub{
  54. fn: fn,
  55. inCh: ch,
  56. closeCh: make(chan struct{}),
  57. subbed: map[string]bool{},
  58. psubbed: map[string]bool{},
  59. mDoneCh: make(chan struct{}, 1),
  60. }
  61. s.Conn = Stub(remoteNetwork, remoteAddr, s.innerFn)
  62. go s.spin()
  63. return s, ch
  64. }
  65. func (s *pubSubStub) innerFn(ss []string) interface{} {
  66. s.l.Lock()
  67. defer s.l.Unlock()
  68. writeRes := func(mm multiMarshal, cmd, subj string) multiMarshal {
  69. c := len(s.subbed) + len(s.psubbed)
  70. s.pubsubMode = c > 0
  71. return append(mm, resp2.Any{I: []interface{}{cmd, subj, c}})
  72. }
  73. switch strings.ToUpper(ss[0]) {
  74. case "PING":
  75. if !s.pubsubMode {
  76. return s.fn(ss)
  77. }
  78. return []string{"pong", ""}
  79. case "SUBSCRIBE":
  80. var mm multiMarshal
  81. for _, channel := range ss[1:] {
  82. s.subbed[channel] = true
  83. mm = writeRes(mm, "subscribe", channel)
  84. }
  85. return mm
  86. case "UNSUBSCRIBE":
  87. var mm multiMarshal
  88. for _, channel := range ss[1:] {
  89. delete(s.subbed, channel)
  90. mm = writeRes(mm, "unsubscribe", channel)
  91. }
  92. return mm
  93. case "PSUBSCRIBE":
  94. var mm multiMarshal
  95. for _, pattern := range ss[1:] {
  96. s.psubbed[pattern] = true
  97. mm = writeRes(mm, "psubscribe", pattern)
  98. }
  99. return mm
  100. case "PUNSUBSCRIBE":
  101. var mm multiMarshal
  102. for _, pattern := range ss[1:] {
  103. delete(s.psubbed, pattern)
  104. mm = writeRes(mm, "punsubscribe", pattern)
  105. }
  106. return mm
  107. case "MESSAGE":
  108. m := PubSubMessage{
  109. Type: "message",
  110. Channel: ss[1],
  111. Message: []byte(ss[2]),
  112. }
  113. var mm multiMarshal
  114. if s.subbed[m.Channel] {
  115. mm = append(mm, m)
  116. }
  117. return mm
  118. case "PMESSAGE":
  119. m := PubSubMessage{
  120. Type: "pmessage",
  121. Pattern: ss[1],
  122. Channel: ss[2],
  123. Message: []byte(ss[3]),
  124. }
  125. var mm multiMarshal
  126. if s.psubbed[m.Pattern] {
  127. mm = append(mm, m)
  128. }
  129. return mm
  130. default:
  131. if s.pubsubMode {
  132. return errPubSubMode
  133. }
  134. return s.fn(ss)
  135. }
  136. }
  137. func (s *pubSubStub) Close() error {
  138. s.closeOnce.Do(func() {
  139. close(s.closeCh)
  140. s.closeErr = s.Conn.Close()
  141. })
  142. return s.closeErr
  143. }
  144. func (s *pubSubStub) spin() {
  145. for {
  146. select {
  147. case m, ok := <-s.inCh:
  148. if !ok {
  149. panic("PubSubStub message channel was closed")
  150. }
  151. if m.Type == "" {
  152. if m.Pattern == "" {
  153. m.Type = "message"
  154. } else {
  155. m.Type = "pmessage"
  156. }
  157. }
  158. if err := s.Conn.Encode(m); err != nil {
  159. panic(fmt.Sprintf("error encoding message in PubSubStub: %s", err))
  160. }
  161. select {
  162. case s.mDoneCh <- struct{}{}:
  163. default:
  164. }
  165. case <-s.closeCh:
  166. return
  167. }
  168. }
  169. }