pipeliner.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. package radix
  2. import (
  3. "bufio"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/mediocregopher/radix/v3/resp"
  9. )
  10. var blockingCmds = map[string]bool{
  11. "WAIT": true,
  12. // taken from https://github.com/joomcode/redispipe#limitations
  13. "BLPOP": true,
  14. "BRPOP": true,
  15. "BRPOPLPUSH": true,
  16. "BZPOPMIN": true,
  17. "BZPOPMAX": true,
  18. "XREAD": true,
  19. "XREADGROUP": true,
  20. "SAVE": true,
  21. }
  22. type pipeliner struct {
  23. c Client
  24. limit int
  25. window time.Duration
  26. // reqsBufCh contains buffers for collecting commands and acts as a semaphore
  27. // to limit the number of concurrent flushes.
  28. reqsBufCh chan []CmdAction
  29. reqCh chan *pipelinerCmd
  30. reqWG sync.WaitGroup
  31. l sync.RWMutex
  32. closed bool
  33. }
  34. var _ Client = (*pipeliner)(nil)
  35. func newPipeliner(c Client, concurrency, limit int, window time.Duration) *pipeliner {
  36. if concurrency < 1 {
  37. concurrency = 1
  38. }
  39. p := &pipeliner{
  40. c: c,
  41. limit: limit,
  42. window: window,
  43. reqsBufCh: make(chan []CmdAction, concurrency),
  44. reqCh: make(chan *pipelinerCmd, 32), // https://xkcd.com/221/
  45. }
  46. p.reqWG.Add(1)
  47. go func() {
  48. defer p.reqWG.Done()
  49. p.reqLoop()
  50. }()
  51. for i := 0; i < cap(p.reqsBufCh); i++ {
  52. if p.limit > 0 {
  53. p.reqsBufCh <- make([]CmdAction, 0, limit)
  54. } else {
  55. p.reqsBufCh <- nil
  56. }
  57. }
  58. return p
  59. }
  60. // CanDo checks if the given Action can be executed / passed to p.Do.
  61. //
  62. // If CanDo returns false, the Action must not be given to Do.
  63. func (p *pipeliner) CanDo(a Action) bool {
  64. // there is currently no way to get the command for CmdAction implementations
  65. // from outside the radix package so we can not multiplex those commands. User
  66. // defined pipelines are not pipelined to let the user better control them.
  67. if cmdA, ok := a.(*cmdAction); ok {
  68. return !blockingCmds[strings.ToUpper(cmdA.cmd)]
  69. }
  70. return false
  71. }
  72. // Do executes the given Action as part of the pipeline.
  73. //
  74. // If a is not a CmdAction, Do panics.
  75. func (p *pipeliner) Do(a Action) error {
  76. req := getPipelinerCmd(a.(CmdAction)) // get this outside the lock to avoid
  77. p.l.RLock()
  78. if p.closed {
  79. p.l.RUnlock()
  80. return errClientClosed
  81. }
  82. p.reqCh <- req
  83. p.l.RUnlock()
  84. err := <-req.resCh
  85. poolPipelinerCmd(req)
  86. return err
  87. }
  88. // Close closes the pipeliner and makes sure that all background goroutines
  89. // are stopped before returning.
  90. //
  91. // Close does *not* close the underlying Client.
  92. func (p *pipeliner) Close() error {
  93. p.l.Lock()
  94. defer p.l.Unlock()
  95. if p.closed {
  96. return nil
  97. }
  98. close(p.reqCh)
  99. p.reqWG.Wait()
  100. for i := 0; i < cap(p.reqsBufCh); i++ {
  101. <-p.reqsBufCh
  102. }
  103. p.c, p.closed = nil, true
  104. return nil
  105. }
  106. func (p *pipeliner) reqLoop() {
  107. t := getTimer(time.Hour)
  108. defer putTimer(t)
  109. t.Stop()
  110. reqs := <-p.reqsBufCh
  111. defer func() {
  112. p.reqsBufCh <- reqs
  113. }()
  114. for {
  115. select {
  116. case req, ok := <-p.reqCh:
  117. if !ok {
  118. reqs = p.flush(reqs)
  119. return
  120. }
  121. reqs = append(reqs, req)
  122. if p.limit > 0 && len(reqs) == p.limit {
  123. // if we reached the pipeline limit, execute now to avoid unnecessary waiting
  124. t.Stop()
  125. reqs = p.flush(reqs)
  126. } else if len(reqs) == 1 {
  127. t.Reset(p.window)
  128. }
  129. case <-t.C:
  130. reqs = p.flush(reqs)
  131. }
  132. }
  133. }
  134. func (p *pipeliner) flush(reqs []CmdAction) []CmdAction {
  135. if len(reqs) == 0 {
  136. return reqs
  137. }
  138. go func() {
  139. defer func() {
  140. p.reqsBufCh <- reqs[:0]
  141. }()
  142. pp := &pipelinerPipeline{pipeline: pipeline(reqs)}
  143. defer pp.flush()
  144. if err := p.c.Do(pp); err != nil {
  145. pp.doErr = err
  146. }
  147. }()
  148. return <-p.reqsBufCh
  149. }
  150. type pipelinerCmd struct {
  151. CmdAction
  152. resCh chan error
  153. unmarshalCalled bool
  154. unmarshalErr error
  155. }
  156. var (
  157. _ resp.Unmarshaler = (*pipelinerCmd)(nil)
  158. )
  159. func (p *pipelinerCmd) sendRes(err error) {
  160. p.resCh <- err
  161. }
  162. func (p *pipelinerCmd) UnmarshalRESP(br *bufio.Reader) error {
  163. p.unmarshalErr = p.CmdAction.UnmarshalRESP(br)
  164. p.unmarshalCalled = true // important: we set this after unmarshalErr in case the call to UnmarshalRESP panics
  165. return p.unmarshalErr
  166. }
  167. var pipelinerCmdPool sync.Pool
  168. func getPipelinerCmd(cmd CmdAction) *pipelinerCmd {
  169. req, _ := pipelinerCmdPool.Get().(*pipelinerCmd)
  170. if req != nil {
  171. *req = pipelinerCmd{
  172. CmdAction: cmd,
  173. resCh: req.resCh,
  174. }
  175. return req
  176. }
  177. return &pipelinerCmd{
  178. CmdAction: cmd,
  179. // using a buffer of 1 is faster than no buffer in most cases
  180. resCh: make(chan error, 1),
  181. }
  182. }
  183. func poolPipelinerCmd(req *pipelinerCmd) {
  184. req.CmdAction = nil
  185. pipelinerCmdPool.Put(req)
  186. }
  187. type pipelinerPipeline struct {
  188. pipeline
  189. doErr error
  190. }
  191. func (p *pipelinerPipeline) flush() {
  192. for _, req := range p.pipeline {
  193. var err error
  194. cmd := req.(*pipelinerCmd)
  195. if cmd.unmarshalCalled {
  196. err = cmd.unmarshalErr
  197. } else {
  198. err = p.doErr
  199. }
  200. cmd.sendRes(err)
  201. }
  202. }
  203. func (p *pipelinerPipeline) Run(c Conn) (err error) {
  204. defer func() {
  205. if v := recover(); v != nil {
  206. err = fmt.Errorf("%s", v)
  207. }
  208. }()
  209. if err := c.Encode(p); err != nil {
  210. return err
  211. }
  212. errConn := ioErrConn{Conn: c}
  213. for _, req := range p.pipeline {
  214. if _ = errConn.Decode(req); errConn.lastIOErr != nil {
  215. return errConn.lastIOErr
  216. }
  217. }
  218. return nil
  219. }