dialer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. package kafka
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "github.com/segmentio/kafka-go/sasl"
  13. )
  14. // The Dialer type mirrors the net.Dialer API but is designed to open kafka
  15. // connections instead of raw network connections.
  16. type Dialer struct {
  17. // Unique identifier for client connections established by this Dialer.
  18. ClientID string
  19. // Optionally specifies the function that the dialer uses to establish
  20. // network connections. If nil, net.(*Dialer).DialContext is used instead.
  21. //
  22. // When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive
  23. // are ignored.
  24. DialFunc func(ctx context.Context, network string, address string) (net.Conn, error)
  25. // Timeout is the maximum amount of time a dial will wait for a connect to
  26. // complete. If Deadline is also set, it may fail earlier.
  27. //
  28. // The default is no timeout.
  29. //
  30. // When dialing a name with multiple IP addresses, the timeout may be
  31. // divided between them.
  32. //
  33. // With or without a timeout, the operating system may impose its own
  34. // earlier timeout. For instance, TCP timeouts are often around 3 minutes.
  35. Timeout time.Duration
  36. // Deadline is the absolute point in time after which dials will fail.
  37. // If Timeout is set, it may fail earlier.
  38. // Zero means no deadline, or dependent on the operating system as with the
  39. // Timeout option.
  40. Deadline time.Time
  41. // LocalAddr is the local address to use when dialing an address.
  42. // The address must be of a compatible type for the network being dialed.
  43. // If nil, a local address is automatically chosen.
  44. LocalAddr net.Addr
  45. // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the
  46. // network is "tcp" and the destination is a host name with both IPv4 and
  47. // IPv6 addresses. This allows a client to tolerate networks where one
  48. // address family is silently broken.
  49. DualStack bool
  50. // FallbackDelay specifies the length of time to wait before spawning a
  51. // fallback connection, when DualStack is enabled.
  52. // If zero, a default delay of 300ms is used.
  53. FallbackDelay time.Duration
  54. // KeepAlive specifies the keep-alive period for an active network
  55. // connection.
  56. // If zero, keep-alives are not enabled. Network protocols that do not
  57. // support keep-alives ignore this field.
  58. KeepAlive time.Duration
  59. // Resolver optionally gives a hook to convert the broker address into an
  60. // alternate host or IP address which is useful for custom service discovery.
  61. // If a custom resolver returns any possible hosts, the first one will be
  62. // used and the original discarded. If a port number is included with the
  63. // resolved host, it will only be used if a port number was not previously
  64. // specified. If no port is specified or resolved, the default of 9092 will be
  65. // used.
  66. Resolver Resolver
  67. // TLS enables Dialer to open secure connections. If nil, standard net.Conn
  68. // will be used.
  69. TLS *tls.Config
  70. // SASLMechanism configures the Dialer to use SASL authentication. If nil,
  71. // no authentication will be performed.
  72. SASLMechanism sasl.Mechanism
  73. // The transactional id to use for transactional delivery. Idempotent
  74. // deliver should be enabled if transactional id is configured.
  75. // For more details look at transactional.id description here: http://kafka.apache.org/documentation.html#producerconfigs
  76. // Empty string means that the connection will be non-transactional.
  77. TransactionalID string
  78. }
  79. // Dial connects to the address on the named network.
  80. func (d *Dialer) Dial(network string, address string) (*Conn, error) {
  81. return d.DialContext(context.Background(), network, address)
  82. }
  83. // DialContext connects to the address on the named network using the provided
  84. // context.
  85. //
  86. // The provided Context must be non-nil. If the context expires before the
  87. // connection is complete, an error is returned. Once successfully connected,
  88. // any expiration of the context will not affect the connection.
  89. //
  90. // When using TCP, and the host in the address parameter resolves to multiple
  91. // network addresses, any dial timeout (from d.Timeout or ctx) is spread over
  92. // each consecutive dial, such that each is given an appropriate fraction of the
  93. // time to connect. For example, if a host has 4 IP addresses and the timeout is
  94. // 1 minute, the connect to each single address will be given 15 seconds to
  95. // complete before trying the next one.
  96. func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) {
  97. return d.connect(
  98. ctx,
  99. network,
  100. address,
  101. ConnConfig{
  102. ClientID: d.ClientID,
  103. TransactionalID: d.TransactionalID,
  104. },
  105. )
  106. }
  107. // DialLeader opens a connection to the leader of the partition for a given
  108. // topic.
  109. //
  110. // The address given to the DialContext method may not be the one that the
  111. // connection will end up being established to, because the dialer will lookup
  112. // the partition leader for the topic and return a connection to that server.
  113. // The original address is only used as a mechanism to discover the
  114. // configuration of the kafka cluster that we're connecting to.
  115. func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
  116. p, err := d.LookupPartition(ctx, network, address, topic, partition)
  117. if err != nil {
  118. return nil, err
  119. }
  120. return d.DialPartition(ctx, network, address, p)
  121. }
  122. // DialPartition opens a connection to the leader of the partition specified by partition
  123. // descriptor. It's strongly advised to use descriptor of the partition that comes out of
  124. // functions LookupPartition or LookupPartitions.
  125. func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
  126. return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{
  127. ClientID: d.ClientID,
  128. Topic: partition.Topic,
  129. Partition: partition.ID,
  130. Broker: partition.Leader.ID,
  131. Rack: partition.Leader.Rack,
  132. TransactionalID: d.TransactionalID,
  133. })
  134. }
  135. // LookupLeader searches for the kafka broker that is the leader of the
  136. // partition for a given topic, returning a Broker value representing it.
  137. func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) {
  138. p, err := d.LookupPartition(ctx, network, address, topic, partition)
  139. return p.Leader, err
  140. }
  141. // LookupPartition searches for the description of specified partition id.
  142. func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
  143. c, err := d.DialContext(ctx, network, address)
  144. if err != nil {
  145. return Partition{}, err
  146. }
  147. defer c.Close()
  148. brkch := make(chan Partition, 1)
  149. errch := make(chan error, 1)
  150. go func() {
  151. for attempt := 0; true; attempt++ {
  152. if attempt != 0 {
  153. if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) {
  154. errch <- ctx.Err()
  155. return
  156. }
  157. }
  158. partitions, err := c.ReadPartitions(topic)
  159. if err != nil {
  160. if isTemporary(err) {
  161. continue
  162. }
  163. errch <- err
  164. return
  165. }
  166. for _, p := range partitions {
  167. if p.ID == partition {
  168. brkch <- p
  169. return
  170. }
  171. }
  172. }
  173. errch <- UnknownTopicOrPartition
  174. }()
  175. var prt Partition
  176. select {
  177. case prt = <-brkch:
  178. case err = <-errch:
  179. case <-ctx.Done():
  180. err = ctx.Err()
  181. }
  182. return prt, err
  183. }
  184. // LookupPartitions returns the list of partitions that exist for the given topic.
  185. func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
  186. conn, err := d.DialContext(ctx, network, address)
  187. if err != nil {
  188. return nil, err
  189. }
  190. defer conn.Close()
  191. prtch := make(chan []Partition, 1)
  192. errch := make(chan error, 1)
  193. go func() {
  194. if prt, err := conn.ReadPartitions(topic); err != nil {
  195. errch <- err
  196. } else {
  197. prtch <- prt
  198. }
  199. }()
  200. var prt []Partition
  201. select {
  202. case prt = <-prtch:
  203. case err = <-errch:
  204. case <-ctx.Done():
  205. err = ctx.Err()
  206. }
  207. return prt, err
  208. }
  209. // connectTLS returns a tls.Conn that has already completed the Handshake.
  210. func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) {
  211. tlsConn = tls.Client(conn, config)
  212. errch := make(chan error)
  213. go func() {
  214. defer close(errch)
  215. errch <- tlsConn.Handshake()
  216. }()
  217. select {
  218. case <-ctx.Done():
  219. conn.Close()
  220. tlsConn.Close()
  221. <-errch // ignore possible error from Handshake
  222. err = ctx.Err()
  223. case err = <-errch:
  224. }
  225. return
  226. }
  227. // connect opens a socket connection to the broker, wraps it to create a
  228. // kafka connection, and performs SASL authentication if configured to do so.
  229. func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) {
  230. if d.Timeout != 0 {
  231. var cancel context.CancelFunc
  232. ctx, cancel = context.WithTimeout(ctx, d.Timeout)
  233. defer cancel()
  234. }
  235. if !d.Deadline.IsZero() {
  236. var cancel context.CancelFunc
  237. ctx, cancel = context.WithDeadline(ctx, d.Deadline)
  238. defer cancel()
  239. }
  240. c, err := d.dialContext(ctx, network, address)
  241. if err != nil {
  242. return nil, fmt.Errorf("failed to dial: %w", err)
  243. }
  244. conn := NewConnWith(c, connCfg)
  245. if d.SASLMechanism != nil {
  246. host, port, err := splitHostPortNumber(address)
  247. if err != nil {
  248. return nil, fmt.Errorf("could not determine host/port for SASL authentication: %w", err)
  249. }
  250. metadata := &sasl.Metadata{
  251. Host: host,
  252. Port: port,
  253. }
  254. if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
  255. _ = conn.Close()
  256. return nil, fmt.Errorf("could not successfully authenticate to %s:%d with SASL: %w", host, port, err)
  257. }
  258. }
  259. return conn, nil
  260. }
  261. // authenticateSASL performs all of the required requests to authenticate this
  262. // connection. If any step fails, this function returns with an error. A nil
  263. // error indicates successful authentication.
  264. //
  265. // In case of error, this function *does not* close the connection. That is the
  266. // responsibility of the caller.
  267. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
  268. if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
  269. return fmt.Errorf("SASL handshake failed: %w", err)
  270. }
  271. sess, state, err := d.SASLMechanism.Start(ctx)
  272. if err != nil {
  273. return fmt.Errorf("SASL authentication process could not be started: %w", err)
  274. }
  275. for completed := false; !completed; {
  276. challenge, err := conn.saslAuthenticate(state)
  277. switch {
  278. case err == nil:
  279. case errors.Is(err, io.EOF):
  280. // the broker may communicate a failed exchange by closing the
  281. // connection (esp. in the case where we're passing opaque sasl
  282. // data over the wire since there's no protocol info).
  283. return SASLAuthenticationFailed
  284. default:
  285. return err
  286. }
  287. completed, state, err = sess.Next(ctx, challenge)
  288. if err != nil {
  289. return fmt.Errorf("SASL authentication process has failed: %w", err)
  290. }
  291. }
  292. return nil
  293. }
  294. func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
  295. address, err := lookupHost(ctx, addr, d.Resolver)
  296. if err != nil {
  297. return nil, fmt.Errorf("failed to resolve host: %w", err)
  298. }
  299. dial := d.DialFunc
  300. if dial == nil {
  301. dial = (&net.Dialer{
  302. LocalAddr: d.LocalAddr,
  303. DualStack: d.DualStack,
  304. FallbackDelay: d.FallbackDelay,
  305. KeepAlive: d.KeepAlive,
  306. }).DialContext
  307. }
  308. conn, err := dial(ctx, network, address)
  309. if err != nil {
  310. return nil, fmt.Errorf("failed to open connection to %s: %w", address, err)
  311. }
  312. if d.TLS != nil {
  313. c := d.TLS
  314. // If no ServerName is set, infer the ServerName
  315. // from the hostname we're connecting to.
  316. if c.ServerName == "" {
  317. c = d.TLS.Clone()
  318. // Copied from tls.go in the standard library.
  319. colonPos := strings.LastIndex(address, ":")
  320. if colonPos == -1 {
  321. colonPos = len(address)
  322. }
  323. hostname := address[:colonPos]
  324. c.ServerName = hostname
  325. }
  326. return d.connectTLS(ctx, conn, c)
  327. }
  328. return conn, nil
  329. }
  330. // DefaultDialer is the default dialer used when none is specified.
  331. var DefaultDialer = &Dialer{
  332. Timeout: 10 * time.Second,
  333. DualStack: true,
  334. }
  335. // Dial is a convenience wrapper for DefaultDialer.Dial.
  336. func Dial(network string, address string) (*Conn, error) {
  337. return DefaultDialer.Dial(network, address)
  338. }
  339. // DialContext is a convenience wrapper for DefaultDialer.DialContext.
  340. func DialContext(ctx context.Context, network string, address string) (*Conn, error) {
  341. return DefaultDialer.DialContext(ctx, network, address)
  342. }
  343. // DialLeader is a convenience wrapper for DefaultDialer.DialLeader.
  344. func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
  345. return DefaultDialer.DialLeader(ctx, network, address, topic, partition)
  346. }
  347. // DialPartition is a convenience wrapper for DefaultDialer.DialPartition.
  348. func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
  349. return DefaultDialer.DialPartition(ctx, network, address, partition)
  350. }
  351. // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition.
  352. func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
  353. return DefaultDialer.LookupPartition(ctx, network, address, topic, partition)
  354. }
  355. // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions.
  356. func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
  357. return DefaultDialer.LookupPartitions(ctx, network, address, topic)
  358. }
  359. func sleep(ctx context.Context, duration time.Duration) bool {
  360. if duration == 0 {
  361. select {
  362. default:
  363. return true
  364. case <-ctx.Done():
  365. return false
  366. }
  367. }
  368. timer := time.NewTimer(duration)
  369. defer timer.Stop()
  370. select {
  371. case <-timer.C:
  372. return true
  373. case <-ctx.Done():
  374. return false
  375. }
  376. }
  377. func backoff(attempt int, min time.Duration, max time.Duration) time.Duration {
  378. d := time.Duration(attempt*attempt) * min
  379. if d > max {
  380. d = max
  381. }
  382. return d
  383. }
  384. func canonicalAddress(s string) string {
  385. return net.JoinHostPort(splitHostPort(s))
  386. }
  387. func splitHostPort(s string) (host string, port string) {
  388. host, port, _ = net.SplitHostPort(s)
  389. if len(host) == 0 && len(port) == 0 {
  390. host = s
  391. port = "9092"
  392. }
  393. return
  394. }
  395. func splitHostPortNumber(s string) (host string, portNumber int, err error) {
  396. host, port := splitHostPort(s)
  397. portNumber, err = strconv.Atoi(port)
  398. if err != nil {
  399. return host, 0, fmt.Errorf("%s: %w", s, err)
  400. }
  401. return host, portNumber, nil
  402. }
  403. func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
  404. host, port := splitHostPort(address)
  405. if resolver != nil {
  406. resolved, err := resolver.LookupHost(ctx, host)
  407. if err != nil {
  408. return "", fmt.Errorf("failed to resolve host %s: %w", host, err)
  409. }
  410. // if the resolver doesn't return anything, we'll fall back on the provided
  411. // address instead
  412. if len(resolved) > 0 {
  413. resolvedHost, resolvedPort := splitHostPort(resolved[0])
  414. // we'll always prefer the resolved host
  415. host = resolvedHost
  416. // in the case of port though, the provided address takes priority, and we
  417. // only use the resolved address to set the port when not specified
  418. if port == "" {
  419. port = resolvedPort
  420. }
  421. }
  422. }
  423. return net.JoinHostPort(host, port), nil
  424. }