package mqtt import ( "encoding/hex" "errors" "net" "sparrow/pkg/models" "sparrow/pkg/rpcs" "sparrow/pkg/server" "sync" "time" ) // const def const ( SendChanLen = 16 defaultKeepAlive = 30 ) // ResponseType response type type ResponseType struct { SendTime uint8 PublishType uint8 DataType string } // Connection client connection type Connection struct { Mgr *Manager DeviceID string Conn net.Conn SendChan chan Message MessageID uint16 MessageWaitChan map[uint16]chan error KeepAlive uint16 LastHbTime int64 Token []byte VendorId string closeChan chan struct{} DeviceCode string mLock sync.Mutex } // NewConnection create a connection func NewConnection(conn net.Conn, mgr *Manager) *Connection { sendchan := make(chan Message, SendChanLen) c := &Connection{ Conn: conn, SendChan: sendchan, Mgr: mgr, KeepAlive: defaultKeepAlive, MessageWaitChan: make(map[uint16]chan error), closeChan: make(chan struct{}), } go c.SendMsgToClient() go c.RcvMsgFromClient() return c } // Submit submit a message to send chan func (c *Connection) Submit(msg Message) { if c.Conn != nil { c.SendChan <- msg } } // Publish will publish a message , and return a chan to wait for completion. func (c *Connection) Publish(msg Message, timeout time.Duration) error { message := msg.(*Publish) message.MessageID = c.MessageID c.MessageID++ c.Submit(message) server.Log.Debugf("publishing message Id : %v, timeout %v", message.MessageID, timeout) ch := make(chan error) // we don't wait for confirm. if timeout == 0 { return nil } c.MessageWaitChan[message.MessageID] = ch // wait for timeout and go func() { timer := time.NewTimer(timeout) <-timer.C waitCh, exist := c.MessageWaitChan[message.MessageID] if exist { waitCh <- errors.New("timeout publishing message") delete(c.MessageWaitChan, message.MessageID) close(waitCh) } }() err := <-ch return err } func (c *Connection) confirmPublish(MessageID uint16) { server.Log.Debugf("[confirmPublish]收到消息Id: %d", MessageID) waitCh, exist := c.MessageWaitChan[MessageID] if exist { waitCh <- nil delete(c.MessageWaitChan, MessageID) close(waitCh) } } // ValidateToken validate token func (c *Connection) ValidateToken(token []byte) error { err := c.Mgr.Provider.ValidateDeviceToken(c.DeviceID, token) if err != nil { return err } c.Token = token return nil } // Close connection func (c *Connection) Close() { c.mLock.Lock() defer c.mLock.Unlock() DeviceID := c.DeviceID c.MessageID = 0 server.Log.Infof("closing connection of device %v", DeviceID) if c.Conn != nil { _ = c.Conn.Close() c.Conn = nil _ = c.Mgr.Provider.OnDeviceOffline(c.DeviceCode, c.VendorId) } if c.SendChan != nil { close(c.SendChan) c.SendChan = nil close(c.closeChan) } } func (c *Connection) RcvMsgFromClient() { conn := c.Conn host := conn.RemoteAddr().String() server.Log.Infof("receive new connection from %s", host) for { msg, err := DecodeOneMessage(conn) if err != nil { server.Log.Errorf("read error: %s", err) c.Close() return } c.LastHbTime = time.Now().Unix() switch msg := msg.(type) { case *Connect: ret := RetCodeAccepted if msg.ProtocolVersion == 3 && msg.ProtocolName != "MQIsdp" { ret = RetCodeUnacceptableProtocolVersion } else if msg.ProtocolVersion == 4 && msg.ProtocolName != "MQTT" { ret = RetCodeUnacceptableProtocolVersion } else if msg.ProtocolVersion > 4 { ret = RetCodeUnacceptableProtocolVersion } if len(msg.ClientID) < 1 || len(msg.ClientID) > 23 { server.Log.Errorf("invalid ClientID length: %d", len(msg.ClientID)) ret = RetCodeIdentifierRejected c.Close() return } DeviceID, err := ClientIDToDeviceID(msg.ClientID) if err != nil { server.Log.Errorf("invalid Identify: %d", ret) c.Close() return } device := &models.Device{} err = server.RPCCallByName(nil, rpcs.RegistryServerName, "Registry.FindDeviceById", DeviceID, device) if err != nil { server.Log.Errorf("device not found %d", DeviceID) c.Close() return } c.DeviceID = device.RecordId c.VendorId = device.VendorID c.DeviceCode = device.DeviceIdentifier token, err := hex.DecodeString(msg.Password) if err != nil { server.Log.Errorf("token format error : %v", err) ret = RetCodeNotAuthorized c.Close() return } err = c.ValidateToken(token) if err != nil { server.Log.Errorf("validate device token error. deviceid : %v, token : %s, error: %v", c.DeviceCode, hex.EncodeToString(token), err) ret = RetCodeNotAuthorized c.Close() return } if ret != RetCodeAccepted { server.Log.Errorf("invalid CON: %d", ret) c.Close() return } args := rpcs.ArgsGetOnline{ Id: device.DeviceIdentifier, ClientIP: host, AccessRPCHost: server.GetRPCHost(), HeartbeatInterval: 300, } c.Mgr.AddConn(c.DeviceCode, c) connack := &ConnAck{ ReturnCode: ret, } c.Submit(connack) c.KeepAlive = msg.KeepAliveTimer err = c.Mgr.Provider.OnDeviceOnline(args, c.VendorId) if err != nil { server.Log.Errorf("device online error : %v", err) c.Close() return } case *Publish: err = c.Mgr.PublishMessage2Server(c.DeviceCode, c.VendorId, msg) if err != nil { server.Log.Errorf("PublishMessage2Server error:%s", err.Error()) } if msg.QosLevel.IsAtLeastOnce() { publishack := &PubAck{MessageID: msg.MessageID} c.Submit(publishack) } else if msg.QosLevel.IsExactlyOnce() { server.Log.Infof("publish Rec send now") publishRec := &PubRec{MessageID: msg.MessageID} c.Submit(publishRec) } err := c.Mgr.Provider.OnDeviceHeartBeat(c.DeviceCode) if err != nil { server.Log.Errorf("%s, heartbeat set error %s, close now...", host, err) c.Close() return } case *PubAck: c.confirmPublish(msg.MessageID) err := c.Mgr.Provider.OnDeviceHeartBeat(c.DeviceCode) if err != nil { server.Log.Errorf("%s, heartbeat set error %s, close now...", host, err) c.Close() return } case *PubRec: server.Log.Infof("%s, comes publish rec", host) publishRel := &PubRel{MessageID: msg.MessageID} c.Submit(publishRel) case *PubRel: server.Log.Infof("%s, comes publish rel", host) publishCom := &PubComp{MessageID: msg.MessageID} c.Submit(publishCom) case *PubComp: server.Log.Infof("%s, comes publish comp", host) c.confirmPublish(msg.MessageID) err := c.Mgr.Provider.OnDeviceHeartBeat(c.DeviceCode) if err != nil { server.Log.Errorf("%s, heartbeat set error %s, close now...", host, err) c.Close() return } case *PingReq: pingrsp := &PingResp{} err := c.Mgr.Provider.OnDeviceHeartBeat(c.DeviceCode) if err != nil { server.Log.Errorf("%s, heartbeat set error %s, close now...", host, err) c.Close() return } c.Submit(pingrsp) case *Subscribe: server.Log.Infof("%s, subscribe topic: %v", c.DeviceCode, msg.Topics) case *Unsubscribe: server.Log.Infof("%s, unsubscribe topic: %v", host, msg.Topics) case *Disconnect: server.Log.Infof("%s, disconnect now, exit...", c.DeviceCode) c.Close() return default: server.Log.Errorf("unknown msg type %T", msg) c.Close() return } } } func (c *Connection) SendMsgToClient() { if c.Conn == nil { return } host := c.Conn.RemoteAddr() for { select { case <-c.closeChan: return case msg, ok := <-c.SendChan: if c.Conn == nil { return } if !ok { server.Log.Errorf("%s is end now", host) return } err := msg.Encode(c.Conn) if err != nil { server.Log.Errorf("send msg err: %s=====\n%v\n=====", err, msg) continue } } } }