@@ -7,17 +7,23 @@ import (
77 "crypto/ed25519"
88 "crypto/rsa"
99 "crypto/x509"
10+ "encoding/base64"
1011 "encoding/json"
1112 "errors"
1213 "fmt"
1314 "io"
15+ "math/big"
1416 "net/http"
1517 "net/url"
16- "reflect"
1718 "slices"
1819 "time"
1920)
2021
22+ var (
23+ // ErrPadding indicates that there is invalid padding.
24+ ErrPadding = errors .New ("padding error" )
25+ )
26+
2127// JWK represents a JSON Web Key.
2228type JWK struct {
2329 key any
@@ -67,6 +73,8 @@ type JWKValidateOptions struct {
6773 SkipUse bool
6874 // SkipX5UScheme is used to skip checking if the X5U URI scheme is https.
6975 SkipX5UScheme bool
76+ // StrictPadding is used to indicate that the JWK should be validated with strict padding.
77+ StrictPadding bool
7078}
7179
7280// JWKMetadataOptions are direct passthroughs into the JWKMarshal.
@@ -221,6 +229,10 @@ func (j JWK) Validate() error {
221229 return fmt .Errorf ("%w: invalid or unsupported key type %q" , ErrJWKValidation , j .marshal .KTY )
222230 }
223231
232+ if ! j .options .Validate .SkipUse && ! j .marshal .USE .IANARegistered () {
233+ return fmt .Errorf ("%w: invalid or unsupported key use %q" , ErrJWKValidation , j .marshal .USE )
234+ }
235+
224236 if ! j .options .Validate .SkipKeyOps {
225237 for _ , o := range j .marshal .KEYOPS {
226238 if ! o .IANARegistered () {
@@ -229,10 +241,6 @@ func (j JWK) Validate() error {
229241 }
230242 }
231243
232- if ! j .options .Validate .SkipUse && ! j .marshal .USE .IANARegistered () {
233- return fmt .Errorf ("%w: invalid or unsupported key use %q" , ErrJWKValidation , j .marshal .USE )
234- }
235-
236244 if ! j .options .Validate .SkipMetadata {
237245 if j .marshal .ALG != j .options .Metadata .ALG {
238246 return fmt .Errorf ("%w: ALG in marshal does not match ALG in options" , errors .Join (ErrJWKValidation , ErrOptions ))
@@ -301,18 +309,100 @@ func (j JWK) Validate() error {
301309 return fmt .Errorf ("failed to marshal JSON Web Key: %w" , errors .Join (ErrJWKValidation , err ))
302310 }
303311
312+ // Remove automatically computed thumbprints if not set in given JWK.
304313 if j .marshal .X5T == "" {
305314 marshalled .X5T = ""
306315 }
307316 if j .marshal .X5TS256 == "" {
308317 marshalled .X5TS256 = ""
309318 }
310319
311- ok := reflect .DeepEqual (j .marshal , marshalled )
312- if ! ok {
313- return fmt .Errorf ("%w: marshaled JWK does not match original JWK" , ErrJWKValidation )
320+ if j .marshal .X5T != marshalled .X5T {
321+ return fmt .Errorf ("%w: X5T in marshal does not match X5T in marshalled" , ErrJWKValidation )
322+ }
323+ if j .marshal .X5TS256 != marshalled .X5TS256 {
324+ return fmt .Errorf ("%w: X5TS256 in marshal does not match X5TS256 in marshalled" , ErrJWKValidation )
325+ }
326+ if j .marshal .CRV != marshalled .CRV {
327+ return fmt .Errorf ("%w: CRV in marshal does not match CRV in marshalled" , ErrJWKValidation )
328+ }
329+ switch j .marshal .KTY {
330+ case KtyEC :
331+ err = cmpBase64Int (j .marshal .X , marshalled .X , j .options .Validate .StrictPadding )
332+ if err != nil {
333+ return fmt .Errorf ("%w: X in marshal does not match X in marshalled" , errors .Join (ErrJWKValidation , err ))
334+ }
335+ err = cmpBase64Int (j .marshal .Y , marshalled .Y , j .options .Validate .StrictPadding )
336+ if err != nil {
337+ return fmt .Errorf ("%w: Y in marshal does not match Y in marshalled" , errors .Join (ErrJWKValidation , err ))
338+ }
339+ err = cmpBase64Int (j .marshal .D , marshalled .D , j .options .Validate .StrictPadding )
340+ if err != nil {
341+ return fmt .Errorf ("%w: D in marshal does not match D in marshalled" , errors .Join (ErrJWKValidation , err ))
342+ }
343+ case KtyOKP :
344+ if j .marshal .X != marshalled .X {
345+ return fmt .Errorf ("%w: X in marshal does not match X in marshalled" , ErrJWKValidation )
346+ }
347+ if j .marshal .D != marshalled .D {
348+ return fmt .Errorf ("%w: D in marshal does not match D in marshalled" , ErrJWKValidation )
349+ }
350+ case KtyRSA :
351+ err = cmpBase64Int (j .marshal .D , marshalled .D , j .options .Validate .StrictPadding )
352+ if err != nil {
353+ return fmt .Errorf ("%w: D in marshal does not match D in marshalled" , errors .Join (ErrJWKValidation , err ))
354+ }
355+ err = cmpBase64Int (j .marshal .N , marshalled .N , j .options .Validate .StrictPadding )
356+ if err != nil {
357+ return fmt .Errorf ("%w: N in marshal does not match N in marshalled" , errors .Join (ErrJWKValidation , err ))
358+ }
359+ err = cmpBase64Int (j .marshal .E , marshalled .E , j .options .Validate .StrictPadding )
360+ if err != nil {
361+ return fmt .Errorf ("%w: E in marshal does not match E in marshalled" , errors .Join (ErrJWKValidation , err ))
362+ }
363+ err = cmpBase64Int (j .marshal .P , marshalled .P , j .options .Validate .StrictPadding )
364+ if err != nil {
365+ return fmt .Errorf ("%w: P in marshal does not match P in marshalled" , errors .Join (ErrJWKValidation , err ))
366+ }
367+ err = cmpBase64Int (j .marshal .Q , marshalled .Q , j .options .Validate .StrictPadding )
368+ if err != nil {
369+ return fmt .Errorf ("%w: Q in marshal does not match Q in marshalled" , errors .Join (ErrJWKValidation , err ))
370+ }
371+ err = cmpBase64Int (j .marshal .DP , marshalled .DP , j .options .Validate .StrictPadding )
372+ if err != nil {
373+ return fmt .Errorf ("%w: DP in marshal does not match DP in marshalled" , errors .Join (ErrJWKValidation , err ))
374+ }
375+ err = cmpBase64Int (j .marshal .DQ , marshalled .DQ , j .options .Validate .StrictPadding )
376+ if err != nil {
377+ return fmt .Errorf ("%w: DQ in marshal does not match DQ in marshalled" , errors .Join (ErrJWKValidation , err ))
378+ }
379+ if len (j .marshal .OTH ) != len (marshalled .OTH ) {
380+ return fmt .Errorf ("%w: OTH in marshal does not match OTH in marshalled" , ErrJWKValidation )
381+ }
382+ for i , o := range j .marshal .OTH {
383+ err = cmpBase64Int (o .R , marshalled .OTH [i ].R , j .options .Validate .StrictPadding )
384+ if err != nil {
385+ return fmt .Errorf ("%w: OTH index %d in marshal does not match OTH in marshalled" , errors .Join (ErrJWKValidation , err ), i )
386+ }
387+ err = cmpBase64Int (o .D , marshalled .OTH [i ].D , j .options .Validate .StrictPadding )
388+ if err != nil {
389+ return fmt .Errorf ("%w: OTH index %d in marshal does not match OTH in marshalled" , errors .Join (ErrJWKValidation , err ), i )
390+ }
391+ err = cmpBase64Int (o .T , marshalled .OTH [i ].T , j .options .Validate .StrictPadding )
392+ if err != nil {
393+ return fmt .Errorf ("%w: OTH index %d in marshal does not match OTH in marshalled" , errors .Join (ErrJWKValidation , err ), i )
394+ }
395+ }
396+ case KtyOct :
397+ err = cmpBase64Int (j .marshal .K , marshalled .K , j .options .Validate .StrictPadding )
398+ if err != nil {
399+ return fmt .Errorf ("%w: K in marshal does not match K in marshalled" , errors .Join (ErrJWKValidation , err ))
400+ }
401+ default :
402+ return fmt .Errorf ("%w: invalid or unsupported key type %q" , ErrJWKValidation , j .marshal .KTY )
314403 }
315404
405+ // Saved for last because it may involve a network request.
316406 if j .marshal .X5U != "" || j .options .X509 .X5U != "" {
317407 if j .marshal .X5U != j .options .X509 .X5U {
318408 return fmt .Errorf ("%w: X5U in marshal does not match X5U in options" , errors .Join (ErrJWKValidation , ErrOptions ))
@@ -376,3 +466,28 @@ func DefaultGetX5U(u *url.URL) ([]*x509.Certificate, error) {
376466 }
377467 return certs , nil
378468}
469+
470+ func cmpBase64Int (first , second string , strictPadding bool ) error {
471+ if first == second {
472+ return nil
473+ }
474+ b , err := base64 .RawURLEncoding .DecodeString (first )
475+ if err != nil {
476+ return fmt .Errorf ("failed to decode Base64 raw URL decode first string: %w" , err )
477+ }
478+ fLen := len (b )
479+ f := new (big.Int ).SetBytes (b )
480+ b , err = base64 .RawURLEncoding .DecodeString (second )
481+ if err != nil {
482+ return fmt .Errorf ("failed to decode Base64 raw URL decode second string: %w" , err )
483+ }
484+ sLen := len (b )
485+ s := new (big.Int ).SetBytes (b )
486+ if f .Cmp (s ) != 0 {
487+ return fmt .Errorf ("%w: the parsed integers do not match" , ErrJWKValidation )
488+ }
489+ if strictPadding && fLen != sLen {
490+ return fmt .Errorf ("%w: the Base64 raw URL inputs do not have matching padding" , errors .Join (ErrJWKValidation , ErrPadding ))
491+ }
492+ return nil
493+ }
0 commit comments