123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696 |
- package mqtt
- import (
- "bytes"
- "errors"
- "io"
- )
- // OoS only support QoS 0
- const (
- QosAtMostOnce = TagQosLevel(iota)
- QosAtLeastOnce
- QosExactlyOnce
- QosInvalid
- )
- // Max Payload size
- const (
- MaxPayloadSize = (1 << (4 * 7)) - 1
- )
- type TagQosLevel uint8
- func (qos TagQosLevel) IsValid() bool {
- return qos < QosInvalid && qos >= QosAtMostOnce
- }
- func (qos TagQosLevel) HasId() bool {
- return qos == QosAtLeastOnce || qos == QosExactlyOnce
- }
- func (qos TagQosLevel) IsAtLeastOnce() bool {
- return qos == QosAtLeastOnce
- }
- func (qos TagQosLevel) IsExactlyOnce() bool {
- return qos == QosExactlyOnce
- }
- // Message Type
- const (
- MsgConnect = TagMessageType(iota + 1)
- MsgConnAck
- MsgPublish
- MsgPubAck
- MsgPubRec
- MsgPubRel
- MsgPubComp
- MsgSubscribe
- MsgSubAck
- MsgUnsubscribe
- MsgUnsubAck
- MsgPingReq
- MsgPingResp
- MsgDisconnect
- MsgInvalid
- )
- // retcode
- const (
- RetCodeAccepted = TagRetCode(iota)
- RetCodeUnacceptableProtocolVersion
- RetCodeIdentifierRejected
- RetCodeServerUnavailable
- RetCodeBadUsernameOrPassword
- RetCodeNotAuthorized
- RetCodeInvalid
- )
- type TagRetCode uint8
- func (rc TagRetCode) IsValid() bool {
- return rc >= RetCodeAccepted && rc < RetCodeInvalid
- }
- type TagMessageType uint8
- func (msg TagMessageType) IsValid() bool {
- return msg >= MsgConnect && msg < MsgInvalid
- }
- // message interface
- type Message interface {
- Encode(w io.Writer) error
- Decode(r io.Reader, hdr Header, packetRemaining int32) error
- }
- // message fix header
- type Header struct {
- DupFlag bool
- QosLevel TagQosLevel
- Retain bool
- }
- func (hdr *Header) Encode(w io.Writer, msgType TagMessageType, remainingLength int32) error {
- buf := new(bytes.Buffer)
- err := hdr.EncodeInto(buf, msgType, remainingLength)
- if err != nil {
- return err
- }
- _, err = w.Write(buf.Bytes())
- return err
- }
- func (hdr *Header) EncodeInto(buf *bytes.Buffer, msgType TagMessageType, remainingLength int32) error {
- if !hdr.QosLevel.IsValid() {
- return errors.New("Invalid Qos level")
- }
- if !msgType.IsValid() {
- return errors.New("Invalid MsgType")
- }
- val := byte(msgType) << 4
- val |= (boolToByte(hdr.DupFlag) << 3)
- val |= byte(hdr.QosLevel) << 1
- val |= boolToByte(hdr.Retain)
- buf.WriteByte(val)
- encodeLength(remainingLength, buf)
- return nil
- }
- func (hdr *Header) Decode(r io.Reader) (msgType TagMessageType, remainingLength int32, err error) {
- var buf [1]byte
- if _, err := io.ReadFull(r, buf[:]); err != nil {
- return 0, 0, err
- }
- byte1 := buf[0]
- msgType = TagMessageType(byte1 & 0xf0 >> 4)
- *hdr = Header{
- DupFlag: byte1&0x08 > 0,
- QosLevel: TagQosLevel(byte1 & 0x06 >> 1),
- Retain: byte1&0x01 > 0,
- }
- remainingLength, err = decodeLength(r)
- return msgType, remainingLength, err
- }
- func writeMessage(w io.Writer, msgType TagMessageType, hdr *Header, payloadBuf *bytes.Buffer, extraLength int32) error {
- totalPayloadLength := int64(len(payloadBuf.Bytes())) + int64(extraLength)
- if totalPayloadLength > MaxPayloadSize {
- return errors.New("message too long")
- }
- buf := new(bytes.Buffer)
- err := hdr.EncodeInto(buf, msgType, int32(totalPayloadLength))
- if err != nil {
- return err
- }
- buf.Write(payloadBuf.Bytes())
- _, err = w.Write(buf.Bytes())
- return err
- }
- // Connect represents an MQTT CONNECT message.
- type Connect struct {
- Header
- ProtocolName string
- ProtocolVersion uint8
- WillRetain bool
- WillFlag bool
- CleanSession bool
- WillQos TagQosLevel
- KeepAliveTimer uint16
- ClientID string
- WillTopic, WillMessage string
- UsernameFlag, PasswordFlag bool
- Username, Password string
- }
- func (msg *Connect) Encode(w io.Writer) (err error) {
- if msg.WillQos > QosInvalid {
- return errors.New("invalid Qos")
- }
- buf := new(bytes.Buffer)
- flags := boolToByte(msg.UsernameFlag) << 7
- flags |= boolToByte(msg.PasswordFlag) << 6
- flags |= boolToByte(msg.WillRetain) << 5
- flags |= byte(msg.WillQos) << 3
- flags |= boolToByte(msg.WillFlag) << 2
- flags |= boolToByte(msg.CleanSession) << 1
- setString(msg.ProtocolName, buf)
- setUint8(msg.ProtocolVersion, buf)
- buf.WriteByte(flags)
- setUint16(msg.KeepAliveTimer, buf)
- setString(msg.ClientID, buf)
- if msg.WillFlag {
- setString(msg.WillTopic, buf)
- setString(msg.WillMessage, buf)
- }
- if msg.UsernameFlag {
- setString(msg.Username, buf)
- }
- if msg.PasswordFlag {
- setString(msg.Password, buf)
- }
- return writeMessage(w, MsgConnect, &msg.Header, buf, 0)
- }
- func (msg *Connect) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- protocolName, err := getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- protocolVersion, err := getUint8(r, &packetRemaining)
- if err != nil {
- return err
- }
- flags, err := getUint8(r, &packetRemaining)
- if err != nil {
- return err
- }
- keepAliveTimer, err := getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- ClientID, err := getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- *msg = Connect{
- Header: hdr,
- ProtocolName: protocolName,
- ProtocolVersion: protocolVersion,
- UsernameFlag: flags&0x80 > 0,
- PasswordFlag: flags&0x40 > 0,
- WillRetain: flags&0x20 > 0,
- WillQos: TagQosLevel(flags & 0x18 >> 3),
- WillFlag: flags&0x04 > 0,
- CleanSession: flags&0x02 > 0,
- KeepAliveTimer: keepAliveTimer,
- ClientID: ClientID,
- }
- if msg.WillFlag {
- msg.WillTopic, err = getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- msg.WillMessage, err = getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- if msg.UsernameFlag {
- msg.Username, err = getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- if msg.PasswordFlag {
- msg.Password, err = getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- if packetRemaining != 0 {
- return errors.New("message too long")
- }
- return nil
- }
- // ConnAck represents an MQTT CONNACK message.
- type ConnAck struct {
- Header
- ReturnCode TagRetCode
- }
- func (msg *ConnAck) Encode(w io.Writer) (err error) {
- buf := new(bytes.Buffer)
- buf.WriteByte(byte(0)) // Reserved byte.
- setUint8(uint8(msg.ReturnCode), buf)
- return writeMessage(w, MsgConnAck, &msg.Header, buf, 0)
- }
- func (msg *ConnAck) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- _, err = getUint8(r, &packetRemaining) // Skip reserved byte.
- if err != nil {
- return err
- }
- code, err := getUint8(r, &packetRemaining)
- if err != nil {
- return err
- }
- msg.ReturnCode = TagRetCode(code)
- if !msg.ReturnCode.IsValid() {
- return errors.New("invliad retcode")
- }
- if packetRemaining != 0 {
- return errors.New("message too long")
- }
- return nil
- }
- // Publish represents an MQTT PUBLISH message.
- type Publish struct {
- Header
- TopicName string
- MessageID uint16
- Payload Payload
- }
- func (msg *Publish) Encode(w io.Writer) (err error) {
- buf := new(bytes.Buffer)
- setString(msg.TopicName, buf)
- if msg.Header.QosLevel.HasId() {
- setUint16(msg.MessageID, buf)
- }
- if err = msg.Payload.WritePayload(buf); err != nil {
- return err
- }
- if err = writeMessage(w, MsgPublish, &msg.Header, buf, int32(0)); err != nil {
- return err
- }
- return nil
- }
- func (msg *Publish) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- msg.TopicName, err = getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- if msg.Header.QosLevel.HasId() {
- msg.MessageID, err = getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- payloadReader := &io.LimitedReader{r, int64(packetRemaining)}
- msg.Payload = make(BytesPayload, int(packetRemaining))
- return msg.Payload.ReadPayload(payloadReader, int(packetRemaining))
- }
- // PubAck represents an MQTT PUBACK message.
- type PubAck struct {
- Header
- MessageID uint16
- }
- func (msg *PubAck) Encode(w io.Writer) error {
- return encodeAckCommon(w, &msg.Header, msg.MessageID, MsgPubAck)
- }
- func (msg *PubAck) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- return decodeAckCommon(r, packetRemaining, &msg.MessageID)
- }
- // PubRec represents an MQTT PUBREC message.
- type PubRec struct {
- Header
- MessageID uint16
- }
- func (msg *PubRec) Encode(w io.Writer) error {
- return encodeAckCommon(w, &msg.Header, msg.MessageID, MsgPubRec)
- }
- func (msg *PubRec) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- return decodeAckCommon(r, packetRemaining, &msg.MessageID)
- }
- // PubRel represents an MQTT PUBREL message.
- type PubRel struct {
- Header
- MessageID uint16
- }
- func (msg *PubRel) Encode(w io.Writer) error {
- return encodeAckCommon(w, &msg.Header, msg.MessageID, MsgPubRel)
- }
- func (msg *PubRel) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- return decodeAckCommon(r, packetRemaining, &msg.MessageID)
- }
- // PubComp represents an MQTT PUBCOMP message.
- type PubComp struct {
- Header
- MessageID uint16
- }
- func (msg *PubComp) Encode(w io.Writer) error {
- return encodeAckCommon(w, &msg.Header, msg.MessageID, MsgPubComp)
- }
- func (msg *PubComp) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- return decodeAckCommon(r, packetRemaining, &msg.MessageID)
- }
- // Subscribe represents an MQTT SUBSCRIBE message.
- type Subscribe struct {
- Header
- MessageID uint16
- Topics []TopicQos
- }
- type TopicQos struct {
- Topic string
- Qos TagQosLevel
- }
- func (msg *Subscribe) Encode(w io.Writer) (err error) {
- buf := new(bytes.Buffer)
- if msg.Header.QosLevel.HasId() {
- setUint16(msg.MessageID, buf)
- }
- for _, topicSub := range msg.Topics {
- setString(topicSub.Topic, buf)
- setUint8(uint8(topicSub.Qos), buf)
- }
- return writeMessage(w, MsgSubscribe, &msg.Header, buf, 0)
- }
- func (msg *Subscribe) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- if msg.Header.QosLevel.HasId() {
- msg.MessageID, err = getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- var topics []TopicQos
- for packetRemaining > 0 {
- topic, err := getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- qos, err := getUint8(r, &packetRemaining)
- if err != nil {
- return err
- }
- topics = append(topics, TopicQos{
- Topic: topic,
- Qos: TagQosLevel(qos),
- })
- }
- msg.Topics = topics
- return nil
- }
- // SubAck represents an MQTT SUBACK message.
- type SubAck struct {
- Header
- MessageID uint16
- TopicsQos []TagQosLevel
- }
- func (msg *SubAck) Encode(w io.Writer) (err error) {
- buf := new(bytes.Buffer)
- setUint16(msg.MessageID, buf)
- for i := 0; i < len(msg.TopicsQos); i += 1 {
- setUint8(uint8(msg.TopicsQos[i]), buf)
- }
- return writeMessage(w, MsgSubAck, &msg.Header, buf, 0)
- }
- func (msg *SubAck) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- msg.MessageID, err = getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- topicsQos := make([]TagQosLevel, 0)
- for packetRemaining > 0 {
- qos, err := getUint8(r, &packetRemaining)
- if err != nil {
- return err
- }
- grantedQos := TagQosLevel(qos & 0x03)
- topicsQos = append(topicsQos, grantedQos)
- }
- msg.TopicsQos = topicsQos
- return nil
- }
- // Unsubscribe represents an MQTT UNSUBSCRIBE message.
- type Unsubscribe struct {
- Header
- MessageID uint16
- Topics []string
- }
- func (msg *Unsubscribe) Encode(w io.Writer) (err error) {
- buf := new(bytes.Buffer)
- if msg.Header.QosLevel.HasId() {
- setUint16(msg.MessageID, buf)
- }
- for _, topic := range msg.Topics {
- setString(topic, buf)
- }
- return writeMessage(w, MsgUnsubscribe, &msg.Header, buf, 0)
- }
- func (msg *Unsubscribe) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- if msg.Header.QosLevel.HasId() {
- msg.MessageID, err = getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- }
- topics := make([]string, 0)
- for packetRemaining > 0 {
- topic, err := getString(r, &packetRemaining)
- if err != nil {
- return err
- }
- topics = append(topics, topic)
- }
- msg.Topics = topics
- return nil
- }
- // UnsubAck represents an MQTT UNSUBACK message.
- type UnsubAck struct {
- Header
- MessageID uint16
- }
- func (msg *UnsubAck) Encode(w io.Writer) error {
- return encodeAckCommon(w, &msg.Header, msg.MessageID, MsgUnsubAck)
- }
- func (msg *UnsubAck) Decode(r io.Reader, hdr Header, packetRemaining int32) (err error) {
- msg.Header = hdr
- return decodeAckCommon(r, packetRemaining, &msg.MessageID)
- }
- // PingReq represents an MQTT PINGREQ message.
- type PingReq struct {
- Header
- }
- func (msg *PingReq) Encode(w io.Writer) error {
- return msg.Header.Encode(w, MsgPingReq, 0)
- }
- func (msg *PingReq) Decode(r io.Reader, hdr Header, packetRemaining int32) error {
- if packetRemaining != 0 {
- return errors.New("msg too long")
- }
- return nil
- }
- // PingResp represents an MQTT PINGRESP message.
- type PingResp struct {
- Header
- }
- func (msg *PingResp) Encode(w io.Writer) error {
- return msg.Header.Encode(w, MsgPingResp, 0)
- }
- func (msg *PingResp) Decode(r io.Reader, hdr Header, packetRemaining int32) error {
- if packetRemaining != 0 {
- return errors.New("msg too long")
- }
- return nil
- }
- // Disconnect represents an MQTT DISCONNECT message.
- type Disconnect struct {
- Header
- }
- func (msg *Disconnect) Encode(w io.Writer) error {
- return msg.Header.Encode(w, MsgDisconnect, 0)
- }
- func (msg *Disconnect) Decode(r io.Reader, hdr Header, packetRemaining int32) error {
- if packetRemaining != 0 {
- return errors.New("msg too long")
- }
- return nil
- }
- func encodeAckCommon(w io.Writer, hdr *Header, MessageID uint16, msgType TagMessageType) error {
- buf := new(bytes.Buffer)
- setUint16(MessageID, buf)
- return writeMessage(w, msgType, hdr, buf, 0)
- }
- func decodeAckCommon(r io.Reader, packetRemaining int32, MessageID *uint16) (err error) {
- *MessageID, err = getUint16(r, &packetRemaining)
- if err != nil {
- return err
- }
- if packetRemaining != 0 {
- return errors.New("msg too long")
- }
- return nil
- }
- // DecodeOneMessage decodes one message from r. config provides specifics on
- // how to decode messages, nil indicates that the DefaultDecoderConfig should
- // be used.
- func DecodeOneMessage(r io.Reader) (msg Message, err error) {
- var hdr Header
- var msgType TagMessageType
- var packetRemaining int32
- msgType, packetRemaining, err = hdr.Decode(r)
- if err != nil {
- return nil, err
- }
- msg, err = NewMessage(msgType)
- if err != nil {
- return nil, err
- }
- return msg, msg.Decode(r, hdr, packetRemaining)
- }
- func NewMessage(msgType TagMessageType) (msg Message, err error) {
- switch msgType {
- case MsgConnect:
- msg = new(Connect)
- case MsgConnAck:
- msg = new(ConnAck)
- case MsgPublish:
- msg = new(Publish)
- case MsgPubAck:
- msg = new(PubAck)
- case MsgPubRec:
- msg = new(PubRec)
- case MsgPubRel:
- msg = new(PubRel)
- case MsgPubComp:
- msg = new(PubComp)
- case MsgSubscribe:
- msg = new(Subscribe)
- case MsgUnsubAck:
- msg = new(UnsubAck)
- case MsgSubAck:
- msg = new(SubAck)
- case MsgUnsubscribe:
- msg = new(Unsubscribe)
- case MsgPingReq:
- msg = new(PingReq)
- case MsgPingResp:
- msg = new(PingResp)
- case MsgDisconnect:
- msg = new(Disconnect)
- default:
- return nil, errors.New("msgType error")
- }
- return msg, nil
- }
|