| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- package internal
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "regexp"
- "strings"
- iContext "yx-dataset-server/app/context"
- "yx-dataset-server/app/errors"
- "yx-dataset-server/app/model"
- "yx-dataset-server/app/schema"
- "yx-dataset-server/library/ragflow"
- "yx-dataset-server/library/utils"
- )
// TransFunc is the unit of work executed by ExecTrans / ExecTransWithLock.
// It receives the context that carries the active transaction and returns
// an error to request a rollback, or nil to allow the commit.
type TransFunc func(ctx context.Context) error
- // GetRootUser 获取root用户
- func GetRootUser() *schema.User {
- return &schema.User{
- RecordID: utils.GetConfig("root.user_name").String(),
- UserName: utils.GetConfig("root.user_name").String(),
- RealName: utils.GetConfig("root.real_name").String(),
- Photo: utils.GetConfig("root.photo").String(),
- Password: utils.MD5HashString(utils.GetConfig("root.password").String()),
- }
- }
- // GetUserID 获取用户ID
- func GetUserID(ctx context.Context) string {
- userID, _ := iContext.FromUserID(ctx)
- return userID
- }
- // CheckIsRootUser 检查是否是root用户
- func CheckIsRootUser(ctx context.Context) bool {
- userId := GetUserID(ctx)
- return GetRootUser().RecordID == userId
- }
- // ExecTrans 执行事务
- func ExecTrans(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
- if _, ok := iContext.FromTrans(ctx); ok {
- return fn(ctx)
- }
- trans, err := transModel.Begin(ctx)
- if err != nil {
- return err
- }
- defer func() {
- if r := recover(); r != nil {
- _ = transModel.Rollback(ctx, trans)
- panic(r)
- }
- }()
- err = fn(iContext.NewTrans(ctx, trans))
- if err != nil {
- _ = transModel.Rollback(ctx, trans)
- return err
- }
- return transModel.Commit(ctx, trans)
- }
- // ExecTransWithLock 执行事务(加锁)
- func ExecTransWithLock(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
- if !iContext.FromTransLock(ctx) {
- ctx = iContext.NewTransLock(ctx)
- }
- return ExecTrans(ctx, transModel, fn)
- }
- func CollectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) (*schema.ChatMessage, error) {
- // 解决 token too long 报错:调大缓冲区到 10MB
- const maxBufferSize = 10 * 1024 * 1024
- scanner := bufio.NewScanner(stream)
- buf := make([]byte, maxBufferSize)
- scanner.Buffer(buf, maxBufferSize)
- // 用于拼接答案
- var fullAnswer strings.Builder
- // 用于存储最终完整结果
- finalResp := &ragflow.ChatCompletionResp{}
- re := regexp.MustCompile(`\[ID:\d+\]`)
- // 逐行读取流式返回
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if line == "" {
- continue
- }
- // 只处理 SSE 格式 data: ...
- if !strings.HasPrefix(line, "data:") {
- continue
- }
- jsonStr := strings.TrimPrefix(line, "data:")
- jsonStr = strings.TrimSpace(jsonStr)
- // 流结束标记:data: true
- if jsonStr == "true" {
- break
- }
- // 解析当前分片
- var chunk ragflow.ChatCompletionResp
- if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
- continue
- }
- // 拼接回答内容
- if chunk.Data.Answer != "" {
- cleanAnswer := re.ReplaceAllString(chunk.Data.Answer, "")
- fullAnswer.WriteString(cleanAnswer)
- }
- // 保存最后一个分片(包含 reference/session_id 等完整信息)
- *finalResp = chunk
- }
- if err := scanner.Err(); err != nil {
- return nil, errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
- }
- cleanedFullAnswer := re.ReplaceAllString(fullAnswer.String(), "")
- item.Answer = cleanedFullAnswer
- return item, nil
- }
|