b_chat_message.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. package internal
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gogf/gf/util/guid"
  8. "github.com/gogf/gf/v2/net/ghttp"
  9. "strings"
  10. "yx-dataset-server/app/errors"
  11. "yx-dataset-server/app/model"
  12. "yx-dataset-server/app/schema"
  13. "yx-dataset-server/library/ragflow"
  14. )
  15. // NewChatMessage 创建ChatMessage
  16. func NewChatMessage(
  17. mSession model.IChatSession,
  18. mChatMessage model.IChatMessage,
  19. mAssistant model.IChatAssistant,
  20. ) *ChatMessage {
  21. return &ChatMessage{
  22. ChatMessageModel: mChatMessage,
  23. sessionModel: mSession,
  24. assistantModel: mAssistant,
  25. }
  26. }
  27. // ChatMessage 创建ChatMessage对象
  28. type ChatMessage struct {
  29. ChatMessageModel model.IChatMessage
  30. sessionModel model.IChatSession
  31. assistantModel model.IChatAssistant
  32. }
  33. // Query 查询数据
  34. func (a *ChatMessage) Query(ctx context.Context, params schema.ChatMessageQueryParam, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessageQueryResult, error) {
  35. return a.ChatMessageModel.Query(ctx, params, opts...)
  36. }
  37. // Get 查询指定数据
  38. func (a *ChatMessage) Get(ctx context.Context, recordID string, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessage, error) {
  39. item, err := a.ChatMessageModel.Get(ctx, recordID, opts...)
  40. if err != nil {
  41. return nil, err
  42. } else if item == nil {
  43. return nil, errors.ErrNotFound
  44. }
  45. return item, nil
  46. }
  47. func (a *ChatMessage) getUpdate(ctx context.Context, recordID string) (*schema.ChatMessage, error) {
  48. return a.Get(ctx, recordID)
  49. }
  50. // StreamChatMessage 流式消息
  51. func (a *ChatMessage) StreamChatMessage(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
  52. item.RecordID = guid.S()
  53. item.CreatorId = GetUserID(ctx)
  54. //var sessionId string
  55. if item.SessionId == "" {
  56. r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n")
  57. return
  58. }
  59. session, err := a.sessionModel.Get(ctx, item.SessionId)
  60. if err != nil {
  61. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  62. return
  63. }
  64. //sessionId = session.RagSessionId
  65. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  66. if err != nil {
  67. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  68. return
  69. }
  70. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  71. Question: item.Question,
  72. SessionID: session.RagSessionId,
  73. })
  74. if err != nil {
  75. r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  76. return
  77. }
  78. defer stream.Close() // 关闭流
  79. // 用于拼接答案
  80. var fullAnswer strings.Builder
  81. //err = a.collectStreamDataAndCreate(ctx, stream, &item)
  82. //if err != nil {
  83. // r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
  84. // return
  85. //}
  86. // ===================== 5. 设置SSE响应头(必须!) =====================
  87. r.Response.Header().Set("Content-Type", "text/event-stream")
  88. r.Response.Header().Set("Cache-Control", "no-cache")
  89. r.Response.Header().Set("Connection", "keep-alive")
  90. r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
  91. r.Response.Flush()
  92. scanner := bufio.NewScanner(stream)
  93. buf := make([]byte, 1024*1024*10) // 增大缓冲区
  94. scanner.Buffer(buf, 1024*1024*10)
  95. for scanner.Scan() {
  96. line := scanner.Text()
  97. // 原封不动写给前端
  98. r.Response.Write(line + "\n")
  99. r.Response.Flush() // 实时刷新
  100. //2. 后端收集数据(拼接完整答案)
  101. if strings.HasPrefix(line, "data:") {
  102. jsonStr := strings.TrimPrefix(line, "data:")
  103. jsonStr = strings.TrimSpace(jsonStr)
  104. if jsonStr != "true" {
  105. var chunk ragflow.ChatCompletionResp
  106. if err := json.Unmarshal([]byte(jsonStr), &chunk); err == nil {
  107. fullAnswer.WriteString(chunk.Data.Answer)
  108. }
  109. }
  110. }
  111. }
  112. // 处理流读取错误
  113. if err := scanner.Err(); err != nil {
  114. r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"读取流失败:%s\"}\n\n", err.Error()))
  115. r.Response.Flush()
  116. return
  117. }
  118. // ===================== 【业务层:保存完整答案到数据库】 =====================
  119. item.Answer = fullAnswer.String()
  120. if err := a.ChatMessageModel.Create(ctx, item); err != nil {
  121. fmt.Printf("保存聊天记录失败:%v\n", err)
  122. }
  123. }
  124. func (a *ChatMessage) ChatMessage(ctx context.Context, r *ghttp.Request, item *schema.ChatMessage) (*schema.ChatMessage, error) {
  125. item.RecordID = guid.S()
  126. item.CreatorId = GetUserID(ctx)
  127. //var sessionId string
  128. if item.SessionId == "" {
  129. return nil, errors.New400Response("会话ID不能为空")
  130. }
  131. session, err := a.sessionModel.Get(ctx, item.SessionId)
  132. if err != nil {
  133. return nil, err
  134. }
  135. //sessionId = session.RagSessionId
  136. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  137. if err != nil {
  138. return nil, err
  139. }
  140. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  141. Question: item.Question,
  142. SessionID: session.RagSessionId,
  143. })
  144. if err != nil {
  145. return nil, err
  146. }
  147. defer stream.Close() // 关闭流
  148. item, err = CollectStreamDataAndCreate(ctx, stream, item)
  149. if err != nil {
  150. return nil, err
  151. }
  152. err = a.ChatMessageModel.Create(ctx, *item)
  153. if err != nil {
  154. return nil, err
  155. }
  156. return item, nil
  157. }
  158. //func (a *ChatMessage) collectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) error {
  159. // // 解决 token too long 报错:调大缓冲区到 10MB
  160. // const maxBufferSize = 10 * 1024 * 1024
  161. // scanner := bufio.NewScanner(stream)
  162. // buf := make([]byte, maxBufferSize)
  163. // scanner.Buffer(buf, maxBufferSize)
  164. //
  165. // // 用于拼接答案
  166. // var fullAnswer strings.Builder
  167. // // 用于存储最终完整结果
  168. // finalResp := &ragflow.ChatCompletionResp{}
  169. //
  170. // // 逐行读取流式返回
  171. // for scanner.Scan() {
  172. // line := strings.TrimSpace(scanner.Text())
  173. // if line == "" {
  174. // continue
  175. // }
  176. //
  177. // // 只处理 SSE 格式 data: ...
  178. // if !strings.HasPrefix(line, "data:") {
  179. // continue
  180. // }
  181. // jsonStr := strings.TrimPrefix(line, "data:")
  182. // jsonStr = strings.TrimSpace(jsonStr)
  183. //
  184. // // 流结束标记:data: true
  185. // if jsonStr == "true" {
  186. // break
  187. // }
  188. //
  189. // // 解析当前分片
  190. // var chunk ragflow.ChatCompletionResp
  191. // if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
  192. // continue
  193. // }
  194. // // 拼接回答内容
  195. // if chunk.Data.Answer != "" {
  196. // fullAnswer.WriteString(chunk.Data.Answer)
  197. // }
  198. // // 保存最后一个分片(包含 reference/session_id 等完整信息)
  199. // *finalResp = chunk
  200. // }
  201. //
  202. // if err := scanner.Err(); err != nil {
  203. // return errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
  204. // }
  205. //
  206. // item.Answer = fullAnswer.String()
  207. //
  208. // return a.ChatMessageModel.Create(ctx, *item)
  209. //
  210. //}
  211. func (a *ChatMessage) CreateV3(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
  212. item.RecordID = guid.S()
  213. item.CreatorId = GetUserID(ctx)
  214. //var sessionId string
  215. if item.SessionId == "" {
  216. //return nil, errors.New400Response("会话ID不能为空")
  217. return
  218. }
  219. session, err := a.sessionModel.Get(ctx, item.SessionId)
  220. if err != nil {
  221. //return nil, err
  222. return
  223. }
  224. //sessionId = session.RagSessionId
  225. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  226. if err != nil {
  227. //return nil, err
  228. return
  229. }
  230. //resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  231. // Question: item.Question,
  232. // SessionID: sessionId,
  233. //})
  234. //if err != nil {
  235. // return nil, err
  236. //}
  237. //item.Answer = resp.Data.Answer
  238. err = a.ChatMessageModel.Create(ctx, item)
  239. if err != nil {
  240. return
  241. }
  242. stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  243. Question: item.Question,
  244. SessionID: session.RagSessionId,
  245. })
  246. if err != nil {
  247. return
  248. }
  249. defer stream.Close() // 关闭流
  250. // ===================== 5. 设置SSE响应头(必须!) =====================
  251. r.Response.Header().Set("Content-Type", "text/event-stream")
  252. r.Response.Header().Set("Cache-Control", "no-cache")
  253. r.Response.Header().Set("Connection", "keep-alive")
  254. r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
  255. //r.Response.Flush()
  256. //
  257. //_, err = io.Copy(r.Response.Writer, stream)
  258. //if err != nil {
  259. // // 可自定义错误处理
  260. // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
  261. //}
  262. // 实时刷新
  263. //===================== 6. 逐行转发RAG流给前端 =====================
  264. scanner := bufio.NewScanner(stream)
  265. buf := make([]byte, 1024*1024*10) // 增大缓冲区
  266. scanner.Buffer(buf, 1024*1024*10)
  267. for scanner.Scan() {
  268. line := scanner.Text()
  269. // 原封不动写给前端
  270. r.Response.Write(line + "\n")
  271. r.Response.Flush() // 实时刷新
  272. }
  273. //// 处理错误
  274. //if err = scanner.Err(); err != nil {
  275. // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
  276. // r.Response.Flush() // 实时刷新
  277. //}
  278. }
  279. func (a *ChatMessage) CreateV2(ctx context.Context, item schema.ChatMessage) (*schema.ChatMessage, error) {
  280. item.RecordID = guid.S()
  281. item.CreatorId = GetUserID(ctx)
  282. if item.SessionId == "" {
  283. return nil, errors.New400Response("会话ID不能为空")
  284. }
  285. session, err := a.sessionModel.Get(ctx, item.SessionId)
  286. if err != nil {
  287. return nil, err
  288. }
  289. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  290. if err != nil {
  291. return nil, err
  292. }
  293. // 调用非流式接口
  294. resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
  295. Question: item.Question,
  296. SessionID: session.RagSessionId,
  297. })
  298. if err != nil {
  299. return nil, err
  300. }
  301. item.Answer = resp.Data.Answer
  302. // 保存数据库
  303. err = a.ChatMessageModel.Create(ctx, item)
  304. if err != nil {
  305. return nil, err
  306. }
  307. return a.getUpdate(ctx, item.RecordID)
  308. }
  309. // Update 更新数据
  310. func (a *ChatMessage) Update(ctx context.Context, recordID string, item schema.ChatMessage) (*schema.ChatMessage, error) {
  311. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  312. if err != nil {
  313. return nil, err
  314. } else if oldItem == nil {
  315. return nil, errors.ErrNotFound
  316. }
  317. err = a.ChatMessageModel.Update(ctx, recordID, item)
  318. if err != nil {
  319. return nil, err
  320. }
  321. return a.getUpdate(ctx, recordID)
  322. }
  323. // Delete 删除数据
  324. func (a *ChatMessage) Delete(ctx context.Context, recordID string) error {
  325. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  326. if err != nil {
  327. return err
  328. } else if oldItem == nil {
  329. return errors.ErrNotFound
  330. }
  331. return a.ChatMessageModel.Delete(ctx, recordID)
  332. }
  333. // UpdateStatus 更新状态
  334. func (a *ChatMessage) UpdateStatus(ctx context.Context, recordID string, status int) error {
  335. oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
  336. if err != nil {
  337. return err
  338. } else if oldItem == nil {
  339. return errors.ErrNotFound
  340. }
  341. return a.ChatMessageModel.UpdateStatus(ctx, recordID, status)
  342. }