package internal import ( "context" "fmt" "github.com/gogf/gf/util/guid" "yx-dataset-server/app/errors" "yx-dataset-server/app/model" "yx-dataset-server/app/schema" "yx-dataset-server/library/ragflow" ) // NewDatasetFile 创建DatasetFile func NewDatasetFile( mDatasetFile model.IDatasetFile, mDataset model.IDataset, mTrans model.ITrans, mUser model.IUser, ) *DatasetFile { return &DatasetFile{ DatasetFileModel: mDatasetFile, datasetModel: mDataset, transModel: mTrans, userModel: mUser, } } // DatasetFile 创建DatasetFile对象 type DatasetFile struct { DatasetFileModel model.IDatasetFile datasetModel model.IDataset transModel model.ITrans userModel model.IUser } // Query 查询数据 func (a *DatasetFile) Query(ctx context.Context, params schema.DatasetFileQueryParam, opts ...schema.DatasetFileQueryOptions) (*schema.DatasetFileQueryResult, error) { result, err := a.DatasetFileModel.Query(ctx, params, opts...) if err != nil { return nil, err } if len(result.Data) > 0 { user, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{"11", "12"}}) if err != nil { return nil, err } result.Data.FillCreator(user.Data) } return result, nil } // Get 查询指定数据 func (a *DatasetFile) Get(ctx context.Context, recordID string, opts ...schema.DatasetFileQueryOptions) (*schema.DatasetFile, error) { item, err := a.DatasetFileModel.Get(ctx, recordID, opts...) if err != nil { return nil, err } else if item == nil { return nil, errors.ErrNotFound } return item, nil } func (a *DatasetFile) getUpdate(ctx context.Context, recordID string) (*schema.DatasetFile, error) { return a.Get(ctx, recordID) } // Create 创建数据 func (a *DatasetFile) Create(ctx context.Context, item schema.DatasetFile) error { item.RecordID = guid.S() err := ExecTrans(ctx, a.transModel, func(ctx context.Context) error { dataset, err := a.datasetModel.Get(ctx, item.DatasetId) if err != nil { return err } err = a.DatasetFileModel.Create(ctx, item) if err != nil { return err } dataset.FileCount += 1 err = a.datasetModel.Update(ctx, dataset.RecordID, *dataset) if err != nil { return err } fileInfo := ragflow.FileInfo{ FileName: item.Name, Url: fmt.Sprintf("%s%s", "https://app.yongxulvjian.com", item.Url), } req := ragflow.UploadFileReq{ File: []*ragflow.FileInfo{&fileInfo}, } file, err := ragflow.GetHttpClient().UploadDocument(ctx, dataset.RagDataId, req) if err != nil { return errors.New(fmt.Sprintf("文件上传失败:%s", err.Error())) } //go func() { // _, err = ragflow.GetHttpClient().ParseDocuments(ctx, dataset.RagDataId, []string{file.Data[0].ID}) // if err != nil { // glog.Errorf(ctx, "文件解析失败:%s", err.Error()) // } //}() item.RagFileId = file.Data[0].ID return a.DatasetFileModel.Update(ctx, item.RecordID, item) }) return err } // BatchCreate 批量创建数据 func (a *DatasetFile) BatchCreate(ctx context.Context, files schema.DatasetFiles) error { if len(files) == 0 { return errors.New("文件不能为空") } for _, v := range files { err := a.Create(ctx, *v) if err != nil { return err } } return nil } // Update 更新数据 func (a *DatasetFile) Update(ctx context.Context, recordID string, item schema.DatasetFile) error { oldItem, err := a.DatasetFileModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.DatasetFileModel.Update(ctx, recordID, item) } // Delete 删除数据 func (a *DatasetFile) Delete(ctx context.Context, recordID string) error { oldItem, err := a.DatasetFileModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } dataset, err := a.datasetModel.Get(ctx, oldItem.DatasetId) if err != nil { return err } if dataset == nil || dataset.RagDataId == "" { return errors.New("知识库不存在") } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { dataset, err := a.datasetModel.Get(ctx, oldItem.DatasetId) if err != nil { return err } dataset.FileCount -= 1 err = a.datasetModel.Update(ctx, dataset.RecordID, *dataset) if err != nil { return err } _, err = ragflow.GetHttpClient().DeleteDocuments(ctx, dataset.RagDataId, []string{oldItem.RagFileId}) if err != nil { return err } return a.DatasetFileModel.Delete(ctx, recordID) }) return err } // UpdateStatus 更新状态 func (a *DatasetFile) UpdateStatus(ctx context.Context, recordID string, status int) error { oldItem, err := a.DatasetFileModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.DatasetFileModel.UpdateStatus(ctx, recordID, status) } // UpdateEnabled 更新启用状态 func (a *DatasetFile) UpdateEnabled(ctx context.Context, recordID string, status bool) error { oldItem, err := a.DatasetFileModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.DatasetFileModel.UpdateEnabled(ctx, recordID, status) } // BatchDelete 批量删除 func (a *DatasetFile) BatchDelete(ctx context.Context, fileIDs []string) error { if len(fileIDs) == 0 { return errors.New("文件不能为空") } file, err := a.DatasetFileModel.Query(ctx, schema.DatasetFileQueryParam{RecordIDs: fileIDs}) if err != nil { return err } if len(file.Data) == 0 { return nil } dataset, err := a.datasetModel.Get(ctx, file.Data[0].DatasetId) if err != nil { return err } _, err = ragflow.GetHttpClient().DeleteDocuments(ctx, dataset.RagDataId, file.Data.ToRagFileIds()) if err != nil { return err } return a.DatasetFileModel.BatchDelete(ctx, fileIDs) }