b_chat_message.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package internal
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gogf/gf/v2/net/ghttp"
  8. "github.com/gogf/gf/v2/os/glog"
  9. "github.com/gogf/gf/v2/util/guid"
  10. "regexp"
  11. "strings"
  12. "yx-dataset-server/app/errors"
  13. "yx-dataset-server/app/model"
  14. "yx-dataset-server/app/schema"
  15. "yx-dataset-server/library/ragflow"
  16. )
  17. // NewChatMessage 创建ChatMessage
  18. func NewChatMessage(
  19. mSession model.IChatSession,
  20. mChatMessage model.IChatMessage,
  21. mAssistant model.IChatAssistant,
  22. ) *ChatMessage {
  23. return &ChatMessage{
  24. ChatMessageModel: mChatMessage,
  25. sessionModel: mSession,
  26. assistantModel: mAssistant,
  27. }
  28. }
  29. // ChatMessage 创建ChatMessage对象
  30. type ChatMessage struct {
  31. ChatMessageModel model.IChatMessage
  32. sessionModel model.IChatSession
  33. assistantModel model.IChatAssistant
  34. }
  35. // Query 查询数据
  36. func (a *ChatMessage) Query(ctx context.Context, params schema.ChatMessageQueryParam, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessageQueryResult, error) {
  37. return a.ChatMessageModel.Query(ctx, params, opts...)
  38. }
  39. // Get 查询指定数据
  40. func (a *ChatMessage) Get(ctx context.Context, recordID string, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessage, error) {
  41. item, err := a.ChatMessageModel.Get(ctx, recordID, opts...)
  42. if err != nil {
  43. return nil, err
  44. } else if item == nil {
  45. return nil, errors.ErrNotFound
  46. }
  47. return item, nil
  48. }
  49. func (a *ChatMessage) getUpdate(ctx context.Context, recordID string) (*schema.ChatMessage, error) {
  50. return a.Get(ctx, recordID)
  51. }
  52. // StreamChatMessage 流式消息
  53. func (a *ChatMessage) StreamChatMessage(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
  54. item.RecordID = guid.S()
  55. item.CreatorId = GetUserID(ctx)
  56. //var sessionId string
  57. if item.SessionId == "" {
  58. r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n")
  59. return
  60. }
  61. session, err := a.sessionModel.Get(ctx, item.SessionId)
  62. if err != nil {
  63. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  64. return
  65. }
  66. messages, err := a.ChatMessageModel.Query(ctx, schema.ChatMessageQueryParam{SessionId: item.SessionId})
  67. if err != nil {
  68. r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n")
  69. return
  70. }
  71. if len(messages.Data) == 0 {
  72. go a.updateSessionName(ctx, *session, item.Question)
  73. }
  74. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  75. if err != nil {
  76. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  77. return
  78. }
  79. re := regexp.MustCompile(`\[ID:\d+\]`)
  80. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  81. Question: item.Question,
  82. SessionID: session.RagSessionId,
  83. })
  84. if err != nil {
  85. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  86. return
  87. }
  88. defer stream.Close() // 关闭流
  89. // 用于拼接答案
  90. var fullAnswer strings.Builder
  91. //err = a.collectStreamDataAndCreate(ctx, stream, &item)
  92. //if err != nil {
  93. // r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  94. // return
  95. //}
  96. // ===================== 5. 设置SSE响应头(必须!) =====================
  97. r.Response.Header().Set("Content-Type", "text/event-stream")
  98. r.Response.Header().Set("Cache-Control", "no-cache")
  99. r.Response.Header().Set("Connection", "keep-alive")
  100. r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
  101. r.Response.Flush()
  102. scanner := bufio.NewScanner(stream)
  103. buf := make([]byte, 1024*1024*10) // 增大缓冲区
  104. scanner.Buffer(buf, 1024*1024*10)
  105. finalResp := &ragflow.ChatCompletionResp{}
  106. for scanner.Scan() {
  107. line := scanner.Text()
  108. // 原封不动写给前端
  109. r.Response.Write(line + "\n")
  110. r.Response.Flush() // 实时刷新
  111. var chunk ragflow.ChatCompletionResp
  112. //2. 后端收集数据(拼接完整答案)
  113. if strings.HasPrefix(line, "data:") {
  114. jsonStr := strings.TrimPrefix(line, "data:")
  115. jsonStr = strings.TrimSpace(jsonStr)
  116. if jsonStr != "true" {
  117. if err := json.Unmarshal([]byte(jsonStr), &chunk); err == nil {
  118. cleanAnswer := re.ReplaceAllString(chunk.Data.Answer, "")
  119. fullAnswer.WriteString(cleanAnswer)
  120. }
  121. }
  122. }
  123. *finalResp = chunk
  124. }
  125. // 处理流读取错误
  126. if err := scanner.Err(); err != nil {
  127. r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"读取流失败:%s\"}\n\n", err.Error()))
  128. r.Response.Flush()
  129. return
  130. }
  131. // ===================== 【业务层:保存完整答案到数据库】 =====================
  132. item.Answer = fullAnswer.String()
  133. if err := a.ChatMessageModel.Create(ctx, item); err != nil {
  134. fmt.Printf("保存聊天记录失败:%v\n", err)
  135. }
  136. }
  137. // updateSessionName 更新会话名称
  138. func (a *ChatMessage) updateSessionName(ctx context.Context, session schema.ChatSession, question string) {
  139. idx := strings.Index(question, ",")
  140. if idx == -1 {
  141. session.Name = question
  142. }
  143. // 截取从开头到逗号的内容(不包含逗号)
  144. session.Name = question[:idx]
  145. err := a.sessionModel.Update(ctx, session.RecordID, session)
  146. if err != nil {
  147. glog.Errorf(ctx, "更新会话名称失败:%v", err)
  148. }
  149. _, err = ragflow.GetHttpClient().UpdateSession(ctx, session.RagChatId, session.RagSessionId, &ragflow.UpdateSessionReq{Name: session.Name})
  150. if err != nil {
  151. glog.Errorf(ctx, "更新ragflow会话名称失败:%v", err)
  152. return
  153. }
  154. }
  155. func (a *ChatMessage) ChatMessage(ctx context.Context, r *ghttp.Request, item *schema.ChatMessage) (*schema.ChatMessage, error) {
  156. item.RecordID = guid.S()
  157. item.CreatorId = GetUserID(ctx)
  158. //var sessionId string
  159. if item.SessionId == "" {
  160. return nil, errors.New400Response("会话ID不能为空")
  161. }
  162. session, err := a.sessionModel.Get(ctx, item.SessionId)
  163. if err != nil {
  164. return nil, err
  165. }
  166. //sessionId = session.RagSessionId
  167. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  168. if err != nil {
  169. return nil, err
  170. }
  171. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  172. Question: item.Question,
  173. SessionID: session.RagSessionId,
  174. })
  175. if err != nil {
  176. return nil, err
  177. }
  178. defer stream.Close() // 关闭流
  179. item, err = CollectStreamDataAndCreate(ctx, stream, item)
  180. if err != nil {
  181. return nil, err
  182. }
  183. err = a.ChatMessageModel.Create(ctx, *item)
  184. if err != nil {
  185. return nil, err
  186. }
  187. return item, nil
  188. }
  189. //func (a *ChatMessage) collectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) error {
  190. // // 解决 token too long 报错:调大缓冲区到 10MB
  191. // const maxBufferSize = 10 * 1024 * 1024
  192. // scanner := bufio.NewScanner(stream)
  193. // buf := make([]byte, maxBufferSize)
  194. // scanner.Buffer(buf, maxBufferSize)
  195. //
  196. // // 用于拼接答案
  197. // var fullAnswer strings.Builder
  198. // // 用于存储最终完整结果
  199. // finalResp := &ragflow.ChatCompletionResp{}
  200. //
  201. // // 逐行读取流式返回
  202. // for scanner.Scan() {
  203. // line := strings.TrimSpace(scanner.Text())
  204. // if line == "" {
  205. // continue
  206. // }
  207. //
  208. // // 只处理 SSE 格式 data: ...
  209. // if !strings.HasPrefix(line, "data:") {
  210. // continue
  211. // }
  212. // jsonStr := strings.TrimPrefix(line, "data:")
  213. // jsonStr = strings.TrimSpace(jsonStr)
  214. //
  215. // // 流结束标记:data: true
  216. // if jsonStr == "true" {
  217. // break
  218. // }
  219. //
  220. // // 解析当前分片
  221. // var chunk ragflow.ChatCompletionResp
  222. // if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
  223. // continue
  224. // }
  225. // // 拼接回答内容
  226. // if chunk.Data.Answer != "" {
  227. // fullAnswer.WriteString(chunk.Data.Answer)
  228. // }
  229. // // 保存最后一个分片(包含 reference/session_id 等完整信息)
  230. // *finalResp = chunk
  231. // }
  232. //
  233. // if err := scanner.Err(); err != nil {
  234. // return errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
  235. // }
  236. //
  237. // item.Answer = fullAnswer.String()
  238. //
  239. // return a.ChatMessageModel.Create(ctx, *item)
  240. //
  241. //}
  242. func (a *ChatMessage) CreateV3(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
  243. item.RecordID = guid.S()
  244. item.CreatorId = GetUserID(ctx)
  245. //var sessionId string
  246. if item.SessionId == "" {
  247. //return nil, errors.New400Response("会话ID不能为空")
  248. return
  249. }
  250. session, err := a.sessionModel.Get(ctx, item.SessionId)
  251. if err != nil {
  252. //return nil, err
  253. return
  254. }
  255. //sessionId = session.RagSessionId
  256. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  257. if err != nil {
  258. //return nil, err
  259. return
  260. }
  261. //resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  262. // Question: item.Question,
  263. // SessionID: sessionId,
  264. //})
  265. //if err != nil {
  266. // return nil, err
  267. //}
  268. //item.Answer = resp.Data.Answer
  269. err = a.ChatMessageModel.Create(ctx, item)
  270. if err != nil {
  271. return
  272. }
  273. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  274. Question: item.Question,
  275. SessionID: session.RagSessionId,
  276. })
  277. if err != nil {
  278. return
  279. }
  280. defer stream.Close() // 关闭流
  281. // ===================== 5. 设置SSE响应头(必须!) =====================
  282. r.Response.Header().Set("Content-Type", "text/event-stream")
  283. r.Response.Header().Set("Cache-Control", "no-cache")
  284. r.Response.Header().Set("Connection", "keep-alive")
  285. r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
  286. //r.Response.Flush()
  287. //
  288. //_, err = io.Copy(r.Response.Writer, stream)
  289. //if err != nil {
  290. // // 可自定义错误处理
  291. // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
  292. //}
  293. // 实时刷新
  294. //===================== 6. 逐行转发RAG流给前端 =====================
  295. scanner := bufio.NewScanner(stream)
  296. buf := make([]byte, 1024*1024*10) // 增大缓冲区
  297. scanner.Buffer(buf, 1024*1024*10)
  298. for scanner.Scan() {
  299. line := scanner.Text()
  300. // 原封不动写给前端
  301. r.Response.Write(line + "\n")
  302. r.Response.Flush() // 实时刷新
  303. }
  304. //// 处理错误
  305. //if err = scanner.Err(); err != nil {
  306. // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
  307. // r.Response.Flush() // 实时刷新
  308. //}
  309. }
  310. func (a *ChatMessage) CreateV2(ctx context.Context, item schema.ChatMessage) (*schema.ChatMessage, error) {
  311. item.RecordID = guid.S()
  312. item.CreatorId = GetUserID(ctx)
  313. if item.SessionId == "" {
  314. return nil, errors.New400Response("会话ID不能为空")
  315. }
  316. session, err := a.sessionModel.Get(ctx, item.SessionId)
  317. if err != nil {
  318. return nil, err
  319. }
  320. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  321. if err != nil {
  322. return nil, err
  323. }
  324. // 调用非流式接口
  325. resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  326. Question: item.Question,
  327. SessionID: session.RagSessionId,
  328. })
  329. if err != nil {
  330. return nil, err
  331. }
  332. item.Answer = resp.Data.Answer
  333. // 保存数据库
  334. err = a.ChatMessageModel.Create(ctx, item)
  335. if err != nil {
  336. return nil, err
  337. }
  338. return a.getUpdate(ctx, item.RecordID)
  339. }
  340. // Update 更新数据
  341. func (a *ChatMessage) Update(ctx context.Context, recordID string, item schema.ChatMessage) (*schema.ChatMessage, error) {
  342. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  343. if err != nil {
  344. return nil, err
  345. } else if oldItem == nil {
  346. return nil, errors.ErrNotFound
  347. }
  348. err = a.ChatMessageModel.Update(ctx, recordID, item)
  349. if err != nil {
  350. return nil, err
  351. }
  352. return a.getUpdate(ctx, recordID)
  353. }
  354. // Delete 删除数据
  355. func (a *ChatMessage) Delete(ctx context.Context, recordID string) error {
  356. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  357. if err != nil {
  358. return err
  359. } else if oldItem == nil {
  360. return errors.ErrNotFound
  361. }
  362. return a.ChatMessageModel.Delete(ctx, recordID)
  363. }
  364. // UpdateStatus 更新状态
  365. func (a *ChatMessage) UpdateStatus(ctx context.Context, recordID string, status int) error {
  366. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  367. if err != nil {
  368. return err
  369. } else if oldItem == nil {
  370. return errors.ErrNotFound
  371. }
  372. return a.ChatMessageModel.UpdateStatus(ctx, recordID, status)
  373. }