server.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. package neffos
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. uuid "github.com/iris-contrib/go.uuid"
  13. )
// Upgrader is the definition type of a protocol upgrader, gorilla or gobwas or custom.
// It is the first parameter of the `New` function which constructs a neffos server.
//
// It receives the incoming HTTP request/response pair and returns the upgraded
// `Socket`, or a non-nil error when the upgrade failed.
type Upgrader func(w http.ResponseWriter, r *http.Request) (Socket, error)
// IDGenerator is the type of function that is used
// to generate unique identifiers for new connections.
// It receives the HTTP request/response pair of the connection being upgraded
// and returns the identifier as a string.
//
// See `Server.IDGenerator`.
type IDGenerator func(w http.ResponseWriter, r *http.Request) string
  22. // DefaultIDGenerator returns a universal unique identifier for a new connection.
  23. // It's the default `IDGenerator` for `Server`.
  24. var DefaultIDGenerator IDGenerator = func(http.ResponseWriter, *http.Request) string {
  25. id, err := uuid.NewV4()
  26. if err != nil {
  27. return strconv.FormatInt(time.Now().Unix(), 10)
  28. }
  29. return id.String()
  30. }
// Server is the neffos server.
// Keeps the `IDGenerator` which can be customized, by default it's the `DefaultIDGenerator` which
// generates connections unique identifiers using the uuid/v4.
//
// Callers can optionally register callbacks for connection, disconnection and errored.
// Its most important methods are `ServeHTTP` which is used to register the server on a specific endpoint
// and `Broadcast` and `Close`.
// Use the `New` function to create a new server, server starts automatically, no further action is required.
type Server struct {
	// uuid identifies this server instance, it is generated by `New`
	// and it is part of every connection's server-side ID (see `genServerConnID`).
	uuid string

	// upgrader upgrades an incoming HTTP request to a `Socket`, see `Upgrade`.
	upgrader Upgrader

	// IDGenerator generates the ID of each new connection,
	// unless a custom generator is passed to `Upgrade`.
	IDGenerator IDGenerator

	// StackExchange, when not nil, is notified on connect and disconnect
	// and it is used to publish and ask messages. See `UseStackExchange`.
	StackExchange StackExchange

	// If `StackExchange` is set then this field is ignored.
	//
	// It overrides the default behavior (when no StackExchange is used)
	// which publishes a message independently.
	// In short the default behavior doesn't wait for a message to be published to all clients
	// before any next broadcast call.
	//
	// Therefore, if set to true,
	// each broadcast call will publish its own message(s) by order.
	SyncBroadcaster bool

	// FireDisconnectAlways will allow firing the `OnDisconnect` server's
	// event even if the connection was immediately closed from the `OnConnect` server's event
	// through `Close()` or non-nil error.
	// See https://github.com/kataras/neffos/issues/41
	//
	// Defaults to false.
	FireDisconnectAlways bool

	// mu is read-locked by `GetConnections` and `GetConnectionsByNamespace`.
	mu sync.RWMutex

	// namespaces are the namespaces and events given to `New` through its "connHandler".
	namespaces Namespaces

	// connection read/write timeouts.
	readTimeout time.Duration

	writeTimeout time.Duration

	// count is the number of registered connections,
	// it is updated atomically, see `GetTotalConnections`.
	count uint64

	// connections holds the registered connections, it is modified by the `start` loop.
	connections map[*Conn]struct{}

	// connect and disconnect register and unregister a connection on the `start` loop.
	connect chan *Conn

	disconnect chan *Conn

	// actions carries the `Do` requests to the `start` loop.
	actions chan action

	// broadcastMessages carries the messages of a `Broadcast` call
	// to the `start` loop when `SyncBroadcaster` is true.
	broadcastMessages chan []Message

	// broadcaster is used by `Broadcast` when neither a `StackExchange`
	// nor the `SyncBroadcaster` is used.
	broadcaster *broadcaster

	// messages that this server must wait
	// for a reply from one of its own connections (see `Ask`).
	waitingMessages map[string]chan Message

	waitingMessagesMutex sync.RWMutex

	// closed is set to 1 by `Close`, a value greater than zero makes `Upgrade` reject new requests.
	closed uint32

	// OnUpgradeError can be optionally registered to catch upgrade errors.
	OnUpgradeError func(err error)

	// OnConnect can be optionally registered to be notified for any new neffos client connection,
	// it can be used to force-connect a client to a specific namespace(s) or to send data immediately or
	// even to cancel a client connection and disallow its connection when its return error value is not nil.
	// Don't confuse it with the `OnNamespaceConnect`, this callback is for the entire client side connection.
	OnConnect func(c *Conn) error

	// OnDisconnect can be optionally registered to notify about a connection's disconnect.
	// Don't confuse it with the `OnNamespaceDisconnect`, this callback is for the entire client side connection.
	OnDisconnect func(c *Conn)
}
  89. // New constructs and returns a new neffos server.
  90. // Listens to incoming connections automatically, no further action is required from the caller.
  91. // The second parameter is the "connHandler", it can be
  92. // filled as `Namespaces`, `Events` or `WithTimeout`, same namespaces and events can be used on the client-side as well,
  93. // Use the `Conn#IsClient` on any event callback to determinate if it's a client-side connection or a server-side one.
  94. //
  95. // See examples for more.
  96. func New(upgrader Upgrader, connHandler ConnHandler) *Server {
  97. readTimeout, writeTimeout := getTimeouts(connHandler)
  98. namespaces := connHandler.GetNamespaces()
  99. s := &Server{
  100. uuid: uuid.Must(uuid.NewV4()).String(),
  101. upgrader: upgrader,
  102. namespaces: namespaces,
  103. readTimeout: readTimeout,
  104. writeTimeout: writeTimeout,
  105. connections: make(map[*Conn]struct{}),
  106. connect: make(chan *Conn, 1),
  107. disconnect: make(chan *Conn),
  108. actions: make(chan action),
  109. broadcastMessages: make(chan []Message),
  110. broadcaster: newBroadcaster(),
  111. waitingMessages: make(map[string]chan Message),
  112. IDGenerator: DefaultIDGenerator,
  113. }
  114. go s.start()
  115. return s
  116. }
  117. // UseStackExchange can be used to add one or more StackExchange
  118. // to the server.
  119. // Returns a non-nil error when "exc"
  120. // completes the `StackExchangeInitializer` interface and its `Init` failed.
  121. //
  122. // Read more at the `StackExchange` type's docs.
  123. func (s *Server) UseStackExchange(exc StackExchange) error {
  124. if exc == nil {
  125. return nil
  126. }
  127. if err := stackExchangeInit(exc, s.namespaces); err != nil {
  128. return err
  129. }
  130. if s.usesStackExchange() {
  131. s.StackExchange = wrapStackExchanges(s.StackExchange, exc)
  132. } else {
  133. s.StackExchange = exc
  134. }
  135. return nil
  136. }
  137. // usesStackExchange reports whether this server
  138. // uses one or more `StackExchange`s.
  139. func (s *Server) usesStackExchange() bool {
  140. return s.StackExchange != nil
  141. }
  142. func (s *Server) start() {
  143. atomic.StoreUint32(&s.closed, 0)
  144. for {
  145. select {
  146. case c := <-s.connect:
  147. s.connections[c] = struct{}{}
  148. atomic.AddUint64(&s.count, 1)
  149. case c := <-s.disconnect:
  150. if _, ok := s.connections[c]; ok {
  151. // close(c.out)
  152. delete(s.connections, c)
  153. atomic.AddUint64(&s.count, ^uint64(0))
  154. // println("disconnect...")
  155. if s.OnDisconnect != nil {
  156. // don't fire disconnect if was immediately closed on the `OnConnect` server event.
  157. if !s.FireDisconnectAlways && (!c.readiness.isReady() || (c.readiness.err != nil)) {
  158. continue
  159. }
  160. s.OnDisconnect(c)
  161. }
  162. if s.usesStackExchange() {
  163. s.StackExchange.OnDisconnect(c)
  164. }
  165. }
  166. case msgs := <-s.broadcastMessages:
  167. for c := range s.connections {
  168. publishMessages(c, msgs)
  169. }
  170. case act := <-s.actions:
  171. for c := range s.connections {
  172. act.call(c)
  173. }
  174. if act.done != nil {
  175. act.done <- struct{}{}
  176. }
  177. }
  178. }
  179. }
  180. // Close terminates the server and all of its connections, client connections are getting notified.
  181. func (s *Server) Close() {
  182. if atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
  183. s.Do(func(c *Conn) {
  184. c.Close()
  185. }, false)
  186. }
  187. }
