Skip to content

Commit c4bf656

Browse files
committed
Add flag to JWT and KeyAuth middleware to allow continuing execution next(c) when error handler decides to swallow the error (returns nil).
1 parent c0fdaa2 commit c4bf656

File tree

4 files changed

+191
-3
lines changed

4 files changed

+191
-3
lines changed

middleware/jwt.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ type (
2121
// BeforeFunc defines a function which is executed just before the middleware.
2222
BeforeFunc BeforeFunc
2323

24-
// SuccessHandler defines a function which is executed for a valid token.
24+
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
25+
// middleware or handler.
2526
SuccessHandler JWTSuccessHandler
2627

2728
// ErrorHandler defines a function which is executed for an invalid token.
@@ -31,6 +32,13 @@ type (
3132
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
3233
ErrorHandlerWithContext JWTErrorHandlerWithContext
3334

35+
// NoErrorContinuesExecution allows next middleware/handler to be called when ErrorHandlerWithContext decides to
36+
// swallow the error (returns nil).
37+
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized
38+
// users. In that case you can use ErrorHandlerWithContext to set default public JWT token value to request and
39+
// continue with handler chain. Assuming logic downstream execution chain has to check that (public) token value.
40+
NoErrorContinuesExecution bool
41+
3442
// Signing key to validate token.
3543
// This is one of the three options to provide a token validation key.
3644
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
@@ -228,7 +236,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
228236
return config.ErrorHandler(err)
229237
}
230238
if config.ErrorHandlerWithContext != nil {
231-
return config.ErrorHandlerWithContext(err, c)
239+
tmpErr := config.ErrorHandlerWithContext(err, c)
240+
if config.NoErrorContinuesExecution && tmpErr == nil {
241+
return next(c)
242+
}
243+
return tmpErr
232244
}
233245

234246
// backwards compatible errors codes

middleware/jwt_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,77 @@ func TestJWTConfig_SuccessHandler(t *testing.T) {
703703
})
704704
}
705705
}
706+
707+
func TestJWTConfig_NoErrorContinuesExecution(t *testing.T) {
708+
var testCases = []struct {
709+
name string
710+
whenNoErrorContinuesExecution bool
711+
givenToken string
712+
expectStatus int
713+
expectBody string
714+
}{
715+
{
716+
name: "no error handler is called",
717+
whenNoErrorContinuesExecution: true,
718+
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
719+
expectStatus: http.StatusTeapot,
720+
expectBody: "",
721+
},
722+
{
723+
name: "NoErrorContinuesExecution is false and error handler is called for missing token",
724+
whenNoErrorContinuesExecution: false,
725+
givenToken: "",
726+
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
727+
expectStatus: http.StatusOK,
728+
expectBody: "",
729+
},
730+
{
731+
name: "error handler is called for missing token",
732+
whenNoErrorContinuesExecution: true,
733+
givenToken: "",
734+
expectStatus: http.StatusTeapot,
735+
expectBody: "public-token",
736+
},
737+
{
738+
name: "error handler is called for invalid token",
739+
whenNoErrorContinuesExecution: true,
740+
givenToken: "x.x.x",
741+
expectStatus: http.StatusUnauthorized,
742+
expectBody: "{\"message\":\"Unauthorized\"}\n",
743+
},
744+
}
745+
746+
for _, tc := range testCases {
747+
t.Run(tc.name, func(t *testing.T) {
748+
e := echo.New()
749+
750+
e.GET("/", func(c echo.Context) error {
751+
testValue, _ := c.Get("test").(string)
752+
return c.String(http.StatusTeapot, testValue)
753+
})
754+
755+
e.Use(JWTWithConfig(JWTConfig{
756+
NoErrorContinuesExecution: tc.whenNoErrorContinuesExecution,
757+
SigningKey: []byte("secret"),
758+
ErrorHandlerWithContext: func(err error, c echo.Context) error {
759+
if err == ErrJWTMissing {
760+
c.Set("test", "public-token")
761+
return nil
762+
}
763+
return echo.ErrUnauthorized
764+
},
765+
}))
766+
767+
req := httptest.NewRequest(http.MethodGet, "/", nil)
768+
if tc.givenToken != "" {
769+
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
770+
}
771+
res := httptest.NewRecorder()
772+
773+
e.ServeHTTP(res, req)
774+
775+
assert.Equal(t, tc.expectStatus, res.Code)
776+
assert.Equal(t, tc.expectBody, res.Body.String())
777+
})
778+
}
779+
}

