b_dataset.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. package internal
  2. import (
  3. "context"
  4. "github.com/gogf/gf/util/guid"
  5. "yx-dataset-server/app/errors"
  6. "yx-dataset-server/app/model"
  7. "yx-dataset-server/app/schema"
  8. "yx-dataset-server/library/ragflow"
  9. )
  10. // NewDataset 创建Dataset
  11. func NewDataset(
  12. mDataset model.IDataset,
  13. mFile model.IDatasetFile,
  14. mTrans model.ITrans,
  15. mUser model.IUser,
  16. mRole model.IRole,
  17. mRelation model.IDatasetRelation,
  18. mOrg model.IOrganization,
  19. ) *Dataset {
  20. return &Dataset{
  21. DatasetModel: mDataset,
  22. fileModel: mFile,
  23. transModel: mTrans,
  24. userModel: mUser,
  25. roleModel: mRole,
  26. relationModel: mRelation,
  27. orgModel: mOrg,
  28. }
  29. }
  30. // Dataset 创建Dataset对象
  31. type Dataset struct {
  32. DatasetModel model.IDataset
  33. fileModel model.IDatasetFile
  34. transModel model.ITrans
  35. userModel model.IUser
  36. roleModel model.IRole
  37. relationModel model.IDatasetRelation
  38. orgModel model.IOrganization
  39. }
  40. // Query 查询数据
  41. func (a *Dataset) Query(ctx context.Context, params schema.DatasetQueryParam, opts ...schema.DatasetQueryOptions) (*schema.DatasetQueryResult, error) {
  42. if isSys, _ := IsSystemAdmin(ctx, a.userModel, a.roleModel); !isSys {
  43. // 非系统管理员:仅能看到自己有授权的知识库(user 维度 + 所属组织维度)
  44. userID := GetUserID(ctx)
  45. bizIds := []string{userID}
  46. if u, err := a.userModel.Get(ctx, userID); err == nil && u != nil && u.OrgId != "" {
  47. bizIds = append(bizIds, u.OrgId)
  48. }
  49. rel, err := a.relationModel.Query(ctx, schema.DatasetRelationQueryParam{BizIds: bizIds})
  50. if err != nil {
  51. return nil, err
  52. }
  53. if len(rel.Data) == 0 {
  54. return &schema.DatasetQueryResult{
  55. Data: make(schema.Datasets, 0),
  56. PageResult: &schema.PaginationResult{Total: 0},
  57. }, nil
  58. }
  59. params.RecordIds = rel.Data.ToDatasetIds()
  60. }
  61. result, err := a.DatasetModel.Query(ctx, params, opts...)
  62. if err != nil {
  63. return nil, err
  64. }
  65. users, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{RoleCodeSystemAdmin, RoleCodeEnterpriseAdmin}})
  66. if err != nil {
  67. return nil, err
  68. }
  69. result.Data.FillCreator(users.Data)
  70. return result, nil
  71. }
  72. // Get 查询指定数据
  73. func (a *Dataset) Get(ctx context.Context, recordID string, opts ...schema.DatasetQueryOptions) (*schema.Dataset, error) {
  74. item, err := a.DatasetModel.Get(ctx, recordID, opts...)
  75. if err != nil {
  76. return nil, err
  77. } else if item == nil {
  78. return nil, errors.ErrNotFound
  79. }
  80. file, err := a.fileModel.Query(ctx, schema.DatasetFileQueryParam{DatasetId: recordID})
  81. if err != nil {
  82. return nil, err
  83. }
  84. if len(file.Data) > 0 {
  85. item.Files = file.Data
  86. }
  87. return item, nil
  88. }
  89. // Create 创建数据
  90. // 依据角色决定 type:
  91. // - 系统管理员 / root → 公共/共享知识库 (type=1)
  92. // - 企业管理员 → 企业知识库 (type=2),biz_id=org_id
  93. // - 员工 → 个人知识库 (type=3),biz_id=user_id
  94. func (a *Dataset) Create(ctx context.Context, item schema.Dataset) error {
  95. item.RecordID = guid.S()
  96. item.CreatorId = GetUserID(ctx)
  97. // 非 root:取创建者所属组织
  98. if !CheckIsRootUser(ctx) {
  99. user, err := a.userModel.Get(ctx, GetUserID(ctx))
  100. if err != nil {
  101. return err
  102. }
  103. item.OrgId = user.OrgId
  104. }
  105. data, err := ragflow.GetHttpClient().CreateDataset(ctx, &ragflow.CreateDatasetReq{
  106. Name: item.Name,
  107. ChunkMethod: "naive",
  108. EmbeddingModel: "text-embedding-v4@Tongyi-Qianwen",
  109. })
  110. if err != nil {
  111. return err
  112. }
  113. item.RagDataId = data.Data.ID
  114. return ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  115. if err := a.DatasetModel.Create(ctx, item); err != nil {
  116. return err
  117. }
  118. // 根据类型建立初始关系映射
  119. switch item.Type {
  120. case schema.DatasetTypePublic:
  121. // 共享知识库:由系统管理员建立,初始不绑定到任何业务实体;
  122. // 通过新建/编辑组织或分配给员工时再写入映射记录。
  123. return nil
  124. case schema.DatasetTypeOrg:
  125. // 企业知识库:自动给所属组织建立映射
  126. if err := a.relationModel.Create(ctx, schema.DatasetRelation{
  127. RecordID: guid.S(),
  128. DatasetId: item.RecordID,
  129. BizId: item.OrgId,
  130. Type: schema.DatasetTypeOrg,
  131. CreatorId: item.CreatorId,
  132. }); err != nil {
  133. return err
  134. }
  135. // 组织内的企业管理员自动获得该知识库访问权
  136. admins, err := a.userModel.Query(ctx, schema.UserQueryParam{OrgId: item.OrgId, RoleCode: []string{RoleCodeEnterpriseAdmin}})
  137. if err != nil {
  138. return err
  139. }
  140. for _, v := range admins.Data {
  141. if err := a.relationModel.Create(ctx, schema.DatasetRelation{
  142. RecordID: guid.S(),
  143. DatasetId: item.RecordID,
  144. BizId: v.RecordID,
  145. Type: schema.DatasetTypeOrg,
  146. CreatorId: item.CreatorId,
  147. }); err != nil {
  148. return err
  149. }
  150. }
  151. return nil
  152. case schema.DatasetTypePersonal:
  153. // 个人知识库:仅给创建者本人建立映射
  154. return a.relationModel.Create(ctx, schema.DatasetRelation{
  155. RecordID: guid.S(),
  156. DatasetId: item.RecordID,
  157. BizId: item.CreatorId,
  158. Type: schema.DatasetTypePersonal,
  159. CreatorId: item.CreatorId,
  160. })
  161. }
  162. return nil
  163. })
  164. }
  165. // Update 更新数据
  166. func (a *Dataset) Update(ctx context.Context, recordID string, item schema.Dataset) error {
  167. isSys, err := IsSystemAdmin(ctx, a.userModel, a.roleModel)
  168. if err != nil {
  169. return err
  170. }
  171. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  172. if err != nil {
  173. return err
  174. } else if oldItem == nil {
  175. return errors.ErrNotFound
  176. }
  177. if !isSys {
  178. u, err := a.userModel.Get(ctx, GetUserID(ctx))
  179. if err != nil {
  180. return err
  181. }
  182. if u == nil {
  183. return errors.ErrUserNotFound
  184. }
  185. // 企业自建:限本组织;个人:限本人
  186. switch oldItem.Type {
  187. case schema.DatasetTypeOrg:
  188. if oldItem.OrgId != u.OrgId {
  189. return errors.New400Response("仅能修改本组织自建知识库")
  190. }
  191. case schema.DatasetTypePersonal:
  192. if oldItem.CreatorId != u.RecordID {
  193. return errors.New400Response("仅能修改本人个人知识库")
  194. }
  195. default:
  196. return errors.New400Response("仅系统管理员可修改共享知识库")
  197. }
  198. }
  199. // 保留 type/creator_id,不允许变更
  200. item.Type = oldItem.Type
  201. item.CreatorId = oldItem.CreatorId
  202. item.OrgId = oldItem.OrgId
  203. if item.Name != oldItem.Name {
  204. if _, err := ragflow.GetHttpClient().UpdateDataset(ctx, oldItem.RagDataId, &ragflow.UpdateDatasetReq{
  205. Name: item.Name,
  206. ChunkMethod: "naive",
  207. }); err != nil {
  208. return err
  209. }
  210. }
  211. return a.DatasetModel.Update(ctx, recordID, item)
  212. }
  213. // Delete 删除数据
  214. func (a *Dataset) Delete(ctx context.Context, recordID string) error {
  215. isSys, err := IsSystemAdmin(ctx, a.userModel, a.roleModel)
  216. if err != nil {
  217. return err
  218. }
  219. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  220. if err != nil {
  221. return err
  222. } else if oldItem == nil {
  223. return errors.ErrNotFound
  224. }
  225. if !isSys {
  226. u, err := a.userModel.Get(ctx, GetUserID(ctx))
  227. if err != nil {
  228. return err
  229. }
  230. if u == nil {
  231. return errors.ErrUserNotFound
  232. }
  233. switch oldItem.Type {
  234. case schema.DatasetTypeOrg:
  235. if oldItem.OrgId != u.OrgId {
  236. return errors.New400Response("仅能删除本组织自建知识库")
  237. }
  238. case schema.DatasetTypePersonal:
  239. if oldItem.CreatorId != u.RecordID {
  240. return errors.New400Response("仅能删除本人个人知识库")
  241. }
  242. default:
  243. return errors.New400Response("仅系统管理员可删除共享知识库")
  244. }
  245. }
  246. return ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  247. if err := a.relationModel.DeleteByDatasetId(ctx, recordID); err != nil {
  248. return err
  249. }
  250. if err := a.fileModel.DeleteByDatasetIds(ctx, []string{recordID}); err != nil {
  251. return err
  252. }
  253. if err := a.DatasetModel.Delete(ctx, recordID); err != nil {
  254. return err
  255. }
  256. _, err = ragflow.GetHttpClient().DeleteDataset(ctx, []string{oldItem.RagDataId})
  257. return err
  258. })
  259. }
  260. // UpdateStatus 更新状态
  261. func (a *Dataset) UpdateStatus(ctx context.Context, recordID string, status int) error {
  262. oldItem, err := a.DatasetModel.Get(ctx, recordID)
  263. if err != nil {
  264. return err
  265. } else if oldItem == nil {
  266. return errors.ErrNotFound
  267. }
  268. return a.DatasetModel.UpdateStatus(ctx, recordID, status)
  269. }
  270. // GetPermissionDatasets 获取有权限的知识库(按组织分组)
  271. // - orgId 非空:返回该组织维度映射到的知识库(用于前端“组织可选知识库”列表)
  272. // - orgId 空:返回当前登录用户有权限的知识库
  273. func (a *Dataset) GetPermissionDatasets(ctx context.Context, orgId string) (schema.Organizations, error) {
  274. var datasetQueryParm schema.DatasetQueryParam
  275. if orgId != "" {
  276. rel, err := a.relationModel.Query(ctx, schema.DatasetRelationQueryParam{BizId: orgId})
  277. if err != nil {
  278. return nil, err
  279. }
  280. datasetQueryParm.RecordIds = rel.Data.ToDatasetIds()
  281. } else if !CheckIsRootUser(ctx) {
  282. userID := GetUserID(ctx)
  283. bizIds := []string{userID}
  284. if u, err := a.userModel.Get(ctx, userID); err == nil && u != nil && u.OrgId != "" {
  285. bizIds = append(bizIds, u.OrgId)
  286. }
  287. rel, err := a.relationModel.Query(ctx, schema.DatasetRelationQueryParam{BizIds: bizIds})
  288. if err != nil {
  289. return nil, err
  290. }
  291. datasetQueryParm.RecordIds = rel.Data.ToDatasetIds()
  292. }
  293. datasets, err := a.DatasetModel.Query(ctx, datasetQueryParm)
  294. if err != nil {
  295. return nil, err
  296. }
  297. orgs, err := a.orgModel.Query(ctx, schema.OrganizationQueryParam{RecordIds: datasets.Data.ToOrgIds()})
  298. if err != nil {
  299. return nil, err
  300. }
  301. orgs.Data.FillDataset(datasets.Data)
  302. return orgs.Data, nil
  303. }
  304. // GetAvailableDatasets 按类型查询可用知识库
  305. // - orgId 为空:返回全部公共/共享知识库(PublicDatasets)
  306. // - orgId 非空:返回该企业可访问的全部知识库(被分配的共享 PublicDatasets + 企业自建 OrgDatasets)
  307. func (a *Dataset) GetAvailableDatasets(ctx context.Context, orgId string) (*schema.AvailableDatasets, error) {
  308. result := &schema.AvailableDatasets{
  309. PublicDatasets: make(schema.Datasets, 0),
  310. OrgDatasets: make(schema.Datasets, 0),
  311. }
  312. if orgId == "" {
  313. // 新建企业场景:只返回全部公共/共享知识库
  314. shared, err := a.DatasetModel.Query(ctx, schema.DatasetQueryParam{Type: schema.DatasetTypePublic})
  315. if err != nil {
  316. return nil, err
  317. }
  318. result.PublicDatasets = shared.Data
  319. return result, nil
  320. }
  321. // 该企业被分配的共享知识库(通过组织级关系映射获得)
  322. rel, err := a.relationModel.Query(ctx, schema.DatasetRelationQueryParam{
  323. BizId: orgId,
  324. Type: schema.DatasetTypePublic,
  325. })
  326. if err != nil {
  327. return nil, err
  328. }
  329. if ids := rel.Data.ToDatasetIds(); len(ids) > 0 {
  330. shared, err := a.DatasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: ids})
  331. if err != nil {
  332. return nil, err
  333. }
  334. result.PublicDatasets = shared.Data
  335. }
  336. // 该企业自建的企业知识库
  337. orgDs, err := a.DatasetModel.Query(ctx, schema.DatasetQueryParam{
  338. OrgId: orgId,
  339. Type: schema.DatasetTypeOrg,
  340. })
  341. if err != nil {
  342. return nil, err
  343. }
  344. result.OrgDatasets = orgDs.Data
  345. return result, nil
  346. }