package internal

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"regexp"
	"strings"

	iContext "yx-dataset-server/app/context"
	"yx-dataset-server/app/errors"
	"yx-dataset-server/app/model"
	"yx-dataset-server/app/schema"
	"yx-dataset-server/library/ragflow"
	"yx-dataset-server/library/utils"
)

// idMarkerRe matches inline markers of the form "[ID:<digits>]" that are
// stripped from streamed answers. It is compiled once at package scope; a
// *regexp.Regexp is safe for concurrent use.
var idMarkerRe = regexp.MustCompile(`\[ID:\d+\]`)

// TransFunc is the function type executed inside a transaction. The context
// it receives carries the active transaction.
type TransFunc func(context.Context) error

// GetRootUser builds the root user from configuration values.
//
// RecordID and UserName are both taken from "root.user_name"; Password is the
// MD5 hash (via utils.MD5HashString) of the configured "root.password".
func GetRootUser() *schema.User {
	return &schema.User{
		RecordID: utils.GetConfig("root.user_name").String(),
		UserName: utils.GetConfig("root.user_name").String(),
		RealName: utils.GetConfig("root.real_name").String(),
		Photo:    utils.GetConfig("root.photo").String(),
		Password: utils.MD5HashString(utils.GetConfig("root.password").String()),
	}
}

// GetUserID returns the user ID stored in ctx.
//
// The "found" result of iContext.FromUserID is intentionally discarded, so the
// caller cannot distinguish a missing user ID from whatever value FromUserID
// yields in that case (presumably the empty string; verify in iContext).
func GetUserID(ctx context.Context) string {
	userID, _ := iContext.FromUserID(ctx)
	return userID
}

// CheckIsRootUser reports whether the user ID stored in ctx equals the root
// user's RecordID.
func CheckIsRootUser(ctx context.Context) bool {
	userID := GetUserID(ctx)
	return GetRootUser().RecordID == userID
}

// ExecTrans runs fn inside a transaction.
//
// If ctx already carries a transaction, fn is invoked directly with ctx and
// no new transaction is started (the outer owner commits or rolls back).
// Otherwise a transaction is begun through transModel, fn receives a context
// carrying it, and the transaction is committed when fn returns nil or rolled
// back when fn returns an error. If fn panics, the transaction is rolled back
// and the panic is re-raised.
//
// Rollback errors are ignored on purpose: the original failure (fn's error or
// the panic value) is the one propagated to the caller.
func ExecTrans(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
	if _, ok := iContext.FromTrans(ctx); ok {
		return fn(ctx)
	}

	trans, err := transModel.Begin(ctx)
	if err != nil {
		return err
	}

	defer func() {
		if r := recover(); r != nil {
			_ = transModel.Rollback(ctx, trans)
			panic(r)
		}
	}()

	if err := fn(iContext.NewTrans(ctx, trans)); err != nil {
		_ = transModel.Rollback(ctx, trans)
		return err
	}
	return transModel.Commit(ctx, trans)
}

// ExecTransWithLock runs fn inside a transaction with the transaction-lock
// flag set on the context. The flag is added only when ctx does not already
// carry it; execution is then delegated to ExecTrans.
func ExecTransWithLock(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
	if !iContext.FromTransLock(ctx) {
		ctx = iContext.NewTransLock(ctx)
	}
	return ExecTrans(ctx, transModel, fn)
}

// CollectStreamDataAndCreate reads an SSE-style stream line by line,
// concatenates the answer fragments it contains, and stores the result in
// item.Answer.
//
// Only lines starting with "data:" are considered. A payload of exactly
// "true" ends processing. Payloads that fail to decode as
// ragflow.ChatCompletionResp are skipped (best-effort). "[ID:<digits>]"
// markers are removed from every fragment and once more from the joined text,
// which also removes markers that only form after fragments are concatenated.
//
// Despite its name, this function performs no persistence itself, does not
// use ctx, and does not close stream; the caller remains responsible for
// closing it. item must be non-nil.
//
// It returns item with Answer filled in, or a 400 response error when reading
// from the stream fails (including a single line exceeding 10 MiB).
func CollectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) (*schema.ChatMessage, error) {
	// Raise the scanner's token limit to 10 MiB to avoid "token too long" on
	// large lines. The initial buffer is kept small and grows on demand, so
	// the full 10 MiB is only allocated when a line actually needs it.
	const (
		initialBufferSize = 64 * 1024
		maxBufferSize     = 10 * 1024 * 1024
	)
	scanner := bufio.NewScanner(stream)
	scanner.Buffer(make([]byte, 0, initialBufferSize), maxBufferSize)

	var fullAnswer strings.Builder

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Blank lines and non-"data:" lines are ignored.
		if !strings.HasPrefix(line, "data:") {
			continue
		}

		payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		// End-of-stream marker: "data: true".
		if payload == "true" {
			break
		}

		var chunk ragflow.ChatCompletionResp
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			// Deliberately best-effort: an undecodable fragment is skipped
			// rather than failing the whole stream.
			continue
		}

		if chunk.Data.Answer != "" {
			fullAnswer.WriteString(idMarkerRe.ReplaceAllString(chunk.Data.Answer, ""))
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
	}

	// Second pass over the joined text catches markers that were split across
	// fragment boundaries.
	item.Answer = idMarkerRe.ReplaceAllString(fullAnswer.String(), "")
	return item, nil
}