b_dataset.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package internal
  2. import (
  3. "context"
  4. "github.com/gogf/gf/util/guid"
  5. "yx-dataset-server/app/errors"
  6. "yx-dataset-server/app/model"
  7. "yx-dataset-server/app/schema"
  8. "yx-dataset-server/library/ragflow"
  9. )
  10. // NewDataset 创建Dataset
  11. func NewDataset(
  12. mDataset model.IDataset,
  13. mFile model.IDatasetFile,
  14. mTrans model.ITrans,
  15. mUser model.IUser,
  16. ) *Dataset {
  17. return &Dataset{
  18. DatasetModel: mDataset,
  19. fileModel: mFile,
  20. transModel: mTrans,
  21. userModel: mUser,
  22. }
  23. }
  24. // Dataset 创建Dataset对象
  25. type Dataset struct {
  26. DatasetModel model.IDataset
  27. fileModel model.IDatasetFile
  28. transModel model.ITrans
  29. userModel model.IUser
  30. }
  31. // Query 查询数据
  32. func (a *Dataset) Query(ctx context.Context, params schema.DatasetQueryParam, opts ...schema.DatasetQueryOptions) (*schema.DatasetQueryResult, error) {
  33. result, err := a.DatasetModel.Query(ctx, params, opts...)
  34. if err != nil {
  35. return nil, err
  36. }
  37. user, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{"11", "12"}})
  38. if err != nil {
  39. return nil, err
  40. }
  41. result.Data.FillCreator(user.Data)
  42. return result, nil
  43. }
  44. // Get 查询指定数据
  45. func (a *Dataset) Get(ctx context.Context, recordID string, opts ...schema.DatasetQueryOptions) (*schema.Dataset, error) {
  46. item, err := a.DatasetModel.Get(ctx, recordID, opts...)
  47. if err != nil {
  48. return nil, err
  49. } else if item == nil {
  50. return nil, errors.ErrNotFound
  51. }
  52. file, err := a.fileModel.Query(ctx, schema.DatasetFileQueryParam{DatasetId: recordID})
  53. if len(file.Data) > 0 {
  54. item.Files = file.Data
  55. }
  56. return item, nil
  57. }
  58. func (a *Dataset) getUpdate(ctx context.Context, recordID string) (*schema.Dataset, error) {
  59. return a.Get(ctx, recordID)
  60. }
  61. // Create 创建数据
  62. func (a *Dataset) Create(ctx context.Context, item schema.Dataset) error {
  63. item.RecordID = guid.S()
  64. //item.Sequence = 0
  65. //sequence, err := a.DatasetModel.GetMaxSequence(ctx)
  66. //if err != nil {
  67. // return err
  68. //}
  69. //if sequence > 0 {
  70. // item.Sequence = sequence + 1
  71. //}
  72. item.CreatorId = GetUserID(ctx)
  73. if !CheckIsRootUser(ctx) {
  74. user, err := a.userModel.Get(ctx, GetUserID(ctx))
  75. if err != nil {
  76. return err
  77. }
  78. item.OrgId = user.OrgId
  79. }
  80. data, err := ragflow.GetHttpClient().CreateDataset(ctx, &ragflow.CreateDatasetReq{
  81. Name: item.Name,
  82. ChunkMethod: "naive",
  83. EmbeddingModel: "text-embedding-v4@Tongyi-Qianwen",
  84. })
  85. if err != nil {
  86. return err
  87. }
  88. item.RagDataId = data.Data.ID
  89. _, err = ragflow.GetHttpClient().UpdateDataset(ctx, item.RagDataId, &ragflow.UpdateDatasetReq{
  90. Language: "Chinese",
  91. })
  92. if err != nil {
  93. return err
  94. }
  95. return a.DatasetModel.Create(ctx, item)
  96. }
  97. // Update 更新数据
  98. func (a *Dataset) Update(ctx context.Context, recordID string, item schema.Dataset) error {
  99. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  100. if err != nil {
  101. return err
  102. } else if oldItem == nil {
  103. return errors.ErrNotFound
  104. }
  105. if item.Name != oldItem.Name {
  106. _, err = ragflow.GetHttpClient().UpdateDataset(ctx, item.RagDataId, &ragflow.UpdateDatasetReq{
  107. Name: item.Name,
  108. ChunkMethod: "naive",
  109. })
  110. if err != nil {
  111. return err
  112. }
  113. }
  114. err = a.DatasetModel.Update(ctx, recordID, item)
  115. return err
  116. }
  117. // Delete 删除数据
  118. func (a *Dataset) Delete(ctx context.Context, recordID string) error {
  119. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  120. if err != nil {
  121. return err
  122. } else if oldItem == nil {
  123. return errors.ErrNotFound
  124. }
  125. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  126. err = a.fileModel.DeleteByDatasetIds(ctx, []string{recordID})
  127. if err != nil {
  128. return err
  129. }
  130. err = a.DatasetModel.Delete(ctx, recordID)
  131. if err != nil {
  132. return err
  133. }
  134. _, err = ragflow.GetHttpClient().DeleteDataset(ctx, []string{oldItem.RagDataId})
  135. return err
  136. })
  137. return err
  138. }
  139. // UpdateStatus 更新状态
  140. func (a *Dataset) UpdateStatus(ctx context.Context, recordID string, status int) error {
  141. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  142. if err != nil {
  143. return err
  144. } else if oldItem == nil {
  145. return errors.ErrNotFound
  146. }
  147. return a.DatasetModel.UpdateStatus(ctx, recordID, status)
  148. }
  149. // UpdateSequence 更新排序
  150. //func (a *Dataset) UpdateSequence(ctx context.Context, recordID string, sequence int) error {
  151. // item, err := a.DatasetModel.Get(ctx, recordID)
  152. // if err != nil {
  153. // return err
  154. // } else if item == nil {
  155. // return errors.ErrNotFound
  156. // }
  157. // err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  158. // err = a.DatasetModel.UpdateBatchSequenceMinus(ctx, item.Sequence, sequence)
  159. // if err != nil {
  160. // return err
  161. // }
  162. // return a.DatasetModel.UpdateSequence(ctx, recordID, sequence)
  163. // })
  164. // return err
  165. //}