middleware/key_auth.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ type (
3535
// ErrorHandler defines a function which is executed for an invalid key.
3636
// It may be used to define a custom error.
3737
ErrorHandler KeyAuthErrorHandler
38+
39+
// NoErrorContinuesExecution allows next middleware/handler to be called when ErrorHandler decides to swallow
40+
// the error (returns nil).
41+
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for valid
42+
// requests. In that case you can use ErrorHandler to set default public auth values to request and continue with
43+
// handler chain. Assuming logic downstream execution chain has to check that (public) auth value.
44+
NoErrorContinuesExecution bool
3845
}
3946

4047
// KeyAuthValidator defines a function to validate KeyAuth credentials.
@@ -53,6 +60,21 @@ var (
5360
}
5461
)
5562

63+
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
64+
type ErrKeyAuthMissing struct {
65+
Err error
66+
}
67+
68+
// Error returns errors text
69+
func (e *ErrKeyAuthMissing) Error() string {
70+
return e.Err.Error()
71+
}
72+
73+
// Unwrap unwraps error
74+
func (e *ErrKeyAuthMissing) Unwrap() error {
75+
return e.Err
76+
}
77+
5678
// KeyAuth returns an KeyAuth middleware.
5779
//
5880
// For valid key it calls the next handler.
@@ -131,10 +153,15 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
131153
} else {
132154
err = lastExtractorErr
133155
}
156+
err = &ErrKeyAuthMissing{Err: err}
134157
}
135158

136159
if config.ErrorHandler != nil {
137-
return config.ErrorHandler(err, c)
160+
tmpErr := config.ErrorHandler(err, c)
161+
if config.NoErrorContinuesExecution && tmpErr == nil {
162+
return next(c)
163+
}
164+
return tmpErr
138165
}
139166
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
140167
return &echo.HTTPError{

middleware/key_auth_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,78 @@ func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
288288
},
289289
)
290290
}
291+
292+
func TestKeyAuthWithConfig_NoErrorContinuesExecution(t *testing.T) {
293+
var testCases = []struct {
294+
name string
295+
whenNoErrorContinuesExecution bool
296+
givenKey string
297+
expectStatus int
298+
expectBody string
299+
}{
300+
{
301+
name: "no error handler is called",
302+
whenNoErrorContinuesExecution: true,
303+
givenKey: "valid-key",
304+
expectStatus: http.StatusTeapot,
305+
expectBody: "",
306+
},
307+
{
308+
name: "NoErrorContinuesExecution is false and error handler is called for missing token",
309+
whenNoErrorContinuesExecution: false,
310+
givenKey: "",
311+
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
312+
expectStatus: http.StatusOK,
313+
expectBody: "",
314+
},
315+
{
316+
name: "error handler is called for missing token",
317+
whenNoErrorContinuesExecution: true,
318+
givenKey: "",
319+
expectStatus: http.StatusTeapot,
320+
expectBody: "public-auth",
321+
},
322+
{
323+
name: "error handler is called for invalid token",
324+
whenNoErrorContinuesExecution: true,
325+
givenKey: "x.x.x",
326+
expectStatus: http.StatusUnauthorized,
327+
expectBody: "{\"message\":\"Unauthorized\"}\n",
328+
},
329+
}
330+
331+
for _, tc := range testCases {
332+
t.Run(tc.name, func(t *testing.T) {
333+
e := echo.New()
334+
335+
e.GET("/", func(c echo.Context) error {
336+
testValue, _ := c.Get("test").(string)
337+
return c.String(http.StatusTeapot, testValue)
338+
})
339+
340+
e.Use(KeyAuthWithConfig(KeyAuthConfig{
341+
Validator: testKeyValidator,
342+
ErrorHandler: func(err error, c echo.Context) error {
343+
if _, ok := err.(*ErrKeyAuthMissing); ok {
344+
c.Set("test", "public-auth")
345+
return nil
346+
}
347+
return echo.ErrUnauthorized
348+
},
349+
KeyLookup: "header:X-API-Key",
350+
NoErrorContinuesExecution: tc.whenNoErrorContinuesExecution,
351+
}))
352+
353+
req := httptest.NewRequest(http.MethodGet, "/", nil)
354+
if tc.givenKey != "" {
355+
req.Header.Set("X-API-Key", tc.givenKey)
356+
}
357+
res := httptest.NewRecorder()
358+
359+
e.ServeHTTP(res, req)
360+
361+
assert.Equal(t, tc.expectStatus, res.Code)
362+
assert.Equal(t, tc.expectBody, res.Body.String())
363+
})
364+
}
365+
}

0 commit comments

Comments
 (0)