liuxiulin 7 mēneši atpakaļ
vecāks
revīzija
0fdbb0f02b
2 mainītis faili ar 65 papildinājumiem un 5 dzēšanām
  1. 62 3
      services/emqx-agent/agent.go
  2. 3 2
      services/emqx-agent/sub_dev.go

+ 62 - 3
services/emqx-agent/agent.go

@@ -13,17 +13,29 @@ import (
 	"sparrow/pkg/protocol"
 	"sparrow/pkg/rpcs"
 	"sparrow/pkg/server"
+	"sync"
 	"time"
 )
 
 type Access struct {
-	client SubDev
+	client        SubDev
+	lockedDevices map[string]*Device
 }
 
 func NewAgent(client SubDev) *Access {
-	return &Access{
-		client: client,
+	a := &Access{
+		client:        client,
+		lockedDevices: make(map[string]*Device),
 	}
+	go a.UnlockDevice()
+	return a
+}
+
+type Device struct {
+	Id       string
+	Locked   bool
+	LastSeen time.Time
+	Mutex    sync.Mutex
 }
 
 // Message 收到设备上报消息处理
@@ -200,6 +212,15 @@ func (a *Access) processDeviceUpgrade(deviceId string, message *gjson.Json) erro
 			server.Log.Errorf("OTA升级进度保存失败:%v", err)
 			return err
 		}
+
+	case "finish":
+		device := a.GetLockDevice(deviceId)
+		device.Mutex.Lock()
+		if device != nil {
+			device.Locked = false
+		}
+		device.Mutex.Unlock()
+		server.Log.Infof("OTA升级完成;%s", deviceId)
 	}
 
 	return nil
@@ -260,6 +281,14 @@ func (a *Access) Disconnected(status *protocol.DevConnectStatus) error {
 // SendCommand rpc 发送设备命令
 func (a *Access) SendCommand(args rpcs.ArgsSendCommand, reply *rpcs.ReplySendCommand) error {
 	// 查询设备信息
+	lockDevice := a.GetLockDevice(args.DeviceId)
+	lockDevice.Mutex.Lock()
+
+	if lockDevice.Locked {
+		return errors.New("设备正在进行OTA升级,请稍后重试")
+	}
+	lockDevice.Mutex.Unlock()
+
 	device := &models.Device{}
 	err := server.RPCCallByName(nil, rpcs.RegistryServerName, "Registry.FindDeviceByIdentifier", args.DeviceId, device)
 	if err != nil {
@@ -305,6 +334,13 @@ func (a *Access) GetStatus(args rpcs.ArgsGetStatus, reply *rpcs.ReplyGetStatus)
 }
 
 func (a *Access) chunkUpgrade(params rpcs.ChunkUpgrade) error {
+	lockDevice := a.GetLockDevice(params.DeviceId)
+	lockDevice.Mutex.Lock()
+
+	lockDevice.Locked = true
+	lockDevice.LastSeen = time.Now()
+	lockDevice.Mutex.Unlock()
+
 	server.Log.Infof("4G模组OTA升级:%s", params.DeviceId)
 
 	buf := bytes.NewBuffer(gbinary.BeEncodeUint16(gconv.Uint16(params.Offset)))
@@ -367,3 +403,26 @@ func (a *Access) SendByteData(args rpcs.ArgsSendByteData, reply *rpcs.ReplySendC
 
 	return a.client.PublishToMsgToDev(protocol.GetCommandTopic(args.DeviceId, product.ProductKey), args.Data)
 }
+
+func (a *Access) GetLockDevice(id string) *Device {
+	if d, exists := a.lockedDevices[id]; exists {
+		return d
+	}
+	device := &Device{Id: id, Locked: false}
+	a.lockedDevices[id] = device
+	return device
+}
+
+func (a *Access) UnlockDevice() {
+	for {
+		time.Sleep(5 * time.Second) // 每5秒检查一次
+		for _, device := range a.lockedDevices {
+			device.Mutex.Lock()
+			if device.Locked && time.Since(device.LastSeen) > 1*time.Minute {
+				device.Locked = false
+				server.Log.Infof("Device %s unlocked\n", device.Id)
+			}
+			device.Mutex.Unlock()
+		}
+	}
+}

+ 3 - 2
services/emqx-agent/sub_dev.go

@@ -44,8 +44,9 @@ type DevSubHandle interface {
 }
 
 type MqttClient struct {
-	client     *client.MqttClient
-	handlePool *grpool.Pool
+	client      *client.MqttClient
+	handlePool  *grpool.Pool
+	lockDevices map[string]*Device
 }
 
 func (d *MqttClient) PublishToMsgToDev(topic string, payload []byte) error {