Skip to content

Commit 78a18c0

Browse files
authored
Implementing Is(err) bool to support Go 1.13 style error checking (#136)
1 parent 0fb40d3 commit 78a18c0

File tree

5 files changed

+111
-21
lines changed

5 files changed

+111
-21
lines changed

claims.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,17 @@ func (c RegisteredClaims) Valid() error {
5656
// default value in Go, let's not fail the verification for them.
5757
if !c.VerifyExpiresAt(now, false) {
5858
delta := now.Sub(c.ExpiresAt.Time)
59-
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
59+
vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
6060
vErr.Errors |= ValidationErrorExpired
6161
}
6262

6363
if !c.VerifyIssuedAt(now, false) {
64-
vErr.Inner = fmt.Errorf("token used before issued")
64+
vErr.Inner = ErrTokenUsedBeforeIssued
6565
vErr.Errors |= ValidationErrorIssuedAt
6666
}
6767

6868
if !c.VerifyNotBefore(now, false) {
69-
vErr.Inner = fmt.Errorf("token is not valid yet")
69+
vErr.Inner = ErrTokenNotValidYet
7070
vErr.Errors |= ValidationErrorNotValidYet
7171
}
7272

@@ -149,17 +149,17 @@ func (c StandardClaims) Valid() error {
149149
// default value in Go, let's not fail the verification for them.
150150
if !c.VerifyExpiresAt(now, false) {
151151
delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0))
152-
vErr.Inner = fmt.Errorf("token is expired by %v", delta)
152+
vErr.Inner = fmt.Errorf("%s by %v", delta, ErrTokenExpired)
153153
vErr.Errors |= ValidationErrorExpired
154154
}
155155

156156
if !c.VerifyIssuedAt(now, false) {
157-
vErr.Inner = fmt.Errorf("token used before issued")
157+
vErr.Inner = ErrTokenUsedBeforeIssued
158158
vErr.Errors |= ValidationErrorIssuedAt
159159
}
160160

161161
if !c.VerifyNotBefore(now, false) {
162-
vErr.Inner = fmt.Errorf("token is not valid yet")
162+
vErr.Inner = ErrTokenNotValidYet
163163
vErr.Errors |= ValidationErrorNotValidYet
164164
}
165165

errors.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ var (
99
ErrInvalidKey = errors.New("key is invalid")
1010
ErrInvalidKeyType = errors.New("key is of invalid type")
1111
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
12+
13+
ErrTokenMalformed = errors.New("token is malformed")
14+
ErrTokenUnverifiable = errors.New("token is unverifiable")
15+
ErrTokenSignatureInvalid = errors.New("token signature is invalid")
16+
17+
ErrTokenInvalidAudience = errors.New("token has invalid audience")
18+
ErrTokenExpired = errors.New("token is expired")
19+
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
20+
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
21+
ErrTokenNotValidYet = errors.New("token is not valid yet")
22+
ErrTokenInvalidId = errors.New("token has invalid id")
23+
ErrTokenInvalidClaims = errors.New("token has invalid claims")
1224
)
1325

1426
// The errors that might occur when parsing and validating a token
@@ -62,3 +74,39 @@ func (e *ValidationError) Unwrap() error {
6274
func (e *ValidationError) valid() bool {
6375
return e.Errors == 0
6476
}
77+
78+
// Is checks if this ValidationError is of the supplied error. We are first checking for the exact error message
79+
// by comparing the inner error message. If that fails, we compare using the error flags. This way we can use
80+
// custom error messages (mainly for backwards compatability) and still leverage errors.Is using the global error variables.
81+
func (e *ValidationError) Is(err error) bool {
82+
// Check, if our inner error is a direct match
83+
if errors.Is(errors.Unwrap(e), err) {
84+
return true
85+
}
86+
87+
// Otherwise, we need to match using our error flags
88+
switch err {
89+
case ErrTokenMalformed:
90+
return e.Errors&ValidationErrorMalformed != 0
91+
case ErrTokenUnverifiable:
92+
return e.Errors&ValidationErrorUnverifiable != 0
93+
case ErrTokenSignatureInvalid:
94+
return e.Errors&ValidationErrorSignatureInvalid != 0
95+
case ErrTokenInvalidAudience:
96+
return e.Errors&ValidationErrorAudience != 0
97+
case ErrTokenExpired:
98+
return e.Errors&ValidationErrorExpired != 0
99+
case ErrTokenUsedBeforeIssued:
100+
return e.Errors&ValidationErrorIssuedAt != 0
101+
case ErrTokenInvalidIssuer:
102+
return e.Errors&ValidationErrorIssuer != 0
103+
case ErrTokenNotValidYet:
104+
return e.Errors&ValidationErrorNotValidYet != 0
105+
case ErrTokenInvalidId:
106+
return e.Errors&ValidationErrorId != 0
107+
case ErrTokenInvalidClaims:
108+
return e.Errors&ValidationErrorClaimsInvalid != 0
109+
}
110+
111+
return false
112+
}

example_test.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jwt_test
22

33
import (
4+
"errors"
45
"fmt"
56
"time"
67

@@ -103,15 +104,11 @@ func ExampleParse_errorChecking() {
103104

104105
if token.Valid {
105106
fmt.Println("You look nice today")
106-
} else if ve, ok := err.(*jwt.ValidationError); ok {
107-
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
108-
fmt.Println("That's not even a token")
109-
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
110-
// Token is either expired or not active yet
111-
fmt.Println("Timing is everything")
112-
} else {
113-
fmt.Println("Couldn't handle this token:", err)
114-
}
107+
} else if errors.Is(err, jwt.ErrTokenMalformed) {
108+
fmt.Println("That's not even a token")
109+
} else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) {
110+
// Token is either expired or not active yet
111+
fmt.Println("Timing is everything")
115112
} else {
116113
fmt.Println("Couldn't handle this token:", err)
117114
}

map_claims.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,19 @@ func (m MapClaims) Valid() error {
126126
now := TimeFunc().Unix()
127127

128128
if !m.VerifyExpiresAt(now, false) {
129+
// TODO(oxisto): this should be replaced with ErrTokenExpired
129130
vErr.Inner = errors.New("Token is expired")
130131
vErr.Errors |= ValidationErrorExpired
131132
}
132133

133134
if !m.VerifyIssuedAt(now, false) {
135+
// TODO(oxisto): this should be replaced with ErrTokenUsedBeforeIssued
134136
vErr.Inner = errors.New("Token used before issued")
135137
vErr.Errors |= ValidationErrorIssuedAt
136138
}
137139

138140
if !m.VerifyNotBefore(now, false) {
141+
// TODO(oxisto): this should be replaced with ErrTokenNotValidYet
139142
vErr.Inner = errors.New("Token is not valid yet")
140143
vErr.Errors |= ValidationErrorNotValidYet
141144
}

parser_test.go

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto"
55
"crypto/rsa"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"reflect"
910
"testing"
@@ -51,6 +52,7 @@ var jwtTestData = []struct {
5152
claims jwt.Claims
5253
valid bool
5354
errors uint32
55+
err []error
5456
parser *jwt.Parser
5557
signingMethod jwt.SigningMethod // The method to sign the JWT token for test purpose
5658
}{
@@ -62,6 +64,7 @@ var jwtTestData = []struct {
6264
true,
6365
0,
6466
nil,
67+
nil,
6568
jwt.SigningMethodRS256,
6669
},
6770
{
@@ -71,6 +74,7 @@ var jwtTestData = []struct {
7174
jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)},
7275
false,
7376
jwt.ValidationErrorExpired,
77+
[]error{jwt.ErrTokenExpired},
7478
nil,
7579
jwt.SigningMethodRS256,
7680
},
@@ -81,6 +85,7 @@ var jwtTestData = []struct {
8185
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)},
8286
false,
8387
jwt.ValidationErrorNotValidYet,
88+
[]error{jwt.ErrTokenNotValidYet},
8489
nil,
8590
jwt.SigningMethodRS256,
8691
},
@@ -91,6 +96,7 @@ var jwtTestData = []struct {
9196
jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)},
9297
false,
9398
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
99+
[]error{jwt.ErrTokenNotValidYet},
94100
nil,
95101
jwt.SigningMethodRS256,
96102
},
@@ -101,6 +107,7 @@ var jwtTestData = []struct {
101107
jwt.MapClaims{"foo": "bar"},
102108
false,
103109
jwt.ValidationErrorSignatureInvalid,
110+
[]error{jwt.ErrTokenSignatureInvalid, rsa.ErrVerification},
104111
nil,
105112
jwt.SigningMethodRS256,
106113
},
@@ -111,6 +118,7 @@ var jwtTestData = []struct {
111118
jwt.MapClaims{"foo": "bar"},
112119
false,
113120
jwt.ValidationErrorUnverifiable,
121+
[]error{jwt.ErrTokenUnverifiable},
114122
nil,
115123
jwt.SigningMethodRS256,
116124
},
@@ -121,6 +129,7 @@ var jwtTestData = []struct {
121129
jwt.MapClaims{"foo": "bar"},
122130
false,
123131
jwt.ValidationErrorSignatureInvalid,
132+
[]error{jwt.ErrTokenSignatureInvalid},
124133
nil,
125134
jwt.SigningMethodRS256,
126135
},
@@ -131,6 +140,7 @@ var jwtTestData = []struct {
131140
jwt.MapClaims{"foo": "bar"},
132141
false,
133142
jwt.ValidationErrorUnverifiable,
143+
[]error{jwt.ErrTokenUnverifiable, errKeyFuncError},
134144
nil,
135145
jwt.SigningMethodRS256,
136146
},
@@ -141,6 +151,7 @@ var jwtTestData = []struct {
141151
jwt.MapClaims{"foo": "bar"},
142152
false,
143153
jwt.ValidationErrorSignatureInvalid,
154+
[]error{jwt.ErrTokenSignatureInvalid},
144155
&jwt.Parser{ValidMethods: []string{"HS256"}},
145156
jwt.SigningMethodRS256,
146157
},
@@ -151,6 +162,7 @@ var jwtTestData = []struct {
151162
jwt.MapClaims{"foo": "bar"},
152163
true,
153164
0,
165+
nil,
154166
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
155167
jwt.SigningMethodRS256,
156168
},
@@ -161,6 +173,7 @@ var jwtTestData = []struct {
161173
jwt.MapClaims{"foo": "bar"},
162174
false,
163175
jwt.ValidationErrorSignatureInvalid,
176+
[]error{jwt.ErrTokenSignatureInvalid},
164177
&jwt.Parser{ValidMethods: []string{"RS256", "HS256"}},
165178
jwt.SigningMethodES256,
166179
},
@@ -171,6 +184,7 @@ var jwtTestData = []struct {
171184
jwt.MapClaims{"foo": "bar"},
172185
true,
173186
0,
187+
nil,
174188
&jwt.Parser{ValidMethods: []string{"HS256", "ES256"}},
175189
jwt.SigningMethodES256,
176190
},
@@ -181,6 +195,7 @@ var jwtTestData = []struct {
181195
jwt.MapClaims{"foo": json.Number("123.4")},
182196
true,
183197
0,
198+
nil,
184199
&jwt.Parser{UseJSONNumber: true},
185200
jwt.SigningMethodRS256,
186201
},
@@ -193,6 +208,7 @@ var jwtTestData = []struct {
193208
},
194209
true,
195210
0,
211+
nil,
196212
&jwt.Parser{UseJSONNumber: true},
197213
jwt.SigningMethodRS256,
198214
},
@@ -203,6 +219,7 @@ var jwtTestData = []struct {
203219
jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
204220
false,
205221
jwt.ValidationErrorExpired,
222+
[]error{jwt.ErrTokenExpired},
206223
&jwt.Parser{UseJSONNumber: true},
207224
jwt.SigningMethodRS256,
208225
},
@@ -213,6 +230,7 @@ var jwtTestData = []struct {
213230
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
214231
false,
215232
jwt.ValidationErrorNotValidYet,
233+
[]error{jwt.ErrTokenNotValidYet},
216234
&jwt.Parser{UseJSONNumber: true},
217235
jwt.SigningMethodRS256,
218236
},
@@ -223,6 +241,7 @@ var jwtTestData = []struct {
223241
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))},
224242
false,
225243
jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired,
244+
[]error{jwt.ErrTokenNotValidYet},
226245
&jwt.Parser{UseJSONNumber: true},
227246
jwt.SigningMethodRS256,
228247
},
@@ -233,6 +252,7 @@ var jwtTestData = []struct {
233252
jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))},
234253
true,
235254
0,
255+
nil,
236256
&jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true},
237257
jwt.SigningMethodRS256,
238258
},
@@ -245,6 +265,7 @@ var jwtTestData = []struct {
245265
},
246266
true,
247267
0,
268+
nil,
248269
&jwt.Parser{UseJSONNumber: true},
249270
jwt.SigningMethodRS256,
250271
},
@@ -257,6 +278,7 @@ var jwtTestData = []struct {
257278
},
258279
true,
259280
0,
281+
nil,
260282
&jwt.Parser{UseJSONNumber: true},
261283
jwt.SigningMethodRS256,
262284
},
@@ -269,6 +291,7 @@ var jwtTestData = []struct {
269291
},
270292
true,
271293
0,
294+
nil,
272295
&jwt.Parser{UseJSONNumber: true},
273296
jwt.SigningMethodRS256,
274297
},
@@ -281,6 +304,7 @@ var jwtTestData = []struct {
281304
},
282305
false,
283306
jwt.ValidationErrorMalformed,
307+
[]error{jwt.ErrTokenMalformed},
284308
&jwt.Parser{UseJSONNumber: true},
285309
jwt.SigningMethodRS256,
286310
},
@@ -293,6 +317,7 @@ var jwtTestData = []struct {
293317
},
294318
false,
295319
jwt.ValidationErrorMalformed,
320+
[]error{jwt.ErrTokenMalformed},
296321
&jwt.Parser{UseJSONNumber: true},
297322
jwt.SigningMethodRS256,
298323
},
@@ -325,6 +350,7 @@ func TestParser_Parse(t *testing.T) {
325350

326351
// Parse the token
327352
var token *jwt.Token
353+
var ve *jwt.ValidationError
328354
var err error
329355
var parser = data.parser
330356
if parser == nil {
@@ -361,18 +387,34 @@ func TestParser_Parse(t *testing.T) {
361387
if err == nil {
362388
t.Errorf("[%v] Expecting error. Didn't get one.", data.name)
363389
} else {
390+
if errors.As(err, &ve) {
391+
// compare the bitfield part of the error
392+
if e := ve.Errors; e != data.errors {
393+
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
394+
}
395+
396+
if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
397+
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
398+
}
399+
}
400+
}
401+
}
364402

365-
ve := err.(*jwt.ValidationError)
366-
// compare the bitfield part of the error
367-
if e := ve.Errors; e != data.errors {
368-
t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors)
403+
if data.err != nil {
404+
if err == nil {
405+
t.Errorf("[%v] Expecting error(s). Didn't get one.", data.name)
406+
} else {
407+
var all = false
408+
for _, e := range data.err {
409+
all = errors.Is(err, e)
369410
}
370411

371-
if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError {
372-
t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError)
412+
if !all {
413+
t.Errorf("[%v] Errors don't match expectation. %v should contain all of %v", data.name, err, data.err)
373414
}
374415
}
375416
}
417+
376418
if data.valid {
377419
if token.Signature == "" {
378420
t.Errorf("[%v] Signature is left unpopulated after parsing", data.name)

0 commit comments

Comments
 (0)