rpc_client.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. // RPCClient implements an RPC client with automatic reconnection and load balancing.
  2. package server
  3. import (
  4. "fmt"
  5. "math/rand"
  6. "net/rpc"
  7. "time"
  8. )
  9. type RPCClient struct {
  10. clients map[string]*rpc.Client
  11. random *rand.Rand
  12. }
  13. func NewRPCClient() (*RPCClient, error) {
  14. if serverInstance == nil {
  15. return nil, errorf(errServerNotInit)
  16. }
  17. if serverInstance.svrmgr == nil {
  18. return nil, errorf(errServerManagerNotInit)
  19. }
  20. return &RPCClient{
  21. clients: make(map[string]*rpc.Client),
  22. random: rand.New(rand.NewSource(time.Now().UnixNano())),
  23. }, nil
  24. }
  25. func rpcCallWithReconnect(client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
  26. err := client.Call(serverMethod, args, reply)
  27. if err == rpc.ErrShutdown {
  28. Log.Warn("rpc connection shut down, trying to reconnect...")
  29. client, err = rpc.Dial("tcp", addr)
  30. if err != nil {
  31. return err
  32. }
  33. return client.Call(serverMethod, args, reply)
  34. }
  35. return err
  36. }
  37. // RPC call with reconnect and retry.
  38. func (client *RPCClient) Call(severName string, serverMethod string, args interface{}, reply interface{}) error {
  39. addrs, err := serverInstance.svrmgr.GetServerHosts(severName, FlagRPCHost)
  40. if err != nil {
  41. return err
  42. }
  43. // pick a random start index for round robin
  44. total := len(addrs)
  45. start := client.random.Intn(total)
  46. for idx := 0; idx < total; idx++ {
  47. addr := addrs[(start+idx)%total]
  48. mapkey := fmt.Sprintf("%s[%s]", severName, addr)
  49. if client.clients[mapkey] == nil {
  50. client.clients[mapkey], err = rpc.Dial("tcp", addr)
  51. if err != nil {
  52. Log.Warnf("RPC dial error : %s", err)
  53. continue
  54. }
  55. }
  56. err = rpcCallWithReconnect(client.clients[mapkey], addr, serverMethod, args, reply)
  57. if err != nil {
  58. Log.Warnf("RpcCallWithReconnect error : %s", err)
  59. continue
  60. }
  61. return nil
  62. }
  63. return errorf(err.Error())
  64. }
  65. // RPC call by host
  66. func (client *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
  67. if client.clients[host] == nil {
  68. var err error
  69. client.clients[host], err = rpc.Dial("tcp", host)
  70. if err != nil {
  71. Log.Errorf("RPC dial error : %s", err)
  72. return err
  73. }
  74. }
  75. return rpcCallWithReconnect(client.clients[host], host, serverMethod, args, reply)
  76. }