net.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. /*
  2. * Copyright (c) 2021 IBM Corp and others.
  3. *
  4. * All rights reserved. This program and the accompanying materials
  5. * are made available under the terms of the Eclipse Public License v2.0
  6. * and Eclipse Distribution License v1.0 which accompany this distribution.
  7. *
  8. * The Eclipse Public License is available at
  9. * https://www.eclipse.org/legal/epl-2.0/
  10. * and the Eclipse Distribution License is available at
  11. * http://www.eclipse.org/org/documents/edl-v10.php.
  12. *
  13. * Contributors:
  14. * Seth Hoenig
  15. * Allan Stockdill-Mander
  16. * Mike Robertson
  17. * Matt Brittan
  18. */
  19. package mqtt
  20. import (
  21. "errors"
  22. "io"
  23. "net"
  24. "reflect"
  25. "strings"
  26. "sync"
  27. "time"
  28. "github.com/eclipse/paho.mqtt.golang/packets"
  29. )
  30. const closedNetConnErrorText = "use of closed network connection" // error string for closed conn (https://golang.org/src/net/error_test.go)
  31. // ConnectMQTT takes a connected net.Conn and performs the initial MQTT handshake. Parameters are:
  32. // conn - Connected net.Conn
  33. // cm - Connect Packet with everything other than the protocol name/version populated (historical reasons)
  34. // protocolVersion - The protocol version to attempt to connect with
  35. //
  36. // Note that, for backward compatibility, ConnectMQTT() suppresses the actual connection error (compare to connectMQTT()).
  37. func ConnectMQTT(conn net.Conn, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool) {
  38. rc, sessionPresent, _ := connectMQTT(conn, cm, protocolVersion)
  39. return rc, sessionPresent
  40. }
  41. func connectMQTT(conn io.ReadWriter, cm *packets.ConnectPacket, protocolVersion uint) (byte, bool, error) {
  42. switch protocolVersion {
  43. case 3:
  44. DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
  45. cm.ProtocolName = "MQIsdp"
  46. cm.ProtocolVersion = 3
  47. case 0x83:
  48. DEBUG.Println(CLI, "Using MQTT 3.1b protocol")
  49. cm.ProtocolName = "MQIsdp"
  50. cm.ProtocolVersion = 0x83
  51. case 0x84:
  52. DEBUG.Println(CLI, "Using MQTT 3.1.1b protocol")
  53. cm.ProtocolName = "MQTT"
  54. cm.ProtocolVersion = 0x84
  55. default:
  56. DEBUG.Println(CLI, "Using MQTT 3.1.1 protocol")
  57. cm.ProtocolName = "MQTT"
  58. cm.ProtocolVersion = 4
  59. }
  60. if err := cm.Write(conn); err != nil {
  61. ERROR.Println(CLI, err)
  62. return packets.ErrNetworkError, false, err
  63. }
  64. rc, sessionPresent, err := verifyCONNACK(conn)
  65. return rc, sessionPresent, err
  66. }
  67. // This function is only used for receiving a connack
  68. // when the connection is first started.
  69. // This prevents receiving incoming data while resume
  70. // is in progress if clean session is false.
  71. func verifyCONNACK(conn io.Reader) (byte, bool, error) {
  72. DEBUG.Println(NET, "connect started")
  73. ca, err := packets.ReadPacket(conn)
  74. if err != nil {
  75. ERROR.Println(NET, "connect got error", err)
  76. return packets.ErrNetworkError, false, err
  77. }
  78. if ca == nil {
  79. ERROR.Println(NET, "received nil packet")
  80. return packets.ErrNetworkError, false, errors.New("nil CONNACK packet")
  81. }
  82. msg, ok := ca.(*packets.ConnackPacket)
  83. if !ok {
  84. ERROR.Println(NET, "received msg that was not CONNACK")
  85. return packets.ErrNetworkError, false, errors.New("non-CONNACK first packet received")
  86. }
  87. DEBUG.Println(NET, "received connack")
  88. return msg.ReturnCode, msg.SessionPresent, nil
  89. }
  90. // inbound encapsulates the output from startIncoming.
  91. // err - If != nil then an error has occurred
  92. // cp - A control packet received over the network link
  93. type inbound struct {
  94. err error
  95. cp packets.ControlPacket
  96. }
  97. // startIncoming initiates a goroutine that reads incoming messages off the wire and sends them to the channel (returned).
  98. // If there are any issues with the network connection then the returned channel will be closed and the goroutine will exit
  99. // (so closing the connection will terminate the goroutine)
  100. func startIncoming(conn io.Reader) <-chan inbound {
  101. var err error
  102. var cp packets.ControlPacket
  103. ibound := make(chan inbound)
  104. DEBUG.Println(NET, "incoming started")
  105. go func() {
  106. for {
  107. if cp, err = packets.ReadPacket(conn); err != nil {
  108. // We do not want to log the error if it is due to the network connection having been closed
  109. // elsewhere (i.e. after sending DisconnectPacket). Detecting this situation is the subject of
  110. // https://github.com/golang/go/issues/4373
  111. if !strings.Contains(err.Error(), closedNetConnErrorText) {
  112. ibound <- inbound{err: err}
  113. }
  114. close(ibound)
  115. DEBUG.Println(NET, "incoming complete")
  116. return
  117. }
  118. DEBUG.Println(NET, "startIncoming Received Message")
  119. ibound <- inbound{cp: cp}
  120. }
  121. }()
  122. return ibound
  123. }
  124. // incomingComms encapsulates the possible output of the incomingComms routine. If err != nil then an error has occurred and
  125. // the routine will have terminated; otherwise one of the other members should be non-nil
  126. type incomingComms struct {
  127. err error // If non-nil then there has been an error (ignore everything else)
  128. outbound *PacketAndToken // Packet (with token) than needs to be sent out (e.g. an acknowledgement)
  129. incomingPub *packets.PublishPacket // A new publish has been received; this will need to be passed on to our user
  130. }
  131. // startIncomingComms initiates incoming communications; this includes starting a goroutine to process incoming
  132. // messages.
  133. // Accepts a channel of inbound messages from the store (persisted messages); note this must be closed as soon as
  134. // everything in the store has been sent.
  135. // Returns a channel that will be passed any received packets; this will be closed on a network error (and inboundFromStore closed)
  136. func startIncomingComms(conn io.Reader,
  137. c commsFns,
  138. inboundFromStore <-chan packets.ControlPacket,
  139. ) <-chan incomingComms {
  140. ibound := startIncoming(conn) // Start goroutine that reads from network connection
  141. output := make(chan incomingComms)
  142. DEBUG.Println(NET, "startIncomingComms started")
  143. go func() {
  144. for {
  145. if inboundFromStore == nil && ibound == nil {
  146. close(output)
  147. DEBUG.Println(NET, "startIncomingComms goroutine complete")
  148. return // As soon as ibound is closed we can exit (should have already processed an error)
  149. }
  150. DEBUG.Println(NET, "logic waiting for msg on ibound")
  151. var msg packets.ControlPacket
  152. var ok bool
  153. select {
  154. case msg, ok = <-inboundFromStore:
  155. if !ok {
  156. DEBUG.Println(NET, "startIncomingComms: inboundFromStore complete")
  157. inboundFromStore = nil // should happen quickly as this is only for persisted messages
  158. continue
  159. }
  160. DEBUG.Println(NET, "startIncomingComms: got msg from store")
  161. case ibMsg, ok := <-ibound:
  162. if !ok {
  163. DEBUG.Println(NET, "startIncomingComms: ibound complete")
  164. ibound = nil
  165. continue
  166. }
  167. DEBUG.Println(NET, "startIncomingComms: got msg on ibound")
  168. // If the inbound comms routine encounters any issues it will send us an error.
  169. if ibMsg.err != nil {
  170. output <- incomingComms{err: ibMsg.err}
  171. continue // Usually the channel will be closed immediately after sending an error but safer that we do not assume this
  172. }
  173. msg = ibMsg.cp
  174. c.persistInbound(msg)
  175. c.UpdateLastReceived() // Notify keepalive logic that we recently received a packet
  176. }
  177. switch m := msg.(type) {
  178. case *packets.PingrespPacket:
  179. DEBUG.Println(NET, "startIncomingComms: received pingresp")
  180. c.pingRespReceived()
  181. case *packets.SubackPacket:
  182. DEBUG.Println(NET, "startIncomingComms: received suback, id:", m.MessageID)
  183. token := c.getToken(m.MessageID)
  184. if t, ok := token.(*SubscribeToken); ok {
  185. DEBUG.Println(NET, "startIncomingComms: granted qoss", m.ReturnCodes)
  186. for i, qos := range m.ReturnCodes {
  187. t.subResult[t.subs[i]] = qos
  188. }
  189. }
  190. token.flowComplete()
  191. c.freeID(m.MessageID)
  192. case *packets.UnsubackPacket:
  193. DEBUG.Println(NET, "startIncomingComms: received unsuback, id:", m.MessageID)
  194. c.getToken(m.MessageID).flowComplete()
  195. c.freeID(m.MessageID)
  196. case *packets.PublishPacket:
  197. DEBUG.Println(NET, "startIncomingComms: received publish, msgId:", m.MessageID)
  198. output <- incomingComms{incomingPub: m}
  199. case *packets.PubackPacket:
  200. DEBUG.Println(NET, "startIncomingComms: received puback, id:", m.MessageID)
  201. c.getToken(m.MessageID).flowComplete()
  202. c.freeID(m.MessageID)
  203. case *packets.PubrecPacket:
  204. DEBUG.Println(NET, "startIncomingComms: received pubrec, id:", m.MessageID)
  205. prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
  206. prel.MessageID = m.MessageID
  207. output <- incomingComms{outbound: &PacketAndToken{p: prel, t: nil}}
  208. case *packets.PubrelPacket:
  209. DEBUG.Println(NET, "startIncomingComms: received pubrel, id:", m.MessageID)
  210. pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
  211. pc.MessageID = m.MessageID
  212. c.persistOutbound(pc)
  213. output <- incomingComms{outbound: &PacketAndToken{p: pc, t: nil}}
  214. case *packets.PubcompPacket:
  215. DEBUG.Println(NET, "startIncomingComms: received pubcomp, id:", m.MessageID)
  216. c.getToken(m.MessageID).flowComplete()
  217. c.freeID(m.MessageID)
  218. }
  219. }
  220. }()
  221. return output
  222. }
  223. // startOutgoingComms initiates a go routine to transmit outgoing packets.
  224. // Pass in an open network connection and channels for outbound messages (including those triggered
  225. // directly from incoming comms).
  226. // Returns a channel that will receive details of any errors (closed when the goroutine exits)
  227. // This function wil only terminate when all input channels are closed
  228. func startOutgoingComms(conn net.Conn,
  229. c commsFns,
  230. oboundp <-chan *PacketAndToken,
  231. obound <-chan *PacketAndToken,
  232. oboundFromIncoming <-chan *PacketAndToken,
  233. ) <-chan error {
  234. errChan := make(chan error)
  235. DEBUG.Println(NET, "outgoing started")
  236. go func() {
  237. for {
  238. DEBUG.Println(NET, "outgoing waiting for an outbound message")
  239. // This goroutine will only exits when all of the input channels we receive on have been closed. This approach is taken to avoid any
  240. // deadlocks (if the connection goes down there are limited options as to what we can do with anything waiting on us and
  241. // throwing away the packets seems the best option)
  242. if oboundp == nil && obound == nil && oboundFromIncoming == nil {
  243. DEBUG.Println(NET, "outgoing comms stopping")
  244. close(errChan)
  245. return
  246. }
  247. select {
  248. case pub, ok := <-obound:
  249. if !ok {
  250. obound = nil
  251. continue
  252. }
  253. msg := pub.p.(*packets.PublishPacket)
  254. DEBUG.Println(NET, "obound msg to write", msg.MessageID)
  255. writeTimeout := c.getWriteTimeOut()
  256. if writeTimeout > 0 {
  257. if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
  258. ERROR.Println(NET, "SetWriteDeadline ", err)
  259. }
  260. }
  261. if err := msg.Write(conn); err != nil {
  262. ERROR.Println(NET, "outgoing obound reporting error ", err)
  263. pub.t.setError(err)
  264. // report error if it's not due to the connection being closed elsewhere
  265. if !strings.Contains(err.Error(), closedNetConnErrorText) {
  266. errChan <- err
  267. }
  268. continue
  269. }
  270. if writeTimeout > 0 {
  271. // If we successfully wrote, we don't want the timeout to happen during an idle period
  272. // so we reset it to infinite.
  273. if err := conn.SetWriteDeadline(time.Time{}); err != nil {
  274. ERROR.Println(NET, "SetWriteDeadline to 0 ", err)
  275. }
  276. }
  277. if msg.Qos == 0 {
  278. pub.t.flowComplete()
  279. }
  280. DEBUG.Println(NET, "obound wrote msg, id:", msg.MessageID)
  281. case msg, ok := <-oboundp:
  282. if !ok {
  283. oboundp = nil
  284. continue
  285. }
  286. DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(msg.p))
  287. if err := msg.p.Write(conn); err != nil {
  288. ERROR.Println(NET, "outgoing oboundp reporting error ", err)
  289. if msg.t != nil {
  290. msg.t.setError(err)
  291. }
  292. errChan <- err
  293. continue
  294. }
  295. if _, ok := msg.p.(*packets.DisconnectPacket); ok {
  296. msg.t.(*DisconnectToken).flowComplete()
  297. DEBUG.Println(NET, "outbound wrote disconnect, closing connection")
  298. // As per the MQTT spec "After sending a DISCONNECT Packet the Client MUST close the Network Connection"
  299. // Closing the connection will cause the goroutines to end in sequence (starting with incoming comms)
  300. _ = conn.Close()
  301. }
  302. case msg, ok := <-oboundFromIncoming: // message triggered by an inbound message (PubrecPacket or PubrelPacket)
  303. if !ok {
  304. oboundFromIncoming = nil
  305. continue
  306. }
  307. DEBUG.Println(NET, "obound from incoming msg to write, type", reflect.TypeOf(msg.p), " ID ", msg.p.Details().MessageID)
  308. if err := msg.p.Write(conn); err != nil {
  309. ERROR.Println(NET, "outgoing oboundFromIncoming reporting error", err)
  310. if msg.t != nil {
  311. msg.t.setError(err)
  312. }
  313. errChan <- err
  314. continue
  315. }
  316. }
  317. c.UpdateLastSent() // Record that a packet has been received (for keepalive routine)
  318. }
  319. }()
  320. return errChan
  321. }
  322. // commsFns provide access to the client state (messageids, requesting disconnection and updating timing)
  323. type commsFns interface {
  324. getToken(id uint16) tokenCompletor // Retrieve the token for the specified messageid (if none then a dummy token must be returned)
  325. freeID(id uint16) // Release the specified messageid (clearing out of any persistent store)
  326. UpdateLastReceived() // Must be called whenever a packet is received
  327. UpdateLastSent() // Must be called whenever a packet is successfully sent
  328. getWriteTimeOut() time.Duration // Return the writetimeout (or 0 if none)
  329. persistOutbound(m packets.ControlPacket) // add the packet to the outbound store
  330. persistInbound(m packets.ControlPacket) // add the packet to the inbound store
  331. pingRespReceived() // Called when a ping response is received
  332. }
  333. // startComms initiates goroutines that handles communications over the network connection
  334. // Messages will be stored (via commsFns) and deleted from the store as necessary
  335. // It returns two channels:
  336. //
  337. // packets.PublishPacket - Will receive publish packets received over the network.
  338. // Closed when incoming comms routines exit (on shutdown or if network link closed)
  339. // error - Any errors will be sent on this channel. The channel is closed when all comms routines have shut down
  340. //
  341. // Note: The comms routines monitoring oboundp and obound will not shutdown until those channels are both closed. Any messages received between the
  342. // connection being closed and those channels being closed will generate errors (and nothing will be sent). That way the chance of a deadlock is
  343. // minimised.
  344. func startComms(conn net.Conn, // Network connection (must be active)
  345. c commsFns, // getters and setters to enable us to cleanly interact with client
  346. inboundFromStore <-chan packets.ControlPacket, // Inbound packets from the persistence store (should be closed relatively soon after startup)
  347. oboundp <-chan *PacketAndToken,
  348. obound <-chan *PacketAndToken) (
  349. <-chan *packets.PublishPacket, // Publishpackages received over the network
  350. <-chan error, // Any errors (should generally trigger a disconnect)
  351. ) {
  352. // Start inbound comms handler; this needs to be able to transmit messages so we start a go routine to add these to the priority outbound channel
  353. ibound := startIncomingComms(conn, c, inboundFromStore)
  354. outboundFromIncoming := make(chan *PacketAndToken) // Will accept outgoing messages triggered by startIncomingComms (e.g. acknowledgements)
  355. // Start the outgoing handler. It is important to note that output from startIncomingComms is fed into startOutgoingComms (for ACK's)
  356. oboundErr := startOutgoingComms(conn, c, oboundp, obound, outboundFromIncoming)
  357. DEBUG.Println(NET, "startComms started")
  358. // Run up go routines to handle the output from the above comms functions - these are handled in separate
  359. // go routines because they can interact (e.g. ibound triggers an ACK to obound which triggers an error)
  360. var wg sync.WaitGroup
  361. wg.Add(2)
  362. outPublish := make(chan *packets.PublishPacket)
  363. outError := make(chan error)
  364. // Any messages received get passed to the appropriate channel
  365. go func() {
  366. for ic := range ibound {
  367. if ic.err != nil {
  368. outError <- ic.err
  369. continue
  370. }
  371. if ic.outbound != nil {
  372. outboundFromIncoming <- ic.outbound
  373. continue
  374. }
  375. if ic.incomingPub != nil {
  376. outPublish <- ic.incomingPub
  377. continue
  378. }
  379. ERROR.Println(STR, "startComms received empty incomingComms msg")
  380. }
  381. // Close channels that will not be written to again (allowing other routines to exit)
  382. close(outboundFromIncoming)
  383. close(outPublish)
  384. wg.Done()
  385. }()
  386. // Any errors will be passed out to our caller
  387. go func() {
  388. for err := range oboundErr {
  389. outError <- err
  390. }
  391. wg.Done()
  392. }()
  393. // outError is used by both routines so can only be closed when they are both complete
  394. go func() {
  395. wg.Wait()
  396. close(outError)
  397. DEBUG.Println(NET, "startComms closing outError")
  398. }()
  399. return outPublish, outError
  400. }
  401. // ackFunc acknowledges a packet
  402. // WARNING the function returned must not be called if the comms routine is shutting down or not running
  403. // (it needs outgoing comms in order to send the acknowledgement). Currently this is only called from
  404. // matchAndDispatch which will be shutdown before the comms are
  405. func ackFunc(oboundP chan *PacketAndToken, persist Store, packet *packets.PublishPacket) func() {
  406. return func() {
  407. switch packet.Qos {
  408. case 2:
  409. pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
  410. pr.MessageID = packet.MessageID
  411. DEBUG.Println(NET, "putting pubrec msg on obound")
  412. oboundP <- &PacketAndToken{p: pr, t: nil}
  413. DEBUG.Println(NET, "done putting pubrec msg on obound")
  414. case 1:
  415. pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
  416. pa.MessageID = packet.MessageID
  417. DEBUG.Println(NET, "putting puback msg on obound")
  418. persistOutbound(persist, pa)
  419. oboundP <- &PacketAndToken{p: pa, t: nil}
  420. DEBUG.Println(NET, "done putting puback msg on obound")
  421. case 0:
  422. // do nothing, since there is no need to send an ack packet back
  423. }
  424. }
  425. }