package internal import ( "context" "github.com/gogf/gf/util/guid" "yx-dataset-server/app/errors" "yx-dataset-server/app/model" "yx-dataset-server/app/schema" "yx-dataset-server/library/ragflow" ) // NewChatAssistant 创建ChatAssistant func NewChatAssistant( mChatAssistant model.IChatAssistant, mChatDataset model.IChatDataset, mTrans model.ITrans, mDataset model.IDataset, mUser model.IUser, mSession model.IChatSession, ) *ChatAssistant { return &ChatAssistant{ ChatAssistantModel: mChatAssistant, chatDatasetModel: mChatDataset, transModel: mTrans, datasetModel: mDataset, userModel: mUser, sessionModel: mSession, } } // ChatAssistant 创建ChatAssistant对象 type ChatAssistant struct { ChatAssistantModel model.IChatAssistant chatDatasetModel model.IChatDataset transModel model.ITrans datasetModel model.IDataset userModel model.IUser sessionModel model.IChatSession } // Query 查询数据 func (a *ChatAssistant) Query(ctx context.Context, params schema.ChatAssistantQueryParam, opts ...schema.ChatAssistantQueryOptions) (*schema.ChatAssistantQueryResult, error) { var userQueryParam schema.UserQueryParam userQueryParam.RoleCode = []string{"11", "12"} if !CheckIsRootUser(ctx) { user, err := a.userModel.Get(ctx, GetUserID(ctx)) if err != nil { return nil, err } params.OrgId = user.OrgId userQueryParam.OrgId = user.OrgId } result, err := a.ChatAssistantModel.Query(ctx, params, opts...) if err != nil { return nil, err } chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{}) if err != nil { return nil, err } dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{}) if err != nil { return nil, err } users, err := a.userModel.Query(ctx, userQueryParam) if err != nil { return nil, err } result.Data.FillCreator(users.Data) result.Data.FillDatasetId(chatDataset.Data) result.Data.FillDataset(dataset.Data) return result, nil } // Get 查询指定数据 func (a *ChatAssistant) Get(ctx context.Context, recordID string, opts ...schema.ChatAssistantQueryOptions) (*schema.ChatAssistant, error) { item, err := a.ChatAssistantModel.Get(ctx, recordID, opts...) if err != nil { return nil, err } else if item == nil { return nil, errors.ErrNotFound } user, err := a.userModel.Get(ctx, item.CreatorId) if err != nil { return nil, err } item.CreatorName = user.RealName chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{}) if err != nil { return nil, err } dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: chatDataset.Data.ToDatasetIds()}) if err != nil { return nil, err } users, err := a.userModel.Query(ctx, schema.UserQueryParam{RoleCode: []string{"11", "12"}}) if err != nil { return nil, err } dataset.Data.FillCreator(users.Data) item.Datasets = dataset.Data sessions, err := a.sessionModel.Query(ctx, schema.ChatSessionQueryParam{AssistantId: item.RecordID}) if err != nil { return nil, err } item.Sessions = sessions.Data return item, nil } // Create 创建数据 func (a *ChatAssistant) Create(ctx context.Context, item schema.ChatAssistant) error { item.RecordID = guid.S() item.CreatorId = GetUserID(ctx) name := "new_chat" if item.Name != "" { name = item.Name } if len(item.Datasets) == 0 { return errors.New400Response("关联知识库不能为空") } if !CheckIsRootUser(ctx) { user, err := a.userModel.Get(ctx, item.CreatorId) if err != nil { return err } item.OrgId = user.OrgId } err := ExecTrans(ctx, a.transModel, func(ctx context.Context) error { for _, v := range item.Datasets { err := a.chatDatasetModel.Create(ctx, schema.ChatDataset{ RecordID: guid.S(), ChatAssistantId: item.RecordID, DatasetId: v.RecordID, }) if err != nil { return err } } resp, err := ragflow.GetHttpClient().CreateChat(ctx, &ragflow.CreateChatReq{ Name: name, DatasetIDs: item.Datasets.ToRagDataIds(), }) if err != nil { return err } item.RagChatId = resp.Data.ID return a.ChatAssistantModel.Create(ctx, item) }) return err } // Update 更新数据 func (a *ChatAssistant) Update(ctx context.Context, recordID string, item schema.ChatAssistant) error { oldItem, err := a.ChatAssistantModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { err = a.ChatAssistantModel.Update(ctx, recordID, item) if err != nil { return err } if item.Name != oldItem.Name { _, err = ragflow.GetHttpClient().UpdateChat(ctx, item.RagChatId, &ragflow.UpdateChatReq{ Name: item.Name, }) if err != nil { return err } } return nil }) return err } // Delete 删除数据 func (a *ChatAssistant) Delete(ctx context.Context, recordID string) error { oldItem, err := a.ChatAssistantModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error { err = a.ChatAssistantModel.Delete(ctx, recordID) if err != nil { return err } _, err = ragflow.GetHttpClient().DeleteChats(ctx, []string{oldItem.RagChatId}) if err != nil { return err } return nil }) return err } // UpdateStatus 更新状态 func (a *ChatAssistant) UpdateStatus(ctx context.Context, recordID string, status int) error { oldItem, err := a.ChatAssistantModel.Get(ctx, recordID) if err != nil { return err } else if oldItem == nil { return errors.ErrNotFound } return a.ChatAssistantModel.UpdateStatus(ctx, recordID, status) }