rpc_client.go 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  // RPCClient implements an RPC client with reconnect and load balancing.
  2. package server
  3. import (
  4. "fmt"
  5. "math/rand"
  6. "net/rpc"
  7. "time"
  8. "github.com/opentracing/opentracing-go"
  9. )
  10. type RPCClient struct {
  11. clients map[string]*rpc.Client
  12. random *rand.Rand
  13. }
  14. func NewRPCClient() (*RPCClient, error) {
  15. if serverInstance == nil {
  16. return nil, errorf(errServerNotInit)
  17. }
  18. if serverInstance.svrmgr == nil {
  19. return nil, errorf(errServerManagerNotInit)
  20. }
  21. return &RPCClient{
  22. clients: make(map[string]*rpc.Client),
  23. random: rand.New(rand.NewSource(time.Now().UnixNano())),
  24. }, nil
  25. }
  26. func rpcCallWithReconnect(client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
  27. err := client.Call(serverMethod, args, reply)
  28. if err == rpc.ErrShutdown {
  29. Log.Warn("rpc connection shut down, trying to reconnect...")
  30. client, err = rpc.Dial("tcp", addr)
  31. if err != nil {
  32. return err
  33. }
  34. return client.Call(serverMethod, args, reply)
  35. }
  36. return err
  37. }
  38. // RPC call with reconnect and retry.
  39. func (client *RPCClient) Call(span opentracing.Span, severName string, serverMethod string, args interface{}, reply interface{}) error {
  40. defer span.Finish()
  41. addrs, err := serverInstance.svrmgr.GetServerHosts(severName, FlagRPCHost)
  42. if err != nil {
  43. return err
  44. }
  45. // pick a random start index for round robin
  46. total := len(addrs)
  47. start := client.random.Intn(total)
  48. for idx := 0; idx < total; idx++ {
  49. addr := addrs[(start+idx)%total]
  50. mapkey := fmt.Sprintf("%s[%s]", severName, addr)
  51. if client.clients[mapkey] == nil {
  52. client.clients[mapkey], err = rpc.Dial("tcp", addr)
  53. if err != nil {
  54. Log.Warnf("RPC dial error : %s", err)
  55. continue
  56. }
  57. }
  58. span.SetTag("server.method", serverMethod)
  59. span.SetTag("server.addr", addr)
  60. err = rpcCallWithReconnect(client.clients[mapkey], addr, serverMethod, args, reply)
  61. if err != nil {
  62. Log.WithField("method", serverMethod).Warnf("RpcCallWithReconnect error : %s", err)
  63. continue
  64. }
  65. return nil
  66. }
  67. return errorf(err.Error())
  68. }
  69. // RPC call by host
  70. func (client *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
  71. if client.clients[host] == nil {
  72. var err error
  73. client.clients[host], err = rpc.Dial("tcp", host)
  74. if err != nil {
  75. Log.Errorf("RPC dial error : %s", err)
  76. return err
  77. }
  78. }
  79. return rpcCallWithReconnect(client.clients[host], host, serverMethod, args, reply)
  80. }