b_dataset_file.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. package internal
  2. import (
  3. "context"
  4. "fmt"
  5. "path/filepath"
  6. "strings"
  7. "time"
  8. "github.com/gogf/gf/v2/os/glog"
  9. "github.com/gogf/gf/v2/util/guid"
  10. "yx-dataset-server/app/errors"
  11. "yx-dataset-server/app/model"
  12. "yx-dataset-server/app/schema"
  13. "yx-dataset-server/library/ragflow"
  14. )
  15. // NewDatasetFile 创建DatasetFile
  16. func NewDatasetFile(
  17. mDatasetFile model.IDatasetFile,
  18. mDataset model.IDataset,
  19. mTrans model.ITrans,
  20. mUser model.IUser,
  21. ) *DatasetFile {
  22. return &DatasetFile{
  23. DatasetFileModel: mDatasetFile,
  24. datasetModel: mDataset,
  25. transModel: mTrans,
  26. userModel: mUser,
  27. }
  28. }
  29. // DatasetFile 创建DatasetFile对象
  30. type DatasetFile struct {
  31. DatasetFileModel model.IDatasetFile
  32. datasetModel model.IDataset
  33. transModel model.ITrans
  34. userModel model.IUser
  35. }
  36. // Query 查询数据
  37. func (a *DatasetFile) Query(ctx context.Context, params schema.DatasetFileQueryParam, opts ...schema.DatasetFileQueryOptions) (*schema.DatasetFileQueryResult, error) {
  38. result, err := a.DatasetFileModel.Query(ctx, params, opts...)
  39. if err != nil {
  40. return nil, err
  41. }
  42. if len(result.Data) > 0 {
  43. user, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{"11", "12"}})
  44. if err != nil {
  45. return nil, err
  46. }
  47. result.Data.FillCreator(user.Data)
  48. }
  49. return result, nil
  50. }
  51. // Get 查询指定数据
  52. func (a *DatasetFile) Get(ctx context.Context, recordID string, opts ...schema.DatasetFileQueryOptions) (*schema.DatasetFile, error) {
  53. item, err := a.DatasetFileModel.Get(ctx, recordID, opts...)
  54. if err != nil {
  55. return nil, err
  56. } else if item == nil {
  57. return nil, errors.ErrNotFound
  58. }
  59. return item, nil
  60. }
  61. func (a *DatasetFile) getUpdate(ctx context.Context, recordID string) (*schema.DatasetFile, error) {
  62. return a.Get(ctx, recordID)
  63. }
  64. // Create 创建数据
  65. // 流程:
  66. // 1. 事务内:写 DatasetFile、累加 Dataset.FileCount、调用 ragflow 上传与解析;
  67. // 2. 事务成功后,异步轮询 ragflow 获取解析状态;
  68. //
  69. // 注意:
  70. // - 上传与解析只发生在事务中一次;
  71. // - 轮询使用独立 background context,避免父请求结束后 ctx 取消导致的 ragflow 调用失败;
  72. // - 只有事务成功才会启动轮询。
  73. func (a *DatasetFile) Create(ctx context.Context, item schema.DatasetFile) error {
  74. item.RecordID = guid.S()
  75. dataset, err := a.datasetModel.Get(ctx, item.DatasetId)
  76. if err != nil {
  77. return err
  78. }
  79. if dataset == nil || dataset.RagDataId == "" {
  80. return errors.New("知识库不存在")
  81. }
  82. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  83. fileInfo := ragflow.FileInfo{
  84. FileName: item.Name,
  85. Url: fmt.Sprintf("%s%s", "https://app.yongxulvjian.com", item.Url),
  86. }
  87. uploadResp, err := ragflow.GetHttpClient().UploadDocument(ctx, dataset.RagDataId, ragflow.UploadFileReq{
  88. File: []*ragflow.FileInfo{&fileInfo},
  89. })
  90. if err != nil {
  91. return errors.New(fmt.Sprintf("文件上传失败:%s", err.Error()))
  92. }
  93. if uploadResp == nil || len(uploadResp.Data) == 0 {
  94. return errors.New("文件上传失败:ragflow 返回为空")
  95. }
  96. item.RagFileId = uploadResp.Data[0].ID
  97. item.ParseStatus = 1
  98. if _, err := ragflow.GetHttpClient().ParseDocuments(ctx, dataset.RagDataId, []string{item.RagFileId}); err != nil {
  99. glog.Errorf(ctx, "文件解析触发失败:%s", err.Error())
  100. }
  101. if err := a.DatasetFileModel.Create(ctx, item); err != nil {
  102. return err
  103. }
  104. dataset.FileCount += 1
  105. return a.datasetModel.Update(ctx, dataset.RecordID, *dataset)
  106. })
  107. if err != nil {
  108. return err
  109. }
  110. // 异步轮询解析状态,使用 Background 避免请求 ctx 提前取消
  111. go a.waitParse(context.Background(), item.RecordID, item.RagFileId, dataset.RagDataId)
  112. return nil
  113. }
  114. // CreateV2 创建数据(V2版本:直接从内存上传到 RagFlow,无需经 Minio / 临时文件)
  115. // 流程:
  116. // 1. 参数校验 & 知识库校验;
  117. // 2. 事务外调用 RagFlow 上传(IO 密集,避免长事务占用 DB 连接);
  118. // 3. 事务内只做 DB 操作:写 DatasetFile + 累加 Dataset.FileCount;
  119. // 4. 若入库失败,尝试反向删除 RagFlow 已上传的文档,避免数据残留;
  120. // 5. 事务成功后,异步轮询 RagFlow 解析状态。
  121. func (a *DatasetFile) CreateV2(ctx context.Context, param schema.UpdateFileParam) error {
  122. if len(param.FileData) == 0 {
  123. return errors.New("文件内容不能为空")
  124. }
  125. if param.DatasetId == "" {
  126. return errors.New("知识库ID不能为空")
  127. }
  128. dataset, err := a.datasetModel.Get(ctx, param.DatasetId)
  129. if err != nil {
  130. return err
  131. }
  132. if dataset == nil || dataset.RagDataId == "" {
  133. return errors.New("知识库不存在")
  134. }
  135. ragClient := ragflow.GetHttpClient()
  136. ragResp, err := ragClient.UploadDocumentV2(ctx, dataset.RagDataId, &ragflow.UploadDocumentV2Req{
  137. FileName: param.FileName,
  138. FileData: param.FileData,
  139. })
  140. if err != nil {
  141. return fmt.Errorf("RagFlow上传失败: %s", err.Error())
  142. }
  143. if ragResp == nil || len(ragResp.Data) == 0 {
  144. return errors.New("RagFlow上传失败:返回为空")
  145. }
  146. ragFileId := ragResp.Data[0].ID
  147. // 解析失败不回滚上传,后续由 waitParse 反映真实状态
  148. if _, err := ragClient.ParseDocuments(ctx, dataset.RagDataId, []string{ragFileId}); err != nil {
  149. glog.Errorf(ctx, "文件解析触发失败:%s", err.Error())
  150. }
  151. fileSize := param.FileSize
  152. if fileSize <= 0 {
  153. fileSize = int64(len(param.FileData))
  154. }
  155. item := schema.DatasetFile{
  156. RecordID: guid.S(),
  157. Name: param.FileName,
  158. DatasetId: param.DatasetId,
  159. Size: fileSize,
  160. Type: a.getFileType(param.FileName),
  161. Enabled: true,
  162. CreatorId: param.CreatorId,
  163. RagFileId: ragFileId,
  164. ParseStatus: 1,
  165. }
  166. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  167. if err := a.DatasetFileModel.Create(ctx, item); err != nil {
  168. return err
  169. }
  170. dataset.FileCount += 1
  171. return a.datasetModel.Update(ctx, dataset.RecordID, *dataset)
  172. })
  173. if err != nil {
  174. if _, delErr := ragClient.DeleteDocuments(ctx, dataset.RagDataId, []string{ragFileId}); delErr != nil {
  175. glog.Errorf(ctx, "入库失败后清理 RagFlow 文档失败:%s", delErr.Error())
  176. }
  177. return err
  178. }
  179. go a.waitParse(context.Background(), item.RecordID, ragFileId, dataset.RagDataId)
  180. return nil
  181. }
  182. // getFileType 根据文件名获取文件类型
  183. func (a *DatasetFile) getFileType(fileName string) string {
  184. ext := strings.ToLower(filepath.Ext(fileName))
  185. switch ext {
  186. case ".pdf":
  187. return "pdf"
  188. case ".doc", ".docx":
  189. return "word"
  190. case ".xls", ".xlsx":
  191. return "excel"
  192. case ".txt":
  193. return "txt"
  194. case ".md":
  195. return "markdown"
  196. case ".jpg", ".jpeg":
  197. return "jpg"
  198. case ".png":
  199. return "png"
  200. default:
  201. return "other"
  202. }
  203. }
  204. // waitParse 轮询 ragflow 获取指定文档的解析状态并更新本地 parse_status
  205. // 1 解析中 / 2 解析完成 / 3 解析失败
  206. func (a *DatasetFile) waitParse(ctx context.Context, recordId, fileId, ragDataId string) {
  207. if fileId == "" || ragDataId == "" {
  208. return
  209. }
  210. const maxAttempts = 600 // ~20 分钟上限,防止死循环
  211. for i := 0; i < maxAttempts; i++ {
  212. res, err := ragflow.GetHttpClient().ListDocuments(ctx, ragDataId, &ragflow.ListDocumentReq{DocumentID: fileId})
  213. if err != nil || res == nil || len(res.Data.Docs) == 0 {
  214. if err != nil {
  215. glog.Errorf(ctx, "查询文件解析状态失败:%s", err.Error())
  216. }
  217. time.Sleep(2 * time.Second)
  218. continue
  219. }
  220. doc := res.Data.Docs[0]
  221. switch doc.Run {
  222. case "RUNNING":
  223. time.Sleep(2 * time.Second)
  224. continue
  225. case "DONE":
  226. glog.Info(ctx, fmt.Sprintf("%s:【%s】文件解析完成", fileId, doc.Name))
  227. _ = a.DatasetFileModel.UpdateParseStatus(ctx, recordId, 2)
  228. return
  229. default:
  230. glog.Error(ctx, fmt.Sprintf("%s:【%s】文件解析失败,Run=%s", fileId, doc.Name, doc.Run))
  231. _ = a.DatasetFileModel.UpdateParseStatus(ctx, recordId, 3)
  232. return
  233. }
  234. }
  235. glog.Error(ctx, fmt.Sprintf("%s: 文件解析状态轮询超时", fileId))
  236. _ = a.DatasetFileModel.UpdateParseStatus(ctx, recordId, 3)
  237. }
  238. // BatchCreate 批量创建数据
  239. func (a *DatasetFile) BatchCreate(ctx context.Context, files schema.DatasetFiles) error {
  240. if len(files) == 0 {
  241. return errors.New("文件不能为空")
  242. }
  243. for _, v := range files {
  244. err := a.Create(ctx, *v)
  245. if err != nil {
  246. return err
  247. }
  248. }
  249. return nil
  250. }
  251. // Update 更新数据
  252. func (a *DatasetFile) Update(ctx context.Context, recordID string, item schema.DatasetFile) error {
  253. oldItem, err := a.DatasetFileModel.Get(ctx, recordID)
  254. if err != nil {
  255. return err
  256. } else if oldItem == nil {
  257. return errors.ErrNotFound
  258. }
  259. return a.DatasetFileModel.Update(ctx, recordID, item)
  260. }
  261. // Delete 删除数据
  262. func (a *DatasetFile) Delete(ctx context.Context, recordID string) error {
  263. oldItem, err := a.DatasetFileModel.Get(ctx, recordID)
  264. if err != nil {
  265. return err
  266. } else if oldItem == nil {
  267. return errors.ErrNotFound
  268. }
  269. dataset, err := a.datasetModel.Get(ctx, oldItem.DatasetId)
  270. if err != nil {
  271. return err
  272. }
  273. if dataset == nil || dataset.RagDataId == "" {
  274. return errors.New("知识库不存在")
  275. }
  276. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  277. dataset, err := a.datasetModel.Get(ctx, oldItem.DatasetId)
  278. if err != nil {
  279. return err
  280. }
  281. dataset.FileCount -= 1
  282. err = a.datasetModel.Update(ctx, dataset.RecordID, *dataset)
  283. if err != nil {
  284. return err
  285. }
  286. _, err = ragflow.GetHttpClient().DeleteDocuments(ctx, dataset.RagDataId, []string{oldItem.RagFileId})
  287. if err != nil {
  288. return err
  289. }
  290. return a.DatasetFileModel.Delete(ctx, recordID)
  291. })
  292. return err
  293. }
  294. // UpdateStatus 更新状态
  295. func (a *DatasetFile) UpdateStatus(ctx context.Context, recordID string, status int) error {
  296. oldItem, err := a.DatasetFileModel.Get(ctx, recordID)
  297. if err != nil {
  298. return err
  299. } else if oldItem == nil {
  300. return errors.ErrNotFound
  301. }
  302. return a.DatasetFileModel.UpdateStatus(ctx, recordID, status)
  303. }
  304. // UpdateEnabled 更新启用状态
  305. func (a *DatasetFile) UpdateEnabled(ctx context.Context, recordID string, status bool) error {
  306. oldItem, err := a.DatasetFileModel.Get(ctx, recordID)
  307. if err != nil {
  308. return err
  309. } else if oldItem == nil {
  310. return errors.ErrNotFound
  311. }
  312. dataset, err := a.datasetModel.Get(ctx, oldItem.DatasetId)
  313. if err != nil {
  314. return err
  315. }
  316. if dataset == nil || dataset.RagDataId == "" {
  317. return errors.New("知识库不存在")
  318. }
  319. enabled := 0
  320. if status {
  321. enabled = 1
  322. }
  323. // 更新ragFlow
  324. _, err = ragflow.GetHttpClient().UpdateDocument(ctx, dataset.RagDataId, oldItem.RagFileId, &ragflow.UpdateDocumentReq{Enabled: enabled})
  325. if err != nil {
  326. return err
  327. }
  328. return a.DatasetFileModel.UpdateEnabled(ctx, recordID, status)
  329. }
  330. // BatchDelete 批量删除
  331. func (a *DatasetFile) BatchDelete(ctx context.Context, fileIDs []string) error {
  332. if len(fileIDs) == 0 {
  333. return errors.New("文件不能为空")
  334. }
  335. file, err := a.DatasetFileModel.Query(ctx, schema.DatasetFileQueryParam{RecordIDs: fileIDs})
  336. if err != nil {
  337. return err
  338. }
  339. if len(file.Data) == 0 {
  340. return nil
  341. }
  342. dataset, err := a.datasetModel.Get(ctx, file.Data[0].DatasetId)
  343. if err != nil {
  344. return err
  345. }
  346. _, err = ragflow.GetHttpClient().DeleteDocuments(ctx, dataset.RagDataId, file.Data.ToRagFileIds())
  347. if err != nil {
  348. return err
  349. }
  350. return a.DatasetFileModel.BatchDelete(ctx, fileIDs)
  351. }
  352. // BatchCreateV2 批量上传多个文件到指定知识库
  353. // 入参 files 通常由 controller 从 multipart 表单解析而来,已包含文件名与内容。
  354. // 单文件失败不会中断整批;最终成功数和失败明细统一返回给调用方。
  355. func (a *DatasetFile) BatchCreateV2(ctx context.Context, datasetId string, files []schema.UpdateFileParam) error {
  356. if datasetId == "" {
  357. return errors.New("dataset_id 不能为空")
  358. }
  359. if len(files) == 0 {
  360. return errors.New("文件列表不能为空")
  361. }
  362. dataset, err := a.datasetModel.Get(ctx, datasetId)
  363. if err != nil {
  364. return err
  365. }
  366. if dataset == nil || dataset.RagDataId == "" {
  367. return errors.New("知识库不存在")
  368. }
  369. var (
  370. failed []string
  371. success int
  372. )
  373. for _, f := range files {
  374. // 保证 datasetId 一致,避免调用方遗漏
  375. f.DatasetId = datasetId
  376. if err := a.CreateV2(ctx, f); err != nil {
  377. glog.Errorf(ctx, "批量导入【%s】失败:%s", f.FileName, err.Error())
  378. failed = append(failed, f.FileName)
  379. continue
  380. }
  381. success++
  382. }
  383. if len(failed) > 0 {
  384. return errors.New(fmt.Sprintf("部分文件导入失败(成功 %d 个):%v", success, failed))
  385. }
  386. return nil
  387. }