|
@@ -4,10 +4,12 @@ import (
|
|
|
"fmt"
|
|
|
"math/rand"
|
|
|
"net/rpc"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
type RPCClient struct {
|
|
|
+ mu sync.Mutex
|
|
|
clients map[string]*rpc.Client
|
|
|
random *rand.Rand
|
|
|
}
|
|
@@ -25,41 +27,43 @@ func NewRPCClient() (*RPCClient, error) {
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
-func rpcCallWithReconnect(client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
+func (a *RPCClient) rpcCallWithReconnect(key string, client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
err := client.Call(serverMethod, args, reply)
|
|
|
if err == rpc.ErrShutdown {
|
|
|
Log.Warnf("rpc %s connection shut down, trying to reconnect...", addr)
|
|
|
client, err = rpc.Dial("tcp", addr)
|
|
|
if err != nil {
|
|
|
- Log.Debugf("重新连接%s失败:%s", addr, err.Error())
|
|
|
return err
|
|
|
}
|
|
|
+ a.clients[key] = client
|
|
|
return client.Call(serverMethod, args, reply)
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// Call RPC call with reconnect and retry.
|
|
|
-func (client *RPCClient) Call(severName string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
+func (a *RPCClient) Call(severName string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
+ a.mu.Lock()
|
|
|
+ defer a.mu.Unlock()
|
|
|
addrs, err := serverInstance.serverManager.GetServerHosts(severName, FlagRPCHost)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
// pick a random start index for round robin
|
|
|
total := len(addrs)
|
|
|
- start := client.random.Intn(total)
|
|
|
+ start := a.random.Intn(total)
|
|
|
|
|
|
for idx := 0; idx < total; idx++ {
|
|
|
addr := addrs[(start+idx)%total]
|
|
|
mapkey := fmt.Sprintf("%s[%s]", severName, addr)
|
|
|
- if client.clients[mapkey] == nil {
|
|
|
- client.clients[mapkey], err = rpc.Dial("tcp", addr)
|
|
|
+ if a.clients[mapkey] == nil {
|
|
|
+ a.clients[mapkey], err = rpc.Dial("tcp", addr)
|
|
|
if err != nil {
|
|
|
Log.Warnf("RPC dial error : %s", err)
|
|
|
continue
|
|
|
}
|
|
|
}
|
|
|
- err = rpcCallWithReconnect(client.clients[mapkey], addr, serverMethod, args, reply)
|
|
|
+ err = a.rpcCallWithReconnect(mapkey, a.clients[mapkey], addr, serverMethod, args, reply)
|
|
|
if err != nil {
|
|
|
continue
|
|
|
}
|
|
@@ -70,14 +74,14 @@ func (client *RPCClient) Call(severName string, serverMethod string, args interf
|
|
|
}
|
|
|
|
|
|
// CallHost RPC call by host
|
|
|
-func (client *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
- if client.clients[host] == nil {
|
|
|
+func (a *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
|
|
|
+ if a.clients[host] == nil {
|
|
|
var err error
|
|
|
- client.clients[host], err = rpc.Dial("tcp", host)
|
|
|
+ a.clients[host], err = rpc.Dial("tcp", host)
|
|
|
if err != nil {
|
|
|
Log.Errorf("RPC dial error : %s", err)
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
- return rpcCallWithReconnect(client.clients[host], host, serverMethod, args, reply)
|
|
|
+ return a.rpcCallWithReconnect(host, a.clients[host], host, serverMethod, args, reply)
|
|
|
}
|