| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- package internal
- import (
- "context"
- "fmt"
- "github.com/gogf/gf/v2/util/guid"
- "strings"
- "time"
- "yx-dataset-server/app/errors"
- "yx-dataset-server/app/model"
- "yx-dataset-server/app/schema"
- "yx-dataset-server/library/ragflow"
- "yx-dataset-server/library/redis"
- )
- // NewChatSession 创建ChatSession
- func NewChatSession(
- mChatSession model.IChatSession,
- mTrans model.ITrans,
- mAssistant model.IChatAssistant,
- mChatDataset model.IChatDataset,
- mDataset model.IDataset,
- mMessage model.IChatMessage,
- mUser model.IUser,
- ) *ChatSession {
- return &ChatSession{
- ChatSessionModel: mChatSession,
- transModel: mTrans,
- assistantModel: mAssistant,
- chatDatasetModel: mChatDataset,
- datasetModel: mDataset,
- messageModel: mMessage,
- userModel: mUser,
- }
- }
- // ChatSession 创建ChatSession对象
- type ChatSession struct {
- ChatSessionModel model.IChatSession
- transModel model.ITrans
- assistantModel model.IChatAssistant
- chatDatasetModel model.IChatDataset
- datasetModel model.IDataset
- messageModel model.IChatMessage
- userModel model.IUser
- }
- // Query 查询数据
- func (a *ChatSession) Query(ctx context.Context, params schema.ChatSessionQueryParam, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSessionQueryResult, error) {
- result, err := a.ChatSessionModel.Query(ctx, params, opts...)
- if err != nil {
- return nil, err
- }
- user, err := a.userModel.Query(ctx, schema.UserQueryParam{})
- if err != nil {
- return nil, err
- }
- result.Data.FillCreator(user.Data)
- return result, nil
- }
- // Get 查询指定数据
- func (a *ChatSession) Get(ctx context.Context, recordID string, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSession, error) {
- item, err := a.ChatSessionModel.Get(ctx, recordID, opts...)
- if err != nil {
- return nil, err
- } else if item == nil {
- return nil, errors.ErrNotFound
- }
- user, err := a.userModel.Get(ctx, item.CreatorId)
- if err != nil {
- return nil, err
- }
- item.CreatorName = user.RealName
- message, err := a.messageModel.Query(ctx, schema.ChatMessageQueryParam{SessionId: recordID})
- if err != nil {
- return nil, err
- }
- item.Messages = message.Data
- return item, nil
- }
- // Create 创建数据
- func (a *ChatSession) Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话")
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- return nil, err
- }
- s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
- Name: item.Name,
- })
- if err != nil {
- return nil, err
- }
- item.RagSessionId = s.Data.Id
- err = a.ChatSessionModel.Create(ctx, item)
- if err != nil {
- return nil, err
- }
- if item.Source == 2 {
- key := fmt.Sprintf("chat:session:%s", item.RecordID)
- _ = redis.GetClient().Set(ctx, key, item, 7200)
- }
- return a.Get(ctx, item.RecordID)
- }
- func (a *ChatSession) H5Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
- item.RecordID = guid.S()
- item.CreatorId = GetUserID(ctx)
- item.Name = fmt.Sprintf("%s%s", time.Now().Format("2006-01-02 15:04:05 "), "会话")
- item.Source = 1
- assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
- if err != nil {
- return nil, err
- }
- item.RagChatId = assistant.RagChatId
- chatDataset, err := a.chatDatasetModel.Query(ctx, schema.ChatDatasetQueryParam{ChatAssistantId: assistant.RecordID})
- if err != nil {
- return nil, err
- }
- dataset, err := a.datasetModel.Query(ctx, schema.DatasetQueryParam{RecordIds: chatDataset.Data.ToDatasetIds()})
- if err != nil {
- return nil, err
- }
- var datasetName []string
- for _, v := range dataset.Data {
- quotedName := fmt.Sprintf(`《%s》`, v.Name)
- datasetName = append(datasetName, quotedName)
- }
- err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
- s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
- Name: item.Name,
- })
- if err != nil {
- return err
- }
- item.RagSessionId = s.Data.Id
- err = a.ChatSessionModel.Create(ctx, item)
- if err != nil {
- return err
- }
- //key := fmt.Sprintf("chat:session:%s", item.RecordID)
- //_ = redis.GetClient().Set(ctx, key, item, 8640000)
- answer := fmt.Sprintf("你好,欢迎使用知识库问答助手!关于%s的疑问,我都会尽力为你解答,快来提问吧!", strings.Join(datasetName, ","))
- message := schema.ChatMessage{
- RecordID: guid.S(),
- UserId: item.CreatorId,
- AssistantId: item.AssistantId,
- SessionId: item.RecordID,
- RagSessionId: s.Data.Id,
- Question: "",
- Answer: strings.ReplaceAll(answer, `\`, ``),
- CreatorId: item.CreatorId,
- }
- err = a.messageModel.Create(ctx, message)
- if err != nil {
- return err
- }
- item.Messages = append(item.Messages, &message)
- return nil
- })
- return a.Get(ctx, item.RecordID)
- }
- // Update 更新数据
- func (a *ChatSession) Update(ctx context.Context, recordID string, item schema.ChatSession) (*schema.ChatSession, error) {
- oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
- if err != nil {
- return nil, err
- } else if oldItem == nil {
- return nil, errors.ErrNotFound
- }
- // update 只修改名称
- if item.Name == oldItem.Name {
- return nil, err
- }
- err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
- oldItem.Name = item.Name
- err = a.ChatSessionModel.Update(ctx, recordID, *oldItem)
- if err != nil {
- return err
- }
- if oldItem.Name != item.Name {
- _, err = ragflow.GetHttpClient().UpdateSession(ctx, oldItem.RagChatId, oldItem.RagSessionId, &ragflow.UpdateSessionReq{
- Name: item.Name,
- })
- if err != nil {
- return err
- }
- }
- return nil
- })
- return a.Get(ctx, recordID)
- }
- // Delete 删除数据
- func (a *ChatSession) Delete(ctx context.Context, recordID string) error {
- oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
- if err != nil {
- return err
- } else if oldItem == nil {
- return errors.ErrNotFound
- }
- err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
- err = a.ChatSessionModel.Delete(ctx, recordID)
- if err != nil {
- return err
- }
- _, err = ragflow.GetHttpClient().DeleteSessions(ctx, oldItem.RagChatId, []string{oldItem.RagSessionId})
- return err
- })
- return err
- }
- // UpdateStatus 更新状态
- func (a *ChatSession) UpdateStatus(ctx context.Context, recordID string, status int) error {
- oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
- if err != nil {
- return err
- } else if oldItem == nil {
- return errors.ErrNotFound
- }
- return a.ChatSessionModel.UpdateStatus(ctx, recordID, status)
- }
|