pubsub.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. package radix
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "net"
  7. "sync"
  8. "time"
  9. "errors"
  10. "github.com/mediocregopher/radix/v3/resp"
  11. "github.com/mediocregopher/radix/v3/resp/resp2"
  12. )
  // PubSubMessage describes a message being published to a subscribed channel.
  type PubSubMessage struct {
  	// Type is "message" for a plain channel publish, or "pmessage" for one
  	// received through a pattern subscription.
  	Type string
  	// Pattern is only set when Type is "pmessage"; it is the subscribed
  	// pattern the message was delivered through.
  	Pattern string
  	// Channel is the redis channel the message was published to.
  	Channel string
  	// Message is the raw published payload.
  	Message []byte
  }
  20. // MarshalRESP implements the Marshaler interface.
  21. func (m PubSubMessage) MarshalRESP(w io.Writer) error {
  22. var err error
  23. marshal := func(m resp.Marshaler) {
  24. if err == nil {
  25. err = m.MarshalRESP(w)
  26. }
  27. }
  28. if m.Type == "message" {
  29. marshal(resp2.ArrayHeader{N: 3})
  30. marshal(resp2.BulkString{S: m.Type})
  31. } else if m.Type == "pmessage" {
  32. marshal(resp2.ArrayHeader{N: 4})
  33. marshal(resp2.BulkString{S: m.Type})
  34. marshal(resp2.BulkString{S: m.Pattern})
  35. } else {
  36. return errors.New("unknown message Type")
  37. }
  38. marshal(resp2.BulkString{S: m.Channel})
  39. marshal(resp2.BulkStringBytes{B: m.Message})
  40. return err
  41. }
  // errNotPubSubMessage is returned (in one case wrapped in resp.ErrDiscarded)
  // by PubSubMessage.UnmarshalRESP when the reply it read off the wire was not
  // a published message. spin checks for it with errors.Is and treats such a
  // reply as the response to a pending command.
  var (
  	errNotPubSubMessage = errors.New("message is not a PubSubMessage")
  )
  43. // UnmarshalRESP implements the Unmarshaler interface.
  44. func (m *PubSubMessage) UnmarshalRESP(br *bufio.Reader) error {
  45. // This method will fully consume the message on the wire, regardless of if
  46. // it is a PubSubMessage or not. If it is not then errNotPubSubMessage is
  47. // returned.
  48. // When in subscribe mode redis only allows (P)(UN)SUBSCRIBE commands, which
  49. // all return arrays, and PING, which returns an array when in subscribe
  50. // mode. HOWEVER, when all channels have been unsubscribed from then the
  51. // connection will be taken _out_ of subscribe mode. This is theoretically
  52. // fine, since the driver will still only allow the 5 commands, except PING
  53. // will return a simple string when in the non-subscribed state. So this
  54. // needs to check for that.
  55. if prefix, err := br.Peek(1); err != nil {
  56. return err
  57. } else if bytes.Equal(prefix, resp2.SimpleStringPrefix) {
  58. // if it's a simple string, discard it (it's probably PONG) and error
  59. if err := (resp2.Any{}).UnmarshalRESP(br); err != nil {
  60. return err
  61. }
  62. return resp.ErrDiscarded{Err: errNotPubSubMessage}
  63. }
  64. var ah resp2.ArrayHeader
  65. if err := ah.UnmarshalRESP(br); err != nil {
  66. return err
  67. } else if ah.N < 2 {
  68. return errors.New("message has too few elements")
  69. }
  70. var msgType resp2.BulkStringBytes
  71. if err := msgType.UnmarshalRESP(br); err != nil {
  72. return err
  73. }
  74. switch string(msgType.B) {
  75. case "message":
  76. m.Type = "message"
  77. if ah.N != 3 {
  78. return errors.New("message has wrong number of elements")
  79. }
  80. case "pmessage":
  81. m.Type = "pmessage"
  82. if ah.N != 4 {
  83. return errors.New("message has wrong number of elements")
  84. }
  85. var pattern resp2.BulkString
  86. if err := pattern.UnmarshalRESP(br); err != nil {
  87. return err
  88. }
  89. m.Pattern = pattern.S
  90. default:
  91. // if it's not a PubSubMessage then discard the rest of the array
  92. for i := 1; i < ah.N; i++ {
  93. if err := (resp2.Any{}).UnmarshalRESP(br); err != nil {
  94. return err
  95. }
  96. }
  97. return errNotPubSubMessage
  98. }
  99. var channel resp2.BulkString
  100. if err := channel.UnmarshalRESP(br); err != nil {
  101. return err
  102. }
  103. m.Channel = channel.S
  104. var msg resp2.BulkStringBytes
  105. if err := msg.UnmarshalRESP(br); err != nil {
  106. return err
  107. }
  108. m.Message = msg.B
  109. return nil
  110. }
  111. ////////////////////////////////////////////////////////////////////////////////
  112. type chanSet map[string]map[chan<- PubSubMessage]bool
  113. func (cs chanSet) add(s string, ch chan<- PubSubMessage) {
  114. m, ok := cs[s]
  115. if !ok {
  116. m = map[chan<- PubSubMessage]bool{}
  117. cs[s] = m
  118. }
  119. m[ch] = true
  120. }
  121. func (cs chanSet) del(s string, ch chan<- PubSubMessage) bool {
  122. m, ok := cs[s]
  123. if !ok {
  124. return true
  125. }
  126. delete(m, ch)
  127. if len(m) == 0 {
  128. delete(cs, s)
  129. return true
  130. }
  131. return false
  132. }
  133. func (cs chanSet) missing(ss []string) []string {
  134. out := make([]string, 0, len(ss))
  135. for _, s := range ss {
  136. if _, ok := cs[s]; !ok {
  137. out = append(out, s)
  138. }
  139. }
  140. return out
  141. }
  142. func (cs chanSet) inverse() map[chan<- PubSubMessage][]string {
  143. inv := map[chan<- PubSubMessage][]string{}
  144. for s, m := range cs {
  145. for ch := range m {
  146. inv[ch] = append(inv[ch], s)
  147. }
  148. }
  149. return inv
  150. }
  151. ////////////////////////////////////////////////////////////////////////////////
  152. // PubSubConn wraps an existing Conn to support redis' pubsub system.
  153. // User-created channels can be subscribed to redis channels to receive
  154. // PubSubMessages which have been published.
  155. //
  156. // If any methods return an error it means the PubSubConn has been Close'd and
  157. // subscribed msgCh's will no longer receive PubSubMessages from it. All methods
  158. // are threadsafe, but should be called in a different go-routine than that
  159. // which is reading from the PubSubMessage channels.
  160. //
  161. // NOTE the PubSubMessage channels should never block. If any channels block
  162. // when being written to they will block all other channels from receiving a
  163. // publish and block methods from returning.
  164. type PubSubConn interface {
  165. // Subscribe subscribes the PubSubConn to the given set of channels. msgCh
  166. // will receieve a PubSubMessage for every publish written to any of the
  167. // channels. This may be called multiple times for the same channels and
  168. // different msgCh's, each msgCh will receieve a copy of the PubSubMessage
  169. // for each publish.
  170. Subscribe(msgCh chan<- PubSubMessage, channels ...string) error
  171. // Unsubscribe unsubscribes the msgCh from the given set of channels, if it
  172. // was subscribed at all.
  173. //
  174. // NOTE even if msgCh is not subscribed to any other redis channels, it
  175. // should still be considered "active", and therefore still be having
  176. // messages read from it, until Unsubscribe has returned
  177. Unsubscribe(msgCh chan<- PubSubMessage, channels ...string) error
  178. // PSubscribe is like Subscribe, but it subscribes msgCh to a set of
  179. // patterns and not individual channels.
  180. PSubscribe(msgCh chan<- PubSubMessage, patterns ...string) error
  181. // PUnsubscribe is like Unsubscribe, but it unsubscribes msgCh from a set of
  182. // patterns and not individual channels.
  183. //
  184. // NOTE even if msgCh is not subscribed to any other redis channels, it
  185. // should still be considered "active", and therefore still be having
  186. // messages read from it, until PUnsubscribe has returned
  187. PUnsubscribe(msgCh chan<- PubSubMessage, patterns ...string) error
  188. // Ping performs a simple Ping command on the PubSubConn, returning an error
  189. // if it failed for some reason
  190. Ping() error
  191. // Close closes the PubSubConn so it can't be used anymore. All subscribed
  192. // channels will stop receiving PubSubMessages from this Conn (but will not
  193. // themselves be closed).
  194. //
  195. // NOTE all msgChs should be considered "active", and therefore still be
  196. // having messages read from them, until Close has returned.
  197. Close() error
  198. }
  199. type pubSubConn struct {
  200. conn Conn
  201. csL sync.RWMutex
  202. subs chanSet
  203. psubs chanSet
  204. // These are used for writing commands and waiting for their response (e.g.
  205. // SUBSCRIBE, PING). See the do method for how that works.
  206. cmdL sync.Mutex
  207. cmdResCh chan error
  208. close sync.Once
  209. closeErr error
  210. // This one is optional, and kind of cheating. We use it in persistent to
  211. // get on-the-fly updates of when the connection fails. Maybe one day this
  212. // could be exposed if there's a clean way of doing so, or another way
  213. // accomplishing the same thing could be done instead.
  214. closeErrCh chan error
  215. // only used during testing
  216. testEventCh chan string
  217. }
  218. // PubSub wraps the given Conn so that it becomes a PubSubConn. The passed in
  219. // Conn should not be used after this call.
  220. func PubSub(rc Conn) PubSubConn {
  221. return newPubSub(rc, nil)
  222. }
  223. func newPubSub(rc Conn, closeErrCh chan error) PubSubConn {
  224. c := &pubSubConn{
  225. conn: rc,
  226. subs: chanSet{},
  227. psubs: chanSet{},
  228. cmdResCh: make(chan error, 1),
  229. closeErrCh: closeErrCh,
  230. }
  231. go c.spin()
  232. // Periodically call Ping so the connection has a keepalive on the
  233. // application level. If the Conn is closed Ping will return an error and
  234. // this will clean itself up.
  235. go func() {
  236. t := time.NewTicker(5 * time.Second)
  237. defer t.Stop()
  238. for range t.C {
  239. if err := c.Ping(); err != nil {
  240. return
  241. }
  242. }
  243. }()
  244. return c
  245. }
  246. func (c *pubSubConn) testEvent(str string) {
  247. if c.testEventCh != nil {
  248. c.testEventCh <- str
  249. }
  250. }
  251. func (c *pubSubConn) publish(m PubSubMessage) {
  252. c.csL.RLock()
  253. defer c.csL.RUnlock()
  254. var subs map[chan<- PubSubMessage]bool
  255. if m.Type == "pmessage" {
  256. subs = c.psubs[m.Pattern]
  257. } else {
  258. subs = c.subs[m.Channel]
  259. }
  260. for ch := range subs {
  261. ch <- m
  262. }
  263. }
  264. func (c *pubSubConn) spin() {
  265. for {
  266. var m PubSubMessage
  267. err := c.conn.Decode(&m)
  268. if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() {
  269. c.testEvent("timeout")
  270. continue
  271. } else if errors.Is(err, errNotPubSubMessage) {
  272. c.cmdResCh <- nil
  273. continue
  274. } else if err != nil {
  275. // closeInner returns the error from closing the Conn, which doesn't
  276. // really matter here.
  277. _ = c.closeInner(err)
  278. return
  279. }
  280. c.publish(m)
  281. }
  282. }
  283. // NOTE cmdL _must_ be held to use do.
  284. func (c *pubSubConn) do(exp int, cmd string, args ...string) error {
  285. rcmd := Cmd(nil, cmd, args...)
  286. if err := c.conn.Encode(rcmd); err != nil {
  287. return err
  288. }
  289. for i := 0; i < exp; i++ {
  290. err, ok := <-c.cmdResCh
  291. if err != nil {
  292. return err
  293. } else if !ok {
  294. return errors.New("connection closed")
  295. }
  296. }
  297. return nil
  298. }
  299. func (c *pubSubConn) closeInner(cmdResErr error) error {
  300. c.close.Do(func() {
  301. c.csL.Lock()
  302. defer c.csL.Unlock()
  303. c.closeErr = c.conn.Close()
  304. c.subs = nil
  305. c.psubs = nil
  306. if cmdResErr != nil {
  307. select {
  308. case c.cmdResCh <- cmdResErr:
  309. default:
  310. }
  311. }
  312. if c.closeErrCh != nil {
  313. c.closeErrCh <- cmdResErr
  314. close(c.closeErrCh)
  315. }
  316. close(c.cmdResCh)
  317. })
  318. return c.closeErr
  319. }
  320. func (c *pubSubConn) Close() error {
  321. return c.closeInner(nil)
  322. }
  323. func (c *pubSubConn) Subscribe(msgCh chan<- PubSubMessage, channels ...string) error {
  324. c.cmdL.Lock()
  325. defer c.cmdL.Unlock()
  326. c.csL.RLock()
  327. missing := c.subs.missing(channels)
  328. c.csL.RUnlock()
  329. if len(missing) > 0 {
  330. if err := c.do(len(missing), "SUBSCRIBE", missing...); err != nil {
  331. return err
  332. }
  333. }
  334. c.csL.Lock()
  335. for _, channel := range channels {
  336. c.subs.add(channel, msgCh)
  337. }
  338. c.csL.Unlock()
  339. return nil
  340. }
  341. func (c *pubSubConn) Unsubscribe(msgCh chan<- PubSubMessage, channels ...string) error {
  342. c.cmdL.Lock()
  343. defer c.cmdL.Unlock()
  344. c.csL.Lock()
  345. emptyChannels := make([]string, 0, len(channels))
  346. for _, channel := range channels {
  347. if empty := c.subs.del(channel, msgCh); empty {
  348. emptyChannels = append(emptyChannels, channel)
  349. }
  350. }
  351. c.csL.Unlock()
  352. if len(emptyChannels) == 0 {
  353. return nil
  354. }
  355. return c.do(len(emptyChannels), "UNSUBSCRIBE", emptyChannels...)
  356. }
  357. func (c *pubSubConn) PSubscribe(msgCh chan<- PubSubMessage, patterns ...string) error {
  358. c.cmdL.Lock()
  359. defer c.cmdL.Unlock()
  360. c.csL.RLock()
  361. missing := c.psubs.missing(patterns)
  362. c.csL.RUnlock()
  363. if len(missing) > 0 {
  364. if err := c.do(len(missing), "PSUBSCRIBE", missing...); err != nil {
  365. return err
  366. }
  367. }
  368. c.csL.Lock()
  369. for _, pattern := range patterns {
  370. c.psubs.add(pattern, msgCh)
  371. }
  372. c.csL.Unlock()
  373. return nil
  374. }
  375. func (c *pubSubConn) PUnsubscribe(msgCh chan<- PubSubMessage, patterns ...string) error {
  376. c.cmdL.Lock()
  377. defer c.cmdL.Unlock()
  378. c.csL.Lock()
  379. emptyPatterns := make([]string, 0, len(patterns))
  380. for _, pattern := range patterns {
  381. if empty := c.psubs.del(pattern, msgCh); empty {
  382. emptyPatterns = append(emptyPatterns, pattern)
  383. }
  384. }
  385. c.csL.Unlock()
  386. if len(emptyPatterns) == 0 {
  387. return nil
  388. }
  389. return c.do(len(emptyPatterns), "PUNSUBSCRIBE", emptyPatterns...)
  390. }
  391. func (c *pubSubConn) Ping() error {
  392. c.cmdL.Lock()
  393. defer c.cmdL.Unlock()
  394. return c.do(1, "PING")
  395. }