client.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. package client
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/open-dingtalk/dingtalk-stream-sdk-go/card"
  9. "io"
  10. "net/http"
  11. "net/url"
  12. "sync"
  13. "time"
  14. "github.com/gorilla/websocket"
  15. "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
  16. "github.com/open-dingtalk/dingtalk-stream-sdk-go/handler"
  17. "github.com/open-dingtalk/dingtalk-stream-sdk-go/logger"
  18. "github.com/open-dingtalk/dingtalk-stream-sdk-go/payload"
  19. "github.com/open-dingtalk/dingtalk-stream-sdk-go/plugin"
  20. "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
  21. )
  22. /**
  23. * @Author linya.jj
  24. * @Date 2023/3/22 14:23
  25. */
  26. type StreamClient struct {
  27. AppCredential *AppCredentialConfig
  28. UserAgent *UserAgentConfig
  29. AutoReconnect bool
  30. subscriptions map[string]map[string]handler.IFrameHandler
  31. conn *websocket.Conn
  32. sessionId string
  33. mutex sync.Mutex
  34. extras map[string]string
  35. openApiHost string
  36. proxy string
  37. }
  38. func NewStreamClient(options ...ClientOption) *StreamClient {
  39. cli := &StreamClient{}
  40. defaultOptions := []ClientOption{
  41. WithSubscription(utils.SubscriptionTypeKSystem, "disconnect", cli.OnDisconnect),
  42. WithSubscription(utils.SubscriptionTypeKSystem, "ping", cli.OnPing),
  43. WithUserAgent(NewDingtalkGoSDKUserAgent()),
  44. WithAutoReconnect(true),
  45. }
  46. for _, option := range defaultOptions {
  47. option(cli)
  48. }
  49. for _, option := range options {
  50. if option == nil {
  51. continue
  52. }
  53. option(cli)
  54. }
  55. return cli
  56. }
  57. func (cli *StreamClient) Start(ctx context.Context) error {
  58. if cli.conn != nil {
  59. return nil
  60. }
  61. cli.mutex.Lock()
  62. defer cli.mutex.Unlock()
  63. if cli.conn != nil {
  64. return nil
  65. }
  66. endpoint, err := cli.GetConnectionEndpoint(ctx)
  67. if err != nil {
  68. return err
  69. }
  70. wssUrl := fmt.Sprintf("%s?ticket=%s", endpoint.Endpoint, endpoint.Ticket)
  71. header := make(http.Header)
  72. var dialer *websocket.Dialer
  73. if len(cli.proxy) == 0 {
  74. dialer = websocket.DefaultDialer
  75. } else {
  76. proxyURL, err := url.Parse(cli.proxy)
  77. if err != nil {
  78. return err
  79. }
  80. dialer = &websocket.Dialer{
  81. Proxy: http.ProxyURL(proxyURL),
  82. }
  83. }
  84. conn, resp, err := dialer.Dial(wssUrl, header)
  85. if err != nil {
  86. return err
  87. }
  88. // 建连失败
  89. if resp.StatusCode >= http.StatusBadRequest {
  90. return utils.ErrorFromHttpResponseBody(resp)
  91. }
  92. cli.conn = conn
  93. cli.sessionId = endpoint.Ticket
  94. logger.GetLogger().Infof("connect success, sessionId=[%s]", cli.sessionId)
  95. go cli.processLoop()
  96. return nil
  97. }
  98. func (cli *StreamClient) processLoop() {
  99. defer func() {
  100. if err := recover(); err != nil {
  101. logger.GetLogger().Errorf("connection process panic due to unknown reason, error=[%s]", err)
  102. }
  103. if cli.AutoReconnect {
  104. go cli.reconnect()
  105. }
  106. }()
  107. for {
  108. if cli.conn == nil {
  109. logger.GetLogger().Errorf("connection process connect nil, maybe disconnected.")
  110. return
  111. }
  112. messageType, message, err := cli.conn.ReadMessage()
  113. if err != nil {
  114. logger.GetLogger().Errorf("connection process read message error: messageType=[%d] message=[%s] error=[%s]", messageType, string(message), err)
  115. return
  116. }
  117. logger.GetLogger().Debugf("[wire] [websocket] remote => local: \n%s", string(message))
  118. go cli.processDataFrame(message)
  119. }
  120. }
  121. func (cli *StreamClient) processDataFrame(rawData []byte) {
  122. defer func() {
  123. if err := recover(); err != nil {
  124. logger.GetLogger().Errorf("connection processDataFrame panic, error=[%s]", err)
  125. }
  126. }()
  127. dataFrame, err := payload.DecodeDataFrame(rawData)
  128. if err != nil {
  129. logger.GetLogger().Errorf("connection process decode data frame error: length=[%d] error=[%s]", len(rawData), err)
  130. return
  131. }
  132. if dataFrame == nil || dataFrame.Headers == nil {
  133. logger.GetLogger().Errorf("connection processDataFrame dataFrame nil.")
  134. return
  135. }
  136. var dataAck *payload.DataFrameResponse
  137. frameHandler, err := cli.GetHandler(dataFrame.Type, dataFrame.GetTopic())
  138. if err != nil || frameHandler == nil {
  139. // 没有注册handler,返回404
  140. dataAck = payload.NewDataFrameResponse(payload.DataFrameResponseStatusCodeKHandlerNotFound)
  141. } else {
  142. dataAck, err = frameHandler(context.Background(), dataFrame)
  143. if err != nil && dataAck == nil {
  144. dataAck = payload.NewErrorDataFrameResponse(err)
  145. }
  146. }
  147. if dataAck == nil {
  148. dataAck = payload.NewSuccessDataFrameResponse()
  149. }
  150. if dataAck.GetHeader(payload.DataFrameHeaderKMessageId) == "" {
  151. dataAck.SetHeader(payload.DataFrameHeaderKMessageId, dataFrame.GetMessageId())
  152. }
  153. if dataAck.GetHeader(payload.DataFrameHeaderKContentType) == "" {
  154. dataAck.SetHeader(payload.DataFrameHeaderKContentType, payload.DataFrameContentTypeKJson)
  155. }
  156. errSend := cli.SendDataFrameResponse(context.Background(), dataAck)
  157. sentBytes, _ := json.Marshal(dataAck)
  158. logger.GetLogger().Debugf("[wire] [websocket] local => remote:\n%s", string(sentBytes))
  159. if errSend != nil {
  160. logger.GetLogger().Errorf("connection processDataFrame send response error: error=[%s]", errSend)
  161. }
  162. }
  163. func (cli *StreamClient) Close() {
  164. if cli.conn == nil {
  165. return
  166. }
  167. cli.mutex.Lock()
  168. defer cli.mutex.Unlock()
  169. if cli.conn == nil {
  170. return
  171. }
  172. if err := cli.conn.Close(); err != nil {
  173. logger.GetLogger().Errorf("StreamClient close. error=[%s]", err)
  174. }
  175. cli.conn = nil
  176. cli.sessionId = ""
  177. }
  178. func (cli *StreamClient) reconnect() {
  179. defer func() {
  180. if err := recover(); err != nil {
  181. logger.GetLogger().Errorf("reconect panic due to unknown reason. error=[%s]", err)
  182. }
  183. }()
  184. cli.Close()
  185. for {
  186. err := cli.Start(context.Background())
  187. if err != nil {
  188. logger.GetLogger().Errorf("StreamClient reconnect error. error=[%s]", err)
  189. time.Sleep(time.Second * 3)
  190. } else {
  191. logger.GetLogger().Infof("StreamClient reconnect success")
  192. return
  193. }
  194. }
  195. }
  196. func (cli *StreamClient) GetHandler(stype, stopic string) (handler.IFrameHandler, error) {
  197. subs := cli.subscriptions[stype]
  198. if subs == nil || subs[stopic] == nil {
  199. return nil, errors.New("HandlerNotRegistedForTypeTopic_" + stype + "_" + stopic)
  200. }
  201. return subs[stopic], nil
  202. }
  203. func (cli *StreamClient) CheckConfigValid() error {
  204. if err := cli.AppCredential.Valid(); err != nil {
  205. return err
  206. }
  207. if err := cli.UserAgent.Valid(); err != nil {
  208. return err
  209. }
  210. if cli.subscriptions == nil {
  211. return errors.New("subscriptionsNil")
  212. }
  213. for ttype, subs := range cli.subscriptions {
  214. if _, ok := utils.SubscriptionTypeSet[ttype]; !ok {
  215. return errors.New("UnKnownSubscriptionType_" + ttype)
  216. }
  217. if len(subs) <= 0 {
  218. return errors.New("NoHandlersRegistedForType_" + ttype)
  219. }
  220. for ttopic, h := range subs {
  221. if h == nil {
  222. return errors.New("HandlerNilForTypeTopic_" + ttype + "_" + ttopic)
  223. }
  224. }
  225. }
  226. return nil
  227. }
  228. func (cli *StreamClient) GetConnectionEndpoint(ctx context.Context) (*payload.ConnectionEndpointResponse, error) {
  229. if err := cli.CheckConfigValid(); err != nil {
  230. return nil, err
  231. }
  232. requestModel := payload.ConnectionEndpointRequest{
  233. ClientId: cli.AppCredential.ClientId,
  234. ClientSecret: cli.AppCredential.ClientSecret,
  235. UserAgent: cli.UserAgent.UserAgent,
  236. Subscriptions: make([]*payload.SubscriptionModel, 0),
  237. Extras: cli.extras,
  238. }
  239. if localIp, err := utils.GetFirstLanIP(); err == nil {
  240. requestModel.LocalIP = localIp
  241. }
  242. for ttype, subs := range cli.subscriptions {
  243. for ttopic := range subs {
  244. requestModel.Subscriptions = append(requestModel.Subscriptions, &payload.SubscriptionModel{
  245. Type: ttype,
  246. Topic: ttopic,
  247. })
  248. }
  249. }
  250. requestJsonBody, _ := json.Marshal(requestModel)
  251. var targetHost string
  252. if len(cli.openApiHost) == 0 {
  253. targetHost = utils.DefaultOpenApiHost
  254. } else {
  255. targetHost = cli.openApiHost
  256. }
  257. req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetHost+utils.GetConnectionEndpointAPIUrl, bytes.NewReader(requestJsonBody))
  258. if err != nil {
  259. return nil, err
  260. }
  261. req.Header.Set("Content-Type", "application/json")
  262. req.Header.Set("Accept", "application/json")
  263. var transport http.RoundTripper
  264. if len(cli.proxy) == 0 {
  265. transport = http.DefaultTransport
  266. } else {
  267. proxyURL, err := url.Parse(cli.proxy)
  268. if err != nil {
  269. return nil, err
  270. }
  271. transport = &http.Transport{
  272. Proxy: http.ProxyURL(proxyURL),
  273. }
  274. }
  275. httpClient := &http.Client{
  276. Transport: transport,
  277. Timeout: 5 * time.Second, //设置超时,包含connection时间、任意重定向时间、读取response body时间
  278. }
  279. logger.GetLogger().Debugf("[wire] [http] local => remote:\n%s %s %s\nHost: %s\n%s\n\n%s",
  280. req.Method, req.URL.RequestURI(), req.Proto, req.Host,
  281. utils.DumpHeaders(req.Header), requestJsonBody)
  282. resp, err := httpClient.Do(req)
  283. if err != nil {
  284. return nil, err
  285. }
  286. if resp.StatusCode != http.StatusOK {
  287. return nil, utils.ErrorFromHttpResponseBody(resp)
  288. }
  289. defer resp.Body.Close()
  290. responseJsonBody, err := io.ReadAll(resp.Body)
  291. if err != nil {
  292. return nil, err
  293. }
  294. logger.GetLogger().Debugf("[wire] [http] remote => localhost:\n%s %s\n%s\n\n%s",
  295. resp.Proto, resp.Status,
  296. utils.DumpHeaders(resp.Header), responseJsonBody)
  297. endpoint := &payload.ConnectionEndpointResponse{}
  298. if err := json.Unmarshal(responseJsonBody, endpoint); err != nil {
  299. return nil, err
  300. }
  301. if err := endpoint.Valid(); err != nil {
  302. return nil, err
  303. }
  304. return endpoint, nil
  305. }
  306. func (cli *StreamClient) OnDisconnect(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
  307. logger.GetLogger().Debugf("StreamClient.OnDisconnect")
  308. cli.Close()
  309. return nil, nil
  310. }
  311. func (cli *StreamClient) OnPing(ctx context.Context, df *payload.DataFrame) (*payload.DataFrameResponse, error) {
  312. dfPong := payload.NewDataFrameAckPong(df.GetMessageId())
  313. dfPong.Data = df.Data
  314. return dfPong, nil
  315. }
  316. // 返回正常数据包
  317. func (cli *StreamClient) SendDataFrameResponse(ctx context.Context, resp *payload.DataFrameResponse) error {
  318. if resp == nil {
  319. return errors.New("SendDataFrameResponseError_ResponseNil")
  320. }
  321. if cli.conn == nil {
  322. logger.GetLogger().Errorf("SendDataFrameResponse error, conn nil, maybe disconnected.")
  323. return errors.New("disconnected")
  324. }
  325. return cli.conn.WriteJSON(resp)
  326. }
  327. // 通用注册函数
  328. func (cli *StreamClient) RegisterRouter(stype, stopic string, frameHandler handler.IFrameHandler) {
  329. if cli.subscriptions == nil {
  330. cli.subscriptions = make(map[string]map[string]handler.IFrameHandler)
  331. }
  332. if _, ok := cli.subscriptions[stype]; !ok {
  333. cli.subscriptions[stype] = make(map[string]handler.IFrameHandler)
  334. }
  335. cli.subscriptions[stype][stopic] = frameHandler
  336. }
  337. // callback类型注册函数
  338. func (cli *StreamClient) RegisterCallbackRouter(topic string, frameHandler handler.IFrameHandler) {
  339. cli.RegisterRouter(utils.SubscriptionTypeKCallback, topic, frameHandler)
  340. }
  341. // 聊天机器人的注册函数
  342. func (cli *StreamClient) RegisterChatBotCallbackRouter(messageHandler chatbot.IChatBotMessageHandler) {
  343. cli.RegisterRouter(utils.SubscriptionTypeKCallback, payload.BotMessageCallbackTopic, chatbot.NewDefaultChatBotFrameHandler(messageHandler).OnEventReceived)
  344. }
  345. // AI插件的注册函数
  346. func (cli *StreamClient) RegisterPluginCallbackRouter(messageHandler plugin.IPluginMessageHandler) {
  347. cli.RegisterRouter(utils.SubscriptionTypeKCallback, payload.PluginMessageCallbackTopic, plugin.NewDefaultPluginFrameHandler(messageHandler).OnEventReceived)
  348. }
  349. // 互动卡片的注册函数
  350. func (cli *StreamClient) RegisterCardCallbackRouter(messageHandler card.ICardCallbackHandler) {
  351. cli.RegisterRouter(utils.SubscriptionTypeKCallback, payload.CardInstanceCallbackTopic, card.NewDefaultPluginFrameHandler(messageHandler).OnEventReceived)
  352. }
  353. // 事件类型的注册函数
  354. func (cli *StreamClient) RegisterEventRouter(topic string, frameHandler handler.IFrameHandler) {
  355. cli.RegisterRouter(utils.SubscriptionTypeKEvent, topic, frameHandler)
  356. }
  357. // 所有事件的注册函数
  358. func (cli *StreamClient) RegisterAllEventRouter(frameHandler handler.IFrameHandler) {
  359. cli.RegisterRouter(utils.SubscriptionTypeKEvent, "*", frameHandler)
  360. }