Skip to content

Commit 9d40357

Browse files
authored
Merge pull request #533 from MarkOtzen/port-blocking
Implement OnValidateTarget callback for target validation hooking
2 parents cb15c3c + fa3e7fe commit 9d40357

File tree

4 files changed

+256
-5
lines changed

4 files changed

+256
-5
lines changed

fastdialer/dialer_private.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"log/slog"
1111
"net"
1212
"os"
13+
"strconv"
1314
"strings"
1415
"sync/atomic"
1516
"time"
@@ -127,15 +128,30 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
127128

128129
filteredIPs := []string{}
129130

131+
portInt, _ := strconv.Atoi(port)
132+
130133
// filter valid/invalid ips
131134
for _, ip := range IPS {
132-
// check if we have allow/deny list
133135
if !d.networkpolicy.Validate(ip) {
134136
if d.options.OnInvalidTarget != nil {
135137
d.options.OnInvalidTarget(hostname, ip, port)
136138
}
137139
continue
138140
}
141+
if !d.validatePort(portInt) {
142+
if d.options.OnInvalidTarget != nil {
143+
d.options.OnInvalidTarget(hostname, ip, port)
144+
}
145+
continue
146+
}
147+
if d.options.OnValidateTarget != nil {
148+
if err := d.options.OnValidateTarget(hostname, ip, port); err != nil {
149+
if d.options.OnInvalidTarget != nil {
150+
d.options.OnInvalidTarget(hostname, ip, port)
151+
}
152+
continue
153+
}
154+
}
139155
if d.options.OnBeforeDial != nil {
140156
d.options.OnBeforeDial(hostname, ip, port)
141157
}
@@ -436,3 +452,20 @@ func closeAfterTimeout(d time.Duration, c ...io.Closer) context.CancelFunc {
436452

437453
return ctxDone
438454
}
455+
456+
func (d *Dialer) validatePort(port int) bool {
457+
for _, p := range d.options.DenyPortList {
458+
if p == port {
459+
return false
460+
}
461+
}
462+
if len(d.options.AllowPortList) > 0 {
463+
for _, p := range d.options.AllowPortList {
464+
if p == port {
465+
return true
466+
}
467+
}
468+
return false
469+
}
470+
return true
471+
}

fastdialer/dialer_private_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func TestDial(t *testing.T) {
9696
if err != nil {
9797
t.Fatalf("Failed to start listener: %v", err)
9898
}
99-
defer listener.Close()
99+
defer func() { _ = listener.Close() }()
100100

101101
serverAddr := listener.Addr().String()
102102

@@ -109,7 +109,7 @@ func TestDial(t *testing.T) {
109109

110110
// hold conn w/o completing handshake
111111
time.Sleep(5 * time.Second)
112-
conn.Close()
112+
_ = conn.Close()
113113
}
114114
}()
115115

@@ -138,7 +138,7 @@ func TestDial(t *testing.T) {
138138
if err != nil {
139139
t.Fatalf("Failed to start listener: %v", err)
140140
}
141-
defer listener.Close()
141+
defer func() { _ = listener.Close() }()
142142

143143
serverAddr := listener.Addr().String()
144144

@@ -149,7 +149,7 @@ func TestDial(t *testing.T) {
149149
return
150150
}
151151
time.Sleep(5 * time.Second)
152-
conn.Close()
152+
_ = conn.Close()
153153
}
154154
}()
155155

fastdialer/dialer_test.go

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package fastdialer
22

33
import (
44
"context"
5+
"errors"
6+
"sync"
57
"testing"
68
)
79

