b_common.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package internal
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "regexp"
  9. "strings"
  10. iContext "yx-dataset-server/app/context"
  11. "yx-dataset-server/app/errors"
  12. "yx-dataset-server/app/model"
  13. "yx-dataset-server/app/schema"
  14. "yx-dataset-server/library/ragflow"
  15. "yx-dataset-server/library/utils"
  16. )
  17. // TransFunc 定义事务执行函数
  18. type TransFunc func(context.Context) error
  19. // GetRootUser 获取root用户
  20. func GetRootUser() *schema.User {
  21. return &schema.User{
  22. RecordID: utils.GetConfig("root.user_name").String(),
  23. UserName: utils.GetConfig("root.user_name").String(),
  24. RealName: utils.GetConfig("root.real_name").String(),
  25. Photo: utils.GetConfig("root.photo").String(),
  26. Password: utils.MD5HashString(utils.GetConfig("root.password").String()),
  27. }
  28. }
  29. // GetUserID 获取用户ID
  30. func GetUserID(ctx context.Context) string {
  31. userID, _ := iContext.FromUserID(ctx)
  32. return userID
  33. }
  34. // CheckIsRootUser 检查是否是root用户
  35. func CheckIsRootUser(ctx context.Context) bool {
  36. userId := GetUserID(ctx)
  37. return GetRootUser().RecordID == userId
  38. }
  39. // ExecTrans 执行事务
  40. func ExecTrans(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
  41. if _, ok := iContext.FromTrans(ctx); ok {
  42. return fn(ctx)
  43. }
  44. trans, err := transModel.Begin(ctx)
  45. if err != nil {
  46. return err
  47. }
  48. defer func() {
  49. if r := recover(); r != nil {
  50. _ = transModel.Rollback(ctx, trans)
  51. panic(r)
  52. }
  53. }()
  54. err = fn(iContext.NewTrans(ctx, trans))
  55. if err != nil {
  56. _ = transModel.Rollback(ctx, trans)
  57. return err
  58. }
  59. return transModel.Commit(ctx, trans)
  60. }
  61. // ExecTransWithLock 执行事务(加锁)
  62. func ExecTransWithLock(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
  63. if !iContext.FromTransLock(ctx) {
  64. ctx = iContext.NewTransLock(ctx)
  65. }
  66. return ExecTrans(ctx, transModel, fn)
  67. }
  68. func CollectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) (*schema.ChatMessage, error) {
  69. // 解决 token too long 报错:调大缓冲区到 10MB
  70. const maxBufferSize = 10 * 1024 * 1024
  71. scanner := bufio.NewScanner(stream)
  72. buf := make([]byte, maxBufferSize)
  73. scanner.Buffer(buf, maxBufferSize)
  74. // 用于拼接答案
  75. var fullAnswer strings.Builder
  76. // 用于存储最终完整结果
  77. finalResp := &ragflow.ChatCompletionResp{}
  78. re := regexp.MustCompile(`\[ID:\d+\]`)
  79. // 逐行读取流式返回
  80. for scanner.Scan() {
  81. line := strings.TrimSpace(scanner.Text())
  82. if line == "" {
  83. continue
  84. }
  85. // 只处理 SSE 格式 data: ...
  86. if !strings.HasPrefix(line, "data:") {
  87. continue
  88. }
  89. jsonStr := strings.TrimPrefix(line, "data:")
  90. jsonStr = strings.TrimSpace(jsonStr)
  91. // 流结束标记:data: true
  92. if jsonStr == "true" {
  93. break
  94. }
  95. // 解析当前分片
  96. var chunk ragflow.ChatCompletionResp
  97. if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
  98. continue
  99. }
  100. // 拼接回答内容
  101. if chunk.Data.Answer != "" {
  102. cleanAnswer := re.ReplaceAllString(chunk.Data.Answer, "")
  103. fullAnswer.WriteString(cleanAnswer)
  104. }
  105. // 保存最后一个分片(包含 reference/session_id 等完整信息)
  106. *finalResp = chunk
  107. }
  108. if err := scanner.Err(); err != nil {
  109. return nil, errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
  110. }
  111. cleanedFullAnswer := re.ReplaceAllString(fullAnswer.String(), "")
  112. item.Answer = cleanedFullAnswer
  113. return item, nil
  114. }