package internal import ( "context" "fmt" "github.com/gogf/gf/v2/util/guid" "strings" "time" "yx-dataset-server/app/errors" "yx-dataset-server/app/model" "yx-dataset-server/app/schema" "yx-dataset-server/library/ragflow" "yx-dataset-server/library/redis" ) // NewChatSession 创建ChatSession func NewChatSession( mChatSession model.IChatSession, mTrans model.ITrans, mAssistant model.IChatAssistant, mChatDataset model.IChatDataset, mDataset model.IDataset, mMessage model.IChatMessage, mUser model.IUser, ) *ChatSession { return &ChatSession{ ChatSessionModel: mChatSession, transModel: mTrans, assistantModel: mAssistant, chatDatasetModel: mChatDataset, datasetModel: mDataset, messageModel: mMessage, userModel: mUser, } } // ChatSession 创建ChatSession对象 type ChatSession struct { ChatSessionModel model.IChatSession transModel model.ITrans assistantModel model.IChatAssistant chatDatasetModel model.IChatDataset datasetModel model.IDataset messageModel model.IChatMessage userModel model.IUser } // Query 查询数据 func (a *ChatSession) Query(ctx context.Context, params schema.ChatSessionQueryParam, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSessionQueryResult, error) { result, err := a.ChatSessionModel.Query(ctx, params, opts...) if err != nil { return nil, err } user, err := a.userModel.Query(ctx, schema.UserQueryParam{}) if err != nil { return nil, err } result.Data.FillCreator(user.Data) return result, nil } // Get 查询指定数据 func (a *ChatSession) Get(ctx context.Context, recordID string, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSession, error) { item, err := a.ChatSessionModel.Get(ctx, recordID, opts...) if err != nil { return nil, err } else if item == nil { return nil, errors.ErrNotFound } user, err := a.userModel.Get(ctx, item.CreatorId) if err != nil { return nil, err } item.CreatorName = user.RealName message, err := a.messageModel.Query(ctx, schema.ChatMessageQueryParam{SessionId: recordID}) if err != nil { return nil, err } item.Messages = message.Data return item, nil } // Create 创建数据 func (a *ChatSession) Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话") assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { return nil, err } s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{ Name: item.Name, }) if err != nil { return nil, err } item.RagSessionId = s.Data.Id err = a.ChatSessionModel.Create(ctx, item) if err != nil { return nil, err } if item.Source == 2 { key := fmt.Sprintf("chat:session:%s", item.RecordID) _ = redis.GetClient().Set(ctx, key, item, 7200) } return a.Get(ctx, item.RecordID) } func (a *ChatSession) H5Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话") item.Source = 1 assistant, err := a.assistantModel.Get(ctx, item.AssistantId) if err != nil { return nil, err } item.RagChatId = assistant.RagChatId chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{ChatAssistantId: assistant.RecordID}) if err != nil { return nil, err } dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: chatDataset.Data.ToDatasetIds()}) if err != nil { return nil, err } var datasetName []string for _, v := range dataset.Data { quotedName := fmt.Sprintf(`《%s》`, v.Name) datasetName = append(datasetName, quotedName) } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{ Name: item.Name, }) if err != nil { return err } item.RagSessionId = s.Data.Id err = a.ChatSessionModel.Create(ctx, item) if err != nil { return err } //key := fmt.Sprintf("chat:session:%s", item.RecordID) //_ = redis.GetClient().Set(ctx, key, item, 8640000) answer := fmt.Sprintf("你好,欢迎使用知识库问答助手!关于%s的疑问,我都会尽力为你解答,快来提问吧!", strings.Join(datasetName, ",")) message := schema.ChatMessage{ RecordID: guid.S(), UserId: item.CreatorId, AssistantId: item.AssistantId, SessionId: item.RecordID, RagSessionId: s.Data.Id, Question: "", Answer: strings.ReplaceAll(answer, `\`, ``), CreatorId: item.CreatorId, } err = a.messageModel.Create(ctx, message) if err != nil { return err } item.Messages = append(item.Messages, &message) return nil }) return a.Get(ctx, item.RecordID) } // Update 更新数据 func (a *ChatSession) Update(ctx context.Context, recordID string, item schema.ChatSession) (*schema.ChatSession, error) { oldItem, err := a.ChatSessionModel.Get(ctx, recordID) if err != nil { return nil, err } else if oldItem == nil { return nil, errors.ErrNotFound } // update 只修改名称 if item.Name == oldItem.Name { return nil, err } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { oldItem.Name = item.Name err = a.ChatSessionModel.Update(ctx, recordID, *oldItem) if err != nil { return err } if oldItem.Name != item.Name { _, err = ragflow.GetHttpClient().UpdateSession(ctx, oldItem.RagChatId, oldItem.RagSessionId, &ragflow.UpdateSessionReq{ Name: item.Name, }) if err != nil { return err } } return nil }) return a.Get(ctx, recordID) } // Delete 删除数据 func (a *ChatSession) Delete(ctx context.Context, recordID string) error { oldItem, err := a.ChatSessionModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { err = a.ChatSessionModel.Delete(ctx, recordID) if err != nil { return err } _, err = ragflow.GetHttpClient().DeleteSessions(ctx, oldItem.RagChatId, []string{oldItem.RagSessionId}) return err }) return err } // UpdateStatus 更新状态 func (a *ChatSession) UpdateStatus(ctx context.Context, recordID string, status int) error { oldItem, err := a.ChatSessionModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.ChatSessionModel.UpdateStatus(ctx, recordID, status) }