Skip to content

Commit f09f2bd

Browse files
authored
Fix open redirect vulnerability with AddTrailingSlashWithConfig and RemoveTrailingSlashWithConfig (#1775,#1771)
* fix open redirect vulnerability with AddTrailingSlashWithConfig and RemoveTrailingSlashWithConfig (fix #1771) * rename trimMultipleSlashes to sanitizeURI
1 parent 932976d commit f09f2bd

File tree

2 files changed

+273
-82
lines changed

2 files changed

+273
-82
lines changed

middleware/slash.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
6060

6161
// Redirect
6262
if config.RedirectCode != 0 {
63-
return c.Redirect(config.RedirectCode, uri)
63+
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
6464
}
6565

6666
// Forward
@@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
108108

109109
// Redirect
110110
if config.RedirectCode != 0 {
111-
return c.Redirect(config.RedirectCode, uri)
111+
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
112112
}
113113

114114
// Forward
@@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
119119
}
120120
}
121121
}
122+
123+
func sanitizeURI(uri string) string {
124+
// double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
125+
// we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
126+
if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
127+
uri = "/" + strings.TrimLeft(uri, `/\`)
128+
}
129+
return uri
130+
}

middleware/slash_test.go

Lines changed: 262 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -9,88 +9,270 @@ import (
99
"github.com/stretchr/testify/assert"
1010
)
1111

12+
func TestAddTrailingSlashWithConfig(t *testing.T) {
13+
var testCases = []struct {
14+
whenURL string
15+
whenMethod string
16+
expectPath string
17+
expectLocation []string
18+
expectStatus int
19+
}{
20+
{
21+
whenURL: "/add-slash",
22+
whenMethod: http.MethodGet,
23+
expectPath: "/add-slash",
24+
expectLocation: []string{`/add-slash/`},
25+
},
26+
{
27+
whenURL: "/add-slash?key=value",
28+
whenMethod: http.MethodGet,
29+
expectPath: "/add-slash",
30+
expectLocation: []string{`/add-slash/?key=value`},
31+
},
32+
{
33+
whenURL: "/",
34+
whenMethod: http.MethodConnect,
35+
expectPath: "/",
36+
expectLocation: nil,
37+
expectStatus: http.StatusOK,
38+
},
39+
// cases for open redirect vulnerability
40+
{
41+
whenURL: "http://localhost:1323/%5Cexample.com",
42+
expectPath: `/\example.com`,
43+
expectLocation: []string{`/example.com/`},
44+
},
45+
{
46+
whenURL: `http://localhost:1323/\example.com`,
47+
expectPath: `/\example.com`,
48+
expectLocation: []string{`/example.com/`},
49+
},
50+
{
51+
whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com`,
52+
expectPath: `/\\\////\\\\example.com`,
53+
expectLocation: []string{`/example.com/`},
54+
},
55+
{
56+
whenURL: "http://localhost:1323//example.com",
57+
expectPath: `//example.com`,
58+
expectLocation: []string{`/example.com/`},
59+
},
60+
{
61+
whenURL: "http://localhost:1323/%5C%5C",
62+
expectPath: `/\\`,
63+
expectLocation: []string{`/`},
64+
},
65+
}
66+
for _, tc := range testCases {
67+
t.Run(tc.whenURL, func(t *testing.T) {
68+
e := echo.New()
69+
70+
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{
71+
RedirectCode: http.StatusMovedPermanently,
72+
})
73+
h := mw(func(c echo.Context) error {
74+
return nil
75+
})
76+
77+
rec := httptest.NewRecorder()
78+
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
79+
c := e.NewContext(req, rec)
80+
81+
err := h(c)
82+
assert.NoError(t, err)
83+
84+
assert.Equal(t, tc.expectPath, req.URL.Path)
85+
assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation])
86+
if tc.expectStatus == 0 {
87+
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
88+
} else {
89+
assert.Equal(t, tc.expectStatus, rec.Code)
90+
}
91+
})
92+
}
93+
}
94+
1295
func TestAddTrailingSlash(t *testing.T) {
13-
is := assert.New(t)
14-
e := echo.New()
15-
req := httptest.NewRequest(http.MethodGet, "/add-slash", nil)
16-
rec := httptest.NewRecorder()
17-
c := e.NewContext(req, rec)
18-
h := AddTrailingSlash()(func(c echo.Context) error {
19-
return nil
20-
})
21-
is.NoError(h(c))
22-
is.Equal("/add-slash/", req.URL.Path)
23-
is.Equal("/add-slash/", req.RequestURI)
24-
25-
// Method Connect must not fail:
26-
req = httptest.NewRequest(http.MethodConnect, "", nil)
27-
rec = httptest.NewRecorder()
28-
c = e.NewContext(req, rec)
29-
h = AddTrailingSlash()(func(c echo.Context) error {
30-
return nil
31-
})
32-
is.NoError(h(c))
33-
is.Equal("/", req.URL.Path)
34-
is.Equal("/", req.RequestURI)
35-
36-
// With config
37-
req = httptest.NewRequest(http.MethodGet, "/add-slash?key=value", nil)
38-
rec = httptest.NewRecorder()
39-
c = e.NewContext(req, rec)
40-
h = AddTrailingSlashWithConfig(TrailingSlashConfig{
41-
RedirectCode: http.StatusMovedPermanently,
42-
})(func(c echo.Context) error {
43-
return nil
44-
})
45-
is.NoError(h(c))
46-
is.Equal(http.StatusMovedPermanently, rec.Code)
47-
is.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation))
96+
var testCases = []struct {
97+
whenURL string
98+
whenMethod string
99+
expectPath string
100+
expectLocation []string
101+
}{
102+
{
103+
whenURL: "/add-slash",
104+
whenMethod: http.MethodGet,
105+
expectPath: "/add-slash/",
106+
},
107+
{
108+
whenURL: "/add-slash?key=value",
109+
whenMethod: http.MethodGet,
110+
expectPath: "/add-slash/",
111+
},
112+
{
113+
whenURL: "/",
114+
whenMethod: http.MethodConnect,
115+
expectPath: "/",
116+
expectLocation: nil,
117+
},
118+
}
119+
for _, tc := range testCases {
120+
t.Run(tc.whenURL, func(t *testing.T) {
121+
e := echo.New()
122+
123+
h := AddTrailingSlash()(func(c echo.Context) error {
124+
return nil
125+
})
126+
127+
rec := httptest.NewRecorder()
128+
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
129+
c := e.NewContext(req, rec)
130+
131+
err := h(c)
132+
assert.NoError(t, err)
133+
134+
assert.Equal(t, tc.expectPath, req.URL.Path)
135+
assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation])
136+
assert.Equal(t, http.StatusOK, rec.Code)
137+
})
138+
}
139+
}
140+
141+
func TestRemoveTrailingSlashWithConfig(t *testing.T) {
142+
var testCases = []struct {
143+
whenURL string
144+
whenMethod string
145+
expectPath string
146+
expectLocation []string
147+
expectStatus int
148+
}{
149+
{
150+
whenURL: "/remove-slash/",
151+
whenMethod: http.MethodGet,
152+
expectPath: "/remove-slash/",
153+
expectLocation: []string{`/remove-slash`},
154+
},
155+
{
156+
whenURL: "/remove-slash/?key=value",
157+
whenMethod: http.MethodGet,
158+
expectPath: "/remove-slash/",
159+
expectLocation: []string{`/remove-slash?key=value`},
160+
},
161+
{
162+
whenURL: "/",
163+
whenMethod: http.MethodConnect,
164+
expectPath: "/",
165+
expectLocation: nil,
166+
expectStatus: http.StatusOK,
167+
},
168+
{
169+
whenURL: "http://localhost",
170+
whenMethod: http.MethodGet,
171+
expectPath: "",
172+
expectLocation: nil,
173+
expectStatus: http.StatusOK,
174+
},
175+
// cases for open redirect vulnerability
176+
{
177+
whenURL: "http://localhost:1323/%5Cexample.com/",
178+
expectPath: `/\example.com/`,
179+
expectLocation: []string{`/example.com`},
180+
},
181+
{
182+
whenURL: `http://localhost:1323/\example.com/`,
183+
expectPath: `/\example.com/`,
184+
expectLocation: []string{`/example.com`},
185+
},
186+
{
187+
whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com/`,
188+
expectPath: `/\\\////\\\\example.com/`,
189+
expectLocation: []string{`/example.com`},
190+
},
191+
{
192+
whenURL: "http://localhost:1323//example.com/",
193+
expectPath: `//example.com/`,
194+
expectLocation: []string{`/example.com`},
195+
},
196+
{
197+
whenURL: "http://localhost:1323/%5C%5C/",
198+
expectPath: `/\\/`,
199+
expectLocation: []string{`/`},
200+
},
201+
}
202+
for _, tc := range testCases {
203+
t.Run(tc.whenURL, func(t *testing.T) {
204+
e := echo.New()
205+
206+
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{
207+
RedirectCode: http.StatusMovedPermanently,
208+
})
209+
h := mw(func(c echo.Context) error {
210+
return nil
211+
})
212+
213+
rec := httptest.NewRecorder()
214+
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
215+
c := e.NewContext(req, rec)
216+
217+
err := h(c)
218+
assert.NoError(t, err)
219+
220+
assert.Equal(t, tc.expectPath, req.URL.Path)
221+
assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation])
222+
if tc.expectStatus == 0 {
223+
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
224+
} else {
225+
assert.Equal(t, tc.expectStatus, rec.Code)
226+
}
227+
})
228+
}
48229
}
49230

