|
@@ -0,0 +1,385 @@
|
|
|
|
|
+package internal
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "bufio"
|
|
|
|
|
+ "context"
|
|
|
|
|
+ "encoding/json"
|
|
|
|
|
+ "fmt"
|
|
|
|
|
+ "github.com/gogf/gf/util/guid"
|
|
|
|
|
+ "github.com/gogf/gf/v2/net/ghttp"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+ "yx-dataset-server/app/errors"
|
|
|
|
|
+ "yx-dataset-server/app/model"
|
|
|
|
|
+ "yx-dataset-server/app/schema"
|
|
|
|
|
+ "yx-dataset-server/library/ragflow"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// NewChatMessage 创建ChatMessage
|
|
|
|
|
+func NewChatMessage(
|
|
|
|
|
+ mSession model.IChatSession,
|
|
|
|
|
+ mChatMessage model.IChatMessage,
|
|
|
|
|
+ mAssistant model.IChatAssistant,
|
|
|
|
|
+) *ChatMessage {
|
|
|
|
|
+ return &ChatMessage{
|
|
|
|
|
+ ChatMessageModel: mChatMessage,
|
|
|
|
|
+ sessionModel: mSession,
|
|
|
|
|
+ assistantModel: mAssistant,
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ChatMessage 创建ChatMessage对象
|
|
|
|
|
+type ChatMessage struct {
|
|
|
|
|
+ ChatMessageModel model.IChatMessage
|
|
|
|
|
+ sessionModel model.IChatSession
|
|
|
|
|
+ assistantModel model.IChatAssistant
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Query 查询数据
|
|
|
|
|
+func (a *ChatMessage) Query(ctx context.Context, params schema.ChatMessageQueryParam, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessageQueryResult, error) {
|
|
|
|
|
+ return a.ChatMessageModel.Query(ctx, params, opts...)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Get 查询指定数据
|
|
|
|
|
+func (a *ChatMessage) Get(ctx context.Context, recordID string, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessage, error) {
|
|
|
|
|
+ item, err := a.ChatMessageModel.Get(ctx, recordID, opts...)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ } else if item == nil {
|
|
|
|
|
+ return nil, errors.ErrNotFound
|
|
|
|
|
+ }
|
|
|
|
|
+ return item, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *ChatMessage) getUpdate(ctx context.Context, recordID string) (*schema.ChatMessage, error) {
|
|
|
|
|
+ return a.Get(ctx, recordID)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// StreamChatMessage 流式消息
|
|
|
|
|
+func (a *ChatMessage) StreamChatMessage(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
|
|
|
|
|
+ item.RecordID = guid.S()
|
|
|
|
|
+ item.CreatorId = GetUserID(ctx)
|
|
|
|
|
+ //var sessionId string
|
|
|
|
|
+ if item.SessionId == "" {
|
|
|
|
|
+ r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n")
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ session, err := a.sessionModel.Get(ctx, item.SessionId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ //sessionId = session.RagSessionId
|
|
|
|
|
+
|
|
|
|
|
+ assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
|
|
|
|
|
+ Question: item.Question,
|
|
|
|
|
+ SessionID: session.RagSessionId,
|
|
|
|
|
+ })
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ defer stream.Close() // 关闭流
|
|
|
|
|
+ // 用于拼接答案
|
|
|
|
|
+ var fullAnswer strings.Builder
|
|
|
|
|
+ //err = a.collectStreamDataAndCreate(ctx, stream, &item)
|
|
|
|
|
+ //if err != nil {
|
|
|
|
|
+ // r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
|
|
|
|
|
+ // return
|
|
|
|
|
+ //}
|
|
|
|
|
+ // ===================== 5. 设置SSE响应头(必须!) =====================
|
|
|
|
|
+ r.Response.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
|
+ r.Response.Header().Set("Cache-Control", "no-cache")
|
|
|
|
|
+ r.Response.Header().Set("Connection", "keep-alive")
|
|
|
|
|
+ r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
|
|
|
|
|
+ r.Response.Flush()
|
|
|
|
|
+
|
|
|
|
|
+ scanner := bufio.NewScanner(stream)
|
|
|
|
|
+ buf := make([]byte, 1024*1024*10) // 增大缓冲区
|
|
|
|
|
+ scanner.Buffer(buf, 1024*1024*10)
|
|
|
|
|
+
|
|
|
|
|
+ for scanner.Scan() {
|
|
|
|
|
+ line := scanner.Text()
|
|
|
|
|
+ // 原封不动写给前端
|
|
|
|
|
+ r.Response.Write(line + "\n")
|
|
|
|
|
+ r.Response.Flush() // 实时刷新
|
|
|
|
|
+
|
|
|
|
|
+ //2. 后端收集数据(拼接完整答案)
|
|
|
|
|
+ if strings.HasPrefix(line, "data:") {
|
|
|
|
|
+ jsonStr := strings.TrimPrefix(line, "data:")
|
|
|
|
|
+ jsonStr = strings.TrimSpace(jsonStr)
|
|
|
|
|
+
|
|
|
|
|
+ if jsonStr != "true" {
|
|
|
|
|
+ var chunk ragflow.ChatCompletionResp
|
|
|
|
|
+ if err := json.Unmarshal([]byte(jsonStr), &chunk); err == nil {
|
|
|
|
|
+ fullAnswer.WriteString(chunk.Data.Answer)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 处理流读取错误
|
|
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
|
|
+ r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"读取流失败:%s\"}\n\n", err.Error()))
|
|
|
|
|
+ r.Response.Flush()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // ===================== 【业务层:保存完整答案到数据库】 =====================
|
|
|
|
|
+ item.Answer = fullAnswer.String()
|
|
|
|
|
+ if err := a.ChatMessageModel.Create(ctx, item); err != nil {
|
|
|
|
|
+ fmt.Printf("保存聊天记录失败:%v\n", err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *ChatMessage) ChatMessage(ctx context.Context, r *ghttp.Request, item *schema.ChatMessage) (*schema.ChatMessage, error) {
|
|
|
|
|
+ item.RecordID = guid.S()
|
|
|
|
|
+ item.CreatorId = GetUserID(ctx)
|
|
|
|
|
+ //var sessionId string
|
|
|
|
|
+ if item.SessionId == "" {
|
|
|
|
|
+ return nil, errors.New400Response("会话ID不能为空")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ session, err := a.sessionModel.Get(ctx, item.SessionId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ //sessionId = session.RagSessionId
|
|
|
|
|
+
|
|
|
|
|
+ assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
|
|
|
|
|
+ Question: item.Question,
|
|
|
|
|
+ SessionID: session.RagSessionId,
|
|
|
|
|
+ })
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ defer stream.Close() // 关闭流
|
|
|
|
|
+
|
|
|
|
|
+ item, err = CollectStreamDataAndCreate(ctx, stream, item)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ err = a.ChatMessageModel.Create(ctx, *item)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return item, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+//func (a *ChatMessage) collectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) error {
|
|
|
|
|
+// // 解决 token too long 报错:调大缓冲区到 10MB
|
|
|
|
|
+// const maxBufferSize = 10 * 1024 * 1024
|
|
|
|
|
+// scanner := bufio.NewScanner(stream)
|
|
|
|
|
+// buf := make([]byte, maxBufferSize)
|
|
|
|
|
+// scanner.Buffer(buf, maxBufferSize)
|
|
|
|
|
+//
|
|
|
|
|
+// // 用于拼接答案
|
|
|
|
|
+// var fullAnswer strings.Builder
|
|
|
|
|
+// // 用于存储最终完整结果
|
|
|
|
|
+// finalResp := &ragflow.ChatCompletionResp{}
|
|
|
|
|
+//
|
|
|
|
|
+// // 逐行读取流式返回
|
|
|
|
|
+// for scanner.Scan() {
|
|
|
|
|
+// line := strings.TrimSpace(scanner.Text())
|
|
|
|
|
+// if line == "" {
|
|
|
|
|
+// continue
|
|
|
|
|
+// }
|
|
|
|
|
+//
|
|
|
|
|
+// // 只处理 SSE 格式 data: ...
|
|
|
|
|
+// if !strings.HasPrefix(line, "data:") {
|
|
|
|
|
+// continue
|
|
|
|
|
+// }
|
|
|
|
|
+// jsonStr := strings.TrimPrefix(line, "data:")
|
|
|
|
|
+// jsonStr = strings.TrimSpace(jsonStr)
|
|
|
|
|
+//
|
|
|
|
|
+// // 流结束标记:data: true
|
|
|
|
|
+// if jsonStr == "true" {
|
|
|
|
|
+// break
|
|
|
|
|
+// }
|
|
|
|
|
+//
|
|
|
|
|
+// // 解析当前分片
|
|
|
|
|
+// var chunk ragflow.ChatCompletionResp
|
|
|
|
|
+// if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
|
|
|
|
|
+// continue
|
|
|
|
|
+// }
|
|
|
|
|
+// // 拼接回答内容
|
|
|
|
|
+// if chunk.Data.Answer != "" {
|
|
|
|
|
+// fullAnswer.WriteString(chunk.Data.Answer)
|
|
|
|
|
+// }
|
|
|
|
|
+// // 保存最后一个分片(包含 reference/session_id 等完整信息)
|
|
|
|
|
+// *finalResp = chunk
|
|
|
|
|
+// }
|
|
|
|
|
+//
|
|
|
|
|
+// if err := scanner.Err(); err != nil {
|
|
|
|
|
+// return errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
|
|
|
|
|
+// }
|
|
|
|
|
+//
|
|
|
|
|
+// item.Answer = fullAnswer.String()
|
|
|
|
|
+//
|
|
|
|
|
+// return a.ChatMessageModel.Create(ctx, *item)
|
|
|
|
|
+//
|
|
|
|
|
+//}
|
|
|
|
|
+
|
|
|
|
|
+func (a *ChatMessage) CreateV3(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
|
|
|
|
|
+ item.RecordID = guid.S()
|
|
|
|
|
+ item.CreatorId = GetUserID(ctx)
|
|
|
|
|
+ //var sessionId string
|
|
|
|
|
+ if item.SessionId == "" {
|
|
|
|
|
+ //return nil, errors.New400Response("会话ID不能为空")
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ session, err := a.sessionModel.Get(ctx, item.SessionId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ //return nil, err
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ //sessionId = session.RagSessionId
|
|
|
|
|
+
|
|
|
|
|
+ assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ //return nil, err
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
|
|
|
|
|
+ // Question: item.Question,
|
|
|
|
|
+ // SessionID: sessionId,
|
|
|
|
|
+ //})
|
|
|
|
|
+ //if err != nil {
|
|
|
|
|
+ // return nil, err
|
|
|
|
|
+ //}
|
|
|
|
|
+ //item.Answer = resp.Data.Answer
|
|
|
|
|
+ err = a.ChatMessageModel.Create(ctx, item)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
|
|
|
|
|
+ Question: item.Question,
|
|
|
|
|
+ SessionID: session.RagSessionId,
|
|
|
|
|
+ })
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ defer stream.Close() // 关闭流
|
|
|
|
|
+
|
|
|
|
|
+ // ===================== 5. 设置SSE响应头(必须!) =====================
|
|
|
|
|
+ r.Response.Header().Set("Content-Type", "text/event-stream")
|
|
|
|
|
+ r.Response.Header().Set("Cache-Control", "no-cache")
|
|
|
|
|
+ r.Response.Header().Set("Connection", "keep-alive")
|
|
|
|
|
+ r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
|
|
|
|
|
+ //r.Response.Flush()
|
|
|
|
|
+ //
|
|
|
|
|
+ //_, err = io.Copy(r.Response.Writer, stream)
|
|
|
|
|
+ //if err != nil {
|
|
|
|
|
+ // // 可自定义错误处理
|
|
|
|
|
+ // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
|
|
|
|
|
+ //}
|
|
|
|
|
+ // 实时刷新
|
|
|
|
|
+ //===================== 6. 逐行转发RAG流给前端 =====================
|
|
|
|
|
+ scanner := bufio.NewScanner(stream)
|
|
|
|
|
+ buf := make([]byte, 1024*1024*10) // 增大缓冲区
|
|
|
|
|
+ scanner.Buffer(buf, 1024*1024*10)
|
|
|
|
|
+
|
|
|
|
|
+ for scanner.Scan() {
|
|
|
|
|
+ line := scanner.Text()
|
|
|
|
|
+ // 原封不动写给前端
|
|
|
|
|
+ r.Response.Write(line + "\n")
|
|
|
|
|
+ r.Response.Flush() // 实时刷新
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //// 处理错误
|
|
|
|
|
+ //if err = scanner.Err(); err != nil {
|
|
|
|
|
+ // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
|
|
|
|
|
+ // r.Response.Flush() // 实时刷新
|
|
|
|
|
+ //}
|
|
|
|
|
+
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (a *ChatMessage) CreateV2(ctx context.Context, item schema.ChatMessage) (*schema.ChatMessage, error) {
|
|
|
|
|
+ item.RecordID = guid.S()
|
|
|
|
|
+ item.CreatorId = GetUserID(ctx)
|
|
|
|
|
+
|
|
|
|
|
+ if item.SessionId == "" {
|
|
|
|
|
+ return nil, errors.New400Response("会话ID不能为空")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ session, err := a.sessionModel.Get(ctx, item.SessionId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 调用非流式接口
|
|
|
|
|
+ resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
|
|
|
|
|
+ Question: item.Question,
|
|
|
|
|
+ SessionID: session.RagSessionId,
|
|
|
|
|
+ })
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ item.Answer = resp.Data.Answer
|
|
|
|
|
+ // 保存数据库
|
|
|
|
|
+ err = a.ChatMessageModel.Create(ctx, item)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return a.getUpdate(ctx, item.RecordID)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Update 更新数据
|
|
|
|
|
+func (a *ChatMessage) Update(ctx context.Context, recordID string, item schema.ChatMessage) (*schema.ChatMessage, error) {
|
|
|
|
|
+ oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ } else if oldItem == nil {
|
|
|
|
|
+ return nil, errors.ErrNotFound
|
|
|
|
|
+ }
|
|
|
|
|
+ err = a.ChatMessageModel.Update(ctx, recordID, item)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return a.getUpdate(ctx, recordID)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Delete 删除数据
|
|
|
|
|
+func (a *ChatMessage) Delete(ctx context.Context, recordID string) error {
|
|
|
|
|
+ oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ } else if oldItem == nil {
|
|
|
|
|
+ return errors.ErrNotFound
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return a.ChatMessageModel.Delete(ctx, recordID)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// UpdateStatus 更新状态
|
|
|
|
|
+func (a *ChatMessage) UpdateStatus(ctx context.Context, recordID string, status int) error {
|
|
|
|
|
+ oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ } else if oldItem == nil {
|
|
|
|
|
+ return errors.ErrNotFound
|
|
|
|
|
+ }
|
|
|
|
|
+ return a.ChatMessageModel.UpdateStatus(ctx, recordID, status)
|
|
|
|
|
+}
|