ソースを参照

fix: 修复rpc重连BUG

lijian 2 年 前
コミット
324a06b6f3
1 ファイル変更15 行追加11 行削除
  1. 15 11
      pkg/server/rpc_client.go

+ 15 - 11
pkg/server/rpc_client.go

@@ -4,10 +4,12 @@ import (
 	"fmt"
 	"math/rand"
 	"net/rpc"
+	"sync"
 	"time"
 )
 
 type RPCClient struct {
+	mu      sync.Mutex
 	clients map[string]*rpc.Client
 	random  *rand.Rand
 }
@@ -25,41 +27,43 @@ func NewRPCClient() (*RPCClient, error) {
 	}, nil
 }
 
-func rpcCallWithReconnect(client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
+func (a *RPCClient) rpcCallWithReconnect(key string, client *rpc.Client, addr string, serverMethod string, args interface{}, reply interface{}) error {
 	err := client.Call(serverMethod, args, reply)
 	if err == rpc.ErrShutdown {
 		Log.Warnf("rpc %s connection shut down, trying to reconnect...", addr)
 		client, err = rpc.Dial("tcp", addr)
 		if err != nil {
-			Log.Debugf("重新连接%s失败:%s", addr, err.Error())
 			return err
 		}
+		a.clients[key] = client
 		return client.Call(serverMethod, args, reply)
 	}
 	return err
 }
 
 // Call RPC call with reconnect and retry.
-func (client *RPCClient) Call(severName string, serverMethod string, args interface{}, reply interface{}) error {
+func (a *RPCClient) Call(severName string, serverMethod string, args interface{}, reply interface{}) error {
+	a.mu.Lock()
+	defer a.mu.Unlock()
 	addrs, err := serverInstance.serverManager.GetServerHosts(severName, FlagRPCHost)
 	if err != nil {
 		return err
 	}
 	// pick a random start index for round robin
 	total := len(addrs)
-	start := client.random.Intn(total)
+	start := a.random.Intn(total)
 
 	for idx := 0; idx < total; idx++ {
 		addr := addrs[(start+idx)%total]
 		mapkey := fmt.Sprintf("%s[%s]", severName, addr)
-		if client.clients[mapkey] == nil {
-			client.clients[mapkey], err = rpc.Dial("tcp", addr)
+		if a.clients[mapkey] == nil {
+			a.clients[mapkey], err = rpc.Dial("tcp", addr)
 			if err != nil {
 				Log.Warnf("RPC dial error : %s", err)
 				continue
 			}
 		}
-		err = rpcCallWithReconnect(client.clients[mapkey], addr, serverMethod, args, reply)
+		err = a.rpcCallWithReconnect(mapkey, a.clients[mapkey], addr, serverMethod, args, reply)
 		if err != nil {
 			continue
 		}
@@ -70,14 +74,14 @@ func (client *RPCClient) Call(severName string, serverMethod string, args interf
 }
 
 // CallHost RPC call by host
-func (client *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
-	if client.clients[host] == nil {
+func (a *RPCClient) CallHost(host string, serverMethod string, args interface{}, reply interface{}) error {
+	if a.clients[host] == nil {
 		var err error
-		client.clients[host], err = rpc.Dial("tcp", host)
+		a.clients[host], err = rpc.Dial("tcp", host)
 		if err != nil {
 			Log.Errorf("RPC dial error : %s", err)
 			return err
 		}
 	}
-	return rpcCallWithReconnect(client.clients[host], host, serverMethod, args, reply)
+	return a.rpcCallWithReconnect(host, a.clients[host], host, serverMethod, args, reply)
 }