packets.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. /*
  2. * Copyright (c) 2021 IBM Corp and others.
  3. *
  4. * All rights reserved. This program and the accompanying materials
  5. * are made available under the terms of the Eclipse Public License v2.0
  6. * and Eclipse Distribution License v1.0 which accompany this distribution.
  7. *
  8. * The Eclipse Public License is available at
  9. * https://www.eclipse.org/legal/epl-2.0/
  10. * and the Eclipse Distribution License is available at
  11. * http://www.eclipse.org/org/documents/edl-v10.php.
  12. *
  13. * Contributors:
  14. * Allan Stockdill-Mander
  15. */
  16. package packets
  17. import (
  18. "bytes"
  19. "encoding/binary"
  20. "errors"
  21. "fmt"
  22. "io"
  23. )
// ControlPacket defines the interface for structs intended to hold
// decoded MQTT packets, either from being read or before being
// written.
type ControlPacket interface {
	// Write takes the destination io.Writer and reports failure as an error.
	// NOTE(review): presumably serialises the complete packet to the writer;
	// the implementations are not in this file, so confirm there.
	Write(io.Writer) error
	// Unpack populates the packet from the supplied reader. ReadPacket calls
	// it with a buffer holding the bytes that follow the fixed header.
	Unpack(io.Reader) error
	// String returns a textual form of the packet.
	String() string
	// Details returns the Qos and MessageID of the packet.
	Details() Details
}
// PacketNames maps the constants for each of the MQTT packet types
// to a string representation of their name. FixedHeader.String uses it
// to label a header; a type with no entry looks up as the empty string.
var PacketNames = map[uint8]string{
	1:  "CONNECT",
	2:  "CONNACK",
	3:  "PUBLISH",
	4:  "PUBACK",
	5:  "PUBREC",
	6:  "PUBREL",
	7:  "PUBCOMP",
	8:  "SUBSCRIBE",
	9:  "SUBACK",
	10: "UNSUBSCRIBE",
	11: "UNSUBACK",
	12: "PINGREQ",
	13: "PINGRESP",
	14: "DISCONNECT",
}
  51. // Below are the constants assigned to each of the MQTT packet types
  52. const (
  53. Connect = 1
  54. Connack = 2
  55. Publish = 3
  56. Puback = 4
  57. Pubrec = 5
  58. Pubrel = 6
  59. Pubcomp = 7
  60. Subscribe = 8
  61. Suback = 9
  62. Unsubscribe = 10
  63. Unsuback = 11
  64. Pingreq = 12
  65. Pingresp = 13
  66. Disconnect = 14
  67. )
