Skip to content

Commit 7091ac8

Browse files
authored
Rework validation to be less strict (#24)
1 parent 0e87ee3 commit 7091ac8

File tree

3 files changed

+233
-9
lines changed

3 files changed

+233
-9
lines changed

jwk.go

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,23 @@ import (
77
"crypto/ed25519"
88
"crypto/rsa"
99
"crypto/x509"
10+
"encoding/base64"
1011
"encoding/json"
1112
"errors"
1213
"fmt"
1314
"io"
15+
"math/big"
1416
"net/http"
1517
"net/url"
16-
"reflect"
1718
"slices"
1819
"time"
1920
)
2021

22+
var (
23+
// ErrPadding indicates that there is invalid padding.
24+
ErrPadding = errors.New("padding error")
25+
)
26+
2127
// JWK represents a JSON Web Key.
2228
type JWK struct {
2329
key any
@@ -67,6 +73,8 @@ type JWKValidateOptions struct {
6773
SkipUse bool
6874
// SkipX5UScheme is used to skip checking if the X5U URI scheme is https.
6975
SkipX5UScheme bool
76+
// StrictPadding is used to indicate that the JWK should be validated with strict padding.
77+
StrictPadding bool
7078
}
7179

7280
// JWKMetadataOptions are direct passthroughs into the JWKMarshal.
@@ -221,6 +229,10 @@ func (j JWK) Validate() error {
221229
return fmt.Errorf("%w: invalid or unsupported key type %q", ErrJWKValidation, j.marshal.KTY)
222230
}
223231

232+
if !j.options.Validate.SkipUse && !j.marshal.USE.IANARegistered() {
233+
return fmt.Errorf("%w: invalid or unsupported key use %q", ErrJWKValidation, j.marshal.USE)
234+
}
235+
224236
if !j.options.Validate.SkipKeyOps {
225237
for _, o := range j.marshal.KEYOPS {
226238
if !o.IANARegistered() {
@@ -229,10 +241,6 @@ func (j JWK) Validate() error {
229241
}
230242
}
231243

232-
if !j.options.Validate.SkipUse && !j.marshal.USE.IANARegistered() {
233-
return fmt.Errorf("%w: invalid or unsupported key use %q", ErrJWKValidation, j.marshal.USE)
234-
}
235-
236244
if !j.options.Validate.SkipMetadata {
237245
if j.marshal.ALG != j.options.Metadata.ALG {
238246
return fmt.Errorf("%w: ALG in marshal does not match ALG in options", errors.Join(ErrJWKValidation, ErrOptions))
@@ -301,18 +309,100 @@ func (j JWK) Validate() error {
301309
return fmt.Errorf("failed to marshal JSON Web Key: %w", errors.Join(ErrJWKValidation, err))
302310
}
303311

312+
// Remove automatically computed thumbprints if not set in given JWK.
304313
if j.marshal.X5T == "" {
305314
marshalled.X5T = ""
306315
}
307316
if j.marshal.X5TS256 == "" {
308317
marshalled.X5TS256 = ""
309318
}
310319

311-
ok := reflect.DeepEqual(j.marshal, marshalled)
312-
if !ok {
313-
return fmt.Errorf("%w: marshaled JWK does not match original JWK", ErrJWKValidation)
320+
if j.marshal.X5T != marshalled.X5T {
321+
return fmt.Errorf("%w: X5T in marshal does not match X5T in marshalled", ErrJWKValidation)
322+
}
323+
if j.marshal.X5TS256 != marshalled.X5TS256 {
324+
return fmt.Errorf("%w: X5TS256 in marshal does not match X5TS256 in marshalled", ErrJWKValidation)
325+
}
326+
if j.marshal.CRV != marshalled.CRV {
327+
return fmt.Errorf("%w: CRV in marshal does not match CRV in marshalled", ErrJWKValidation)
328+
}
329+
switch j.marshal.KTY {
330+
case KtyEC:
331+
err = cmpBase64Int(j.marshal.X, marshalled.X, j.options.Validate.StrictPadding)
332+
if err != nil {
333+
return fmt.Errorf("%w: X in marshal does not match X in marshalled", errors.Join(ErrJWKValidation, err))
334+
}
335+
err = cmpBase64Int(j.marshal.Y, marshalled.Y, j.options.Validate.StrictPadding)
336+
if err != nil {
337+
return fmt.Errorf("%w: Y in marshal does not match Y in marshalled", errors.Join(ErrJWKValidation, err))
338+
}
339+
err = cmpBase64Int(j.marshal.D, marshalled.D, j.options.Validate.StrictPadding)
340+
if err != nil {
341+
return fmt.Errorf("%w: D in marshal does not match D in marshalled", errors.Join(ErrJWKValidation, err))
342+
}
343+
case KtyOKP:
344+
if j.marshal.X != marshalled.X {
345+
return fmt.Errorf("%w: X in marshal does not match X in marshalled", ErrJWKValidation)
346+
}
347+
if j.marshal.D != marshalled.D {
348+
return fmt.Errorf("%w: D in marshal does not match D in marshalled", ErrJWKValidation)
349+
}
350+
case KtyRSA:
351+
err = cmpBase64Int(j.marshal.D, marshalled.D, j.options.Validate.StrictPadding)
352+
if err != nil {
353+
return fmt.Errorf("%w: D in marshal does not match D in marshalled", errors.Join(ErrJWKValidation, err))
354+
}
355+
err = cmpBase64Int(j.marshal.N, marshalled.N, j.options.Validate.StrictPadding)
356+
if err != nil {
357+
return fmt.Errorf("%w: N in marshal does not match N in marshalled", errors.Join(ErrJWKValidation, err))
358+
}
359+
err = cmpBase64Int(j.marshal.E, marshalled.E, j.options.Validate.StrictPadding)
360+
if err != nil {
361+
return fmt.Errorf("%w: E in marshal does not match E in marshalled", errors.Join(ErrJWKValidation, err))
362+
}
363+
err = cmpBase64Int(j.marshal.P, marshalled.P, j.options.Validate.StrictPadding)
364+
if err != nil {
365+
return fmt.Errorf("%w: P in marshal does not match P in marshalled", errors.Join(ErrJWKValidation, err))
366+
}
367+
err = cmpBase64Int(j.marshal.Q, marshalled.Q, j.options.Validate.StrictPadding)
368+
if err != nil {
369+
return fmt.Errorf("%w: Q in marshal does not match Q in marshalled", errors.Join(ErrJWKValidation, err))
370+
}
371+
err = cmpBase64Int(j.marshal.DP, marshalled.DP, j.options.Validate.StrictPadding)
372+
if err != nil {
373+
return fmt.Errorf("%w: DP in marshal does not match DP in marshalled", errors.Join(ErrJWKValidation, err))
374+
}
375+
err = cmpBase64Int(j.marshal.DQ, marshalled.DQ, j.options.Validate.StrictPadding)
376+
if err != nil {
377+
return fmt.Errorf("%w: DQ in marshal does not match DQ in marshalled", errors.Join(ErrJWKValidation, err))
378+
}
379+
if len(j.marshal.OTH) != len(marshalled.OTH) {
380+
return fmt.Errorf("%w: OTH in marshal does not match OTH in marshalled", ErrJWKValidation)
381+
}
382+
for i, o := range j.marshal.OTH {
383+
err = cmpBase64Int(o.R, marshalled.OTH[i].R, j.options.Validate.StrictPadding)
384+
if err != nil {
385+
return fmt.Errorf("%w: OTH index %d in marshal does not match OTH in marshalled", errors.Join(ErrJWKValidation, err), i)
386+
}
387+
err = cmpBase64Int(o.D, marshalled.OTH[i].D, j.options.Validate.StrictPadding)
388+
if err != nil {
389+
return fmt.Errorf("%w: OTH index %d in marshal does not match OTH in marshalled", errors.Join(ErrJWKValidation, err), i)
390+
}
391+
err = cmpBase64Int(o.T, marshalled.OTH[i].T, j.options.Validate.StrictPadding)
392+
if err != nil {
393+
return fmt.Errorf("%w: OTH index %d in marshal does not match OTH in marshalled", errors.Join(ErrJWKValidation, err), i)
394+
}
395+
}
396+
case KtyOct:
397+
err = cmpBase64Int(j.marshal.K, marshalled.K, j.options.Validate.StrictPadding)
398+
if err != nil {
399+
return fmt.Errorf("%w: K in marshal does not match K in marshalled", errors.Join(ErrJWKValidation, err))
400+
}
401+
default:
402+
return fmt.Errorf("%w: invalid or unsupported key type %q", ErrJWKValidation, j.marshal.KTY)
314403
}
315404

405+
// Saved for last because it may involve a network request.
316406
if j.marshal.X5U != "" || j.options.X509.X5U != "" {
317407
if j.marshal.X5U != j.options.X509.X5U {
318408
return fmt.Errorf("%w: X5U in marshal does not match X5U in options", errors.Join(ErrJWKValidation, ErrOptions))
@@ -376,3 +466,28 @@ func DefaultGetX5U(u *url.URL) ([]*x509.Certificate, error) {
376466
}
377467
return certs, nil
378468
}
469+
470+
func cmpBase64Int(first, second string, strictPadding bool) error {
471+
if first == second {
472+
return nil
473+
}
474+
b, err := base64.RawURLEncoding.DecodeString(first)
475+
if err != nil {
476+
return fmt.Errorf("failed to decode Base64 raw URL decode first string: %w", err)
477+
}
478+
fLen := len(b)
479+
f := new(big.Int).SetBytes(b)
480+
b, err = base64.RawURLEncoding.DecodeString(second)
481+
if err != nil {
482+
return fmt.Errorf("failed to decode Base64 raw URL decode second string: %w", err)
483+
}
484+
sLen := len(b)
485+
s := new(big.Int).SetBytes(b)
486+
if f.Cmp(s) != 0 {
487+
return fmt.Errorf("%w: the parsed integers do not match", ErrJWKValidation)
488+
}
489+
if strictPadding && fLen != sLen {
490+
return fmt.Errorf("%w: the Base64 raw URL inputs do not have matching padding", errors.Join(ErrJWKValidation, ErrPadding))
491+
}
492+
return nil
493+
}

jwk_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@ import (
88
"encoding/base64"
99
"encoding/json"
1010
"encoding/pem"
11+
"errors"
1112
"testing"
1213
"time"
1314
)
1415

16+
const (
17+
anyStr = "any"
18+
invalidStr = "invalid"
19+
)
20+
1521
func TestNewJWKFromRawJSON(t *testing.T) {
1622
marshalOptions := JWKMarshalOptions{
1723
Private: true,
@@ -97,6 +103,109 @@ func TestMissingThumbprint(t *testing.T) {
97103
}
98104
}
99105

106+
func TestJWK_Validate(t *testing.T) {
107+
jwk := JWK{}
108+
err := jwk.Validate()
109+
if err == nil {
110+
t.Fatalf("Expected to fail validation for empty JWK.")
111+
}
112+
113+
jwk.options.Validate.SkipAll = true
114+
err = jwk.Validate()
115+
if err != nil {
116+
t.Fatalf("Failed to skip validation. %s", err)
117+
}
118+
jwk.options.Validate.SkipAll = false
119+
120+
jwk.marshal.KTY = KtyOKP
121+
jwk.marshal.USE = invalidStr
122+
err = jwk.Validate()
123+
if err == nil {
124+
t.Fatalf("Expected to fail validation for invalid use.")
125+
}
126+
jwk.marshal.USE = ""
127+
128+
jwk.marshal.KEYOPS = []KEYOPS{invalidStr}
129+
err = jwk.Validate()
130+
if err == nil {
131+
t.Fatalf("Expected to fail validation for invalid key operations.")
132+
}
133+
jwk.marshal.KEYOPS = nil
134+
135+
jwk.options.Metadata.ALG = AlgEdDSA
136+
err = jwk.Validate()
137+
if err == nil {
138+
t.Fatalf("Expected to fail validation for options not matching algorithm.")
139+
}
140+
jwk.options.Metadata.ALG = ""
141+
142+
jwk.options.Metadata.KID = anyStr
143+
err = jwk.Validate()
144+
if err == nil {
145+
t.Fatalf("Expected to fail validation for options not matching key ID.")
146+
}
147+
jwk.options.Metadata.KID = ""
148+
149+
jwk.options.Metadata.KEYOPS = []KEYOPS{KeyOpsSign}
150+
err = jwk.Validate()
151+
if err == nil {
152+
t.Fatalf("Expected to fail validation for options not matching key operations.")
153+
}
154+
jwk.options.Metadata.KEYOPS = nil
155+
156+
jwk.options.Metadata.USE = UseSig
157+
err = jwk.Validate()
158+
if err == nil {
159+
t.Fatalf("Expected to fail validation for options not matching use.")
160+
}
161+
jwk.options.Metadata.USE = ""
162+
}
163+
164+
func TestJWK_Validate_Padding(t *testing.T) {
165+
const invalidRSAModulusPadding = `
166+
{
167+
"kty": "RSA",
168+
"n": "AOpF5dwoCpmW2Th5kBaKDZmygOlyQSJm3JqwGvPTTViHCs4ZitlLF9za9-DPxP3zoNaryEYlFfLhYOFVS7mUjMGtLNTkLafBSIIoF28sy_z1GruxJ2aFchazBimxI1B0MXTKdIw4V268klrOECO5FIcHar7EV9W0XqToFon3oVvHWw3qkPV4o-A7Gdrh3Yh7vRUE_T5XCLYD9jO41nAqYhWYRGN-Kxu51x6VMa595TXTrpzgYGDba1MLQzB9qcHRIvRskt7Gh8M0zgcyo6c6jvktaEzh0j2kdL2JCAFHhMXUZedRUOpeqkEehpxDDR0Deiz7UPlMe6l8Ots97Wm357bgajDcxnqaGGEF5GIkr7xHw15DrTfOWPY35f0sHjNTOn9AU2bPWTy6oHZPhoFjHdSNp3UOIunnf1eXRlTa7YZ5PLmbFFyjNNSnQdcOHgKx1lJExJqXCAJ2pBkp0dX65uiqCLz4WZBcmCHGToi4mvQ5wpFqgUJ_6N8HXpP5ZLZ-hQ",
169+
"e": "AQAB"
170+
}`
171+
jwk, err := NewJWKFromRawJSON([]byte(invalidRSAModulusPadding), JWKMarshalOptions{}, JWKValidateOptions{})
172+
if err != nil {
173+
t.Fatalf("Failed to create JWK from raw JSON. %s", err)
174+
}
175+
err = jwk.Validate()
176+
if err != nil {
177+
t.Fatalf("Failed to validate RSA JWK with acceptably invalid padding. %s", err)
178+
}
179+
jwk.options.Validate.StrictPadding = true
180+
err = jwk.Validate()
181+
if !errors.Is(err, ErrPadding) {
182+
t.Fatalf("Expected to fail validation for invalid RSA modulus padding.")
183+
}
184+
185+
const invalidECDSAPadding = `
186+
{
187+
"kty": "EC",
188+
"crv": "P-521",
189+
"x": "aQnZOuwyXH1APmjESTgHLVUH49Ry19Ay7hgHiOB4Nsv5m_JN18wW-ByFtGtHatVJ_OHL5TuLOTSsp8ctniKTn3E",
190+
"y": "TZAwFszO_oiyvncIviOJdi8MU8VDfZo8Y3q0Z-AxaPDUFQS8aRDCHUzukj6RCNZsRCWd0HGOayIhV_uQZrB_Xbc",
191+
"d": "AZHsd9nLaXHFWH4wjiW5XcCrIO9AWl4Y0aV64kagRFPnWjljC6VxCsFF5IM0vTzCWKdlwFLEIgJO0pfwWlQMXKef"
192+
}
193+
`
194+
jwk, err = NewJWKFromRawJSON([]byte(invalidECDSAPadding), JWKMarshalOptions{}, JWKValidateOptions{})
195+
if err != nil {
196+
t.Fatalf("Failed to create JWK from raw JSON. %s", err)
197+
}
198+
err = jwk.Validate()
199+
if err != nil {
200+
t.Fatalf("Failed to validate ECDSA JWK with acceptably invalid padding. %s", err)
201+
}
202+
jwk.options.Validate.StrictPadding = true
203+
err = jwk.Validate()
204+
if !errors.Is(err, ErrPadding) {
205+
t.Fatalf("Expected to fail validation for invalid ECDSA padding.")
206+
}
207+
}
208+
100209
func testJSON(ctx context.Context, t *testing.T, jwks Storage) {
101210
b, err := base64.RawURLEncoding.DecodeString(x25519PrivateKey)
102211
if err != nil {

marshal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ var (
3737
// OtherPrimes is for RSA private keys that have more than 2 primes.
3838
// https://www.rfc-editor.org/rfc/rfc7518#section-6.3.2.7
3939
type OtherPrimes struct {
40-
D string `json:"d,omitempty"` // https://www.rfc-editor.org/rfc/rfc7518#section-6.3.2.7.2
4140
R string `json:"r,omitempty"` // https://www.rfc-editor.org/rfc/rfc7518#section-6.3.2.7.1
41+
D string `json:"d,omitempty"` // https://www.rfc-editor.org/rfc/rfc7518#section-6.3.2.7.2
4242
T string `json:"t,omitempty"` // https://www.rfc-editor.org/rfc/rfc7518#section-6.3.2.7.3
4343
}
4444

0 commit comments

Comments
 (0)