Skip to content

Commit 3ac08a0

Browse files
committed
proxy: add Dial, DialTimeout, and DialContext
The existing API does not allow client code to take advantage of Dialer implementations that implement DialTimeout and DialContext receivers. These functions provide a familiar API, see Dial and DialTimeout in the net package. Signed-off-by: Jacob Blain Christen <[email protected]>
1 parent addf6b3 commit 3ac08a0

File tree

7 files changed

+374
-18
lines changed

7 files changed

+374
-18
lines changed

proxy/dial.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"net"
6+
"time"
7+
)
8+
9+
// A ContextDialer is a means to establish a connection, with context.
10+
type ContextDialer interface {
11+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
12+
}
13+
14+
// Dial works like net.Dial but using a Dialer derived from the configured proxy environment.
15+
func Dial(network, address string) (net.Conn, error) {
16+
d := FromEnvironment()
17+
return d.Dial(network, address)
18+
}
19+
20+
// DialTimeout works like net.DialTimeout but using a Dialer derived from the configured proxy environment.
21+
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
22+
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
23+
// Passing a timeout of zero or less is equivalent to calling Dial.
24+
func DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
25+
d := FromEnvironment()
26+
if timeout <= 0 {
27+
return d.Dial(network, address)
28+
}
29+
30+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
31+
defer cancel()
32+
33+
if xd, ok := d.(ContextDialer); ok {
34+
return xd.DialContext(ctx, network, address)
35+
}
36+
return dialContext(ctx, network, address, d)
37+
}
38+
39+
// DialContext works like DialContext on net.Dialer but using a dialer derived from the configured proxy environment.
40+
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
41+
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
42+
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
43+
d := FromEnvironment()
44+
if xd, ok := d.(ContextDialer); ok {
45+
return xd.DialContext(ctx, network, address)
46+
}
47+
return dialContext(ctx, network, address, d)
48+
}
49+
50+
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
51+
// A Conn returned from a successfil Dial after the context has been cancelled will be immediately closed.
52+
func dialContext(ctx context.Context, network, address string, d Dialer) (net.Conn, error) {
53+
var (
54+
conn net.Conn
55+
done = make(chan struct{}, 1)
56+
err error
57+
)
58+
go func() {
59+
conn, err = d.Dial(network, address)
60+
close(done)
61+
if conn != nil && ctx.Err() != nil {
62+
conn.Close()
63+
}
64+
}()
65+
select {
66+
case <-ctx.Done():
67+
err = ctx.Err()
68+
case <-done:
69+
}
70+
return conn, err
71+
}

