package internal import ( "bufio" "context" "encoding/json" "fmt" "github.com/gogf/gf/util/guid" "github.com/gogf/gf/v2/net/ghttp" "strings" "yx-dataset-server/app/errors" "yx-dataset-server/app/model" "yx-dataset-server/app/schema" "yx-dataset-server/library/ragflow" ) // NewChatMessage 创建ChatMessage func NewChatMessage( mSession model.IChatSession, mChatMessage model.IChatMessage, mAssistant model.IChatAssistant, ) *ChatMessage { return &ChatMessage{ ChatMessageModel: mChatMessage, sessionModel: mSession, assistantModel: mAssistant, } } // ChatMessage 创建ChatMessage对象 type ChatMessage struct { ChatMessageModel model.IChatMessage sessionModel model.IChatSession assistantModel model.IChatAssistant } // Query 查询数据 func (a *ChatMessage) Query(ctx context.Context, params schema.ChatMessageQueryParam, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessageQueryResult, error) { return a.ChatMessageModel.Query(ctx, params, opts...) } // Get 查询指定数据 func (a *ChatMessage) Get(ctx context.Context, recordID string, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessage, error) { item, err := a.ChatMessageModel.Get(ctx, recordID, opts...) if err != nil { return nil, err } else if item == nil { return nil, errors.ErrNotFound } return item, nil } func (a *ChatMessage) getUpdate(ctx context.Context, recordID string) (*schema.ChatMessage, error) { return a.Get(ctx, recordID) } // StreamChatMessage 流式消息 func (a *ChatMessage) StreamChatMessage(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) //var sessionId string if item.SessionId == "" { r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n") return } session, err := a.sessionModel.Get(ctx, item.SessionId) if err != nil { r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()) return } //sessionId = session.RagSessionId assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()) return } stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{ Question: item.Question, SessionID: session.RagSessionId, }) if err != nil { r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()) return } defer stream.Close() // 关闭流 // 用于拼接答案 var fullAnswer strings.Builder //err = a.collectStreamDataAndCreate(ctx, stream, &item) //if err != nil { // r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()) // return //} // ===================== 5. 设置SSE响应头(必须!) ===================== r.Response.Header().Set("Content-Type", "text/event-stream") r.Response.Header().Set("Cache-Control", "no-cache") r.Response.Header().Set("Connection", "keep-alive") r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲 r.Response.Flush() scanner := bufio.NewScanner(stream) buf := make([]byte, 1024*1024*10) // 增大缓冲区 scanner.Buffer(buf, 1024*1024*10) for scanner.Scan() { line := scanner.Text() // 原封不动写给前端 r.Response.Write(line + "\n") r.Response.Flush() // 实时刷新 //2. 后端收集数据(拼接完整答案) if strings.HasPrefix(line, "data:") { jsonStr := strings.TrimPrefix(line, "data:") jsonStr = strings.TrimSpace(jsonStr) if jsonStr != "true" { var chunk ragflow.ChatCompletionResp if err := json.Unmarshal([]byte(jsonStr), &chunk); err == nil { fullAnswer.WriteString(chunk.Data.Answer) } } } } // 处理流读取错误 if err := scanner.Err(); err != nil { r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"读取流失败:%s\"}\n\n", err.Error())) r.Response.Flush() return } // ===================== 【业务层:保存完整答案到数据库】 ===================== item.Answer = fullAnswer.String() if err := a.ChatMessageModel.Create(ctx, item); err != nil { fmt.Printf("保存聊天记录失败:%v\n", err) } } func (a *ChatMessage) ChatMessage(ctx context.Context, r *ghttp.Request, item *schema.ChatMessage) (*schema.ChatMessage, error) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) //var sessionId string if item.SessionId == "" { return nil, errors.New400Response("会话ID不能为空") } session, err := a.sessionModel.Get(ctx, item.SessionId) if err != nil { return nil, err } //sessionId = session.RagSessionId assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { return nil, err } stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{ Question: item.Question, SessionID: session.RagSessionId, }) if err != nil { return nil, err } defer stream.Close() // 关闭流 item, err = CollectStreamDataAndCreate(ctx, stream, item) if err != nil { return nil, err } err = a.ChatMessageModel.Create(ctx, *item) if err != nil { return nil, err } return item, nil } //func (a *ChatMessage) collectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) error { // // 解决 token too long 报错:调大缓冲区到 10MB // const maxBufferSize = 10 * 1024 * 1024 // scanner := bufio.NewScanner(stream) // buf := make([]byte, maxBufferSize) // scanner.Buffer(buf, maxBufferSize) // // // 用于拼接答案 // var fullAnswer strings.Builder // // 用于存储最终完整结果 // finalResp := &ragflow.ChatCompletionResp{} // // // 逐行读取流式返回 // for scanner.Scan() { // line := strings.TrimSpace(scanner.Text()) // if line == "" { // continue // } // // // 只处理 SSE 格式 data: ... // if !strings.HasPrefix(line, "data:") { // continue // } // jsonStr := strings.TrimPrefix(line, "data:") // jsonStr = strings.TrimSpace(jsonStr) // // // 流结束标记:data: true // if jsonStr == "true" { // break // } // // // 解析当前分片 // var chunk ragflow.ChatCompletionResp // if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil { // continue // } // // 拼接回答内容 // if chunk.Data.Answer != "" { // fullAnswer.WriteString(chunk.Data.Answer) // } // // 保存最后一个分片(包含 reference/session_id 等完整信息) // *finalResp = chunk // } // // if err := scanner.Err(); err != nil { // return errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err)) // } // // item.Answer = fullAnswer.String() // // return a.ChatMessageModel.Create(ctx, *item) // //} func (a *ChatMessage) CreateV3(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) //var sessionId string if item.SessionId == "" { //return nil, errors.New400Response("会话ID不能为空") return } session, err := a.sessionModel.Get(ctx, item.SessionId) if err != nil { //return nil, err return } //sessionId = session.RagSessionId assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { //return nil, err return } //resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{ // Question: item.Question, // SessionID: sessionId, //}) //if err != nil { // return nil, err //} //item.Answer = resp.Data.Answer err = a.ChatMessageModel.Create(ctx, item) if err != nil { return } stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{ Question: item.Question, SessionID: session.RagSessionId, }) if err != nil { return } defer stream.Close() // 关闭流 // ===================== 5. 设置SSE响应头(必须!) ===================== r.Response.Header().Set("Content-Type", "text/event-stream") r.Response.Header().Set("Cache-Control", "no-cache") r.Response.Header().Set("Connection", "keep-alive") r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲 //r.Response.Flush() // //_, err = io.Copy(r.Response.Writer, stream) //if err != nil { // // 可自定义错误处理 // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())) //} // 实时刷新 //===================== 6. 逐行转发RAG流给前端 ===================== scanner := bufio.NewScanner(stream) buf := make([]byte, 1024*1024*10) // 增大缓冲区 scanner.Buffer(buf, 1024*1024*10) for scanner.Scan() { line := scanner.Text() // 原封不动写给前端 r.Response.Write(line + "\n") r.Response.Flush() // 实时刷新 } //// 处理错误 //if err = scanner.Err(); err != nil { // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())) // r.Response.Flush() // 实时刷新 //} } func (a *ChatMessage) CreateV2(ctx context.Context, item schema.ChatMessage) (*schema.ChatMessage, error) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) if item.SessionId == "" { return nil, errors.New400Response("会话ID不能为空") } session, err := a.sessionModel.Get(ctx, item.SessionId) if err != nil { return nil, err } assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { return nil, err } // 调用非流式接口 resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{ Question: item.Question, SessionID: session.RagSessionId, }) if err != nil { return nil, err } item.Answer = resp.Data.Answer // 保存数据库 err = a.ChatMessageModel.Create(ctx, item) if err != nil { return nil, err } return a.getUpdate(ctx, item.RecordID) } // Update 更新数据 func (a *ChatMessage) Update(ctx context.Context, recordID string, item schema.ChatMessage) (*schema.ChatMessage, error) { oldItem, err := a.ChatMessageModel.Get(ctx, recordID) if err != nil { return nil, err } else if oldItem == nil { return nil, errors.ErrNotFound } err = a.ChatMessageModel.Update(ctx, recordID, item) if err != nil { return nil, err } return a.getUpdate(ctx, recordID) } // Delete 删除数据 func (a *ChatMessage) Delete(ctx context.Context, recordID string) error { oldItem, err := a.ChatMessageModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.ChatMessageModel.Delete(ctx, recordID) } // UpdateStatus 更新状态 func (a *ChatMessage) UpdateStatus(ctx context.Context, recordID string, status int) error { oldItem, err := a.ChatMessageModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.ChatMessageModel.UpdateStatus(ctx, recordID, status) }