diff --git a/middleware/slash.go b/middleware/slash.go index 61d6e30b3..0492b334b 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -1,6 +1,8 @@ package middleware import ( + "strings" + "github.com/labstack/echo/v4" ) @@ -49,7 +51,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc url := req.URL path := url.Path qs := c.QueryString() - if path != "/" && path[len(path)-1] != '/' { + if !strings.HasSuffix(path, "/") { path += "/" uri := path if qs != "" { @@ -97,7 +99,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu path := url.Path qs := c.QueryString() l := len(path) - 1 - if l >= 0 && path != "/" && path[l] == '/' { + if l > 0 && strings.HasSuffix(path, "/") { path = path[:l] uri := path if qs != "" { diff --git a/middleware/slash_test.go b/middleware/slash_test.go index e60be740d..2a8e9eeaa 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -10,6 +10,7 @@ import ( ) func TestAddTrailingSlash(t *testing.T) { + is := assert.New(t) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/add-slash", nil) rec := httptest.NewRecorder() @@ -17,11 +18,20 @@ func TestAddTrailingSlash(t *testing.T) { h := AddTrailingSlash()(func(c echo.Context) error { return nil }) - h(c) + is.NoError(h(c)) + is.Equal("/add-slash/", req.URL.Path) + is.Equal("/add-slash/", req.RequestURI) - assert := assert.New(t) - assert.Equal("/add-slash/", req.URL.Path) - assert.Equal("/add-slash/", req.RequestURI) + // Method Connect must not fail: + req = httptest.NewRequest(http.MethodConnect, "", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = AddTrailingSlash()(func(c echo.Context) error { + return nil + }) + is.NoError(h(c)) + is.Equal("/", req.URL.Path) + is.Equal("/", req.RequestURI) // With config req = httptest.NewRequest(http.MethodGet, "/add-slash?key=value", nil) @@ -32,12 +42,13 @@ func TestAddTrailingSlash(t *testing.T) { })(func(c echo.Context) error { return nil }) - h(c) - assert.Equal(http.StatusMovedPermanently, rec.Code) - assert.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) + is.NoError(h(c)) + is.Equal(http.StatusMovedPermanently, rec.Code) + is.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) } func TestRemoveTrailingSlash(t *testing.T) { + is := assert.New(t) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/remove-slash/", nil) rec := httptest.NewRecorder() @@ -45,12 +56,20 @@ func TestRemoveTrailingSlash(t *testing.T) { h := RemoveTrailingSlash()(func(c echo.Context) error { return nil }) - h(c) - - assert := assert.New(t) + is.NoError(h(c)) + is.Equal("/remove-slash", req.URL.Path) + is.Equal("/remove-slash", req.RequestURI) - assert.Equal("/remove-slash", req.URL.Path) - assert.Equal("/remove-slash", req.RequestURI) + // Method Connect must not fail: + req = httptest.NewRequest(http.MethodConnect, "", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = RemoveTrailingSlash()(func(c echo.Context) error { + return nil + }) + is.NoError(h(c)) + is.Equal("", req.URL.Path) + is.Equal("", req.RequestURI) // With config req = httptest.NewRequest(http.MethodGet, "/remove-slash/?key=value", nil) @@ -61,9 +80,9 @@ func TestRemoveTrailingSlash(t *testing.T) { })(func(c echo.Context) error { return nil }) - h(c) - assert.Equal(http.StatusMovedPermanently, rec.Code) - assert.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) + is.NoError(h(c)) + is.Equal(http.StatusMovedPermanently, rec.Code) + is.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) // With bare URL req = httptest.NewRequest(http.MethodGet, "http://localhost", nil) @@ -72,6 +91,6 @@ func TestRemoveTrailingSlash(t *testing.T) { h = RemoveTrailingSlash()(func(c echo.Context) error { return nil }) - h(c) - assert.Equal("", req.URL.Path) + is.NoError(h(c)) + is.Equal("", req.URL.Path) }