@@ -17,6 +17,8 @@ import (
1717
1818 core "github.com/v2fly/v2ray-core/v5"
1919 "github.com/v2fly/v2ray-core/v5/common"
20+ "github.com/v2fly/v2ray-core/v5/common/environment"
21+ "github.com/v2fly/v2ray-core/v5/common/environment/envctx"
2022 "github.com/v2fly/v2ray-core/v5/common/net"
2123 "github.com/v2fly/v2ray-core/v5/common/session"
2224 "github.com/v2fly/v2ray-core/v5/transport/internet"
@@ -38,12 +40,25 @@ func init() {
3840 common .Must (internet .RegisterTransportDialer (protocolName , Dial ))
3941}
4042
41- type dialerCanceller func ()
43+ type transportConnectionState struct {
44+ scopedDialerMap map [net.Destination ]* grpc.ClientConn
45+ scopedDialerAccess sync.Mutex
46+ }
4247
43- var (
44- globalDialerMap map [net.Destination ]* grpc.ClientConn
45- globalDialerAccess sync.Mutex
46- )
48+ func (t * transportConnectionState ) IsTransientStorageLifecycleReceiver () {
49+ }
50+
51+ func (t * transportConnectionState ) Close () error {
52+ t .scopedDialerAccess .Lock ()
53+ defer t .scopedDialerAccess .Unlock ()
54+ for _ , conn := range t .scopedDialerMap {
55+ _ = conn .Close ()
56+ }
57+ t .scopedDialerMap = nil
58+ return nil
59+ }
60+
61+ type dialerCanceller func ()
4762
4863func dialgRPC (ctx context.Context , dest net.Destination , streamSettings * internet.MemoryStreamConfig ) (net.Conn , error ) {
4964 grpcSettings := streamSettings .ProtocolSettings .(* Config )
@@ -70,25 +85,36 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
7085}
7186
7287func getGrpcClient (ctx context.Context , dest net.Destination , dialOption grpc.DialOption , streamSettings * internet.MemoryStreamConfig ) (* grpc.ClientConn , dialerCanceller , error ) {
73- globalDialerAccess .Lock ()
74- defer globalDialerAccess .Unlock ()
88+ transportEnvironment := envctx .EnvironmentFromContext (ctx ).(environment.TransportEnvironment )
89+ state , err := transportEnvironment .TransientStorage ().Get (ctx , "grpc-transport-connection-state" )
90+ if err != nil {
91+ state = & transportConnectionState {}
92+ transportEnvironment .TransientStorage ().Put (ctx , "grpc-transport-connection-state" , state )
93+ state , err = transportEnvironment .TransientStorage ().Get (ctx , "grpc-transport-connection-state" )
94+ if err != nil {
95+ return nil , nil , newError ("failed to get grpc transport connection state" ).Base (err )
96+ }
97+ }
98+ stateTyped := state .(* transportConnectionState )
99+
100+ stateTyped .scopedDialerAccess .Lock ()
101+ defer stateTyped .scopedDialerAccess .Unlock ()
75102
76- if globalDialerMap == nil {
77- globalDialerMap = make (map [net.Destination ]* grpc.ClientConn )
103+ if stateTyped . scopedDialerMap == nil {
104+ stateTyped . scopedDialerMap = make (map [net.Destination ]* grpc.ClientConn )
78105 }
79106
80107 canceller := func () {
81- globalDialerAccess .Lock ()
82- defer globalDialerAccess .Unlock ()
83- delete (globalDialerMap , dest )
108+ stateTyped . scopedDialerAccess .Lock ()
109+ defer stateTyped . scopedDialerAccess .Unlock ()
110+ delete (stateTyped . scopedDialerMap , dest )
84111 }
85112
86- // TODO Should support chain proxy to the same destination
87- if client , found := globalDialerMap [dest ]; found && client .GetState () != connectivity .Shutdown {
113+ if client , found := stateTyped .scopedDialerMap [dest ]; found && client .GetState () != connectivity .Shutdown {
88114 return client , canceller , nil
89115 }
90116
91- conn , err := grpc .Dial (
117+ conn , err := grpc .NewClient (
92118 dest .Address .String ()+ ":" + dest .Port .String (),
93119 dialOption ,
94120 grpc .WithConnectParams (grpc.ConnectParams {
@@ -117,6 +143,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
117143 return internet .DialSystem (detachedContext , net .TCPDestination (address , port ), streamSettings .SocketSettings )
118144 }),
119145 )
120- globalDialerMap [dest ] = conn
146+ canceller = func () {
147+ stateTyped .scopedDialerAccess .Lock ()
148+ defer stateTyped .scopedDialerAccess .Unlock ()
149+ delete (stateTyped .scopedDialerMap , dest )
150+ if err != nil {
151+ conn .Close ()
152+ }
153+ }
154+ stateTyped .scopedDialerMap [dest ] = conn
121155 return conn , canceller , err
122156}
0 commit comments