50231
func TestRemoveTrailingSlash(t *testing.T) {
51-
is := assert.New(t)
52-
e := echo.New()
53-
req := httptest.NewRequest(http.MethodGet, "/remove-slash/", nil)
54-
rec := httptest.NewRecorder()
55-
c := e.NewContext(req, rec)
56-
h := RemoveTrailingSlash()(func(c echo.Context) error {
57-
return nil
58-
})
59-
is.NoError(h(c))
60-
is.Equal("/remove-slash", req.URL.Path)
61-
is.Equal("/remove-slash", req.RequestURI)
62-
63-
// Method Connect must not fail:
64-
req = httptest.NewRequest(http.MethodConnect, "", nil)
65-
rec = httptest.NewRecorder()
66-
c = e.NewContext(req, rec)
67-
h = RemoveTrailingSlash()(func(c echo.Context) error {
68-
return nil
69-
})
70-
is.NoError(h(c))
71-
is.Equal("", req.URL.Path)
72-
is.Equal("", req.RequestURI)
73-
74-
// With config
75-
req = httptest.NewRequest(http.MethodGet, "/remove-slash/?key=value", nil)
76-
rec = httptest.NewRecorder()
77-
c = e.NewContext(req, rec)
78-
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{
79-
RedirectCode: http.StatusMovedPermanently,
80-
})(func(c echo.Context) error {
81-
return nil
82-
})
83-
is.NoError(h(c))
84-
is.Equal(http.StatusMovedPermanently, rec.Code)
85-
is.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation))
86-
87-
// With bare URL
88-
req = httptest.NewRequest(http.MethodGet, "http://localhost", nil)
89-
rec = httptest.NewRecorder()
90-
c = e.NewContext(req, rec)
91-
h = RemoveTrailingSlash()(func(c echo.Context) error {
92-
return nil
93-
})
94-
is.NoError(h(c))
95-
is.Equal("", req.URL.Path)
232+
var testCases = []struct {
233+
whenURL string
234+
whenMethod string
235+
expectPath string
236+
}{
237+
{
238+
whenURL: "/remove-slash/",
239+
whenMethod: http.MethodGet,
240+
expectPath: "/remove-slash",
241+
},
242+
{
243+
whenURL: "/remove-slash/?key=value",
244+
whenMethod: http.MethodGet,
245+
expectPath: "/remove-slash",
246+
},
247+
{
248+
whenURL: "/",
249+
whenMethod: http.MethodConnect,
250+
expectPath: "/",
251+
},
252+
{
253+
whenURL: "http://localhost",
254+
whenMethod: http.MethodGet,
255+
expectPath: "",
256+
},
257+
}
258+
for _, tc := range testCases {
259+
t.Run(tc.whenURL, func(t *testing.T) {
260+
e := echo.New()
261+
262+
h := RemoveTrailingSlash()(func(c echo.Context) error {
263+
return nil
264+
})
265+
266+
rec := httptest.NewRecorder()
267+
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
268+
c := e.NewContext(req, rec)
269+
270+
err := h(c)
271+
assert.NoError(t, err)
272+
273+
assert.Equal(t, tc.expectPath, req.URL.Path)
274+
assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation])
275+
assert.Equal(t, http.StatusOK, rec.Code)
276+
})
277+
}
96278
}

0 commit comments

Comments
 (0)