package internal

import (
	"context"
	"fmt"

	"github.com/gogf/gf/util/guid"

	"yx-dataset-server/app/errors"
	"yx-dataset-server/app/model"
	"yx-dataset-server/app/schema"
	"yx-dataset-server/library/ragflow"
	"yx-dataset-server/library/redis"
)

// NewChatSession creates a ChatSession service wired to the given chat-session,
// transaction and chat-assistant models.
func NewChatSession(
	mChatSession model.IChatSession,
	mTrans model.ITrans,
	mAssistant model.IChatAssistant,
) *ChatSession {
	return &ChatSession{
		ChatSessionModel: mChatSession,
		transModel:       mTrans,
		assistantModel:   mAssistant,
	}
}

// ChatSession implements chat-session business logic. Every local session
// record is paired with a session on the RAGFlow server (identified by
// RagChatId/RagSessionId); Create, Update and Delete keep the two in step.
type ChatSession struct {
	// ChatSessionModel persists local chat-session records.
	ChatSessionModel model.IChatSession
	// transModel is handed to ExecTrans to run multi-step writes in a transaction.
	transModel model.ITrans
	// assistantModel resolves the assistant a new session belongs to.
	assistantModel model.IChatAssistant
}

// Query returns the chat sessions matching params. It delegates directly to
// the chat-session model.
func (a *ChatSession) Query(ctx context.Context, params schema.ChatSessionQueryParam, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSessionQueryResult, error) {
	return a.ChatSessionModel.Query(ctx, params, opts...)
}

// Get returns the chat session identified by recordID. It returns
// errors.ErrNotFound when the model reports no such record (a nil item with a
// nil error).
func (a *ChatSession) Get(ctx context.Context, recordID string, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSession, error) {
	item, err := a.ChatSessionModel.Get(ctx, recordID, opts...)
	if err != nil {
		return nil, err
	}
	if item == nil {
		return nil, errors.ErrNotFound
	}
	return item, nil
}

// getUpdate re-reads recordID after a write so callers receive the stored state.
func (a *ChatSession) getUpdate(ctx context.Context, recordID string) (*schema.ChatSession, error) {
	return a.Get(ctx, recordID)
}

// Create creates a chat session for item.AssistantId. It first creates the
// session on RAGFlow (named item.Name, or "new_session" when empty), then
// stores the local record with the returned RAGFlow ids, caches it in redis,
// and returns the stored record.
//
// Returns errors.ErrNotFound when the assistant does not exist.
func (a *ChatSession) Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
	item.RecordID = guid.S()
	item.CreatorId = GetUserID(ctx)

	name := item.Name
	if name == "" {
		name = "new_session"
	}

	assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
	if err != nil {
		return nil, err
	}
	if assistant == nil {
		// Models in this package report a missing record as (nil, nil) (see Get);
		// without this guard assistant.RagChatId below would dereference nil.
		return nil, errors.ErrNotFound
	}

	s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
		Name: name,
	})
	if err != nil {
		return nil, err
	}
	item.RagSessionId = s.Data.Id
	// Update and Delete address the remote session through the record's
	// RagChatId, so record it together with the session id.
	item.RagChatId = assistant.RagChatId

	err = a.ChatSessionModel.Create(ctx, item)
	if err != nil {
		// Best-effort compensation: remove the RAGFlow session just created so it
		// is not orphaned. The model error is the one the caller needs to see.
		_, _ = ragflow.GetHttpClient().DeleteSessions(ctx, assistant.RagChatId, []string{item.RagSessionId})
		return nil, err
	}

	// Caching is best-effort: a cache failure must not fail an otherwise
	// successful Create, so the result is deliberately discarded.
	// NOTE(review): confirm the unit of the expiration argument for this redis
	// wrapper. If it is a time.Duration (as in go-redis), 8640000 means ~8.6ms,
	// not 100 days.
	key := fmt.Sprintf("chat:session:%s", item.RecordID)
	_ = redis.GetClient().Set(ctx, key, item, 8640000)

	return a.getUpdate(ctx, item.RecordID)
}

// Update overwrites the session recordID with item and returns the stored
// record. When the name changes, the RAGFlow session is renamed as well.
// Both writes run inside ExecTrans, and an error from either one is returned
// to the caller (ExecTrans presumably rolls back the local write; confirm).
//
// Returns errors.ErrNotFound when recordID does not exist.
func (a *ChatSession) Update(ctx context.Context, recordID string, item schema.ChatSession) (*schema.ChatSession, error) {
	oldItem, err := a.Get(ctx, recordID)
	if err != nil {
		return nil, err
	}

	err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
		if err := a.ChatSessionModel.Update(ctx, recordID, item); err != nil {
			return err
		}
		if oldItem.Name == item.Name {
			return nil
		}
		_, err := ragflow.GetHttpClient().UpdateSession(ctx, oldItem.RagChatId, oldItem.RagSessionId, &ragflow.UpdateSessionReq{
			Name: item.Name,
		})
		return err
	})
	if err != nil {
		// Previously this error was overwritten by getUpdate's result, so a
		// failed update was reported to the caller as success.
		return nil, err
	}
	return a.getUpdate(ctx, recordID)
}

// Delete removes the session recordID locally and on RAGFlow, inside a single
// ExecTrans call.
//
// Returns errors.ErrNotFound when recordID does not exist.
// NOTE(review): the "chat:session:<id>" redis entry written by Create is not
// removed here.
func (a *ChatSession) Delete(ctx context.Context, recordID string) error {
	oldItem, err := a.Get(ctx, recordID)
	if err != nil {
		return err
	}

	return ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
		if err := a.ChatSessionModel.Delete(ctx, recordID); err != nil {
			return err
		}
		_, err := ragflow.GetHttpClient().DeleteSessions(ctx, oldItem.RagChatId, []string{oldItem.RagSessionId})
		return err
	})
}

// UpdateStatus sets the status of the session recordID. It returns
// errors.ErrNotFound when recordID does not exist.
func (a *ChatSession) UpdateStatus(ctx context.Context, recordID string, status int) error {
	if _, err := a.Get(ctx, recordID); err != nil {
		return err
	}
	return a.ChatSessionModel.UpdateStatus(ctx, recordID, status)
}