b_common.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package internal
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "regexp"
  9. "strings"
  10. iContext "yx-dataset-server/app/context"
  11. "yx-dataset-server/app/errors"
  12. "yx-dataset-server/app/model"
  13. "yx-dataset-server/app/schema"
  14. "yx-dataset-server/library/ragflow"
  15. "yx-dataset-server/library/utils"
  16. )
  // TransFunc is the unit of work executed by ExecTrans / ExecTransWithLock.
// It receives the (possibly transaction-carrying) context and returns an
// error to signal failure.
type TransFunc func(ctx context.Context) error
  19. // GetRootUser 获取root用户
  20. func GetRootUser() *schema.User {
  21. return &schema.User{
  22. RecordID: utils.GetConfig("root.user_name").String(),
  23. UserName: utils.GetConfig("root.user_name").String(),
  24. RealName: utils.GetConfig("root.real_name").String(),
  25. Photo: utils.GetConfig("root.photo").String(),
  26. RoleCode: "11",
  27. Password: utils.MD5HashString(utils.GetConfig("root.password").String()),
  28. }
  29. }
  30. // GetUserID 获取用户ID
  31. func GetUserID(ctx context.Context) string {
  32. userID, _ := iContext.FromUserID(ctx)
  33. return userID
  34. }
  35. // CheckIsRootUser 检查是否是root用户
  36. func CheckIsRootUser(ctx context.Context) bool {
  37. userId := GetUserID(ctx)
  38. return GetRootUser().RecordID == userId
  39. }
  // Role codes compared against role.Code by the role helpers in this file.
const (
	// RoleCodeSystemAdmin is the system administrator role.
	// GetUserRoleCode also reports the root user with this code.
	RoleCodeSystemAdmin = "11"
	// RoleCodeEnterpriseAdmin is the enterprise administrator role.
	RoleCodeEnterpriseAdmin = "12"
	// RoleCodeEmployee is the enterprise employee role.
	RoleCodeEmployee = "99"
)
  45. // GetUserRoleCode 当前登录用户角色编码(root 视为系统管理员)
  46. func GetUserRoleCode(ctx context.Context, userModel model.IUser, roleModel model.IRole) (string, error) {
  47. if CheckIsRootUser(ctx) {
  48. return RoleCodeSystemAdmin, nil
  49. }
  50. uid := GetUserID(ctx)
  51. if uid == "" {
  52. return "", errors.ErrUnauthorized
  53. }
  54. user, err := userModel.Get(ctx, uid)
  55. if err != nil {
  56. return "", err
  57. }
  58. if user == nil {
  59. return "", errors.ErrUserNotFound
  60. }
  61. role, err := roleModel.Get(ctx, user.RoleId)
  62. if err != nil {
  63. return "", err
  64. }
  65. if role == nil {
  66. return "", errors.ErrNotFound
  67. }
  68. return role.Code, nil
  69. }
  70. // IsSystemAdmin root 或 系统管理员角色
  71. func IsSystemAdmin(ctx context.Context, userModel model.IUser, roleModel model.IRole) (bool, error) {
  72. code, err := GetUserRoleCode(ctx, userModel, roleModel)
  73. if err != nil {
  74. return false, err
  75. }
  76. return code == RoleCodeSystemAdmin, nil
  77. }
  78. // IsEnterpriseAdmin 企业管理员
  79. func IsEnterpriseAdmin(ctx context.Context, userModel model.IUser, roleModel model.IRole) (bool, error) {
  80. code, err := GetUserRoleCode(ctx, userModel, roleModel)
  81. if err != nil {
  82. return false, err
  83. }
  84. return code == RoleCodeEnterpriseAdmin, nil
  85. }
  86. // ExecTrans 执行事务
  87. func ExecTrans(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
  88. if _, ok := iContext.FromTrans(ctx); ok {
  89. return fn(ctx)
  90. }
  91. trans, err := transModel.Begin(ctx)
  92. if err != nil {
  93. return err
  94. }
  95. defer func() {
  96. if r := recover(); r != nil {
  97. _ = transModel.Rollback(ctx, trans)
  98. panic(r)
  99. }
  100. }()
  101. err = fn(iContext.NewTrans(ctx, trans))
  102. if err != nil {
  103. _ = transModel.Rollback(ctx, trans)
  104. return err
  105. }
  106. return transModel.Commit(ctx, trans)
  107. }
  108. // ExecTransWithLock 执行事务(加锁)
  109. func ExecTransWithLock(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
  110. if !iContext.FromTransLock(ctx) {
  111. ctx = iContext.NewTransLock(ctx)
  112. }
  113. return ExecTrans(ctx, transModel, fn)
  114. }
  115. func CollectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) (*schema.ChatMessage, error) {
  116. // 解决 token too long 报错:调大缓冲区到 10MB
  117. const maxBufferSize = 10 * 1024 * 1024
  118. scanner := bufio.NewScanner(stream)
  119. buf := make([]byte, maxBufferSize)
  120. scanner.Buffer(buf, maxBufferSize)
  121. // 用于拼接答案
  122. var fullAnswer strings.Builder
  123. // 用于存储最终完整结果
  124. finalResp := &ragflow.ChatCompletionResp{}
  125. re := regexp.MustCompile(`\[ID:\d+\]`)
  126. // 逐行读取流式返回
  127. for scanner.Scan() {
  128. line := strings.TrimSpace(scanner.Text())
  129. if line == "" {
  130. continue
  131. }
  132. // 只处理 SSE 格式 data: ...
  133. if !strings.HasPrefix(line, "data:") {
  134. continue
  135. }
  136. jsonStr := strings.TrimPrefix(line, "data:")
  137. jsonStr = strings.TrimSpace(jsonStr)
  138. // 流结束标记:data: true
  139. if jsonStr == "true" {
  140. break
  141. }
  142. // 解析当前分片
  143. var chunk ragflow.ChatCompletionResp
  144. if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
  145. continue
  146. }
  147. // 拼接回答内容
  148. if chunk.Data.Answer != "" {
  149. cleanAnswer := re.ReplaceAllString(chunk.Data.Answer, "")
  150. fullAnswer.WriteString(cleanAnswer)
  151. }
  152. // 保存最后一个分片(包含 reference/session_id 等完整信息)
  153. *finalResp = chunk
  154. }
  155. if err := scanner.Err(); err != nil {
  156. return nil, errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
  157. }
  158. cleanedFullAnswer := re.ReplaceAllString(fullAnswer.String(), "")
  159. item.Answer = cleanedFullAnswer
  160. return item, nil
  161. }