diff --git a/errors.go b/errors.go index 1c93024a..10ef72cf 100644 --- a/errors.go +++ b/errors.go @@ -27,7 +27,7 @@ const ( ValidationErrorClaimsInvalid // Generic claims validation error ) -// Helper for constructing a ValidationError with a string error message +// NewValidationError constructs a ValidationError with a string error message func NewValidationError(errorText string, errorFlags uint32) *ValidationError { return &ValidationError{ text: errorText, @@ -35,14 +35,14 @@ func NewValidationError(errorText string, errorFlags uint32) *ValidationError { } } -// The error from Parse if token is not valid +// ValidationError is returned from Parse if the token is not valid. type ValidationError struct { Inner error // stores the error returned by external dependencies, i.e.: KeyFunc Errors uint32 // bitfield. see ValidationError... constants text string // errors that do not have a valid error just have text } -// Validation error is an error type +// Error implements the builtin error interface. func (e ValidationError) Error() string { if e.Inner != nil { return e.Inner.Error() @@ -53,7 +53,46 @@ func (e ValidationError) Error() string { } } -// No errors -func (e *ValidationError) valid() bool { +// IncludesAll tells whether an error includes all the bits provided. +// For instance, to check whether an error matches one condition: +// +// valErr.IncludesAll(ValidationErrorAudience) +// // will return true if ValidationErrorAudience is present in the Errors field +// // and false otherwise +// +// or to check if it matches many conditions: +// +// valErr.IncludesAll(ValidationErrorIssuer, ValidationErrorAudience) +// // will return true only if BOTH ValidationErrorIssuer AND ValidationErrorAudience +// // are present on the Errors field and false otherwise. +func (e ValidationError) IncludesAll(flags ...uint32) bool { + bits := uint32(0) + for _, flag := range flags { + bits |= flag + } + return (e.Errors & bits) == bits +} + +// IncludesAny tells whether an error includes any of the bits provided. +// Checking for matching of one condition is exactly as in IncludesAll. +// To check if an error matches any of several conditions: +// +// valErr.IncludesAny(ValidationErrorNotValidYet, ValidationErrorExpired) +// // will return true if: +// // - ValidationErrorNotValidYet is present +// // - ValidationErrorExpired is present +// // - ValidationErrorNotValidYet and ValidationErrorExpired +// // are somehow both present +// // and will return false only if NEITHER NotValidYet NOR Expired are present. +func (e ValidationError) IncludesAny(flags ...uint32) bool { + bits := uint32(0) + for _, flag := range flags { + bits |= flag + } + return (e.Errors & bits) != 0 +} + +// valid returns true if there are no errors. +func (e ValidationError) valid() bool { return e.Errors == 0 } diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..9c5dca1c --- /dev/null +++ b/errors_test.go @@ -0,0 +1,118 @@ +package jwt + +import ( + "fmt" + "testing" +) + +func TestValidationErrorIncludes(t *testing.T) { + type checks struct { + params []uint32 // the params to pass to .IncludesAll and .IncludesAny + wantAll bool // the desired result of .IncludesAll + wantAny bool // the desired result of .IncludesAny + } + cases := []struct { + name string // the name of the test case + errors uint32 // the errors to put into the ValidationError + wantValid bool // true if the error should be .valid() + checks []checks // the checks to perform against the ValidationError + }{ + { + name: "valid", + errors: 0, + wantValid: true, + checks: []checks{ + { + params: []uint32{}, + wantAll: true, + wantAny: false, + }, + { + params: []uint32{ValidationErrorMalformed}, + wantAll: false, + wantAny: false, + }, + }, + }, + { + name: "one error", + errors: ValidationErrorExpired, + checks: []checks{ + { + params: []uint32{}, + wantAll: true, + wantAny: false, + }, + { + params: []uint32{ValidationErrorExpired}, + wantAll: true, + wantAny: true, + }, + { + params: []uint32{ValidationErrorExpired, ValidationErrorAudience}, + wantAll: false, + wantAny: true, + }, + { + params: []uint32{ValidationErrorAudience}, + wantAll: false, + wantAny: false, + }, + }, + }, + { + name: "many errors", + errors: ValidationErrorAudience | ValidationErrorIssuer, + checks: []checks{ + { + params: []uint32{}, + wantAll: true, + wantAny: false, + }, + { + params: []uint32{ValidationErrorAudience}, + wantAll: true, + wantAny: true, + }, + { + params: []uint32{ValidationErrorAudience, ValidationErrorId}, + wantAll: false, + wantAny: true, + }, + { + params: []uint32{ValidationErrorAudience, ValidationErrorIssuer}, + wantAll: true, + wantAny: true, + }, + { + params: []uint32{ValidationErrorAudience, ValidationErrorIssuer, ValidationErrorNotValidYet}, + wantAll: false, + wantAny: true, + }, + { + params: []uint32{ValidationErrorExpired, ValidationErrorSignatureInvalid}, + wantAll: false, + wantAny: false, + }, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ve := NewValidationError(tc.name, tc.errors) + if got := ve.valid(); got != tc.wantValid { + t.Errorf("ve.valid() = %v, want %v", got, tc.wantValid) + } + for _, ch := range tc.checks { + t.Run(fmt.Sprint(ch.params), func(t *testing.T) { + if got := ve.IncludesAll(ch.params...); got != ch.wantAll { + t.Errorf("ve.IncludesAll(%v) = %v; want %v", ch.params, got, ch.wantAll) + } + if got := ve.IncludesAny(ch.params...); got != ch.wantAny { + t.Errorf("ve.IncludesAny(%v) = %v; want %v", ch.params, got, ch.wantAny) + } + }) + } + }) + } +}