// Below are the const definitions for error codes returned by
// Connect(). ConnackReturnCodes maps each value to a description and
// ConnErrors maps each value to a Go error.
const (
	Accepted                        = 0x00 // "Connection Accepted"
	ErrRefusedBadProtocolVersion    = 0x01 // refused: bad protocol version
	ErrRefusedIDRejected            = 0x02 // refused: client identifier rejected
	ErrRefusedServerUnavailable     = 0x03 // refused: server unavailable
	ErrRefusedBadUsernameOrPassword = 0x04 // refused: username or password in unknown format
	ErrRefusedNotAuthorised         = 0x05 // refused: not authorised
	ErrNetworkError                 = 0xFE // "Connection Error"
	ErrProtocolViolation            = 0xFF // refused: protocol violation
)
// ConnackReturnCodes is a map of the error codes constants for Connect()
// to a string representation of the error. The keys are the numeric values
// of Accepted and the Err* return-code constants.
var ConnackReturnCodes = map[uint8]string{
	0:   "Connection Accepted",
	1:   "Connection Refused: Bad Protocol Version",
	2:   "Connection Refused: Client Identifier Rejected",
	3:   "Connection Refused: Server Unavailable",
	4:   "Connection Refused: Username or Password in unknown format",
	5:   "Connection Refused: Not Authorised",
	254: "Connection Error",
	255: "Connection Refused: Protocol Violation",
}
// Sentinel errors, one per non-zero connect return code. ConnErrors maps
// each Err* return-code constant to the matching value below.
var (
	// ErrorRefusedBadProtocolVersion is the error for ErrRefusedBadProtocolVersion.
	ErrorRefusedBadProtocolVersion = errors.New("unacceptable protocol version")
	// ErrorRefusedIDRejected is the error for ErrRefusedIDRejected.
	ErrorRefusedIDRejected = errors.New("identifier rejected")
	// ErrorRefusedServerUnavailable is the error for ErrRefusedServerUnavailable.
	ErrorRefusedServerUnavailable = errors.New("server Unavailable")
	// ErrorRefusedBadUsernameOrPassword is the error for ErrRefusedBadUsernameOrPassword.
	ErrorRefusedBadUsernameOrPassword = errors.New("bad user name or password")
	// ErrorRefusedNotAuthorised is the error for ErrRefusedNotAuthorised.
	ErrorRefusedNotAuthorised = errors.New("not Authorized")
	// ErrorNetworkError is the error for ErrNetworkError.
	ErrorNetworkError = errors.New("network Error")
	// ErrorProtocolViolation is the error for ErrProtocolViolation.
	ErrorProtocolViolation = errors.New("protocol Violation")
)
// ConnErrors is a map of the errors codes constants for Connect()
// to a Go error. Accepted maps to nil; a code with no entry also looks up
// as nil, so use the two-value map lookup to tell the cases apart.
var ConnErrors = map[byte]error{
	Accepted:                        nil,
	ErrRefusedBadProtocolVersion:    ErrorRefusedBadProtocolVersion,
	ErrRefusedIDRejected:            ErrorRefusedIDRejected,
	ErrRefusedServerUnavailable:     ErrorRefusedServerUnavailable,
	ErrRefusedBadUsernameOrPassword: ErrorRefusedBadUsernameOrPassword,
	ErrRefusedNotAuthorised:         ErrorRefusedNotAuthorised,
	ErrNetworkError:                 ErrorNetworkError,
	ErrProtocolViolation:            ErrorProtocolViolation,
}
  113. // ReadPacket takes an instance of an io.Reader (such as net.Conn) and attempts
  114. // to read an MQTT packet from the stream. It returns a ControlPacket
  115. // representing the decoded MQTT packet and an error. One of these returns will
  116. // always be nil, a nil ControlPacket indicating an error occurred.
  117. func ReadPacket(r io.Reader) (ControlPacket, error) {
  118. var fh FixedHeader
  119. b := make([]byte, 1)
  120. _, err := io.ReadFull(r, b)
  121. if err != nil {
  122. return nil, err
  123. }
  124. err = fh.unpack(b[0], r)
  125. if err != nil {
  126. return nil, err
  127. }
  128. cp, err := NewControlPacketWithHeader(fh)
  129. if err != nil {
  130. return nil, err
  131. }
  132. packetBytes := make([]byte, fh.RemainingLength)
  133. n, err := io.ReadFull(r, packetBytes)
  134. if err != nil {
  135. return nil, err
  136. }
  137. if n != fh.RemainingLength {
  138. return nil, errors.New("failed to read expected data")
  139. }
  140. err = cp.Unpack(bytes.NewBuffer(packetBytes))
  141. return cp, err
  142. }
  143. // NewControlPacket is used to create a new ControlPacket of the type specified
  144. // by packetType, this is usually done by reference to the packet type constants
  145. // defined in packets.go. The newly created ControlPacket is empty and a pointer
  146. // is returned.
  147. func NewControlPacket(packetType byte) ControlPacket {
  148. switch packetType {
  149. case Connect:
  150. return &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}}
  151. case Connack:
  152. return &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}}
  153. case Disconnect:
  154. return &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}}
  155. case Publish:
  156. return &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}}
  157. case Puback:
  158. return &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}}
  159. case Pubrec:
  160. return &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}}
  161. case Pubrel:
  162. return &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}}
  163. case Pubcomp:
  164. return &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}}
  165. case Subscribe:
  166. return &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}}
  167. case Suback:
  168. return &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}}
  169. case Unsubscribe:
  170. return &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}}
  171. case Unsuback:
  172. return &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}}
  173. case Pingreq:
  174. return &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}}
  175. case Pingresp:
  176. return &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}}
  177. }
  178. return nil
  179. }
  180. // NewControlPacketWithHeader is used to create a new ControlPacket of the type
  181. // specified within the FixedHeader that is passed to the function.
  182. // The newly created ControlPacket is empty and a pointer is returned.
  183. func NewControlPacketWithHeader(fh FixedHeader) (ControlPacket, error) {
  184. switch fh.MessageType {
  185. case Connect:
  186. return &ConnectPacket{FixedHeader: fh}, nil
  187. case Connack:
  188. return &ConnackPacket{FixedHeader: fh}, nil
  189. case Disconnect:
  190. return &DisconnectPacket{FixedHeader: fh}, nil
  191. case Publish:
  192. return &PublishPacket{FixedHeader: fh}, nil
  193. case Puback:
  194. return &PubackPacket{FixedHeader: fh}, nil
  195. case Pubrec:
  196. return &PubrecPacket{FixedHeader: fh}, nil
  197. case Pubrel:
  198. return &PubrelPacket{FixedHeader: fh}, nil
  199. case Pubcomp:
  200. return &PubcompPacket{FixedHeader: fh}, nil
  201. case Subscribe:
  202. return &SubscribePacket{FixedHeader: fh}, nil
  203. case Suback:
  204. return &SubackPacket{FixedHeader: fh}, nil
  205. case Unsubscribe:
  206. return &UnsubscribePacket{FixedHeader: fh}, nil
  207. case Unsuback:
  208. return &UnsubackPacket{FixedHeader: fh}, nil
  209. case Pingreq:
  210. return &PingreqPacket{FixedHeader: fh}, nil
  211. case Pingresp:
  212. return &PingrespPacket{FixedHeader: fh}, nil
  213. }
  214. return nil, fmt.Errorf("unsupported packet type 0x%x", fh.MessageType)
  215. }