var (
	// errServerClosed is returned by `Server#Upgrade` when the server is closed.
	errServerClosed = errors.New("server closed")
	// errInvalidMethod is returned by `Server#Upgrade` when the request method
	// is neither GET nor HEAD.
	errInvalidMethod = errors.New("no valid request method")
)
  192. // URLParamAsHeaderPrefix is the prefix that server parses the url parameters as request headers.
  193. // The client's `URLParamAsHeaderPrefix` must match.
  194. // Note that this is mostly useful for javascript browser-side clients, nodejs and go client support custom headers by default.
  195. // No action required from end-developer, exported only for chance to a custom parsing.
  196. const URLParamAsHeaderPrefix = "X-Websocket-Header-"
  197. func tryParseURLParamsToHeaders(r *http.Request) {
  198. q := r.URL.Query()
  199. for k, values := range q {
  200. if len(k) <= len(URLParamAsHeaderPrefix) {
  201. continue
  202. }
  203. k = http.CanonicalHeaderKey(k) // canonical, so no X-WebSocket thing.
  204. idx := strings.Index(k, URLParamAsHeaderPrefix)
  205. if idx != 0 { // must be prefix.
  206. continue
  207. }
  208. if r.Header == nil {
  209. r.Header = make(http.Header)
  210. }
  211. k = k[len(URLParamAsHeaderPrefix):]
  212. for _, v := range values {
  213. r.Header.Add(k, v)
  214. }
  215. }
  216. }
  217. var errUpgradeOnRetry = errors.New("check status")
  218. // IsTryingToReconnect reports whether the returning "err" from the `Server#Upgrade`
  219. // is from a client that was trying to reconnect to the websocket server.
  220. //
  221. // Look the `Conn#WasReconnected` and `Conn#ReconnectTries` too.
  222. func IsTryingToReconnect(err error) (ok bool) {
  223. return err != nil && err == errUpgradeOnRetry
  224. }
