api.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. package ragflow
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/gogf/gf/v2/frame/g"
  8. "github.com/gogf/gf/v2/net/gclient"
  9. "github.com/gogf/gf/v2/os/gctx"
  10. "github.com/gogf/gf/v2/os/glog"
  11. "github.com/gogf/gf/v2/util/gconv"
  12. "io"
  13. "mime/multipart"
  14. "net/http"
  15. "regexp"
  16. "sync"
  17. )
  18. var (
  19. once sync.Once
  20. internalClient *Client // 内部私有客户端
  21. )
  22. // 32位小写十六进制正则校验
  23. var pipelineIDRegex = regexp.MustCompile(`^[0-9a-f]{32}$`)
  24. // Client 包装 gclient.Client
  25. type Client struct {
  26. cli *gclient.Client // 原生GoFrame HTTP客户端
  27. address string // RAGFlow服务地址
  28. apiKey string // 认证密钥
  29. }
  30. // Init 初始化客户端
  31. func Init(ctx context.Context, addr, apiKey string) *Client {
  32. once.Do(func() {
  33. internalClient = newClient(ctx, addr, apiKey)
  34. })
  35. return internalClient
  36. }
  37. // GetHttpClient 获取原生 gclient.Client
  38. func GetHttpClient() *Client {
  39. return internalClient
  40. }
  41. // newClient 创建客户端实例
  42. func newClient(ctx context.Context, addr, apiKey string) *Client {
  43. // 初始化GoFrame HTTP客户端
  44. cli := g.Client()
  45. // 设置全局默认请求头
  46. cli.SetHeader("Content-Type", "application/json")
  47. cli.SetHeader("Authorization", "Bearer "+apiKey)
  48. // 初始化校验
  49. if addr == "" || apiKey == "" {
  50. panic("ragflow client init failed: address or apiKey is empty")
  51. }
  52. return &Client{
  53. cli: cli,
  54. address: addr,
  55. apiKey: apiKey,
  56. }
  57. }
  58. // -------------------------- 工具函数 --------------------------
  59. func Log(ctx context.Context, v ...interface{}) {
  60. glog.Info(ctx, v...)
  61. }
  62. func NewCtx() context.Context {
  63. return gctx.New()
  64. }
  65. // -------------------------- 公共结构体 --------------------------
  66. type CommonResp struct {
  67. Code int `json:"code"`
  68. Msg string `json:"message,omitempty"`
  69. }
  70. // -------------------------- 1. 数据集接口(原有) --------------------------
  71. type CreateDatasetReq struct {
  72. Name string `json:"name"`
  73. ChunkMethod string `json:"chunk_method,omitempty"`
  74. Description string `json:"description,omitempty"`
  75. EmbeddingModel string `json:"embedding_model,omitempty"`
  76. //Avatar string `json:"avatar,omitempty"`
  77. //Permission string `json:"permission,omitempty"`
  78. //ParserConfig interface{} `json:"parser_config,omitempty"`
  79. //ParseType int `json:"parse_type,omitempty"`
  80. //PipelineID string `json:"pipeline_id,omitempty"`
  81. }
  82. type CreateDatasetResp struct {
  83. Code int `json:"code"`
  84. Data DatasetData `json:"data"`
  85. }
  86. type DatasetData struct {
  87. ID string `json:"id"`
  88. Name string `json:"name"`
  89. ChunkMethod string `json:"chunk_method"`
  90. EmbeddingModel string `json:"embedding_model"`
  91. Permission string `json:"permission"`
  92. ParserConfig interface{} `json:"parser_config"`
  93. CreateTime int64 `json:"create_time"`
  94. DocumentCount int `json:"document_count"`
  95. ChunkCount int `json:"chunk_count"`
  96. }
  97. type DeleteDatasetReq struct {
  98. IDs []string `json:"ids"`
  99. }
  100. type UpdateDatasetReq struct {
  101. Name string `json:"name,omitempty"`
  102. Language string `json:"language"`
  103. Avatar string `json:"avatar,omitempty"`
  104. Description string `json:"description,omitempty"`
  105. EmbeddingModel string `json:"embedding_model,omitempty"`
  106. Permission string `json:"permission,omitempty"`
  107. ChunkMethod string `json:"chunk_method,omitempty"`
  108. ParserConfig interface{} `json:"parser_config,omitempty"`
  109. Pagerank int `json:"pagerank,omitempty"`
  110. }
  111. // 创建知识库
  112. func (r *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*CreateDatasetResp, error) {
  113. if req.Name == "" {
  114. return nil, fmt.Errorf("name 是必填参数")
  115. }
  116. url := fmt.Sprintf("%s/api/v1/datasets", r.address)
  117. var resp CreateDatasetResp
  118. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  119. if err != nil {
  120. return nil, fmt.Errorf("请求失败: %v", err)
  121. }
  122. if resp.Code != 0 {
  123. return nil, fmt.Errorf("接口错误 code=%d", resp.Code)
  124. }
  125. return &resp, nil
  126. }
  127. // DeleteDataset 删除知识库
  128. func (r *Client) DeleteDataset(ctx context.Context, ids []string) (*CommonResp, error) {
  129. url := fmt.Sprintf("%s/api/v1/datasets", r.address)
  130. req := DeleteDatasetReq{IDs: ids}
  131. var resp CommonResp
  132. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  133. if err != nil {
  134. return nil, fmt.Errorf("请求失败: %v", err)
  135. }
  136. return &resp, nil
  137. }
  138. // UpdateDataset 更新知识库
  139. func (r *Client) UpdateDataset(ctx context.Context, datasetID string, req *UpdateDatasetReq) (*CommonResp, error) {
  140. if datasetID == "" {
  141. return nil, fmt.Errorf("datasetID 不能为空")
  142. }
  143. url := fmt.Sprintf("%s/api/v1/datasets/%s", r.address, datasetID)
  144. var resp CommonResp
  145. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  146. if err != nil {
  147. return nil, fmt.Errorf("请求失败: %v", err)
  148. }
  149. return &resp, nil
  150. }
  151. // -------------------------- 2. 文档接口(原有) --------------------------
  152. type UploadFileReq struct {
  153. File []*FileInfo `json:"file"`
  154. }
  155. type FileInfo struct {
  156. FileName string `json:"file_name"` // 文件名称
  157. Url string `json:"url"` // 文件地址
  158. }
  159. type UploadDocumentResp struct {
  160. Code int `json:"code"`
  161. Data []DocumentDetail `json:"data"`
  162. Message string `json:"message"` // 失败时返回
  163. }
  164. type DocumentDetail struct {
  165. ID string `json:"id"`
  166. Name string `json:"name"`
  167. DatasetID string `json:"dataset_id"`
  168. Size int64 `json:"size"`
  169. Run string `json:"run"`
  170. ChunkMethod string `json:"chunk_method"`
  171. ParserConfig interface{} `json:"parser_config"`
  172. Status string `json:"status"`
  173. }
  174. type UpdateDocumentReq struct {
  175. Name string `json:"name,omitempty"`
  176. MetaFields interface{} `json:"meta_fields,omitempty"`
  177. ChunkMethod string `json:"chunk_method,omitempty"`
  178. ParserConfig interface{} `json:"parser_config,omitempty"`
  179. Enabled int `json:"enabled,omitempty"`
  180. }
  181. type UpdateDocumentResp struct {
  182. Code int `json:"code"`
  183. Message string `json:"message"` // 失败时返回
  184. Data DocumentDetail `json:"data"`
  185. }
  186. type ListDocumentReq struct {
  187. Page int `json:"page,omitempty"`
  188. PageSize int `json:"page_size,omitempty"`
  189. OrderBy string `json:"orderby,omitempty"`
  190. Desc bool `json:"desc,omitempty"`
  191. Keywords string `json:"keywords,omitempty"`
  192. DocumentID string `json:"id,omitempty"`
  193. DocumentName string `json:"name,omitempty"`
  194. CreateTimeFrom int64 `json:"create_time_from,omitempty"`
  195. CreateTimeTo int64 `json:"create_time_to,omitempty"`
  196. Suffix string `json:"suffix,omitempty"`
  197. Run string `json:"run,omitempty"`
  198. MetadataCondition string `json:"metadata_condition,omitempty"`
  199. }
  200. type ListDocumentResp struct {
  201. Code int `json:"code"`
  202. Data struct {
  203. Docs []DocumentDetail `json:"docs"`
  204. Total int `json:"total_datasets"`
  205. } `json:"data"`
  206. }
  207. type DeleteDocumentsReq struct {
  208. IDs []string `json:"ids"`
  209. }
  210. type ParseDocumentReq struct {
  211. DocumentIDs []string `json:"document_ids"`
  212. }
  213. //func (r *Client) UploadDocument(ctx context.Context, datasetID, fileName string, url string) error {
  214. // // 参数校验
  215. // remoteFileUrl := url
  216. // apiUrl := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  217. // fieldName := "file"
  218. //
  219. // resp, err := http.Get(remoteFileUrl)
  220. // if err != nil {
  221. // panic("下载远程文件失败:" + err.Error())
  222. // }
  223. // defer resp.Body.Close()
  224. //
  225. // body := &bytes.Buffer{}
  226. // writer := multipart.NewWriter(body)
  227. //
  228. // formFile, _ := writer.CreateFormFile(fieldName, "xinfengjinghua.pdf")
  229. // io.Copy(formFile, resp.Body)
  230. // _ = writer.Close()
  231. //
  232. // req, _ := http.NewRequest("POST", apiUrl, body)
  233. // req.Header.Set("Content-Type", writer.FormDataContentType())
  234. // req.Header.Set("Authorization", "Bearer "+"ragflow-nkGJCVuFSxgt7X6AkzuOXjx_9Q3hN3zAUP9LNpmJOFk")
  235. //
  236. // client := &http.Client{}
  237. // uploadResp, err := client.Do(req)
  238. // if err != nil {
  239. // panic("上传失败:" + err.Error())
  240. // }
  241. // defer uploadResp.Body.Close()
  242. //
  243. // result, _ := io.ReadAll(uploadResp.Body)
  244. // println("上传状态:", uploadResp.Status)
  245. // println("接口返回:", string(result))
  246. //
  247. // return nil
  248. //}
  249. func (r *Client) UploadDocument(ctx context.Context, datasetID string, req UploadFileReq) (*UploadDocumentResp, error) {
  250. // 参数校验
  251. if datasetID == "" {
  252. return nil, fmt.Errorf("datasetID 不能为空")
  253. }
  254. body := &bytes.Buffer{}
  255. writer := multipart.NewWriter(body)
  256. for _, v := range req.File {
  257. //下载远程文件
  258. f, err := r.cli.Get(ctx, v.Url)
  259. if err != nil {
  260. return nil, fmt.Errorf("下载文件失败: %v", err)
  261. }
  262. part, err := writer.CreateFormFile("file", v.FileName)
  263. if err != nil {
  264. return nil, fmt.Errorf("创建表单失败: %v", err)
  265. }
  266. _, err = io.Copy(part, f.Response.Body)
  267. if err != nil {
  268. return nil, fmt.Errorf("写入文件流失败: %v", err)
  269. }
  270. _ = f.Response.Body.Close()
  271. }
  272. err := writer.Close()
  273. if err != nil {
  274. return nil, fmt.Errorf("关闭表单失败: %v", err)
  275. }
  276. r.cli.SetHeader("Content-Type", writer.FormDataContentType())
  277. // 拼接上传接口地址
  278. apiUrl := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  279. var resp UploadDocumentResp
  280. if err = r.cli.PostVar(ctx, apiUrl, body).Scan(&resp); err != nil {
  281. return nil, fmt.Errorf("解析响应失败: %v", err)
  282. }
  283. if resp.Code != 0 {
  284. return nil, fmt.Errorf("上传失败[%d]: %s", resp.Code, resp.Message)
  285. }
  286. r.cli.SetHeader("Content-Type", "application/json")
  287. return &resp, nil
  288. }
  289. // UpdateDocument 更新文档
  290. func (r *Client) UpdateDocument(ctx context.Context, datasetID, documentID string, req *UpdateDocumentReq) (*UpdateDocumentResp, error) {
  291. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents/%s", r.address, datasetID, documentID)
  292. var resp UpdateDocumentResp
  293. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  294. if err != nil {
  295. return nil, fmt.Errorf("更新失败: %v", err)
  296. }
  297. return &resp, nil
  298. }
  299. // ListDocuments 列出文档
  300. func (r *Client) ListDocuments(ctx context.Context, datasetID string, req *ListDocumentReq) (*ListDocumentResp, error) {
  301. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  302. params := gconv.Map(req)
  303. var resp ListDocumentResp
  304. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  305. if err != nil {
  306. return nil, err
  307. }
  308. return &resp, nil
  309. }
  310. // DeleteDocuments 删除文档
  311. func (r *Client) DeleteDocuments(ctx context.Context, datasetID string, ids []string) (*CommonResp, error) {
  312. url := fmt.Sprintf("%s/api/v1/datasets/%s/documents", r.address, datasetID)
  313. req := DeleteDocumentsReq{IDs: ids}
  314. var resp CommonResp
  315. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  316. if err != nil {
  317. return nil, err
  318. }
  319. return &resp, nil
  320. }
  321. // ParseDocuments 解析文档
  322. func (r *Client) ParseDocuments(ctx context.Context, datasetID string, ids []string) (*CommonResp, error) {
  323. if len(ids) == 0 {
  324. return nil, fmt.Errorf("document_ids 不能为空")
  325. }
  326. url := fmt.Sprintf("%s/api/v1/datasets/%s/chunks", r.address, datasetID)
  327. req := ParseDocumentReq{DocumentIDs: ids}
  328. var resp CommonResp
  329. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  330. if err != nil {
  331. return nil, err
  332. }
  333. return &resp, nil
  334. }
  335. // -------------------------- 3. 对话助手接口(新增) --------------------------
  336. // LLM配置
  337. type ChatLLM struct {
  338. ModelName string `json:"model_name,omitempty"`
  339. ModelType string `json:"model_type,omitempty"` // chat/image2text
  340. Temperature float64 `json:"temperature,omitempty"`
  341. TopP float64 `json:"top_p,omitempty"`
  342. PresencePenalty float64 `json:"presence_penalty,omitempty"`
  343. FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
  344. }
  345. // Prompt配置
  346. type ChatPrompt struct {
  347. SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
  348. KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight,omitempty"`
  349. TopN int `json:"top_n,omitempty"`
  350. Variables []interface{} `json:"variables,omitempty"`
  351. RerankModel string `json:"rerank_model,omitempty"`
  352. EmptyResponse string `json:"empty_response,omitempty"`
  353. Opener string `json:"opener,omitempty"`
  354. ShowQuote bool `json:"show_quote,omitempty"`
  355. Prompt string `json:"prompt,omitempty"`
  356. }
  357. // 创建对话助手
  358. type CreateChatReq struct {
  359. Name string `json:"name"` // 必填
  360. Avatar string `json:"avatar,omitempty"`
  361. DatasetIDs []string `json:"dataset_ids,omitempty"`
  362. LLM *ChatLLM `json:"llm,omitempty"`
  363. Prompt *ChatPrompt `json:"prompt,omitempty"`
  364. }
  365. type CreateChatResp struct {
  366. Code int `json:"code"`
  367. Data ChatInfo `json:"data"`
  368. }
  369. type ChatInfo struct {
  370. ID string `json:"id"`
  371. Name string `json:"name"`
  372. }
  373. // 更新对话助手
  374. type UpdateChatReq struct {
  375. Name string `json:"name,omitempty"`
  376. Avatar string `json:"avatar,omitempty"`
  377. DatasetIDs []string `json:"dataset_ids,omitempty"`
  378. LLM *ChatLLM `json:"llm,omitempty"`
  379. Prompt *ChatPrompt `json:"prompt,omitempty"`
  380. }
  381. // 列表助手
  382. type ListChatReq struct {
  383. Page int `json:"page,omitempty"`
  384. PageSize int `json:"page_size,omitempty"`
  385. OrderBy string `json:"orderby,omitempty"`
  386. Desc bool `json:"desc,omitempty"`
  387. ID string `json:"id,omitempty"`
  388. Name string `json:"name,omitempty"`
  389. }
  390. type ListChatResp struct {
  391. Code int `json:"code"`
  392. Data []interface{} `json:"data"`
  393. }
  394. // 删除助手
  395. type DeleteChatReq struct {
  396. IDs []string `json:"ids"`
  397. }
  398. // CreateChat 创建对话助手
  399. func (r *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*CreateChatResp, error) {
  400. if req.Name == "" {
  401. return nil, fmt.Errorf("name 是必填参数")
  402. }
  403. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  404. r.cli.SetHeader("Content-Type", "application/json")
  405. var resp CreateChatResp
  406. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  407. if err != nil {
  408. return nil, err
  409. }
  410. return &resp, nil
  411. }
  412. // UpdateChat 更新对话助手
  413. func (r *Client) UpdateChat(ctx context.Context, chatID string, req *UpdateChatReq) (*CommonResp, error) {
  414. if chatID == "" {
  415. return nil, fmt.Errorf("chatID 不能为空")
  416. }
  417. url := fmt.Sprintf("%s/api/v1/chats/%s", r.address, chatID)
  418. r.cli.SetHeader("Content-Type", "application/json")
  419. var resp CommonResp
  420. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  421. if err != nil {
  422. return nil, err
  423. }
  424. return &resp, nil
  425. }
  426. // ListChats 列出助手
  427. func (r *Client) ListChats(ctx context.Context, req *ListChatReq) (*ListChatResp, error) {
  428. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  429. params := gconv.Map(req)
  430. var resp ListChatResp
  431. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  432. if err != nil {
  433. return nil, err
  434. }
  435. return &resp, nil
  436. }
  437. // DeleteChats 删除助手
  438. func (r *Client) DeleteChats(ctx context.Context, ids []string) (*CommonResp, error) {
  439. url := fmt.Sprintf("%s/api/v1/chats", r.address)
  440. r.cli.SetHeader("Content-Type", "application/json")
  441. req := DeleteChatReq{IDs: ids}
  442. var resp CommonResp
  443. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  444. if err != nil {
  445. return nil, err
  446. }
  447. return &resp, nil
  448. }
  449. // -------------------------- 4. 会话管理接口(新增) --------------------------
  450. type CreateSessionReq struct {
  451. Name string `json:"name"` // 必填
  452. UserID string `json:"user_id,omitempty"`
  453. }
  454. type CreateSessionResp struct {
  455. Code int `json:"code"`
  456. Data SessionInfo `json:"data"`
  457. }
  458. type SessionInfo struct {
  459. Id string `json:"id"`
  460. Message struct {
  461. Content string `json:"content"`
  462. Role string `json:"role"`
  463. } `json:"message"`
  464. Name string `json:"name"`
  465. }
  466. type UpdateSessionReq struct {
  467. Name string `json:"name,omitempty"`
  468. UserID string `json:"user_id,omitempty"`
  469. }
  470. type ListSessionReq struct {
  471. Page int `json:"page,omitempty"`
  472. PageSize int `json:"page_size,omitempty"`
  473. OrderBy string `json:"orderby,omitempty"`
  474. Desc bool `json:"desc,omitempty"`
  475. ID string `json:"id,omitempty"`
  476. Name string `json:"name,omitempty"`
  477. UserID string `json:"user_id,omitempty"`
  478. }
  479. type DeleteSessionReq struct {
  480. IDs []string `json:"ids"`
  481. }
  482. // CreateSession 创建会话
  483. func (r *Client) CreateSession(ctx context.Context, chatID string, req *CreateSessionReq) (*CreateSessionResp, error) {
  484. if chatID == "" || req.Name == "" {
  485. return nil, fmt.Errorf("参数不能为空")
  486. }
  487. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  488. var resp CreateSessionResp
  489. err := r.cli.PostVar(ctx, url, req).Scan(&resp)
  490. if err != nil {
  491. return nil, err
  492. }
  493. return &resp, nil
  494. }
  495. // UpdateSession 更新会话
  496. func (r *Client) UpdateSession(ctx context.Context, chatID, sessionID string, req *UpdateSessionReq) (*CommonResp, error) {
  497. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions/%s", r.address, chatID, sessionID)
  498. var resp CommonResp
  499. err := r.cli.PutVar(ctx, url, req).Scan(&resp)
  500. if err != nil {
  501. return nil, err
  502. }
  503. return &resp, nil
  504. }
  505. // ListSessions 列出会话
  506. func (r *Client) ListSessions(ctx context.Context, chatID string, req *ListSessionReq) (*CommonResp, error) {
  507. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  508. params := gconv.Map(req)
  509. var resp CommonResp
  510. err := r.cli.GetVar(ctx, url, params).Scan(&resp)
  511. if err != nil {
  512. return nil, err
  513. }
  514. return &resp, nil
  515. }
  516. // DeleteSessions 删除会话
  517. func (r *Client) DeleteSessions(ctx context.Context, chatID string, ids []string) (*CommonResp, error) {
  518. url := fmt.Sprintf("%s/api/v1/chats/%s/sessions", r.address, chatID)
  519. req := DeleteSessionReq{IDs: ids}
  520. var resp CommonResp
  521. err := r.cli.DeleteVar(ctx, url, req).Scan(&resp)
  522. if err != nil {
  523. return nil, err
  524. }
  525. return &resp, nil
  526. }
  527. // -------------------------- 5. AI对话接口(流式/非流式 新增) --------------------------
  528. // 元数据过滤条件
  529. type MetaCondition struct {
  530. Logic string `json:"logic,omitempty"` // and/or
  531. Conditions []Condition `json:"conditions,omitempty"`
  532. }
  533. type Condition struct {
  534. Name string `json:"name"`
  535. ComparisonOperator string `json:"comparison_operator"`
  536. Value interface{} `json:"value,omitempty"`
  537. }
  538. // 对话请求
  539. type ChatCompletionReq struct {
  540. Question string `json:"question"` // 必填
  541. Stream bool `json:"stream"`
  542. SessionID string `json:"session_id,omitempty"`
  543. UserID string `json:"user_id,omitempty"`
  544. MetadataCondition *MetaCondition `json:"metadata_condition,omitempty"`
  545. }
  546. type Completion struct {
  547. Answer string `json:"answer"`
  548. Reference interface{} `json:"reference"`
  549. AudioBinary interface{} `json:"audio_binary"`
  550. Id string `json:"id"`
  551. SessionId string `json:"session_id"`
  552. }
  553. // 对话请求
  554. type ChatCompletionResp struct {
  555. Code int `json:"code"`
  556. Message string `json:"message"`
  557. Data Completion `json:"data"`
  558. }
  559. // ChatCompletions 对话
  560. func (r *Client) ChatCompletions(ctx context.Context, chatID string, req *ChatCompletionReq) (*ChatCompletionResp, error) {
  561. if chatID == "" || req.Question == "" {
  562. return nil, fmt.Errorf("chatID 和 question 不能为空")
  563. }
  564. url := fmt.Sprintf("%s/api/v1/chats/%s/completions", r.address, chatID)
  565. req.Stream = false
  566. var resp ChatCompletionResp
  567. d, err := r.cli.Post(ctx, url, req)
  568. if err != nil {
  569. return nil, err
  570. }
  571. err = json.Unmarshal(d.ReadAll(), &resp)
  572. return &resp, nil
  573. }
  574. func (r *Client) ChatCompletionsStream(ctx context.Context, chatID string, req *ChatCompletionReq) (io.ReadCloser, error) {
  575. if chatID == "" || req.Question == "" {
  576. return nil, fmt.Errorf("chatID 和 question 不能为空")
  577. }
  578. // 强制开启流式
  579. req.Stream = true
  580. url := fmt.Sprintf("%s/api/v1/chats/%s/completions", r.address, chatID)
  581. r.cli.Timeout(0)
  582. r.cli.Transport = &http.Transport{
  583. DisableCompression: true,
  584. }
  585. resp, err := r.cli.Post(ctx, url, req)
  586. if err != nil {
  587. return nil, err
  588. }
  589. return resp.Body, nil
  590. }