api.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. package ragflow
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gogf/gf/v2/frame/g"
  8. "github.com/gogf/gf/v2/net/gclient"
  9. "github.com/gogf/gf/v2/os/gctx"
  10. "github.com/gogf/gf/v2/os/glog"
  11. "github.com/gogf/gf/v2/util/gconv"
  12. "io"
  13. "mime/multipart"
  14. "net/http"
  15. "regexp"
  16. "strings"
  17. "sync"
  18. )
  19. var (
  20. once sync.Once
  21. internalClient *Client // 内部私有客户端
  22. )
  23. // 32位小写十六进制正则校验
  24. var pipelineIDRegex = regexp.MustCompile(`^[0-9a-f]{32}$`)
  25. // Client 包装 gclient.Client
  26. type Client struct {
  27. cli *gclient.Client // 原生GoFrame HTTP客户端
  28. address string // RAGFlow服务地址
  29. apiKey string // 认证密钥
  30. }
  31. // Init 初始化客户端
  32. func Init(ctx context.Context, addr, apiKey string) *Client {
  33. once.Do(func() {
  34. internalClient = newClient(ctx, addr, apiKey)
  35. })
  36. return internalClient
  37. }
  38. // GetHttpClient 获取原生 gclient.Client
  39. func GetHttpClient() *Client {
  40. return internalClient
  41. }
  42. // newClient 创建客户端实例
  43. func newClient(ctx context.Context, addr, apiKey string) *Client {
  44. // 初始化GoFrame HTTP客户端
  45. cli := g.Client()
  46. // 设置全局默认请求头
  47. //cli.SetHeader("Content-Permission", "application/json")
  48. cli.SetHeader("Content-Type", "application/json")
  49. cli.SetHeader("Authorization", "Bearer "+apiKey)
  50. // 初始化校验
  51. if addr == "" || apiKey == "" {
  52. panic("ragflow client init failed: address or apiKey is empty")
  53. }
  54. return &Client{
  55. cli: cli,
  56. address: addr,
  57. apiKey: apiKey,
  58. }
  59. }
  60. // -------------------------- 工具函数 --------------------------
  61. func Log(ctx context.Context, v ...interface{}) {
  62. glog.Info(ctx, v...)
  63. }
  64. func NewCtx() context.Context {
  65. return gctx.New()
  66. }
  67. // -------------------------- 公共结构体 --------------------------
  68. type CommonResp struct {
  69. Code int `json:"code"`
  70. Msg string `json:"message,omitempty"`
  71. }
  72. // -------------------------- 1. 数据集接口(原有) --------------------------
  73. type CreateDatasetReq struct {
  74. Name string `json:"name"`
  75. ChunkMethod string `json:"chunk_method,omitempty"`
  76. Description string `json:"description,omitempty"`
  77. EmbeddingModel string `json:"embedding_model,omitempty"`
  78. //Avatar string `json:"avatar,omitempty"`
  79. //Permission string `json:"permission,omitempty"`
  80. //ParserConfig interface{} `json:"parser_config,omitempty"`
  81. //ParseType int `json:"parse_type,omitempty"`
  82. //PipelineID string `json:"pipeline_id,omitempty"`
  83. }
  84. type CreateDatasetResp struct {
  85. Code int `json:"code"`
  86. Data DatasetData `json:"data"`
  87. }
  88. type DatasetData struct {
  89. ID string `json:"id"`
  90. Name string `json:"name"`
  91. ChunkMethod string `json:"chunk_method"`
  92. EmbeddingModel string `json:"embedding_model"`
  93. Permission string `json:"permission"`
  94. ParserConfig interface{} `json:"parser_config"`
  95. CreateTime int64 `json:"create_time"`
  96. DocumentCount int `json:"document_count"`
  97. ChunkCount int `json:"chunk_count"`
  98. }
  99. type DeleteDatasetReq struct {
  100. IDs []string `json:"ids"`
  101. }
  102. type UpdateDatasetReq struct {
  103. Name string `json:"name,omitempty"`
  104. Language string `json:"language"`
  105. Avatar string `json:"avatar,omitempty"`
  106. Description string `json:"description,omitempty"`
  107. EmbeddingModel string `json:"embedding_model,omitempty"`
  108. Permission string `json:"permission,omitempty"`
  109. ChunkMethod string `json:"chunk_method,omitempty"`
  110. ParserConfig interface{} `json:"parser_config,omitempty"`
  111. Pagerank int `json:"pagerank,omitempty"`
  112. }
  113. // 创建知识库
  114. func (r *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*CreateDatasetResp, error) {
  115. if req.Name == "" {
  116. return nil, fmt.Errorf("name 是必填参数")
  117. }
  118. url := fmt.Sprintf("%s/api/v1/datasets", r.address)
  119. var resp CreateDatasetResp
  120. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  121. if err != nil {
  122. return nil, fmt.Errorf("请求失败: %v", err)
  123. }
  124. if resp.Code != 0 {
  125. return nil, fmt.Errorf("接口错误 code=%d", resp.Code)
  126. }
  127. return &resp, nil
  128. }
  129. // DeleteDataset 删除知识库
  130. func (r *Client) DeleteDataset(ctx context.Context, ids []string) (*CommonResp, error) {
  131. url := fmt.Sprintf("%s/api/v1/datasets", r.address)
  132. req := DeleteDatasetReq{IDs: ids}
  133. var resp CommonResp
  134. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  135. if err != nil {
  136. return nil, fmt.Errorf("请求失败: %v", err)
  137. }
  138. return &resp, nil
  139. }
  140. // UpdateDataset 更新知识库
  141. func (r *Client) UpdateDataset(ctx context.Context, datasetID string, req *UpdateDatasetReq) (*CommonResp, error) {
  142. if datasetID == "" {
  143. return nil, fmt.Errorf("datasetID 不能为空")
  144. }
  145. url := fmt.Sprintf("%s/api/v1/datasets/%s", r.address, datasetID)
  146. var resp CommonResp
  147. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  148. if err != nil {
  149. return nil, fmt.Errorf("请求失败: %v", err)
  150. }
  151. return &resp, nil
  152. }
  153. // -------------------------- 2. 文档接口(原有) --------------------------
  154. type UploadFileReq struct {
  155. File []*FileInfo `json:"file"`
  156. }
  157. type FileInfo struct {
  158. FileName string `json:"file_name"` // 文件名称
  159. Url string `json:"url"` // 文件地址
  160. }
  161. type UploadDocumentResp struct {
  162. Code int `json:"code"`
  163. Data []DocumentDetail `json:"data"`
  164. Message string `json:"message"` // 失败时返回
  165. }
  166. // UploadDocumentV2Req V2版本上传文档请求(直接接收文件数据)
  167. type UploadDocumentV2Req struct {
  168. FileName string // 文件名
  169. FileData []byte // 文件数据
  170. ChunkMethod string `json:"chunk_method,omitempty"` // 分块方式
  171. }
  172. // UploadDocumentV2Resp V2版本上传文档响应
  173. type UploadDocumentV2Resp struct {
  174. Code int `json:"code"`
  175. Data []DocumentDetail `json:"data"`
  176. Message string `json:"message"`
  177. }
  178. type DocumentDetail struct {
  179. ID string `json:"id"`
  180. Name string `json:"name"`
  181. DatasetID string `json:"dataset_id"`
  182. Size int64 `json:"size"`
  183. Run string `json:"run"`
  184. ChunkMethod string `json:"chunk_method"`
  185. ParserConfig interface{} `json:"parser_config"`
  186. Status string `json:"status"`
  187. }
  188. type UpdateDocumentReq struct {
  189. Name string `json:"name,omitempty"`
  190. MetaFields interface{} `json:"meta_fields,omitempty"`
  191. ChunkMethod string `json:"chunk_method,omitempty"`
  192. ParserConfig interface{} `json:"parser_config,omitempty"`
  193. Enabled int `json:"enabled,omitempty"`
  194. }
  195. type UpdateDocumentResp struct {
  196. Code int `json:"code"`
  197. Message string `json:"message"` // 失败时返回
  198. Data DocumentDetail `json:"data"`
  199. }
  200. type ListDocumentReq struct {
  201. Page int `json:"page,omitempty"`
  202. PageSize int `json:"page_size,omitempty"`
  203. OrderBy string `json:"orderby,omitempty"`
  204. Desc bool `json:"desc,omitempty"`
  205. Keywords string `json:"keywords,omitempty"`
  206. DocumentID string `json:"id,omitempty"`
  207. DocumentName string `json:"name,omitempty"`
  208. CreateTimeFrom int64 `json:"create_time_from,omitempty"`
  209. CreateTimeTo int64 `json:"create_time_to,omitempty"`
  210. Suffix string `json:"suffix,omitempty"`
  211. Run string `json:"run,omitempty"`
  212. MetadataCondition string `json:"metadata_condition,omitempty"`
  213. }
  214. type ListDocumentResp struct {
  215. Code int `json:"code"`
  216. Data struct {
  217. Docs []DocumentDetail `json:"docs"`
  218. Total int `json:"total_datasets"`
  219. } `json:"data"`
  220. }
  221. type DeleteDocumentsReq struct {
  222. IDs []string `json:"ids"`
  223. }
  224. type ParseDocumentReq struct {
  225. DocumentIDs []string `json:"document_ids"`
  226. }
  227. //func (r *Client) UploadDocument(ctx context.Context, datasetID, fileName string, url string) error {
  228. // // 参数校验
  229. // remoteFileUrl := url
  230. // apiUrl := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  231. // fieldName := "file"
  232. //
  233. // resp, err := http.Get(remoteFileUrl)
  234. // if err != nil {
  235. // panic("下载远程文件失败:" + err.Error())
  236. // }
  237. // defer resp.Body.Close()
  238. //
  239. // body := &bytes.Buffer{}
  240. // writer := multipart.NewWriter(body)
  241. //
  242. // formFile, _ := writer.CreateFormFile(fieldName, "xinfengjinghua.pdf")
  243. // io.Copy(formFile, resp.Body)
  244. // _ = writer.Close()
  245. //
  246. // req, _ := http.NewRequest("POST", apiUrl, body)
  247. // req.Header.Set("Content-Permission", writer.FormDataContentType())
  248. // req.Header.Set("Authorization", "Bearer "+"ragflow-nkGJCVuFSxgt7X6AkzuOXjx_9Q3hN3zAUP9LNpmJOFk")
  249. //
  250. // client := &http.Client{}
  251. // uploadResp, err := client.Do(req)
  252. // if err != nil {
  253. // panic("上传失败:" + err.Error())
  254. // }
  255. // defer uploadResp.Body.Close()
  256. //
  257. // result, _ := io.ReadAll(uploadResp.Body)
  258. // println("上传状态:", uploadResp.Status)
  259. // println("接口返回:", string(result))
  260. //
  261. // return nil
  262. //}
  263. func (r *Client) UploadDocument(ctx context.Context, datasetID string, req UploadFileReq) (*UploadDocumentResp, error) {
  264. // 参数校验
  265. if datasetID == "" {
  266. return nil, fmt.Errorf("datasetID 不能为空")
  267. }
  268. body := &bytes.Buffer{}
  269. writer := multipart.NewWriter(body)
  270. for _, v := range req.File {
  271. //下载远程文件
  272. f, err := r.cli.Get(ctx, v.Url)
  273. if err != nil {
  274. return nil, fmt.Errorf("下载文件失败: %v", err)
  275. }
  276. part, err := writer.CreateFormFile("file", v.FileName)
  277. if err != nil {
  278. return nil, fmt.Errorf("创建表单失败: %v", err)
  279. }
  280. _, err = io.Copy(part, f.Response.Body)
  281. if err != nil {
  282. return nil, fmt.Errorf("写入文件流失败: %v", err)
  283. }
  284. _ = f.Response.Body.Close()
  285. }
  286. err := writer.Close()
  287. if err != nil {
  288. return nil, fmt.Errorf("关闭表单失败: %v", err)
  289. }
  290. r.cli.SetHeader("Content-Type", writer.FormDataContentType())
  291. // 拼接上传接口地址
  292. apiUrl := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  293. var resp UploadDocumentResp
  294. if err = r.cli.PostVar(ctx, apiUrl, body).Scan(&resp); err != nil {
  295. return nil, fmt.Errorf("解析响应失败: %v", err)
  296. }
  297. if resp.Code != 0 {
  298. return nil, fmt.Errorf("上传失败[%d]: %s", resp.Code, resp.Message)
  299. }
  300. r.cli.SetHeader("Content-Type", "application/json")
  301. return &resp, nil
  302. }
  303. // UploadDocumentV2 直接接收文件数据上传到RagFlow
  304. // 适用于本地文件直接上传场景,无需先上传到Minio
  305. func (r *Client) UploadDocumentV2(ctx context.Context, datasetID string, req *UploadDocumentV2Req) (*UploadDocumentV2Resp, error) {
  306. // 参数校验
  307. if datasetID == "" {
  308. return nil, fmt.Errorf("datasetID 不能为空")
  309. }
  310. if req.FileName == "" {
  311. return nil, fmt.Errorf("fileName 不能为空")
  312. }
  313. if len(req.FileData) == 0 {
  314. return nil, fmt.Errorf("fileData 不能为空")
  315. }
  316. Log(ctx, fmt.Sprintf("[RagFlow V2] 开始上传文件: %s, 大小: %d bytes", req.FileName, len(req.FileData)))
  317. body := &bytes.Buffer{}
  318. writer := multipart.NewWriter(body)
  319. // 直接使用传入的文件数据创建表单文件
  320. part, err := writer.CreateFormFile("file", req.FileName)
  321. if err != nil {
  322. return nil, fmt.Errorf("创建表单文件失败: %v", err)
  323. }
  324. _, err = part.Write(req.FileData)
  325. if err != nil {
  326. return nil, fmt.Errorf("写入文件数据失败: %v", err)
  327. }
  328. err = writer.Close()
  329. if err != nil {
  330. return nil, fmt.Errorf("关闭表单失败: %v", err)
  331. }
  332. r.cli.SetHeader("Content-Type", writer.FormDataContentType())
  333. // 拼接上传接口地址
  334. apiUrl := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  335. Log(ctx, fmt.Sprintf("[RagFlow V2] 请求地址: %s", apiUrl))
  336. var resp UploadDocumentV2Resp
  337. if err = r.cli.PostVar(ctx, apiUrl, body).Scan(&resp); err != nil {
  338. return nil, fmt.Errorf("上传请求失败: %v", err)
  339. }
  340. if resp.Code != 0 {
  341. Log(ctx, fmt.Sprintf("[RagFlow V2] 上传失败[%d]: %s", resp.Code, resp.Message))
  342. return nil, fmt.Errorf("上传失败[%d]: %s", resp.Code, resp.Message)
  343. }
  344. Log(ctx, fmt.Sprintf("[RagFlow V2] 上传成功,文档ID: %s", resp.Data[0].ID))
  345. r.cli.SetHeader("Content-Type", "application/json")
  346. return &resp, nil
  347. }
  348. // UpdateDocument 更新文档
  349. func (r *Client) UpdateDocument(ctx context.Context, datasetID, documentID string, req *UpdateDocumentReq) (*UpdateDocumentResp, error) {
  350. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents/%s", r.address, datasetID, documentID)
  351. var resp UpdateDocumentResp
  352. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  353. if err != nil {
  354. return nil, fmt.Errorf("更新失败: %v", err)
  355. }
  356. return &resp, nil
  357. }
  358. // ListDocuments 列出文档
  359. func (r *Client) ListDocuments(ctx context.Context, datasetID string, req *ListDocumentReq) (*ListDocumentResp, error) {
  360. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  361. params := gconv.Map(req)
  362. var resp ListDocumentResp
  363. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  364. if err != nil {
  365. return nil, err
  366. }
  367. return &resp, nil
  368. }
  369. // DeleteDocuments 删除文档
  370. func (r *Client) DeleteDocuments(ctx context.Context, datasetID string, ids []string) (*CommonResp, error) {
  371. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  372. req := DeleteDocumentsReq{IDs: ids}
  373. var resp CommonResp
  374. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  375. if err != nil {
  376. return nil, err
  377. }
  378. return &resp, nil
  379. }
  380. // ParseDocuments 解析文档
  381. func (r *Client) ParseDocuments(ctx context.Context, datasetID string, ids []string) (*CommonResp, error) {
  382. if len(ids) == 0 {
  383. return nil, fmt.Errorf("document_ids 不能为空")
  384. }
  385. url := fmt.Sprintf("%s/api/v1/datasets/%s/chunks", r.address, datasetID)
  386. req := ParseDocumentReq{DocumentIDs: ids}
  387. var resp CommonResp
  388. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  389. if err != nil {
  390. return nil, err
  391. }
  392. return &resp, nil
  393. }
  394. // -------------------------- 3. 对话助手接口(新增) --------------------------
  395. // LLM配置
  396. type ChatLLM struct {
  397. ModelName string `json:"model_name,omitempty"`
  398. ModelType string `json:"model_type,omitempty"` // chat/image2text
  399. Temperature float64 `json:"temperature,omitempty"`
  400. TopP float64 `json:"top_p,omitempty"`
  401. PresencePenalty float64 `json:"presence_penalty,omitempty"`
  402. FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
  403. }
  404. // Prompt配置
  405. type ChatPrompt struct {
  406. SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
  407. KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight,omitempty"`
  408. TopN int `json:"top_n,omitempty"`
  409. Variables []interface{} `json:"variables,omitempty"`
  410. RerankModel string `json:"rerank_model,omitempty"`
  411. EmptyResponse string `json:"empty_response,omitempty"`
  412. Opener string `json:"opener,omitempty"`
  413. ShowQuote bool `json:"show_quote,omitempty"`
  414. Prompt string `json:"prompt,omitempty"`
  415. }
  416. // 创建对话助手
  417. type CreateChatReq struct {
  418. Name string `json:"name"` // 必填
  419. Avatar string `json:"avatar,omitempty"`
  420. DatasetIDs []string `json:"dataset_ids,omitempty"`
  421. LLM *ChatLLM `json:"llm,omitempty"`
  422. Prompt *ChatPrompt `json:"prompt,omitempty"`
  423. }
  424. type CreateChatResp struct {
  425. Code int `json:"code"`
  426. Data ChatInfo `json:"data"`
  427. }
  428. type ChatInfo struct {
  429. ID string `json:"id"`
  430. Name string `json:"name"`
  431. }
  432. // 更新对话助手
  433. type UpdateChatReq struct {
  434. Name string `json:"name,omitempty"`
  435. Avatar string `json:"avatar,omitempty"`
  436. DatasetIDs []string `json:"dataset_ids,omitempty"`
  437. LLM *ChatLLM `json:"llm,omitempty"`
  438. Prompt *ChatPrompt `json:"prompt,omitempty"`
  439. }
  440. // 列表助手
  441. type ListChatReq struct {
  442. Page int `json:"page,omitempty"`
  443. PageSize int `json:"page_size,omitempty"`
  444. OrderBy string `json:"orderby,omitempty"`
  445. Desc bool `json:"desc,omitempty"`
  446. ID string `json:"id,omitempty"`
  447. Name string `json:"name,omitempty"`
  448. }
  449. type ListChatResp struct {
  450. Code int `json:"code"`
  451. Data []interface{} `json:"data"`
  452. }
  453. // 删除助手
  454. type DeleteChatReq struct {
  455. IDs []string `json:"ids"`
  456. }
  457. // CreateChat 创建对话助手
  458. func (r *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*CreateChatResp, error) {
  459. if req.Name == "" {
  460. return nil, fmt.Errorf("name 是必填参数")
  461. }
  462. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  463. r.cli.SetHeader("Content-Type", "application/json")
  464. var resp CreateChatResp
  465. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  466. if err != nil {
  467. return nil, err
  468. }
  469. return &resp, nil
  470. }
  471. // UpdateChat 更新对话助手
  472. func (r *Client) UpdateChat(ctx context.Context, chatID string, req *UpdateChatReq) (*CommonResp, error) {
  473. if chatID == "" {
  474. return nil, fmt.Errorf("chatID 不能为空")
  475. }
  476. url := fmt.Sprintf("%s/api/v1/chats/%s", r.address, chatID)
  477. r.cli.SetHeader("Content-Type", "application/json")
  478. var resp CommonResp
  479. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  480. if err != nil {
  481. return nil, err
  482. }
  483. return &resp, nil
  484. }
  485. // ListChats 列出助手
  486. func (r *Client) ListChats(ctx context.Context, req *ListChatReq) (*ListChatResp, error) {
  487. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  488. params := gconv.Map(req)
  489. var resp ListChatResp
  490. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  491. if err != nil {
  492. return nil, err
  493. }
  494. return &resp, nil
  495. }
  496. // DeleteChats 删除助手
  497. func (r *Client) DeleteChats(ctx context.Context, ids []string) (*CommonResp, error) {
  498. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  499. r.cli.SetHeader("Content-Type", "application/json")
  500. req := DeleteChatReq{IDs: ids}
  501. var resp CommonResp
  502. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  503. if err != nil {
  504. return nil, err
  505. }
  506. return &resp, nil
  507. }
  508. // -------------------------- 4. 会话管理接口(新增) --------------------------
  509. type CreateSessionReq struct {
  510. Name string `json:"name"` // 必填
  511. UserID string `json:"user_id,omitempty"`
  512. }
  513. type CreateSessionResp struct {
  514. Code int `json:"code"`
  515. Data SessionInfo `json:"data"`
  516. }
  517. type SessionInfo struct {
  518. Id string `json:"id"`
  519. Message struct {
  520. Content string `json:"content"`
  521. Role string `json:"role"`
  522. } `json:"message"`
  523. Name string `json:"name"`
  524. }
  525. type UpdateSessionReq struct {
  526. Name string `json:"name,omitempty"`
  527. UserID string `json:"user_id,omitempty"`
  528. }
  529. type ListSessionReq struct {
  530. Page int `json:"page,omitempty"`
  531. PageSize int `json:"page_size,omitempty"`
  532. OrderBy string `json:"orderby,omitempty"`
  533. Desc bool `json:"desc,omitempty"`
  534. ID string `json:"id,omitempty"`
  535. Name string `json:"name,omitempty"`
  536. UserID string `json:"user_id,omitempty"`
  537. }
  538. type DeleteSessionReq struct {
  539. IDs []string `json:"ids"`
  540. }
  541. // CreateSession 创建会话
  542. func (r *Client) CreateSession(ctx context.Context, chatID string, req *CreateSessionReq) (*CreateSessionResp, error) {
  543. if chatID == "" || req.Name == "" {
  544. return nil, fmt.Errorf("参数不能为空")
  545. }
  546. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  547. var resp CreateSessionResp
  548. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  549. if err != nil {
  550. return nil, err
  551. }
  552. return &resp, nil
  553. }
  554. // UpdateSession 更新会话
  555. func (r *Client) UpdateSession(ctx context.Context, chatID, sessionID string, req *UpdateSessionReq) (*CommonResp, error) {
  556. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions/%s", r.address, chatID, sessionID)
  557. var resp CommonResp
  558. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  559. if err != nil {
  560. return nil, err
  561. }
  562. return &resp, nil
  563. }
  564. // ListSessions 列出会话
  565. func (r *Client) ListSessions(ctx context.Context, chatID string, req *ListSessionReq) (*CommonResp, error) {
  566. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  567. params := gconv.Map(req)
  568. var resp CommonResp
  569. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  570. if err != nil {
  571. return nil, err
  572. }
  573. return &resp, nil
  574. }
  575. // DeleteSessions 删除会话
  576. func (r *Client) DeleteSessions(ctx context.Context, chatID string, ids []string) (*CommonResp, error) {
  577. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  578. req := DeleteSessionReq{IDs: ids}
  579. var resp CommonResp
  580. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  581. if err != nil {
  582. return nil, err
  583. }
  584. return &resp, nil
  585. }
  586. // -------------------------- 5. AI对话接口(流式/非流式 新增) --------------------------
  587. // 元数据过滤条件
  588. type MetaCondition struct {
  589. Logic string `json:"logic,omitempty"` // and/or
  590. Conditions []Condition `json:"conditions,omitempty"`
  591. }
  592. type Condition struct {
  593. Name string `json:"name"`
  594. ComparisonOperator string `json:"comparison_operator"`
  595. Value interface{} `json:"value,omitempty"`
  596. }
  597. // 对话请求
  598. type ChatCompletionReq struct {
  599. Question string `json:"question"` // 必填
  600. Stream bool `json:"stream"`
  601. SessionID string `json:"session_id,omitempty"`
  602. UserID string `json:"user_id,omitempty"`
  603. MetadataCondition *MetaCondition `json:"metadata_condition,omitempty"`
  604. }
  605. type Completion struct {
  606. Answer string `json:"answer"`
  607. Reference interface{} `json:"reference"`
  608. AudioBinary interface{} `json:"audio_binary"`
  609. Id string `json:"id"`
  610. SessionId string `json:"session_id"`
  611. }
  612. // 对话请求
  613. type ChatCompletionResp struct {
  614. Code int `json:"code"`
  615. Message string `json:"message"`
  616. Data Completion `json:"data"`
  617. }
  618. // ChatCompletions 对话
  619. func (r *Client) ChatCompletions(ctx context.Context, chatID string, req *ChatCompletionReq) (*ChatCompletionResp, error) {
  620. if chatID == "" || req.Question == "" {
  621. return nil, fmt.Errorf("chatID 和 question 不能为空")
  622. }
  623. url := fmt.Sprintf("%s/api/v1/chats/%s/completions", r.address, chatID)
  624. req.Stream = false
  625. var resp ChatCompletionResp
  626. d, err := r.cli.Post(ctx, url, req)
  627. if err != nil {
  628. return nil, err
  629. }
  630. raw := d.ReadAll()
  631. if err = json.Unmarshal(raw, &resp); err != nil {
  632. return nil, fmt.Errorf("解析 RAGFlow completions 响应失败: %w body=%.500s", err, string(raw))
  633. }
  634. return &resp, nil
  635. }
  636. type StreamData struct {
  637. Answer string `json:"answer"` // 流式文本片段
  638. Final bool `json:"final"` // 是否为最后一条
  639. StartToThink bool `json:"start_to_think"` // 思考状态(可忽略)
  640. ID string `json:"id"`
  641. SessionID string `json:"session_id"`
  642. }
  643. // StreamResponse 对应完整的流式响应结构
  644. type StreamResponse struct {
  645. Code int `json:"code"` // 状态码,0为正常
  646. Data StreamData `json:"data"` // 业务数据
  647. }
  648. // ParseChatStreamEnvelope 解析 RAGFlow POST /api/v1/chats/{chat_id}/completions 流式中的单条 JSON。
  649. //
  650. // 官方文档约定:流式最后一帧为 {"code":0,"data":true}(data 为布尔 true),表示流结束;
  651. // 若仍用 Data struct 反序列化,json.Unmarshal 会失败,客户端若忽略该帧且连接长期不关闭,会一直卡在「处理中」。
  652. func ParseChatStreamEnvelope(raw []byte) (answer string, streamDone bool, code int, msg string, err error) {
  653. var envelope struct {
  654. Code int `json:"code"`
  655. Message string `json:"message"`
  656. Data json.RawMessage `json:"data"`
  657. }
  658. if err = json.Unmarshal(raw, &envelope); err != nil {
  659. return "", false, 0, "", err
  660. }
  661. code, msg = envelope.Code, envelope.Message
  662. if len(envelope.Data) == 0 {
  663. return "", false, code, msg, nil
  664. }
  665. var asBool bool
  666. if json.Unmarshal(envelope.Data, &asBool) == nil {
  667. if asBool {
  668. return "", true, code, msg, nil
  669. }
  670. return "", false, code, msg, nil
  671. }
  672. var inner struct {
  673. Answer string `json:"answer"`
  674. Final bool `json:"final"`
  675. }
  676. if err = json.Unmarshal(envelope.Data, &inner); err != nil {
  677. return "", false, code, msg, fmt.Errorf("解析 data 对象: %w", err)
  678. }
  679. if inner.Final {
  680. return inner.Answer, true, code, msg, nil
  681. }
  682. return inner.Answer, false, code, msg, nil
  683. }
  684. // SplitRAGFlowSSEDataPayloads 从一行 SSE 文本中拆出 0..n 条 JSON。
  685. // 支持:标准前缀 "data:{...}"、同一行多个 "data:{...}data:{...}"、以及无前缀的裸 JSON(部分代理场景)。
  686. func SplitRAGFlowSSEDataPayloads(line string) [][]byte {
  687. line = strings.TrimSpace(line)
  688. if line == "" || strings.HasPrefix(line, ":") {
  689. return nil
  690. }
  691. var blobs [][]byte
  692. parts := strings.Split(line, "data:")
  693. for _, part := range parts {
  694. part = strings.TrimSpace(part)
  695. if part == "" {
  696. continue
  697. }
  698. dec := json.NewDecoder(strings.NewReader(part))
  699. for {
  700. var raw json.RawMessage
  701. if err := dec.Decode(&raw); err != nil {
  702. break
  703. }
  704. if len(raw) > 0 {
  705. blobs = append(blobs, append([]byte(nil), raw...))
  706. }
  707. }
  708. }
  709. return blobs
  710. }
  711. // ChatCompletionsStream 获取流式消息
  712. func (r *Client) ChatCompletionsStream(ctx context.Context, chatID string, req *ChatCompletionReq) (io.ReadCloser, error) {
  713. if chatID == "" || req.Question == "" {
  714. return nil, fmt.Errorf("chatID 和 question 不能为空")
  715. }
  716. // 强制开启流式
  717. req.Stream = true
  718. url := fmt.Sprintf("%s/api/v1/chats/%s/completions", r.address, chatID)
  719. bodyBytes, err := json.Marshal(req)
  720. if err != nil {
  721. return nil, err
  722. }
  723. httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
  724. if err != nil {
  725. return nil, err
  726. }
  727. httpReq.Header.Set("Content-Type", "application/json")
  728. httpReq.Header.Set("Authorization", "Bearer "+r.apiKey)
  729. // 独立 http.Client,避免与全局 gclient 并发写 Timeout/Transport 导致互相干扰
  730. httpClient := &http.Client{
  731. Timeout: 0,
  732. Transport: &http.Transport{
  733. DisableCompression: true,
  734. },
  735. }
  736. resp, err := httpClient.Do(httpReq)
  737. if err != nil {
  738. return nil, err
  739. }
  740. if resp.StatusCode != http.StatusOK {
  741. b, _ := io.ReadAll(resp.Body)
  742. _ = resp.Body.Close()
  743. return nil, fmt.Errorf("ragflow 流式接口 HTTP %d: %s", resp.StatusCode, string(b))
  744. }
  745. return resp.Body, nil
  746. }