// websocketReconectHeaderKey is the request header that `Server#Upgrade` reads
// to fill the connection's `ReconnectTries`.
// This header key should match the one that the browser-client's `whenResourceOnline->re-dial` uses.
const websocketReconectHeaderKey = "X-Websocket-Reconnect"
  227. func isServerConnID(s string) bool {
  228. return strings.HasPrefix(s, "neffos(0x")
  229. }
  230. func genServerConnID(s *Server, c *Conn) string {
  231. return fmt.Sprintf("neffos(0x%s(%s%p))", s.uuid, c.id, c)
  232. }
// Upgrade handles the connection, same as `ServeHTTP` but it can accept
// a socket wrapper and a "customIDGen" that overrides the server's IDGenerator
// and it does return the connection or any errors.
//
// In order, it:
//   - rejects the request with a 500 status and `errServerClosed` when the server is closed,
//   - answers HEAD requests with a 302 status and an error that `IsTryingToReconnect` recognises,
//   - rejects any other non-GET method with a 405 status and `errInvalidMethod`,
//   - copies the `URLParamAsHeaderPrefix` url parameters to the request headers,
//   - upgrades the request through the server's upgrader (a failure is passed to `OnUpgradeError`, if set, and returned),
//   - creates the connection, registers it on the server and starts its reader,
//   - notifies the `StackExchange` (if any) and then the `OnConnect` callback (if any);
//     a non-nil error from any of them is returned with a nil connection.
func (s *Server) Upgrade(
	w http.ResponseWriter,
	r *http.Request,
	socketWrapper func(Socket) Socket,
	customIDGen IDGenerator,
) (*Conn, error) {
	if atomic.LoadUint32(&s.closed) > 0 {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return nil, errServerClosed
	}

	// HEAD requests never upgrade, see `IsTryingToReconnect`.
	if r.Method == http.MethodHead {
		w.WriteHeader(http.StatusFound)
		return nil, errUpgradeOnRetry
	}

	if r.Method != http.MethodGet {
		// RFC 2616 https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html
		// The response MUST include an Allow header containing a list of valid methods for the requested resource.
		//
		// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Allow#Examples
		w.Header().Set("Allow", http.MethodGet)
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(http.StatusMethodNotAllowed)
		fmt.Fprintln(w, http.StatusText(http.StatusMethodNotAllowed))
		return nil, errInvalidMethod
	}

	tryParseURLParamsToHeaders(r)

	socket, err := s.upgrader(w, r)
	if err != nil {
		if s.OnUpgradeError != nil {
			s.OnUpgradeError(err)
		}
		return nil, err
	}

	if socketWrapper != nil {
		socket = socketWrapper(socket)
	}

	c := newConn(socket, s.namespaces)
	// The "customIDGen", if given, takes precedence over the server's `IDGenerator`.
	if customIDGen != nil {
		c.id = customIDGen(w, r)
	} else {
		c.id = s.IDGenerator(w, r)
	}
	// Must be called after the ID is set, the server conn ID contains it.
	c.serverConnID = genServerConnID(s, c)
	c.readTimeout = s.readTimeout
	c.writeTimeout = s.writeTimeout
	c.server = s

	retriesHeaderValue := r.Header.Get(websocketReconectHeaderKey)
	if retriesHeaderValue != "" {
		// The conversion error is ignored on purpose, the tries are informational only.
		c.ReconnectTries, _ = strconv.Atoi(retriesHeaderValue)
	}

	// Only the default broadcaster needs a per-connection waiter goroutine,
	// it runs until `waitMessages` reports false.
	if !s.usesStackExchange() && !s.SyncBroadcaster {
		go func(c *Conn) {
			for s.waitMessages(c) {
			}
		}(c)
	}

	s.connect <- c

	go c.startReader()

	// Before `OnConnect` in order to be able
	// to Broadcast inside the `OnConnect` custom func.
	if s.usesStackExchange() {
		if err := s.StackExchange.OnConnect(c); err != nil {
			c.readiness.unwait(err)
			return nil, err
		}
	}

	// Start the reader before `OnConnect`, remember clients may remotely connect to namespace before `Server#OnConnect`
	// therefore any `Server:NSConn#OnNamespaceConnected` can write immediately to the client too.
	// Note also that the `Server#OnConnect` itself can do that as well but if the written Message's Namespace is not locally connected
	// it, correctly, can't pass the write checks. Also, and most important, the `OnConnect` is ready to connect a client to a namespace (locally and remotely).
	//
	// This has a downside:
	// We need a way to check if the `OnConnect` returns a non-nil error which means that the connection should terminate before namespace connect or anything.
	// The solution is to still accept reading messages but add them to the queue (like we already do for any case messages came before ack),
	// the problem to that is that the queue handler is fired when ack is done but `OnConnect` may not even return yet, so we introduce a `mark ready` atomic scope
	// and a channel which will wait for that `mark ready` if handle queue is called before ready.
	// Also make the same check before emitting the connection's disconnect event (if defined),
	// which will be always ready to be called because we added the connections via the connect channel;
	// we still need the connection to be available for any broadcasting on connected events.
	// ^ All these only when server-side connection in order to correctly handle the end-developer's `OnConnect`.
	//
	// Look `Conn.serverReadyWaiter#startReader##handleQueue.serverReadyWaiter.unwait` (to hold the events until no error returned or)
	// `#Write:serverReadyWaiter.unwait` (for things like server connect).
	// All cases tested & worked perfectly.
	if s.OnConnect != nil {
		if err = s.OnConnect(c); err != nil {
			// TODO: Do something with that error.
			// The most suitable thing we can do is to somehow send this to the client's `Dial` return statement.
			// This can be done if client waits for "OK" signal or a failure with an error before return the websocket connection,
			// as for today we have the ack process which does NOT block and end-developer can send messages and server will handle them when both sides are ready.
			// So, maybe it's a better solution to transform that process into a blocking state which can handle any `Server#OnConnect` error and return it at client's `Dial`.
			// Think more later today.
			// Done but with a lot of code.... will try to cleanup some things.
			//println("OnConnect error: " + err.Error())
			c.readiness.unwait(err)
			// No need to disconnect here, connection's .Close will be called on readiness ch errored.
			// c.Close()
			return nil, err
		}
	}

	//println("OnConnect does not exist or no error, fire unwait")
	c.readiness.unwait(nil)

	return c, nil
}
  340. // ServeHTTP completes the `http.Handler` interface, it should be passed on a http server's router
  341. // to serve this neffos server on a specific endpoint.
  342. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  343. s.Upgrade(w, r, nil, nil)
  344. }
  345. // GetTotalConnections returns the total amount of the connected connections to the server, it's fast
  346. // and can be used as frequently as needed.
  347. func (s *Server) GetTotalConnections() uint64 {
  348. return atomic.LoadUint64(&s.count)
  349. }
