liuxiulin 8 месяцев назад
Родитель
Сommit
a608801fd3

+ 20 - 1
pkg/models/device.go

@@ -1,6 +1,9 @@
 package models
 
-import "github.com/jinzhu/gorm"
+import (
+	"errors"
+	"github.com/jinzhu/gorm"
+)
 
 // Device device model
 // device is a product instance, which is managed by our platform
@@ -44,3 +47,19 @@ type DeviceChartData struct {
 	Dt    string
 	Count int
 }
+
+type UpgradeParams struct {
+	VendorID string `json:"vendor_id"`
+	DeviceID string `json:"device_id"`
+	File     []byte `json:"file"`
+	FileName string `json:"file_name"`
+	FileSize int64  `json:"file_size"`
+}
+
+// Validate 验证
+func (a *UpgradeParams) Validate() error {
+	if a.DeviceID == "" {
+		return errors.New("非法参数[Name, Label]")
+	}
+	return nil
+}

+ 49 - 4
services/knowoapi/controllers/device.go

@@ -3,6 +3,8 @@ package controllers
 import (
 	"errors"
 	"fmt"
+	"io"
+	"sparrow/pkg/models"
 	"sparrow/pkg/rpcs"
 	"sparrow/pkg/server"
 	"sparrow/services/knowoapi/services"
@@ -32,10 +34,8 @@ func (a *DeviceController) Get() {
 		badRequest(a.Ctx, err)
 		return
 	}
-	proid, err := a.Ctx.URLParamInt("proid")
-	if err != nil {
-		proid = 0
-	}
+	proid := a.Ctx.URLParam("proid")
+
 	deviceid := a.Ctx.URLParam("device_id")
 	vid := a.getVendorID(a.Ctx)
 	datas, total, err := a.Service.GetDevices(vid, proid, pi, ps, deviceid)
@@ -217,3 +217,48 @@ func (a *DeviceController) GetLivechart() {
 		"chart": datas,
 	})
 }
+
+// Upgrade ota升级
+// POST /devices/upgrade
+func (a *DeviceController) Upgrade() {
+	params := new(models.UpgradeParams)
+	if err := parseBody(a.Ctx, params); err != nil {
+		badRequest(a.Ctx, err)
+		return
+	}
+	params.VendorID = a.Token.getVendorID(a.Ctx)
+	file, header, err := a.Ctx.FormFile("file")
+	if err != nil {
+		responseError(a.Ctx, ErrNormal, err.Error())
+		return
+	}
+	fileBytes, err := io.ReadAll(file)
+	if err != nil {
+		responseError(a.Ctx, ErrNormal, err.Error())
+		return
+	}
+	params.FileSize = header.Size
+	params.File = fileBytes
+	params.FileName = header.Filename
+	err = a.Service.Upgrade(params)
+	if err != nil {
+		responseError(a.Ctx, ErrNormal, err.Error())
+		return
+	}
+	done(a.Ctx, params)
+}
+
+// OtaProgress ota升级
+// GET /devices/ota/progress?deviceId=
+func (a *DeviceController) OtaProgress() {
+	deviceId := a.Ctx.URLParam("deviceId")
+
+	data, err := a.Service.GetUpgradeProgress(deviceId)
+	if err != nil {
+		responseError(a.Ctx, ErrDatabase, err.Error())
+		return
+	}
+	done(a.Ctx, map[string]interface{}{
+		"progress": data.Progress,
+	})
+}

+ 1 - 1
services/knowoapi/controllers/rule_chain.go

@@ -13,7 +13,7 @@ type RuleChainController struct {
 	Token   Token
 }
 
