| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- package internal
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "github.com/gogf/gf/util/guid"
- "github.com/gogf/gf/v2/net/ghttp"
- "strings"
- "yx-dataset-server/app/errors"
- "yx-dataset-server/app/model"
- "yx-dataset-server/app/schema"
- "yx-dataset-server/library/ragflow"
- )
- // NewChatMessage 创建ChatMessage
- func NewChatMessage(
- mSession model.IChatSession,
- mChatMessage model.IChatMessage,
- mAssistant model.IChatAssistant,
- ) *ChatMessage {
- return &ChatMessage{
- ChatMessageModel: mChatMessage,
- sessionModel: mSession,
- assistantModel: mAssistant,
- }
- }
- // ChatMessage 创建ChatMessage对象
- type ChatMessage struct {
- ChatMessageModel model.IChatMessage
- sessionModel model.IChatSession
- assistantModel model.IChatAssistant
- }
- // Query 查询数据
- func (a *ChatMessage) Query(ctx context.Context, params schema.ChatMessageQueryParam, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessageQueryResult, error) {
- return a.ChatMessageModel.Query(ctx, params, opts...)
- }
- // Get 查询指定数据
- func (a *ChatMessage) Get(ctx context.Context, recordID string, opts ...schema.ChatMessageQueryOptions) (*schema.ChatMessage, error) {
- item, err := a.ChatMessageModel.Get(ctx, recordID, opts...)
- if err != nil {
- return nil, err
- } else if item == nil {
- return nil, errors.ErrNotFound
- }
- return item, nil
- }
- func (a *ChatMessage) getUpdate(ctx context.Context, recordID string) (*schema.ChatMessage, error) {
- return a.Get(ctx, recordID)
- }
- // StreamChatMessage 流式消息
- func (a *ChatMessage) StreamChatMessage(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- //var sessionId string
- if item.SessionId == "" {
- r.Response.Write("data: {\"code\":500,\"message\":\"会话ID不能为空\"}\n\n")
- return
- }
- session, err := a.sessionModel.Get(ctx, item.SessionId)
- if err != nil {
- r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
- return
- }
- //sessionId = session.RagSessionId
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
- return
- }
- stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
- Question: item.Question,
- SessionID: session.RagSessionId,
- })
- if err != nil {
- r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
- return
- }
- defer stream.Close() // 关闭流
- // 用于拼接答案
- var fullAnswer strings.Builder
- //err = a.collectStreamDataAndCreate(ctx, stream, &item)
- //if err != nil {
- // r.Response.Write("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error())
- // return
- //}
- // ===================== 5. 设置SSE响应头(必须!) =====================
- r.Response.Header().Set("Content-Type", "text/event-stream")
- r.Response.Header().Set("Cache-Control", "no-cache")
- r.Response.Header().Set("Connection", "keep-alive")
- r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
- r.Response.Flush()
- scanner := bufio.NewScanner(stream)
- buf := make([]byte, 1024*1024*10) // 增大缓冲区
- scanner.Buffer(buf, 1024*1024*10)
- for scanner.Scan() {
- line := scanner.Text()
- // 原封不动写给前端
- r.Response.Write(line + "\n")
- r.Response.Flush() // 实时刷新
- //2. 后端收集数据(拼接完整答案)
- if strings.HasPrefix(line, "data:") {
- jsonStr := strings.TrimPrefix(line, "data:")
- jsonStr = strings.TrimSpace(jsonStr)
- if jsonStr != "true" {
- var chunk ragflow.ChatCompletionResp
- if err := json.Unmarshal([]byte(jsonStr), &chunk); err == nil {
- fullAnswer.WriteString(chunk.Data.Answer)
- }
- }
- }
- }
- // 处理流读取错误
- if err := scanner.Err(); err != nil {
- r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"读取流失败:%s\"}\n\n", err.Error()))
- r.Response.Flush()
- return
- }
- // ===================== 【业务层:保存完整答案到数据库】 =====================
- item.Answer = fullAnswer.String()
- if err := a.ChatMessageModel.Create(ctx, item); err != nil {
- fmt.Printf("保存聊天记录失败:%v\n", err)
- }
- }
- func (a *ChatMessage) ChatMessage(ctx context.Context, r *ghttp.Request, item *schema.ChatMessage) (*schema.ChatMessage, error) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- //var sessionId string
- if item.SessionId == "" {
- return nil, errors.New400Response("会话ID不能为空")
- }
- session, err := a.sessionModel.Get(ctx, item.SessionId)
- if err != nil {
- return nil, err
- }
- //sessionId = session.RagSessionId
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- return nil, err
- }
- stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
- Question: item.Question,
- SessionID: session.RagSessionId,
- })
- if err != nil {
- return nil, err
- }
- defer stream.Close() // 关闭流
- item, err = CollectStreamDataAndCreate(ctx, stream, item)
- if err != nil {
- return nil, err
- }
- err = a.ChatMessageModel.Create(ctx, *item)
- if err != nil {
- return nil, err
- }
- return item, nil
- }
- //func (a *ChatMessage) collectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) error {
- // // 解决 token too long 报错:调大缓冲区到 10MB
- // const maxBufferSize = 10 * 1024 * 1024
- // scanner := bufio.NewScanner(stream)
- // buf := make([]byte, maxBufferSize)
- // scanner.Buffer(buf, maxBufferSize)
- //
- // // 用于拼接答案
- // var fullAnswer strings.Builder
- // // 用于存储最终完整结果
- // finalResp := &ragflow.ChatCompletionResp{}
- //
- // // 逐行读取流式返回
- // for scanner.Scan() {
- // line := strings.TrimSpace(scanner.Text())
- // if line == "" {
- // continue
- // }
- //
- // // 只处理 SSE 格式 data: ...
- // if !strings.HasPrefix(line, "data:") {
- // continue
- // }
- // jsonStr := strings.TrimPrefix(line, "data:")
- // jsonStr = strings.TrimSpace(jsonStr)
- //
- // // 流结束标记:data: true
- // if jsonStr == "true" {
- // break
- // }
- //
- // // 解析当前分片
- // var chunk ragflow.ChatCompletionResp
- // if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
- // continue
- // }
- // // 拼接回答内容
- // if chunk.Data.Answer != "" {
- // fullAnswer.WriteString(chunk.Data.Answer)
- // }
- // // 保存最后一个分片(包含 reference/session_id 等完整信息)
- // *finalResp = chunk
- // }
- //
- // if err := scanner.Err(); err != nil {
- // return errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
- // }
- //
- // item.Answer = fullAnswer.String()
- //
- // return a.ChatMessageModel.Create(ctx, *item)
- //
- //}
- func (a *ChatMessage) CreateV3(ctx context.Context, r *ghttp.Request, item schema.ChatMessage) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- //var sessionId string
- if item.SessionId == "" {
- //return nil, errors.New400Response("会话ID不能为空")
- return
- }
- session, err := a.sessionModel.Get(ctx, item.SessionId)
- if err != nil {
- //return nil, err
- return
- }
- //sessionId = session.RagSessionId
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- //return nil, err
- return
- }
- //resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
- // Question: item.Question,
- // SessionID: sessionId,
- //})
- //if err != nil {
- // return nil, err
- //}
- //item.Answer = resp.Data.Answer
- err = a.ChatMessageModel.Create(ctx, item)
- if err != nil {
- return
- }
- stream, err := ragflow.GetHttpClient().ChatCompletionsStream(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
- Question: item.Question,
- SessionID: session.RagSessionId,
- })
- if err != nil {
- return
- }
- defer stream.Close() // 关闭流
- // ===================== 5. 设置SSE响应头(必须!) =====================
- r.Response.Header().Set("Content-Type", "text/event-stream")
- r.Response.Header().Set("Cache-Control", "no-cache")
- r.Response.Header().Set("Connection", "keep-alive")
- r.Response.Header().Set("X-Accel-Buffering", "no") // 禁用Nginx缓冲
- //r.Response.Flush()
- //
- //_, err = io.Copy(r.Response.Writer, stream)
- //if err != nil {
- // // 可自定义错误处理
- // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
- //}
- // 实时刷新
- //===================== 6. 逐行转发RAG流给前端 =====================
- scanner := bufio.NewScanner(stream)
- buf := make([]byte, 1024*1024*10) // 增大缓冲区
- scanner.Buffer(buf, 1024*1024*10)
- for scanner.Scan() {
- line := scanner.Text()
- // 原封不动写给前端
- r.Response.Write(line + "\n")
- r.Response.Flush() // 实时刷新
- }
- //// 处理错误
- //if err = scanner.Err(); err != nil {
- // r.Response.Write(fmt.Sprintf("data: {\"code\":500,\"message\":\"%s\"}\n\n", err.Error()))
- // r.Response.Flush() // 实时刷新
- //}
- }
- func (a *ChatMessage) CreateV2(ctx context.Context, item schema.ChatMessage) (*schema.ChatMessage, error) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- if item.SessionId == "" {
- return nil, errors.New400Response("会话ID不能为空")
- }
- session, err := a.sessionModel.Get(ctx, item.SessionId)
- if err != nil {
- return nil, err
- }
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- return nil, err
- }
- // 调用非流式接口
- resp, err := ragflow.GetHttpClient().ChatCompletions(ctx, assistant.RagChatId, &ragflow.ChatCompletionReq{
- Question: item.Question,
- SessionID: session.RagSessionId,
- })
- if err != nil {
- return nil, err
- }
- item.Answer = resp.Data.Answer
- // 保存数据库
- err = a.ChatMessageModel.Create(ctx, item)
- if err != nil {
- return nil, err
- }
- return a.getUpdate(ctx, item.RecordID)
- }
- // Update 更新数据
- func (a *ChatMessage) Update(ctx context.Context, recordID string, item schema.ChatMessage) (*schema.ChatMessage, error) {
- oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
- if err != nil {
- return nil, err
- } else if oldItem == nil {
- return nil, errors.ErrNotFound
- }
- err = a.ChatMessageModel.Update(ctx, recordID, item)
- if err != nil {
- return nil, err
- }
- return a.getUpdate(ctx, recordID)
- }
- // Delete 删除数据
- func (a *ChatMessage) Delete(ctx context.Context, recordID string) error {
- oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
- if err != nil {
- return err
- } else if oldItem == nil {
- return errors.ErrNotFound
- }
- return a.ChatMessageModel.Delete(ctx, recordID)
- }
- // UpdateStatus 更新状态
- func (a *ChatMessage) UpdateStatus(ctx context.Context, recordID string, status int) error {
- oldItem, err := a.ChatMessageModel.Get(ctx, recordID)
- if err != nil {
- return err
- } else if oldItem == nil {
- return errors.ErrNotFound
- }
- return a.ChatMessageModel.UpdateStatus(ctx, recordID, status)
- }
|