b_chat_assistant.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package internal
  2. import (
  3. "context"
  4. "github.com/gogf/gf/util/guid"
  5. "yx-dataset-server/app/errors"
  6. "yx-dataset-server/app/model"
  7. "yx-dataset-server/app/schema"
  8. "yx-dataset-server/library/ragflow"
  9. )
  10. // NewChatAssistant 创建ChatAssistant
  11. func NewChatAssistant(
  12. mChatAssistant model.IChatAssistant,
  13. mChatDataset model.IChatDataset,
  14. mTrans model.ITrans,
  15. mDataset model.IDataset,
  16. mUser model.IUser,
  17. mSession model.IChatSession,
  18. ) *ChatAssistant {
  19. return &ChatAssistant{
  20. ChatAssistantModel: mChatAssistant,
  21. chatDatasetModel: mChatDataset,
  22. transModel: mTrans,
  23. datasetModel: mDataset,
  24. userModel: mUser,
  25. sessionModel: mSession,
  26. }
  27. }
  28. // ChatAssistant 创建ChatAssistant对象
  29. type ChatAssistant struct {
  30. ChatAssistantModel model.IChatAssistant
  31. chatDatasetModel model.IChatDataset
  32. transModel model.ITrans
  33. datasetModel model.IDataset
  34. userModel model.IUser
  35. sessionModel model.IChatSession
  36. }
  37. // Query 查询数据
  38. func (a *ChatAssistant) Query(ctx context.Context, params schema.ChatAssistantQueryParam, opts ...schema.ChatAssistantQueryOptions) (*schema.ChatAssistantQueryResult, error) {
  39. var userQueryParam schema.UserQueryParam
  40. userQueryParam.RoleCode = []string{"11", "12"}
  41. if !CheckIsRootUser(ctx) {
  42. user, err := a.userModel.Get(ctx, GetUserID(ctx))
  43. if err != nil {
  44. return nil, err
  45. }
  46. params.OrgId = user.OrgId
  47. userQueryParam.OrgId = user.OrgId
  48. }
  49. result, err := a.ChatAssistantModel.Query(ctx, params, opts...)
  50. if err != nil {
  51. return nil, err
  52. }
  53. chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{})
  54. if err != nil {
  55. return nil, err
  56. }
  57. dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{})
  58. if err != nil {
  59. return nil, err
  60. }
  61. users, err := a.userModel.Query(ctx, userQueryParam)
  62. if err != nil {
  63. return nil, err
  64. }
  65. result.Data.FillCreator(users.Data)
  66. result.Data.FillDatasetId(chatDataset.Data)
  67. result.Data.FillDataset(dataset.Data)
  68. return result, nil
  69. }
  70. // Get 查询指定数据
  71. func (a *ChatAssistant) Get(ctx context.Context, recordID string, opts ...schema.ChatAssistantQueryOptions) (*schema.ChatAssistant, error) {
  72. item, err := a.ChatAssistantModel.Get(ctx, recordID, opts...)
  73. if err != nil {
  74. return nil, err
  75. } else if item == nil {
  76. return nil, errors.ErrNotFound
  77. }
  78. user, err := a.userModel.Get(ctx, item.CreatorId)
  79. if err != nil {
  80. return nil, err
  81. }
  82. item.CreatorName = user.RealName
  83. chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{})
  84. if err != nil {
  85. return nil, err
  86. }
  87. dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: chatDataset.Data.ToDatasetIds()})
  88. if err != nil {
  89. return nil, err
  90. }
  91. users, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{"11", "12"}})
  92. if err != nil {
  93. return nil, err
  94. }
  95. dataset.Data.FillCreator(users.Data)
  96. item.Datasets = dataset.Data
  97. sessions, err := a.sessionModel.Query(ctx, schema.ChatSessionQueryParam{AssistantId: item.RecordID})
  98. if err != nil {
  99. return nil, err
  100. }
  101. item.Sessions = sessions.Data
  102. return item, nil
  103. }
  104. // Create 创建数据
  105. func (a *ChatAssistant) Create(ctx context.Context, item schema.ChatAssistant) error {
  106. item.RecordID = guid.S()
  107. item.CreatorId = GetUserID(ctx)
  108. name := "new_chat"
  109. if item.Name != "" {
  110. name = item.Name
  111. }
  112. if len(item.Datasets) == 0 {
  113. return errors.New400Response("关联知识库不能为空")
  114. }
  115. if !CheckIsRootUser(ctx) {
  116. user, err := a.userModel.Get(ctx, item.CreatorId)
  117. if err != nil {
  118. return err
  119. }
  120. item.OrgId = user.OrgId
  121. }
  122. err := ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  123. for _, v := range item.Datasets {
  124. err := a.chatDatasetModel.Create(ctx, schema.ChatDataset{
  125. RecordID: guid.S(),
  126. ChatAssistantId: item.RecordID,
  127. DatasetId: v.RecordID,
  128. })
  129. if err != nil {
  130. return err
  131. }
  132. }
  133. resp, err := ragflow.GetHttpClient().CreateChat(ctx, &ragflow.CreateChatReq{
  134. Name: name,
  135. DatasetIDs: item.Datasets.ToRagDataIds(),
  136. })
  137. if err != nil {
  138. return err
  139. }
  140. item.RagChatId = resp.Data.ID
  141. return a.ChatAssistantModel.Create(ctx, item)
  142. })
  143. return err
  144. }
  145. // Update 更新数据
  146. func (a *ChatAssistant) Update(ctx context.Context, recordID string, item schema.ChatAssistant) error {
  147. oldItem, err := a.ChatAssistantModel.Get(ctx, recordID)
  148. if err != nil {
  149. return err
  150. } else if oldItem == nil {
  151. return errors.ErrNotFound
  152. }
  153. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  154. err = a.ChatAssistantModel.Update(ctx, recordID, item)
  155. if err != nil {
  156. return err
  157. }
  158. if item.Name != oldItem.Name {
  159. _, err = ragflow.GetHttpClient().UpdateChat(ctx, item.RagChatId, &ragflow.UpdateChatReq{
  160. Name: item.Name,
  161. })
  162. if err != nil {
  163. return err
  164. }
  165. }
  166. return nil
  167. })
  168. return err
  169. }
  170. // Delete 删除数据
  171. func (a *ChatAssistant) Delete(ctx context.Context, recordID string) error {
  172. oldItem, err := a.ChatAssistantModel.Get(ctx, recordID)
  173. if err != nil {
  174. return err
  175. } else if oldItem == nil {
  176. return errors.ErrNotFound
  177. }
  178. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  179. err = a.ChatAssistantModel.Delete(ctx, recordID)
  180. if err != nil {
  181. return err
  182. }
  183. _, err = ragflow.GetHttpClient().DeleteChats(ctx, []string{oldItem.RagChatId})
  184. if err != nil {
  185. return err
  186. }
  187. return nil
  188. })
  189. return err
  190. }
  191. // UpdateStatus 更新状态
  192. func (a *ChatAssistant) UpdateStatus(ctx context.Context, recordID string, status int) error {
  193. oldItem, err := a.ChatAssistantModel.Get(ctx, recordID)
  194. if err != nil {
  195. return err
  196. } else if oldItem == nil {
  197. return errors.ErrNotFound
  198. }
  199. return a.ChatAssistantModel.UpdateStatus(ctx, recordID, status)
  200. }