b_chat_session.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package internal
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gogf/gf/v2/util/guid"
  6. "strings"
  7. "time"
  8. "yx-dataset-server/app/errors"
  9. "yx-dataset-server/app/model"
  10. "yx-dataset-server/app/schema"
  11. "yx-dataset-server/library/ragflow"
  12. "yx-dataset-server/library/redis"
  13. )
  14. // NewChatSession 创建ChatSession
  15. func NewChatSession(
  16. mChatSession model.IChatSession,
  17. mTrans model.ITrans,
  18. mAssistant model.IChatAssistant,
  19. mChatDataset model.IChatDataset,
  20. mDataset model.IDataset,
  21. mMessage model.IChatMessage,
  22. mUser model.IUser,
  23. ) *ChatSession {
  24. return &ChatSession{
  25. ChatSessionModel: mChatSession,
  26. transModel: mTrans,
  27. assistantModel: mAssistant,
  28. chatDatasetModel: mChatDataset,
  29. datasetModel: mDataset,
  30. messageModel: mMessage,
  31. userModel: mUser,
  32. }
  33. }
  34. // ChatSession 创建ChatSession对象
  35. type ChatSession struct {
  36. ChatSessionModel model.IChatSession
  37. transModel model.ITrans
  38. assistantModel model.IChatAssistant
  39. chatDatasetModel model.IChatDataset
  40. datasetModel model.IDataset
  41. messageModel model.IChatMessage
  42. userModel model.IUser
  43. }
  44. // Query 查询数据
  45. func (a *ChatSession) Query(ctx context.Context, params schema.ChatSessionQueryParam, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSessionQueryResult, error) {
  46. result, err := a.ChatSessionModel.Query(ctx, params, opts...)
  47. if err != nil {
  48. return nil, err
  49. }
  50. user, err := a.userModel.Query(ctx, schema.UserQueryParam{})
  51. if err != nil {
  52. return nil, err
  53. }
  54. result.Data.FillCreator(user.Data)
  55. return result, nil
  56. }
  57. // Get 查询指定数据
  58. func (a *ChatSession) Get(ctx context.Context, recordID string, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSession, error) {
  59. item, err := a.ChatSessionModel.Get(ctx, recordID, opts...)
  60. if err != nil {
  61. return nil, err
  62. } else if item == nil {
  63. return nil, errors.ErrNotFound
  64. }
  65. user, err := a.userModel.Get(ctx, item.CreatorId)
  66. if err != nil {
  67. return nil, err
  68. }
  69. item.CreatorName = user.RealName
  70. message, err := a.messageModel.Query(ctx, schema.ChatMessageQueryParam{SessionId: recordID})
  71. if err != nil {
  72. return nil, err
  73. }
  74. item.Messages = message.Data
  75. return item, nil
  76. }
  77. // Create 创建数据
  78. func (a *ChatSession) Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
  79. item.RecordID = guid.S()
  80. item.CreatorId = GetUserID(ctx)
  81. item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话")
  82. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  83. if err != nil {
  84. return nil, err
  85. }
  86. s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
  87. Name: item.Name,
  88. })
  89. if err != nil {
  90. return nil, err
  91. }
  92. item.RagSessionId = s.Data.Id
  93. err = a.ChatSessionModel.Create(ctx, item)
  94. if err != nil {
  95. return nil, err
  96. }
  97. if item.Source == 2 {
  98. key := fmt.Sprintf("chat:session:%s", item.RecordID)
  99. _ = redis.GetClient().Set(ctx, key, item, 7200)
  100. }
  101. return a.Get(ctx, item.RecordID)
  102. }
  103. func (a *ChatSession) H5Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
  104. item.RecordID = guid.S()
  105. item.CreatorId = GetUserID(ctx)
  106. item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话")
  107. item.Source = 1
  108. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  109. if err != nil {
  110. return nil, err
  111. }
  112. item.RagChatId = assistant.RagChatId
  113. chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{ChatAssistantId: assistant.RecordID})
  114. if err != nil {
  115. return nil, err
  116. }
  117. dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: chatDataset.Data.ToDatasetIds()})
  118. if err != nil {
  119. return nil, err
  120. }
  121. var datasetName []string
  122. for _, v := range dataset.Data {
  123. quotedName := fmt.Sprintf(`《%s》`, v.Name)
  124. datasetName = append(datasetName, quotedName)
  125. }
  126. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  127. s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
  128. Name: item.Name,
  129. })
  130. if err != nil {
  131. return err
  132. }
  133. item.RagSessionId = s.Data.Id
  134. err = a.ChatSessionModel.Create(ctx, item)
  135. if err != nil {
  136. return err
  137. }
  138. //key := fmt.Sprintf("chat:session:%s", item.RecordID)
  139. //_ = redis.GetClient().Set(ctx, key, item, 8640000)
  140. answer := fmt.Sprintf("你好,欢迎使用知识库问答助手!关于%s的疑问,我都会尽力为你解答,快来提问吧!", strings.Join(datasetName, ","))
  141. message := schema.ChatMessage{
  142. RecordID: guid.S(),
  143. UserId: item.CreatorId,
  144. AssistantId: item.AssistantId,
  145. SessionId: item.RecordID,
  146. RagSessionId: s.Data.Id,
  147. Question: "",
  148. Answer: strings.ReplaceAll(answer, `\`, ``),
  149. CreatorId: item.CreatorId,
  150. }
  151. err = a.messageModel.Create(ctx, message)
  152. if err != nil {
  153. return err
  154. }
  155. item.Messages = append(item.Messages, &message)
  156. return nil
  157. })
  158. return a.Get(ctx, item.RecordID)
  159. }
  160. // Update 更新数据
  161. func (a *ChatSession) Update(ctx context.Context, recordID string, item schema.ChatSession) (*schema.ChatSession, error) {
  162. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  163. if err != nil {
  164. return nil, err
  165. } else if oldItem == nil {
  166. return nil, errors.ErrNotFound
  167. }
  168. // update 只修改名称
  169. if item.Name == oldItem.Name {
  170. return nil, err
  171. }
  172. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  173. oldItem.Name = item.Name
  174. err = a.ChatSessionModel.Update(ctx, recordID, *oldItem)
  175. if err != nil {
  176. return err
  177. }
  178. if oldItem.Name != item.Name {
  179. _, err = ragflow.GetHttpClient().UpdateSession(ctx, oldItem.RagChatId, oldItem.RagSessionId, &ragflow.UpdateSessionReq{
  180. Name: item.Name,
  181. })
  182. if err != nil {
  183. return err
  184. }
  185. }
  186. return nil
  187. })
  188. return a.Get(ctx, recordID)
  189. }
  190. // Delete 删除数据
  191. func (a *ChatSession) Delete(ctx context.Context, recordID string) error {
  192. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  193. if err != nil {
  194. return err
  195. } else if oldItem == nil {
  196. return errors.ErrNotFound
  197. }
  198. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  199. err = a.ChatSessionModel.Delete(ctx, recordID)
  200. if err != nil {
  201. return err
  202. }
  203. _, err = ragflow.GetHttpClient().DeleteSessions(ctx, oldItem.RagChatId, []string{oldItem.RagSessionId})
  204. return err
  205. })
  206. return err
  207. }
  208. // UpdateStatus 更新状态
  209. func (a *ChatSession) UpdateStatus(ctx context.Context, recordID string, status int) error {
  210. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  211. if err != nil {
  212. return err
  213. } else if oldItem == nil {
  214. return errors.ErrNotFound
  215. }
  216. return a.ChatSessionModel.UpdateStatus(ctx, recordID, status)
  217. }