@@ -57,3 +59,216 @@ func testDialer(t *testing.T, options Options) {
5759
t.Error("no A results found")
5860
}
5961
}
62+
63+
func TestDialerPortPolicy(t *testing.T) {
64+
t.Run("DenyPortBlocks", func(t *testing.T) {
65+
options := DefaultOptions
66+
options.DenyPortList = []int{80}
67+
68+
fd, err := NewDialer(options)
69+
if err != nil {
70+
t.Fatalf("couldn't create fastdialer instance: %s", err)
71+
}
72+
defer fd.Close()
73+
74+
ctx := context.Background()
75+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
76+
if conn != nil {
77+
_ = conn.Close()
78+
}
79+
if err != NoAddressAllowedError {
80+
t.Fatalf("expected NoAddressAllowedError for denied port, got: %v", err)
81+
}
82+
})
83+
84+
t.Run("DenyPortAllowsOther", func(t *testing.T) {
85+
options := DefaultOptions
86+
options.DenyPortList = []int{8081}
87+
88+
fd, err := NewDialer(options)
89+
if err != nil {
90+
t.Fatalf("couldn't create fastdialer instance: %s", err)
91+
}
92+
defer fd.Close()
93+
94+
ctx := context.Background()
95+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
96+
if err != nil || conn == nil {
97+
t.Fatalf("expected connection to succeed on non-denied port, got: %v", err)
98+
}
99+
_ = conn.Close()
100+
})
101+
102+
t.Run("AllowPortPermits", func(t *testing.T) {
103+
options := DefaultOptions
104+
options.AllowPortList = []int{80, 443}
105+
106+
fd, err := NewDialer(options)
107+
if err != nil {
108+
t.Fatalf("couldn't create fastdialer instance: %s", err)
109+
}
110+
defer fd.Close()
111+
112+
ctx := context.Background()
113+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
114+
if err != nil || conn == nil {
115+
t.Fatalf("expected connection to succeed on allowed port, got: %v", err)
116+
}
117+
_ = conn.Close()
118+
})
119+
120+
t.Run("AllowPortBlocksOther", func(t *testing.T) {
121+
options := DefaultOptions
122+
options.AllowPortList = []int{443}
123+
124+
fd, err := NewDialer(options)
125+
if err != nil {
126+
t.Fatalf("couldn't create fastdialer instance: %s", err)
127+
}
128+
defer fd.Close()
129+
130+
ctx := context.Background()
131+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
132+
if conn != nil {
133+
_ = conn.Close()
134+
}
135+
if err != NoAddressAllowedError {
136+
t.Fatalf("expected NoAddressAllowedError for non-allowed port, got: %v", err)
137+
}
138+
})
139+
140+
t.Run("NoPortPolicyUnchanged", func(t *testing.T) {
141+
options := DefaultOptions
142+
143+
fd, err := NewDialer(options)
144+
if err != nil {
145+
t.Fatalf("couldn't create fastdialer instance: %s", err)
146+
}
147+
defer fd.Close()
148+
149+
ctx := context.Background()
150+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
151+
if err != nil || conn == nil {
152+
t.Fatalf("expected connection to succeed without port policy, got: %v", err)
153+
}
154+
_ = conn.Close()
155+
})
156+
157+
t.Run("DenyPortTriggersOnInvalidTarget", func(t *testing.T) {
158+
options := DefaultOptions
159+
options.DenyPortList = []int{80}
160+
161+
var invalidCalled bool
162+
var mu sync.Mutex
163+
options.OnInvalidTarget = func(hostname, ip, port string) {
164+
mu.Lock()
165+
invalidCalled = true
166+
mu.Unlock()
167+
}
168+
169+
fd, err := NewDialer(options)
170+
if err != nil {
171+
t.Fatalf("couldn't create fastdialer instance: %s", err)
172+
}
173+
defer fd.Close()
174+
175+
ctx := context.Background()
176+
conn, _ := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
177+
if conn != nil {
178+
_ = conn.Close()
179+
}
180+
181+
mu.Lock()
182+
called := invalidCalled
183+
mu.Unlock()
184+
if !called {
185+
t.Error("OnInvalidTarget was not called for denied port")
186+
}
187+
})
188+
}
189+
190+
func TestDialerTargetValidation(t *testing.T) {
191+
t.Run("ValidTarget", func(t *testing.T) {
192+
options := DefaultOptions
193+
194+
var validateCalled bool
195+
options.OnValidateTarget = func(hostname, ip, port string) error {
196+
validateCalled = true
197+
if hostname != "www.projectdiscovery.io" {
198+
return errors.New("invalid hostname")
199+
}
200+
return nil
201+
}
202+
203+
var invalidCalled bool
204+
options.OnInvalidTarget = func(hostname, ip, port string) {
205+
invalidCalled = true
206+
}
207+
208+
fd, err := NewDialer(options)
209+
if err != nil {
210+
t.Fatalf("couldn't create fastdialer instance: %s", err)
211+
}
212+
defer fd.Close()
213+
214+
ctx := context.Background()
215+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
216+
if err != nil || conn == nil {
217+
t.Fatalf("couldn't connect to target: %s", err)
218+
}
219+
defer func() {
220+
_ = conn.Close()
221+
}()
222+
223+
if !validateCalled {
224+
t.Error("OnValidateTarget was not called")
225+
}
226+
if invalidCalled {
227+
t.Error("OnInvalidTarget was called for a valid target")
228+
}
229+
})
230+
231+
t.Run("InvalidTarget", func(t *testing.T) {
232+
options := DefaultOptions
233+
234+
var validateCalled bool
235+
options.OnValidateTarget = func(hostname, ip, port string) error {
236+
validateCalled = true
237+
return errors.New("target rejected")
238+
}
239+
240+
var invalidCalled bool
241+
var mu sync.Mutex
242+
options.OnInvalidTarget = func(hostname, ip, port string) {
243+
mu.Lock()
244+
invalidCalled = true
245+
mu.Unlock()
246+
}
247+
248+
fd, err := NewDialer(options)
249+
if err != nil {
250+
t.Fatalf("couldn't create fastdialer instance: %s", err)
251+
}
252+
defer fd.Close()
253+
254+
ctx := context.Background()
255+
conn, err := fd.Dial(ctx, "tcp", "www.projectdiscovery.io:80")
256+
if err != NoAddressAllowedError {
257+
if conn != nil {
258+
_ = conn.Close()
259+
}
260+
t.Fatalf("expected NoAddressAllowedError, got: %v", err)
261+
}
262+
263+
if !validateCalled {
264+
t.Error("OnValidateTarget was not called")
265+
}
266+
267+
mu.Lock()
268+
called := invalidCalled
269+
mu.Unlock()
270+
if !called {
271+
t.Error("OnInvalidTarget was not called for an invalid target")
272+
}
273+
})
274+
}

fastdialer/options.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ type Options struct {
5757
WithZTLS bool
5858
SNIName string
5959
OnBeforeDial func(hostname, IP, port string)
60+
// OnValidateTarget is called after network policy validation and before dialing.
61+
// If it returns an error, the target is considered invalid.
62+
OnValidateTarget func(hostname, IP, port string) error
6063
OnInvalidTarget func(hostname, IP, port string)
6164
OnDialCallback func(hostname, IP string)
6265
DisableZtlsFallback bool

0 commit comments

Comments
 (0)