| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- package internal
- import (
- "bufio"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "regexp"
- "strings"
- iContext "yx-dataset-server/app/context"
- "yx-dataset-server/app/errors"
- "yx-dataset-server/app/model"
- "yx-dataset-server/app/schema"
- "yx-dataset-server/library/ragflow"
- "yx-dataset-server/library/utils"
- )
// TransFunc is the signature of a unit of work executed inside a transaction.
// It receives the context to run under and reports failure through its error
// result.
type TransFunc func(ctx context.Context) error
- // GetRootUser 获取root用户
- func GetRootUser() *schema.User {
- return &schema.User{
- RecordID: utils.GetConfig("root.user_name").String(),
- UserName: utils.GetConfig("root.user_name").String(),
- RealName: utils.GetConfig("root.real_name").String(),
- Photo: utils.GetConfig("root.photo").String(),
- RoleCode: "11",
- Password: utils.MD5HashString(utils.GetConfig("root.password").String()),
- }
- }
- // GetUserID 获取用户ID
- func GetUserID(ctx context.Context) string {
- userID, _ := iContext.FromUserID(ctx)
- return userID
- }
- // CheckIsRootUser 检查是否是root用户
- func CheckIsRootUser(ctx context.Context) bool {
- userId := GetUserID(ctx)
- return GetRootUser().RecordID == userId
- }
// Role codes identifying the kind of account a user holds.
const (
	// RoleCodeSystemAdmin is the role code of a system administrator.
	RoleCodeSystemAdmin = "11"
	// RoleCodeEnterpriseAdmin is the role code of an enterprise administrator.
	RoleCodeEnterpriseAdmin = "12"
	// RoleCodeEmployee is the role code of an enterprise employee.
	RoleCodeEmployee = "99"
)
- // GetUserRoleCode 当前登录用户角色编码(root 视为系统管理员)
- func GetUserRoleCode(ctx context.Context, userModel model.IUser, roleModel model.IRole) (string, error) {
- if CheckIsRootUser(ctx) {
- return RoleCodeSystemAdmin, nil
- }
- uid := GetUserID(ctx)
- if uid == "" {
- return "", errors.ErrUnauthorized
- }
- user, err := userModel.Get(ctx, uid)
- if err != nil {
- return "", err
- }
- if user == nil {
- return "", errors.ErrUserNotFound
- }
- role, err := roleModel.Get(ctx, user.RoleId)
- if err != nil {
- return "", err
- }
- if role == nil {
- return "", errors.ErrNotFound
- }
- return role.Code, nil
- }
- // IsSystemAdmin root 或 系统管理员角色
- func IsSystemAdmin(ctx context.Context, userModel model.IUser, roleModel model.IRole) (bool, error) {
- code, err := GetUserRoleCode(ctx, userModel, roleModel)
- if err != nil {
- return false, err
- }
- return code == RoleCodeSystemAdmin, nil
- }
- // IsEnterpriseAdmin 企业管理员
- func IsEnterpriseAdmin(ctx context.Context, userModel model.IUser, roleModel model.IRole) (bool, error) {
- code, err := GetUserRoleCode(ctx, userModel, roleModel)
- if err != nil {
- return false, err
- }
- return code == RoleCodeEnterpriseAdmin, nil
- }
- // ExecTrans 执行事务
- func ExecTrans(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
- if _, ok := iContext.FromTrans(ctx); ok {
- return fn(ctx)
- }
- trans, err := transModel.Begin(ctx)
- if err != nil {
- return err
- }
- defer func() {
- if r := recover(); r != nil {
- _ = transModel.Rollback(ctx, trans)
- panic(r)
- }
- }()
- err = fn(iContext.NewTrans(ctx, trans))
- if err != nil {
- _ = transModel.Rollback(ctx, trans)
- return err
- }
- return transModel.Commit(ctx, trans)
- }
- // ExecTransWithLock 执行事务(加锁)
- func ExecTransWithLock(ctx context.Context, transModel model.ITrans, fn TransFunc) error {
- if !iContext.FromTransLock(ctx) {
- ctx = iContext.NewTransLock(ctx)
- }
- return ExecTrans(ctx, transModel, fn)
- }
- func CollectStreamDataAndCreate(ctx context.Context, stream io.ReadCloser, item *schema.ChatMessage) (*schema.ChatMessage, error) {
- // 解决 token too long 报错:调大缓冲区到 10MB
- const maxBufferSize = 10 * 1024 * 1024
- scanner := bufio.NewScanner(stream)
- buf := make([]byte, maxBufferSize)
- scanner.Buffer(buf, maxBufferSize)
- // 用于拼接答案
- var fullAnswer strings.Builder
- // 用于存储最终完整结果
- finalResp := &ragflow.ChatCompletionResp{}
- re := regexp.MustCompile(`\[ID:\d+\]`)
- // 逐行读取流式返回
- for scanner.Scan() {
- line := strings.TrimSpace(scanner.Text())
- if line == "" {
- continue
- }
- // 只处理 SSE 格式 data: ...
- if !strings.HasPrefix(line, "data:") {
- continue
- }
- jsonStr := strings.TrimPrefix(line, "data:")
- jsonStr = strings.TrimSpace(jsonStr)
- // 流结束标记:data: true
- if jsonStr == "true" {
- break
- }
- // 解析当前分片
- var chunk ragflow.ChatCompletionResp
- if err := json.Unmarshal([]byte(jsonStr), &chunk); err != nil {
- continue
- }
- // 拼接回答内容
- if chunk.Data.Answer != "" {
- cleanAnswer := re.ReplaceAllString(chunk.Data.Answer, "")
- fullAnswer.WriteString(cleanAnswer)
- }
- // 保存最后一个分片(包含 reference/session_id 等完整信息)
- *finalResp = chunk
- }
- if err := scanner.Err(); err != nil {
- return nil, errors.New400Response(fmt.Sprintf("读取流式数据失败: %v", err))
- }
- cleanedFullAnswer := re.ReplaceAllString(fullAnswer.String(), "")
- item.Answer = cleanedFullAnswer
- return item, nil
- }
|