proxy/dial_test.go

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
package proxy
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"os"
8+
"testing"
9+
"time"
10+
11+
"golang.org/x/net/internal/sockstest"
12+
)
13+
14+
func TestDial(t *testing.T) {
15+
ResetProxyEnv()
16+
t.Run("Direct", func(t *testing.T) {
17+
defer ResetProxyEnv()
18+
l, err := net.Listen("tcp", "127.0.0.1:0")
19+
if err != nil {
20+
t.Fatal(err)
21+
}
22+
defer l.Close()
23+
_, port, err := net.SplitHostPort(l.Addr().String())
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
c, err := Dial(l.Addr().Network(), net.JoinHostPort("", port))
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
c.Close()
32+
})
33+
t.Run("SOCKS5", func(t *testing.T) {
34+
defer ResetProxyEnv()
35+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
36+
if err != nil {
37+
t.Fatal(err)
38+
}
39+
defer s.Close()
40+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
41+
t.Fatal(err)
42+
}
43+
c, err := Dial(s.TargetAddr().Network(), s.TargetAddr().String())
44+
if err != nil {
45+
t.Fatal(err)
46+
}
47+
c.Close()
48+
})
49+
}
50+
51+
func TestDialContext(t *testing.T) {
52+
ResetProxyEnv()
53+
t.Run("DirectWithCancel", func(t *testing.T) {
54+
defer ResetProxyEnv()
55+
l, err := net.Listen("tcp", "127.0.0.1:0")
56+
if err != nil {
57+
t.Fatal(err)
58+
}
59+
defer l.Close()
60+
_, port, err := net.SplitHostPort(l.Addr().String())
61+
if err != nil {
62+
t.Fatal(err)
63+
}
64+
ctx, cancel := context.WithCancel(context.Background())
65+
defer cancel()
66+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
67+
if err != nil {
68+
t.Fatal(err)
69+
}
70+
c.Close()
71+
})
72+
t.Run("DirectWithTimeout", func(t *testing.T) {
73+
defer ResetProxyEnv()
74+
l, err := net.Listen("tcp", "127.0.0.1:0")
75+
if err != nil {
76+
t.Fatal(err)
77+
}
78+
defer l.Close()
79+
_, port, err := net.SplitHostPort(l.Addr().String())
80+
if err != nil {
81+
t.Fatal(err)
82+
}
83+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
84+
defer cancel()
85+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
86+
if err != nil {
87+
t.Fatal(err)
88+
}
89+
c.Close()
90+
})
91+
t.Run("DirectWithTimeoutExceeded", func(t *testing.T) {
92+
defer ResetProxyEnv()
93+
l, err := net.Listen("tcp", "127.0.0.1:0")
94+
if err != nil {
95+
t.Fatal(err)
96+
}
97+
defer l.Close()
98+
_, port, err := net.SplitHostPort(l.Addr().String())
99+
if err != nil {
100+
t.Fatal(err)
101+
}
102+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
103+
time.Sleep(time.Millisecond)
104+
defer cancel()
105+
c, err := DialContext(ctx, l.Addr().Network(), net.JoinHostPort("", port))
106+
if err == nil {
107+
defer c.Close()
108+
t.Fatal("failed to timeout")
109+
}
110+
})
111+
t.Run("SOCKS5", func(t *testing.T) {
112+
defer ResetProxyEnv()
113+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
defer s.Close()
118+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
119+
t.Fatal(err)
120+
}
121+
c, err := DialContext(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String())
122+
if err != nil {
123+
t.Fatal(err)
124+
}
125+
c.Close()
126+
})
127+
t.Run("SOCKS5WithTimeout", func(t *testing.T) {
128+
defer ResetProxyEnv()
129+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
130+
if err != nil {
131+
t.Fatal(err)
132+
}
133+
defer s.Close()
134+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
135+
t.Fatal(err)
136+
}
137+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
138+
defer cancel()
139+
c, err := DialContext(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
140+
if err != nil {
141+
t.Fatal(err)
142+
}
143+
c.Close()
144+
})
145+
t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) {
146+
defer ResetProxyEnv()
147+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
148+
if err != nil {
149+
t.Fatal(err)
150+
}
151+
defer s.Close()
152+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
153+
t.Fatal(err)
154+
}
155+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
156+
time.Sleep(time.Millisecond)
157+
defer cancel()
158+
c, err := DialContext(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
159+
if err == nil {
160+
defer c.Close()
161+
t.Fatal("failed to timeout")
162+
}
163+
})
164+
}
165+
166+
func TestDialTimeout(t *testing.T) {
167+
ResetProxyEnv()
168+
t.Run("Direct", func(t *testing.T) {
169+
defer ResetProxyEnv()
170+
l, err := net.Listen("tcp", "127.0.0.1:0")
171+
if err != nil {
172+
t.Fatal(err)
173+
}
174+
defer l.Close()
175+
_, port, err := net.SplitHostPort(l.Addr().String())
176+
if err != nil {
177+
t.Fatal(err)
178+
}
179+
c, err := DialTimeout(l.Addr().Network(), net.JoinHostPort("", port), 5*time.Second)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
c.Close()
184+
})
185+
t.Run("DirectTooSlow", func(t *testing.T) {
186+
defer ResetProxyEnv()
187+
l, err := net.Listen("tcp", "127.0.0.1:0")
188+
if err != nil {
189+
t.Fatal(err)
190+
}
191+
defer l.Close()
192+
_, port, err := net.SplitHostPort(l.Addr().String())
193+
if err != nil {
194+
t.Fatal(err)
195+
}
196+
c, err := DialTimeout(l.Addr().Network(), net.JoinHostPort("", port), time.Nanosecond)
197+
if err == nil {
198+
defer c.Close()
199+
t.Fatal("failed to timeout")
200+
}
201+
})
202+
t.Run("SOCKS5", func(t *testing.T) {
203+
defer ResetProxyEnv()
204+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
defer s.Close()
209+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
210+
t.Fatal(err)
211+
}
212+
c, err := DialTimeout(s.TargetAddr().Network(), s.TargetAddr().String(), 5*time.Second)
213+
if err != nil {
214+
t.Fatal(err)
215+
}
216+
c.Close()
217+
})
218+
t.Run("SOCKS5TooSlow", func(t *testing.T) {
219+
defer ResetProxyEnv()
220+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
221+
if err != nil {
222+
t.Fatal(err)
223+
}
224+
defer s.Close()
225+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
226+
t.Fatal(err)
227+
}
228+
c, err := DialTimeout(s.TargetAddr().Network(), s.TargetAddr().String(), time.Nanosecond)
229+
if err == nil {
230+
defer c.Close()
231+
t.Fatal("failed to timeout")
232+
}
233+
})
234+
}

proxy/direct.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package proxy
66

77
import (
8+
"context"
89
"net"
910
)
1011

@@ -13,6 +14,13 @@ type direct struct{}
1314
// Direct is a direct proxy: one that makes network connections directly.
1415
var Direct = direct{}
1516

17+
// Dial directly invokes net.Dial with the supplied parameters.
1618
func (direct) Dial(network, addr string) (net.Conn, error) {
1719
return net.Dial(network, addr)
1820
}
21+
22+
// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
23+
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
24+
var d net.Dialer
25+
return d.DialContext(ctx, network, addr)
26+
}

proxy/per_host.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package proxy
66

77
import (
8+
"context"
89
"net"
910
"strings"
1011
)
@@ -41,6 +42,20 @@ func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
4142
return p.dialerForRequest(host).Dial(network, addr)
4243
}
4344

45+
// DialContext connects to the address addr on the given network through either
46+
// defaultDialer or bypass.
47+
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
48+
host, _, err := net.SplitHostPort(addr)
49+
if err != nil {
50+
return nil, err
51+
}
52+
d := p.dialerForRequest(host)
53+
if x, ok := d.(ContextDialer); ok {
54+
return x.DialContext(ctx, network, addr)
55+
}
56+
return dialContext(ctx, network, addr, d)
57+
}
58+
4459
func (p *PerHost) dialerForRequest(host string) Dialer {
4560
if ip := net.ParseIP(host); ip != nil {
4661
for _, net := range p.bypassNetworks {

0 commit comments

Comments
 (0)