Skip to content

Commit e98ca17

Browse files
committed
Add transient storage lifecycle receiver framework
1 parent 8eef1bb commit e98ca17

File tree

4 files changed

+70
-17
lines changed

4 files changed

+70
-17
lines changed

common/environment/rootcap_impl.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ func (r *rootEnvImpl) DropProxyEnvironment(tag string) error {
7474
if err != nil {
7575
return err
7676
}
77-
return transientStorage.DropScope(r.ctx, tag)
77+
transientStorage.Clear(r.ctx)
78+
return r.transientStorage.DropScope(r.ctx, tag)
7879
}
7980

8081
type appEnvImpl struct {

common/environment/transientstorageimpl/storage.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,16 @@ func (s *scopedTransientStorageImpl) List(ctx context.Context, keyPrefix string)
5656
func (s *scopedTransientStorageImpl) Clear(ctx context.Context) {
5757
s.access.Lock()
5858
defer s.access.Unlock()
59+
for _, v := range s.values {
60+
if sw, ok := v.(storage.TransientStorageLifecycleReceiver); ok {
61+
_ = sw.Close()
62+
}
63+
}
5964
s.values = map[string]interface{}{}
65+
for _, v := range s.scopes {
66+
v.Clear(ctx)
67+
}
68+
s.scopes = map[string]storage.ScopedTransientStorage{}
6069
}
6170

6271
func (s *scopedTransientStorageImpl) NarrowScope(ctx context.Context, key string) (storage.ScopedTransientStorage, error) {
@@ -74,6 +83,9 @@ func (s *scopedTransientStorageImpl) NarrowScope(ctx context.Context, key string
7483
func (s *scopedTransientStorageImpl) DropScope(ctx context.Context, key string) error {
7584
s.access.Lock()
7685
defer s.access.Unlock()
86+
if v, ok := s.scopes[key]; ok {
87+
v.Clear(ctx)
88+
}
7789
delete(s.scopes, key)
7890
return nil
7991
}

features/extension/storage/storage.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package storage
33
import (
44
"context"
55

6+
"github.com/v2fly/v2ray-core/v5/common"
67
"github.com/v2fly/v2ray-core/v5/features"
78
)
89

@@ -32,3 +33,8 @@ type ScopedPersistentStorageService interface {
3233
}
3334

3435
var ScopedPersistentStorageServiceType = (*ScopedPersistentStorageService)(nil)
36+
37+
type TransientStorageLifecycleReceiver interface {
38+
IsTransientStorageLifecycleReceiver()
39+
common.Closable
40+
}

transport/internet/grpc/dial.go

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717

1818
core "github.com/v2fly/v2ray-core/v5"
1919
"github.com/v2fly/v2ray-core/v5/common"
20+
"github.com/v2fly/v2ray-core/v5/common/environment"
21+
"github.com/v2fly/v2ray-core/v5/common/environment/envctx"
2022
"github.com/v2fly/v2ray-core/v5/common/net"
2123
"github.com/v2fly/v2ray-core/v5/common/session"
2224
"github.com/v2fly/v2ray-core/v5/transport/internet"
@@ -38,12 +40,25 @@ func init() {
3840
common.Must(internet.RegisterTransportDialer(protocolName, Dial))
3941
}
4042

41-
type dialerCanceller func()
43+
type transportConnectionState struct {
44+
scopedDialerMap map[net.Destination]*grpc.ClientConn
45+
scopedDialerAccess sync.Mutex
46+
}
4247

43-
var (
44-
globalDialerMap map[net.Destination]*grpc.ClientConn
45-
globalDialerAccess sync.Mutex
46-
)
48+
func (t *transportConnectionState) IsTransientStorageLifecycleReceiver() {
49+
}
50+
51+
func (t *transportConnectionState) Close() error {
52+
t.scopedDialerAccess.Lock()
53+
defer t.scopedDialerAccess.Unlock()
54+
for _, conn := range t.scopedDialerMap {
55+
_ = conn.Close()
56+
}
57+
t.scopedDialerMap = nil
58+
return nil
59+
}
60+
61+
type dialerCanceller func()
4762

4863
func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
4964
grpcSettings := streamSettings.ProtocolSettings.(*Config)
@@ -70,25 +85,36 @@ func dialgRPC(ctx context.Context, dest net.Destination, streamSettings *interne
7085
}
7186

7287
func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.DialOption, streamSettings *internet.MemoryStreamConfig) (*grpc.ClientConn, dialerCanceller, error) {
73-
globalDialerAccess.Lock()
74-
defer globalDialerAccess.Unlock()
88+
transportEnvironment := envctx.EnvironmentFromContext(ctx).(environment.TransportEnvironment)
89+
state, err := transportEnvironment.TransientStorage().Get(ctx, "grpc-transport-connection-state")
90+
if err != nil {
91+
state = &transportConnectionState{}
92+
transportEnvironment.TransientStorage().Put(ctx, "grpc-transport-connection-state", state)
93+
state, err = transportEnvironment.TransientStorage().Get(ctx, "grpc-transport-connection-state")
94+
if err != nil {
95+
return nil, nil, newError("failed to get grpc transport connection state").Base(err)
96+
}
97+
}
98+
stateTyped := state.(*transportConnectionState)
99+
100+
stateTyped.scopedDialerAccess.Lock()
101+
defer stateTyped.scopedDialerAccess.Unlock()
75102

76-
if globalDialerMap == nil {
77-
globalDialerMap = make(map[net.Destination]*grpc.ClientConn)
103+
if stateTyped.scopedDialerMap == nil {
104+
stateTyped.scopedDialerMap = make(map[net.Destination]*grpc.ClientConn)
78105
}
79106

80107
canceller := func() {
81-
globalDialerAccess.Lock()
82-
defer globalDialerAccess.Unlock()
83-
delete(globalDialerMap, dest)
108+
stateTyped.scopedDialerAccess.Lock()
109+
defer stateTyped.scopedDialerAccess.Unlock()
110+
delete(stateTyped.scopedDialerMap, dest)
84111
}
85112

86-
// TODO Should support chain proxy to the same destination
87-
if client, found := globalDialerMap[dest]; found && client.GetState() != connectivity.Shutdown {
113+
if client, found := stateTyped.scopedDialerMap[dest]; found && client.GetState() != connectivity.Shutdown {
88114
return client, canceller, nil
89115
}
90116

91-
conn, err := grpc.Dial(
117+
conn, err := grpc.NewClient(
92118
dest.Address.String()+":"+dest.Port.String(),
93119
dialOption,
94120
grpc.WithConnectParams(grpc.ConnectParams{
@@ -117,6 +143,14 @@ func getGrpcClient(ctx context.Context, dest net.Destination, dialOption grpc.Di
117143
return internet.DialSystem(detachedContext, net.TCPDestination(address, port), streamSettings.SocketSettings)
118144
}),
119145
)
120-
globalDialerMap[dest] = conn
146+
canceller = func() {
147+
stateTyped.scopedDialerAccess.Lock()
148+
defer stateTyped.scopedDialerAccess.Unlock()
149+
delete(stateTyped.scopedDialerMap, dest)
150+
if err != nil {
151+
conn.Close()
152+
}
153+
}
154+
stateTyped.scopedDialerMap[dest] = conn
121155
return conn, canceller, err
122156
}

0 commit comments

Comments
 (0)