// Details struct returned by the Details() function called on
// ControlPackets to present details of the Qos and MessageID
// of the ControlPacket.
type Details struct {
	// Qos is the quality-of-service value of the packet.
	Qos byte
	// MessageID is the 16-bit message identifier of the packet.
	MessageID uint16
}
// FixedHeader is a struct to hold the decoded information from
// the fixed header of an MQTT ControlPacket.
type FixedHeader struct {
	// MessageType is the packet type, held in the upper four bits of the
	// first header byte.
	MessageType byte
	// Dup is bit 3 of the first header byte.
	Dup bool
	// Qos is the two-bit value held in bits 1-2 of the first header byte.
	Qos byte
	// Retain is bit 0 of the first header byte.
	Retain bool
	// RemainingLength is the variable-length integer that follows the first
	// header byte; ReadPacket reads that many further bytes as the packet
	// body.
	RemainingLength int
}
  232. func (fh FixedHeader) String() string {
  233. return fmt.Sprintf("%s: dup: %t qos: %d retain: %t rLength: %d", PacketNames[fh.MessageType], fh.Dup, fh.Qos, fh.Retain, fh.RemainingLength)
  234. }
  235. func boolToByte(b bool) byte {
  236. switch b {
  237. case true:
  238. return 1
  239. default:
  240. return 0
  241. }
  242. }
  243. func (fh *FixedHeader) pack() bytes.Buffer {
  244. var header bytes.Buffer
  245. header.WriteByte(fh.MessageType<<4 | boolToByte(fh.Dup)<<3 | fh.Qos<<1 | boolToByte(fh.Retain))
  246. header.Write(encodeLength(fh.RemainingLength))
  247. return header
  248. }
  249. func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) error {
  250. fh.MessageType = typeAndFlags >> 4
  251. fh.Dup = (typeAndFlags>>3)&0x01 > 0
  252. fh.Qos = (typeAndFlags >> 1) & 0x03
  253. fh.Retain = typeAndFlags&0x01 > 0
  254. var err error
  255. fh.RemainingLength, err = decodeLength(r)
  256. return err
  257. }
  258. func decodeByte(b io.Reader) (byte, error) {
  259. num := make([]byte, 1)
  260. _, err := b.Read(num)
  261. if err != nil {
  262. return 0, err
  263. }
  264. return num[0], nil
  265. }
  266. func decodeUint16(b io.Reader) (uint16, error) {
  267. num := make([]byte, 2)
  268. _, err := b.Read(num)
  269. if err != nil {
  270. return 0, err
  271. }
  272. return binary.BigEndian.Uint16(num), nil
  273. }
  274. func encodeUint16(num uint16) []byte {
  275. bytesResult := make([]byte, 2)
  276. binary.BigEndian.PutUint16(bytesResult, num)
  277. return bytesResult
  278. }
// encodeString returns field in length-prefixed form by passing its bytes to
// encodeBytes: a two-byte big-endian length followed by the bytes of the
// string. The prefix is produced by a uint16 conversion of the length, so a
// string longer than 65535 bytes gets a prefix that does not match the data.
func encodeString(field string) []byte {
	return encodeBytes([]byte(field))
}
// decodeString reads a length-prefixed field from b via decodeBytes and
// returns it as a string. When decodeBytes fails it returns a nil slice, so
// the string returned alongside the error is empty.
func decodeString(b io.Reader) (string, error) {
	buf, err := decodeBytes(b)
	return string(buf), err
}
  286. func decodeBytes(b io.Reader) ([]byte, error) {
  287. fieldLength, err := decodeUint16(b)
  288. if err != nil {
  289. return nil, err
  290. }
  291. field := make([]byte, fieldLength)
  292. _, err = b.Read(field)
  293. if err != nil {
  294. return nil, err
  295. }
  296. return field, nil
  297. }
  298. func encodeBytes(field []byte) []byte {
  299. fieldLength := make([]byte, 2)
  300. binary.BigEndian.PutUint16(fieldLength, uint16(len(field)))
  301. return append(fieldLength, field...)
  302. }
  303. func encodeLength(length int) []byte {
  304. var encLength []byte
  305. for {
  306. digit := byte(length % 128)
  307. length /= 128
  308. if length > 0 {
  309. digit |= 0x80
  310. }
  311. encLength = append(encLength, digit)
  312. if length == 0 {
  313. break
  314. }
  315. }
  316. return encLength
  317. }
  318. func decodeLength(r io.Reader) (int, error) {
  319. var rLength uint32
  320. var multiplier uint32
  321. b := make([]byte, 1)
  322. for multiplier < 27 { // fix: Infinite '(digit & 128) == 1' will cause the dead loop
  323. _, err := io.ReadFull(r, b)
  324. if err != nil {
  325. return 0, err
  326. }
  327. digit := b[0]
  328. rLength |= uint32(digit&127) << multiplier
  329. if (digit & 128) == 0 {
  330. break
  331. }
  332. multiplier += 7
  333. }
  334. return int(rLength), nil
  335. }