conn.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. package radix
  2. import (
  3. "bufio"
  4. "crypto/tls"
  5. "net"
  6. "net/url"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/mediocregopher/radix/v3/resp"
  11. )
  12. // Conn is a Client wrapping a single network connection which synchronously
  13. // reads/writes data using the redis resp protocol.
  14. //
  15. // A Conn can be used directly as a Client, but in general you probably want to
  16. // use a *Pool instead.
  17. type Conn interface {
  18. // The Do method of a Conn is _not_ expected to be thread-safe with the
  19. // other methods of Conn, and merely calls the Action's Run method with
  20. // itself as the argument.
  21. Client
  22. // Encode and Decode may be called at the same time by two different
  23. // go-routines, but each should only be called once at a time (i.e. two
  24. // routines shouldn't call Encode at the same time, same with Decode).
  25. //
  26. // Encode and Decode should _not_ be called at the same time as Do.
  27. //
  28. // If either Encode or Decode encounter a net.Error the Conn will be
  29. // automatically closed.
  30. //
  31. // Encode is expected to encode an entire resp message, not a partial one.
  32. // In other words, when sending commands to redis, Encode should only be
  33. // called once per command. Similarly, Decode is expected to decode an
  34. // entire resp response.
  35. Encode(resp.Marshaler) error
  36. Decode(resp.Unmarshaler) error
  37. // Returns the underlying network connection, as-is. Read, Write, and Close
  38. // should not be called on the returned Conn.
  39. NetConn() net.Conn
  40. }
  41. // ConnFunc is a function which returns an initialized, ready-to-be-used Conn.
  42. // Functions like NewPool or NewCluster take in a ConnFunc in order to allow for
  43. // things like calls to AUTH on each new connection, setting timeouts, custom
  44. // Conn implementations, etc... See the package docs for more details.
  45. type ConnFunc func(network, addr string) (Conn, error)
  46. // DefaultConnFunc is a ConnFunc which will return a Conn for a redis instance
  47. // using sane defaults.
  48. var DefaultConnFunc = func(network, addr string) (Conn, error) {
  49. return Dial(network, addr)
  50. }
  51. func wrapDefaultConnFunc(addr string) ConnFunc {
  52. _, opts := parseRedisURL(addr)
  53. return func(network, addr string) (Conn, error) {
  54. return Dial(network, addr, opts...)
  55. }
  56. }
  57. type connWrap struct {
  58. net.Conn
  59. brw *bufio.ReadWriter
  60. }
  // NewConn wraps an existing net.Conn so that it satisfies this package's Conn
  // interface. Reads and writes go through bufio buffers. Once conn has been
  // passed in here, its own Read and Write methods should no longer be used.
  func NewConn(conn net.Conn) Conn {
  	reader := bufio.NewReader(conn)
  	writer := bufio.NewWriter(conn)
  	return &connWrap{
  		Conn: conn,
  		brw:  bufio.NewReadWriter(reader, writer),
  	}
  }
  70. func (cw *connWrap) Do(a Action) error {
  71. return a.Run(cw)
  72. }
  // Encode marshals m into the connection's write buffer and then flushes the
  // buffer, so each successful call sends one whole message.
  //
  // If marshaling fails with an error that is not a net.Error, Encode discards
  // any unflushed bytes m had already written. Otherwise that partial message
  // would be flushed ahead of the next message given to Encode, corrupting the
  // resp stream. Bytes that bufio had already flushed to the connection, which
  // happens only when the message outgrew the buffer, cannot be recalled.
  //
  // Network errors are not reset. bufio.Writer keeps returning a failed
  // write's error, so later writes to a connection which may already carry a
  // partial message keep failing instead of silently going out.
  func (cw *connWrap) Encode(m resp.Marshaler) error {
  	if err := m.MarshalRESP(cw.brw); err != nil {
  		if _, isNetErr := err.(net.Error); !isNetErr {
  			// Assumes brw's Writer writes to cw.Conn, as NewConn sets it up.
  			// Reset drops the buffered partial message (and clears bufio's
  			// sticky error) while keeping the same destination.
  			cw.brw.Writer.Reset(cw.Conn)
  		}
  		return err
  	}
  	return cw.brw.Flush()
  }
  79. func (cw *connWrap) Decode(u resp.Unmarshaler) error {
  80. return u.UnmarshalRESP(cw.brw.Reader)
  81. }
  // NetConn returns the wrapped net.Conn, i.e. the one NewConn was given.
  func (cw *connWrap) NetConn() net.Conn {
  	nc := cw.Conn
  	return nc
  }
  85. type dialOpts struct {
  86. connectTimeout, readTimeout, writeTimeout time.Duration
  87. authUser, authPass string
  88. selectDB string
  89. useTLSConfig bool
  90. tlsConfig *tls.Config
  91. }
  92. // DialOpt is an optional behavior which can be applied to the Dial function to
  93. // effect its behavior, or the behavior of the Conn it creates.
  94. type DialOpt func(*dialOpts)
  // DialConnectTimeout sets how long Dial may spend establishing the connection.
  // Dial uses it as the Timeout of the net.Dialer it connects with; with
  // DialUseTLS that also covers the TLS handshake. A zero or negative d means
  // Dial sets no connect timeout. Dial's default options set this to 10
  // seconds.
  func DialConnectTimeout(d time.Duration) DialOpt {
  	setConnectTimeout := func(do *dialOpts) {
  		do.connectTimeout = d
  	}
  	return setConnectTimeout
  }
  // DialReadTimeout bounds reads on a connection created by Dial. Before every
  // underlying Read, the connection's read deadline is moved to d from now. A
  // zero or negative d disables this, and SetReadDeadline is then never
  // called. Dial's default options set this to 10 seconds.
  func DialReadTimeout(d time.Duration) DialOpt {
  	setReadTimeout := func(do *dialOpts) {
  		do.readTimeout = d
  	}
  	return setReadTimeout
  }
  // DialWriteTimeout bounds writes on a connection created by Dial. Before every
  // underlying Write, the connection's write deadline is moved to d from now. A
  // zero or negative d disables this, and SetWriteDeadline is then never
  // called. Dial's default options set this to 10 seconds.
  func DialWriteTimeout(d time.Duration) DialOpt {
  	setWriteTimeout := func(do *dialOpts) {
  		do.writeTimeout = d
  	}
  	return setWriteTimeout
  }
  // DialTimeout is shorthand for passing DialConnectTimeout, DialReadTimeout and
  // DialWriteTimeout, all with the same duration d.
  func DialTimeout(d time.Duration) DialOpt {
  	parts := []DialOpt{
  		DialConnectTimeout(d),
  		DialReadTimeout(d),
  		DialWriteTimeout(d),
  	}
  	return func(do *dialOpts) {
  		for _, apply := range parts {
  			apply(do)
  		}
  	}
  }
  // defaultAuthUser is the username assumed when none is supplied. When the
  // configured user is this one, Dial sends the single-argument form of AUTH
  // (password only).
  const (
  	defaultAuthUser = "default"
  )
  126. // DialAuthPass will cause Dial to perform an AUTH command once the connection
  127. // is created, using the given pass.
  128. //
  129. // If this is set and a redis URI is passed to Dial which also has a password
  130. // set, this takes precedence.
  131. //
  132. // Using DialAuthPass is equivalent to calling DialAuthUser with user "default"
  133. // and is kept for compatibility with older package versions.
  134. func DialAuthPass(pass string) DialOpt {
  135. return DialAuthUser(defaultAuthUser, pass)
  136. }
  137. // DialAuthUser will cause Dial to perform an AUTH command once the connection
  138. // is created, using the given user and pass.
  139. //
  140. // If this is set and a redis URI is passed to Dial which also has a username
  141. // and password set, this takes precedence.
  142. func DialAuthUser(user, pass string) DialOpt {
  143. return func(do *dialOpts) {
  144. do.authUser = user
  145. do.authPass = pass
  146. }
  147. }
  // DialSelectDB makes Dial issue a SELECT command with the given database index
  // once the connection has been established (after any AUTH).
  //
  // It takes precedence over a database index supplied by a redis URI passed to
  // Dial.
  func DialSelectDB(db int) DialOpt {
  	dbStr := strconv.FormatInt(int64(db), 10)
  	return func(do *dialOpts) {
  		do.selectDB = dbStr
  	}
  }
  // DialUseTLS makes Dial connect with tls.DialWithDialer, which performs a TLS
  // handshake using the provided config. A nil config is treated as the zero
  // configuration. See https://golang.org/pkg/crypto/tls/#Config
  func DialUseTLS(config *tls.Config) DialOpt {
  	return func(do *dialOpts) {
  		do.useTLSConfig = true
  		do.tlsConfig = config
  	}
  }
  167. type timeoutConn struct {
  168. net.Conn
  169. readTimeout, writeTimeout time.Duration
  170. }
  // Read reads from the wrapped connection. When readTimeout is positive, the
  // read deadline is first moved to readTimeout from now. If setting the
  // deadline fails, that error is returned and nothing is read.
  func (tc *timeoutConn) Read(b []byte) (int, error) {
  	if tc.readTimeout <= 0 {
  		return tc.Conn.Read(b)
  	}
  	deadline := time.Now().Add(tc.readTimeout)
  	if err := tc.Conn.SetReadDeadline(deadline); err != nil {
  		return 0, err
  	}
  	return tc.Conn.Read(b)
  }
  // Write writes to the wrapped connection. When writeTimeout is positive, the
  // write deadline is first moved to writeTimeout from now. If setting the
  // deadline fails, that error is returned and nothing is written.
  func (tc *timeoutConn) Write(b []byte) (int, error) {
  	if tc.writeTimeout <= 0 {
  		return tc.Conn.Write(b)
  	}
  	deadline := time.Now().Add(tc.writeTimeout)
  	if err := tc.Conn.SetWriteDeadline(deadline); err != nil {
  		return 0, err
  	}
  	return tc.Conn.Write(b)
  }
  189. var defaultDialOpts = []DialOpt{
  190. DialTimeout(10 * time.Second),
  191. }
  192. func parseRedisURL(urlStr string) (string, []DialOpt) {
  193. // do a quick check before we bust out url.Parse, in case that is very
  194. // unperformant
  195. if !strings.HasPrefix(urlStr, "redis://") {
  196. return urlStr, nil
  197. }
  198. u, err := url.Parse(urlStr)
  199. if err != nil {
  200. return urlStr, nil
  201. }
  202. q := u.Query()
  203. username := defaultAuthUser
  204. if n := u.User.Username(); n != "" {
  205. username = n
  206. } else if n := q.Get("username"); n != "" {
  207. username = n
  208. }
  209. password := q.Get("password")
  210. if p, ok := u.User.Password(); ok {
  211. password = p
  212. }
  213. opts := []DialOpt{
  214. DialAuthUser(username, password),
  215. }
  216. dbStr := q.Get("db")
  217. if u.Path != "" && u.Path != "/" {
  218. dbStr = u.Path[1:]
  219. }
  220. if dbStr, err := strconv.Atoi(dbStr); err == nil {
  221. opts = append(opts, DialSelectDB(dbStr))
  222. }
  223. return u.Host, opts
  224. }
  // Dial is a ConnFunc which creates a Conn by connecting with a net.Dialer,
  // or with tls.DialWithDialer when DialUseTLS is given, and wrapping the result
  // with NewConn. It takes in a number of options which can overwrite its
  // default behavior as well.
  //
  // In place of a host:port address, Dial also accepts a URI, as per:
  // https://www.iana.org/assignments/uri-schemes/prov/redis
  // If the URI has an AUTH username/password or db specified, Dial will attempt
  // to perform the AUTH and/or SELECT as well.
  //
  // Options are applied in this order, later ones overriding earlier ones: the
  // defaults, then options taken from a URI, then opts. So if DialAuthPass,
  // DialAuthUser or DialSelectDB is used, it overwrites the associated value
  // passed in by the URI.
  //
  // For TCP connections, keepalive is enabled with a 10 second period. This
  // includes connections using DialUseTLS.
  //
  // If AUTH or SELECT fails, the connection is closed and the error returned.
  //
  // The default options Dial uses are:
  //
  //	DialTimeout(10 * time.Second)
  //
  func Dial(network, addr string, opts ...DialOpt) (Conn, error) {
  	var do dialOpts
  	for _, opt := range defaultDialOpts {
  		opt(&do)
  	}
  	addr, addrOpts := parseRedisURL(addr)
  	for _, opt := range addrOpts {
  		opt(&do)
  	}
  	for _, opt := range opts {
  		opt(&do)
  	}

  	dialer := net.Dialer{
  		// Enable TCP keepalive with a sane (though slightly aggressive)
  		// period. The dialer applies this to the underlying TCP conn before
  		// any TLS handshake. Setting it on the dialed conn afterwards would miss
  		// TLS connections, because a *tls.Conn doesn't expose SetKeepAlive.
  		// The dialer treats enabling keepalive as best-effort.
  		KeepAlive: 10 * time.Second,
  	}
  	if do.connectTimeout > 0 {
  		dialer.Timeout = do.connectTimeout
  	}

  	var netConn net.Conn
  	var err error
  	if do.useTLSConfig {
  		netConn, err = tls.DialWithDialer(&dialer, network, addr, do.tlsConfig)
  	} else {
  		netConn, err = dialer.Dial(network, addr)
  	}
  	if err != nil {
  		return nil, err
  	}

  	conn := NewConn(&timeoutConn{
  		readTimeout:  do.readTimeout,
  		writeTimeout: do.writeTimeout,
  		Conn:         netConn,
  	})

  	// A non-default user is authenticated with the two-argument form of AUTH.
  	// The default user is authenticated with the password-only form.
  	if do.authUser != "" && do.authUser != defaultAuthUser {
  		if err := conn.Do(Cmd(nil, "AUTH", do.authUser, do.authPass)); err != nil {
  			conn.Close()
  			return nil, err
  		}
  	} else if do.authPass != "" {
  		if err := conn.Do(Cmd(nil, "AUTH", do.authPass)); err != nil {
  			conn.Close()
  			return nil, err
  		}
  	}

  	if do.selectDB != "" {
  		if err := conn.Do(Cmd(nil, "SELECT", do.selectDB)); err != nil {
  			conn.Close()
  			return nil, err
  		}
  	}

  	return conn, nil
  }