// action is a unit of work for the server's `start` loop, see `Server#Do`.
type action struct {
	// call is fired for each one of the registered connections.
	call func(*Conn)
	// done, when not nil, receives a value after "call" ran for all connections.
	done chan struct{}
}
  354. // Do loops through all connected connections and fires the "fn", with this method
  355. // callers can do whatever they want on a connection outside of a event's callback,
  356. // but make sure that these operations are not taking long time to complete because it delays the
  357. // new incoming connections.
  358. // If "async" is true then this method does not block the flow of the program.
  359. func (s *Server) Do(fn func(*Conn), async bool) {
  360. act := action{call: fn}
  361. if !async {
  362. act.done = make(chan struct{})
  363. // go func() { s.actions <- act }()
  364. // <-act.done
  365. }
  366. s.actions <- act
  367. if !async {
  368. <-act.done
  369. }
  370. }
  371. func publishMessages(c *Conn, msgs []Message) bool {
  372. for _, msg := range msgs {
  373. if msg.from == c.ID() {
  374. // if the message is not supposed to return back to any connection with this ID.
  375. return true
  376. }
  377. // if "To" field is given then send to a specific connection.
  378. if msg.To != "" && msg.To != c.ID() {
  379. return true
  380. }
  381. // c.Write may fail if the message is not supposed to end to this client
  382. // but the connection should be still open in order to continue.
  383. if !c.Write(msg) && c.IsClosed() {
  384. return false
  385. }
  386. }
  387. return true
  388. }
  389. func (s *Server) waitMessages(c *Conn) bool {
  390. s.broadcaster.mu.Lock()
  391. defer s.broadcaster.mu.Unlock()
  392. msgs, ok := s.broadcaster.waitUntilClosed(c.closeCh)
  393. if !ok {
  394. return false
  395. }
  396. return publishMessages(c, msgs)
  397. }
  398. type stringerValue struct{ v string }
  399. func (s stringerValue) String() string { return s.v }
  400. // Exclude can be passed on `Server#Broadcast` when
  401. // caller does not have access to the `Conn`, `NSConn` or a `Room` value but
  402. // has access to a string variable which is a connection's ID instead.
  403. //
  404. // Example Code:
  405. // nsConn.Conn.Server().Broadcast(
  406. //
  407. // neffos.Exclude("connection_id_here"),
  408. // neffos.Message{Namespace: "default", Room: "roomName or empty", Event: "chat", Body: [...]})
  409. func Exclude(connID string) fmt.Stringer { return stringerValue{connID} }
  410. // Broadcast method is fast and does not block any new incoming connection by-default,
  411. // it can be used as frequently as needed. Use the "msg"'s Namespace, or/and Event or/and Room to broadcast
  412. // to a specific type of connection collectives.
  413. //
  414. // If first "exceptSender" parameter is not nil then the message "msg" will be
  415. // broadcasted to all connected clients except the given connection's ID,
  416. // any value that completes the `fmt.Stringer` interface is valid. Keep note that
  417. // `Conn`, `NSConn`, `Room` and `Exclude(connID) global function` are valid values.
  418. //
  419. // Example Code:
  420. // nsConn.Conn.Server().Broadcast(
  421. //
  422. // nsConn OR nil,
  423. // neffos.Message{Namespace: "default", Room: "roomName or empty", Event: "chat", Body: [...]})
  424. //
  425. // Note that it if `StackExchange` is nil then its default behavior
  426. // doesn't wait for a publish to complete to all clients before any
  427. // next broadcast call. To change that behavior set the `Server.SyncBroadcaster` to true
  428. // before server start.
  429. func (s *Server) Broadcast(exceptSender fmt.Stringer, msgs ...Message) {
  430. if exceptSender != nil {
  431. var fromExplicit, from string
  432. switch c := exceptSender.(type) {
  433. case *Conn:
  434. fromExplicit = c.serverConnID
  435. case *NSConn:
  436. fromExplicit = c.Conn.serverConnID
  437. default:
  438. from = exceptSender.String()
  439. }
  440. for i := range msgs {
  441. if from != "" {
  442. msgs[i].from = from
  443. } else {
  444. msgs[i].FromExplicit = fromExplicit
  445. }
  446. }
  447. }
  448. if s.usesStackExchange() {
  449. s.StackExchange.Publish(msgs)
  450. return
  451. }
  452. if s.SyncBroadcaster {
  453. s.broadcastMessages <- msgs
  454. return
  455. }
  456. s.broadcaster.broadcast(msgs)
  457. }
  458. // Ask is like `Broadcast` but it blocks until a response
  459. // from a specific connection if "msg.To" is filled otherwise
  460. // from the first connection which will reply to this "msg".
  461. //
  462. // Accepts a context for deadline as its first input argument.
  463. // The second argument is the request message
  464. // which should be sent to a specific namespace:event
  465. // like the `Conn.Ask`.
  466. func (s *Server) Ask(ctx context.Context, msg Message) (Message, error) {
  467. if ctx == nil {
  468. ctx = context.TODO()
  469. }
  470. msg.wait = genWait(false)
  471. if s.usesStackExchange() {
  472. msg.wait = genWaitStackExchange(msg.wait)
  473. return s.StackExchange.Ask(ctx, msg, msg.wait)
  474. }
  475. ch := make(chan Message)
  476. s.waitingMessagesMutex.Lock()
  477. s.waitingMessages[msg.wait] = ch
  478. s.waitingMessagesMutex.Unlock()
  479. s.Broadcast(nil, msg)
  480. select {
  481. case <-ctx.Done():
  482. return Message{}, ctx.Err()
  483. case receive := <-ch:
  484. s.waitingMessagesMutex.Lock()
  485. delete(s.waitingMessages, msg.wait)
  486. s.waitingMessagesMutex.Unlock()
  487. return receive, receive.Err
  488. }
  489. }
  490. // GetConnectionsByNamespace can be used as an alternative way to retrieve
  491. // all connected connections to a specific "namespace" on a specific time point.
  492. // Do not use this function frequently, it is not designed to be fast or cheap, use it for debugging or logging every 'x' time.
  493. // Users should work with the event's callbacks alone, the usability is enough for all type of operations. See `Do` too.
  494. //
  495. // Not thread safe.
  496. func (s *Server) GetConnectionsByNamespace(namespace string) map[string]*NSConn {
  497. conns := make(map[string]*NSConn)
  498. s.mu.RLock()
  499. for c := range s.connections {
  500. if ns := c.Namespace(namespace); ns != nil {
  501. conns[ns.Conn.ID()] = ns
  502. }
  503. }
  504. s.mu.RUnlock()
  505. return conns
  506. }
  507. // GetConnections can be used as an alternative way to retrieve
  508. // all connected connections to the server on a specific time point.
  509. // Do not use this function frequently, it is not designed to be fast or cheap, use it for debugging or logging every 'x' time.
  510. //
  511. // Not thread safe.
  512. func (s *Server) GetConnections() map[string]*Conn {
  513. conns := make(map[string]*Conn)
  514. s.mu.RLock()
  515. for c := range s.connections {
  516. conns[c.ID()] = c
  517. }
  518. s.mu.RUnlock()
  519. return conns
  520. }
var (
	// ErrBadNamespace may return from a `Conn#Connect` method when the remote side does not declare the given namespace.
	ErrBadNamespace = errors.New("bad namespace")
	// ErrBadRoom may return from a `Room#Leave` method when trying to leave from a not joined room.
	ErrBadRoom = errors.New("bad room")
	// ErrWrite may return from any connection's method when the underlying connection is closed (unexpectedly).
	ErrWrite = errors.New("write closed")
)