-// Post 
+// Post post
 // POST /admin/rule_chain
 func (a *RuleChainController) Post() {
 	ptl := new(models.RuleChain)

+ 4 - 4
services/knowoapi/model/device.go

@@ -13,7 +13,7 @@ type Device struct {
 }
 
 // Init device
-//TODO:增加产品ID查询,目前的查询都没有指定产品ID的查询
+// TODO:增加产品ID查询,目前的查询都没有指定产品ID的查询
 func (a *Device) Init(db *gorm.DB) *Device {
 	a.db = db
 	return a
@@ -131,9 +131,9 @@ func (a *Device) GetLivelyOfNumDays(vendorid string, days int) ([]map[string]int
 }
 
 // GetDevices 获取厂商已经激活的设备列表
-func (a *Device) GetDevices(vendorid string, proid, pi, ps int, deviceid string) (datas []models.Device, total int, err error) {
+func (a *Device) GetDevices(vendorid, proid string, pi, ps int, deviceid string) (datas []models.Device, total int, err error) {
 	tx := a.db.Where("vendor_id = ?", vendorid)
-	if proid != 0 {
+	if proid != "" {
 		tx = tx.Where("product_id = ?", proid)
 	}
 	if deviceid != "" {
@@ -144,7 +144,7 @@ func (a *Device) GetDevices(vendorid string, proid, pi, ps int, deviceid string)
 	return
 }
 
-//GetDevicesByVenderId 获取用户设备
+// GetDevicesByVenderId 获取用户设备
 func (a *Device) GetDevicesByVenderId(vendorid string) (datas []models.Device, err error) {
 	a.db.Where("vendor_id = ?", vendorid).Find(&datas)
 	return

+ 49 - 2
services/knowoapi/services/device.go

@@ -1,6 +1,7 @@
 package services
 
 import (
+	"github.com/gogf/gf/util/guid"
 	"sparrow/pkg/models"
 	"sparrow/pkg/rpcs"
 	"sparrow/pkg/server"
@@ -20,9 +21,13 @@ type DeviceService interface {
 	//获取近N日活跃设备数据
 	GetLivelyOfNumDays(string, int) ([]map[string]interface{}, error)
 	//获取已经激活的设备列表
-	GetDevices(vendorid string, proid, pi, ps int, deviceid string) ([]*models.Devices, int, error)
+	GetDevices(vendorid, proid string, pi, ps int, deviceid string) ([]*models.Devices, int, error)
 	//获取用户下所有设备的数量,在线设备的数量,离线设备的数量
 	GetDevicesCountByVenderId(vendorid string) (map[string]interface{}, error)
+	// 发起设备OTA升级
+	Upgrade(params *models.UpgradeParams) error
+	// GetUpgradeProgress 获取ota升级进度
+	GetUpgradeProgress(deviceId string) (rpcs.ReplyOtaProgress, error)
 }
 
 type deviceservice struct {
@@ -35,7 +40,7 @@ func NewDeviceService(models *model.All) DeviceService {
 		models: models,
 	}
 }
-func (a deviceservice) GetDevices(vendorid string, proid, pi, ps int, deviceid string) ([]*models.Devices, int, error) {
+func (a deviceservice) GetDevices(vendorid, proid string, pi, ps int, deviceid string) ([]*models.Devices, int, error) {
 
 	data, total, err := a.models.Device.GetDevices(vendorid, proid, pi, ps, deviceid)
 
@@ -123,3 +128,45 @@ func (a deviceservice) GetDevicesCountByVenderId(vendorid string) (map[string]in
 
 	return deviceCount, nil
 }
+
+func (a deviceservice) Upgrade(param *models.UpgradeParams) error {
+
+	var fileArgs rpcs.ArgsOtaFile
+	fileArgs.FileData = param.File
+	fileArgs.FileId = guid.S()
+	var reply rpcs.ReplyEmptyResult
+
+	err := server.RPCCallByName(nil, rpcs.DeviceManagerName, "DeviceManager.SavaFile", fileArgs, &reply)
+	if err != nil {
+		server.Log.Errorf("OTA升级文件保存失败:%v", err)
+		return err
+	}
+
+	var args rpcs.ArgsUpgrade4G
+	args.DeviceId = param.DeviceID
+	args.FileId = fileArgs.FileId
+	args.FileSize = param.FileSize
+
+	err = server.RPCCallByName(nil, rpcs.MQTTAccessName, "Access.UpgradeFor4G", args, &reply)
+	if err != nil {
+		server.Log.Errorf("4G模组OTA升级失败:%v", err)
+		return err
+	}
+
+	return nil
+}
+
+func (a deviceservice) GetUpgradeProgress(deviceId string) (rpcs.ReplyOtaProgress, error) {
+	var args rpcs.ArgsOtaProgress
+	args.DeviceId = deviceId
+
+	var reply rpcs.ReplyOtaProgress
+
+	err := server.RPCCallByName(nil, rpcs.DeviceManagerName, "DeviceManager.GetProgress", args, &reply)
+	if err != nil {
+		server.Log.Errorf("OTA升级进度获取失败:%v", err)
+		return reply, err
+	}
+
+	return reply, nil
+}