b_chat_session.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package internal
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gogf/gf/util/guid"
  6. "yx-dataset-server/app/errors"
  7. "yx-dataset-server/app/model"
  8. "yx-dataset-server/app/schema"
  9. "yx-dataset-server/library/ragflow"
  10. "yx-dataset-server/library/redis"
  11. )
  12. // NewChatSession 创建ChatSession
  13. func NewChatSession(
  14. mChatSession model.IChatSession,
  15. mTrans model.ITrans,
  16. mAssistant model.IChatAssistant,
  17. ) *ChatSession {
  18. return &ChatSession{
  19. ChatSessionModel: mChatSession,
  20. transModel: mTrans,
  21. assistantModel: mAssistant,
  22. }
  23. }
  24. // ChatSession 创建ChatSession对象
  25. type ChatSession struct {
  26. ChatSessionModel model.IChatSession
  27. transModel model.ITrans
  28. assistantModel model.IChatAssistant
  29. }
  30. // Query 查询数据
  31. func (a *ChatSession) Query(ctx context.Context, params schema.ChatSessionQueryParam, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSessionQueryResult, error) {
  32. return a.ChatSessionModel.Query(ctx, params, opts...)
  33. }
  34. // Get 查询指定数据
  35. func (a *ChatSession) Get(ctx context.Context, recordID string, opts ...schema.ChatSessionQueryOptions) (*schema.ChatSession, error) {
  36. item, err := a.ChatSessionModel.Get(ctx, recordID, opts...)
  37. if err != nil {
  38. return nil, err
  39. } else if item == nil {
  40. return nil, errors.ErrNotFound
  41. }
  42. return item, nil
  43. }
  44. func (a *ChatSession) getUpdate(ctx context.Context, recordID string) (*schema.ChatSession, error) {
  45. return a.Get(ctx, recordID)
  46. }
  47. // Create 创建数据
  48. func (a *ChatSession) Create(ctx context.Context, item schema.ChatSession) (*schema.ChatSession, error) {
  49. item.RecordID = guid.S()
  50. item.CreatorId = GetUserID(ctx)
  51. name := "new_session"
  52. if item.Name != "" {
  53. name = item.Name
  54. }
  55. assistant, err := a.assistantModel.Get(ctx, item.AssistantId)
  56. if err != nil {
  57. return nil, err
  58. }
  59. s, err := ragflow.GetHttpClient().CreateSession(ctx, assistant.RagChatId, &ragflow.CreateSessionReq{
  60. Name: name,
  61. })
  62. if err != nil {
  63. return nil, err
  64. }
  65. item.RagSessionId = s.Data.Id
  66. err = a.ChatSessionModel.Create(ctx, item)
  67. if err != nil {
  68. return nil, err
  69. }
  70. key := fmt.Sprintf("chat:session:%s", item.RecordID)
  71. _ = redis.GetClient().Set(ctx, key, item, 8640000)
  72. return a.getUpdate(ctx, item.RecordID)
  73. }
  74. // Update 更新数据
  75. func (a *ChatSession) Update(ctx context.Context, recordID string, item schema.ChatSession) (*schema.ChatSession, error) {
  76. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  77. if err != nil {
  78. return nil, err
  79. } else if oldItem == nil {
  80. return nil, errors.ErrNotFound
  81. }
  82. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  83. err = a.ChatSessionModel.Update(ctx, recordID, item)
  84. if err != nil {
  85. return err
  86. }
  87. if oldItem.Name != item.Name {
  88. _, err = ragflow.GetHttpClient().UpdateSession(ctx, oldItem.RagChatId, oldItem.RagSessionId, &ragflow.UpdateSessionReq{
  89. Name: item.Name,
  90. })
  91. if err != nil {
  92. return err
  93. }
  94. }
  95. return nil
  96. })
  97. return a.getUpdate(ctx, recordID)
  98. }
  99. // Delete 删除数据
  100. func (a *ChatSession) Delete(ctx context.Context, recordID string) error {
  101. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  102. if err != nil {
  103. return err
  104. } else if oldItem == nil {
  105. return errors.ErrNotFound
  106. }
  107. err = ExecTrans(ctx, a.transModel, func(ctx context.Context) error {
  108. err = a.ChatSessionModel.Delete(ctx, recordID)
  109. if err != nil {
  110. return err
  111. }
  112. _, err = ragflow.GetHttpClient().DeleteSessions(ctx, oldItem.RagChatId, []string{oldItem.RagSessionId})
  113. return err
  114. })
  115. return err
  116. }
  117. // UpdateStatus 更新状态
  118. func (a *ChatSession) UpdateStatus(ctx context.Context, recordID string, status int) error {
  119. oldItem, err := a.ChatSessionModel.Get(ctx, recordID)
  120. if err != nil {
  121. return err
  122. } else if oldItem == nil {
  123. return errors.ErrNotFound
  124. }
  125. return a.ChatSessionModel.UpdateStatus(ctx, recordID, status)
  126. }