@@ -49,19 +49,19 @@ type Storage interface {
4949 MarshalWithOptions (ctx context.Context , marshalOptions JWKMarshalOptions , validationOptions JWKValidateOptions ) (JWKSMarshal , error )
5050}
5151
52- var _ Storage = & memoryJWKSet {}
52+ var _ Storage = & MemoryJWKSet {}
5353
54- type memoryJWKSet struct {
54+ type MemoryJWKSet struct {
5555 set []JWK
5656 mux sync.RWMutex
5757}
5858
5959// NewMemoryStorage creates a new in-memory Storage implementation.
60- func NewMemoryStorage () Storage {
61- return & memoryJWKSet {}
60+ func NewMemoryStorage () * MemoryJWKSet {
61+ return & MemoryJWKSet {}
6262}
6363
64- func (m * memoryJWKSet ) KeyDelete (_ context.Context , keyID string ) (ok bool , err error ) {
64+ func (m * MemoryJWKSet ) KeyDelete (_ context.Context , keyID string ) (ok bool , err error ) {
6565 m .mux .Lock ()
6666 defer m .mux .Unlock ()
6767 for i , jwk := range m .set {
@@ -72,7 +72,12 @@ func (m *memoryJWKSet) KeyDelete(_ context.Context, keyID string) (ok bool, err
7272 }
7373 return ok , nil
7474}
75- func (m * memoryJWKSet ) KeyRead (_ context.Context , keyID string ) (JWK , error ) {
75+ func (m * MemoryJWKSet ) KeyDeleteAll () {
76+ m .mux .Lock ()
77+ defer m .mux .Unlock ()
78+ m .set = make ([]JWK , 0 )
79+ }
80+ func (m * MemoryJWKSet ) KeyRead (_ context.Context , keyID string ) (JWK , error ) {
7681 m .mux .RLock ()
7782 defer m .mux .RUnlock ()
7883 for _ , jwk := range m .set {
@@ -82,12 +87,12 @@ func (m *memoryJWKSet) KeyRead(_ context.Context, keyID string) (JWK, error) {
8287 }
8388 return JWK {}, fmt .Errorf ("%w: kid %q" , ErrKeyNotFound , keyID )
8489}
85- func (m * memoryJWKSet ) KeyReadAll (_ context.Context ) ([]JWK , error ) {
90+ func (m * MemoryJWKSet ) KeyReadAll (_ context.Context ) ([]JWK , error ) {
8691 m .mux .RLock ()
8792 defer m .mux .RUnlock ()
8893 return slices .Clone (m .set ), nil
8994}
90- func (m * memoryJWKSet ) KeyWrite (_ context.Context , jwk JWK ) error {
95+ func (m * MemoryJWKSet ) KeyWrite (_ context.Context , jwk JWK ) error {
9196 m .mux .Lock ()
9297 defer m .mux .Unlock ()
9398 for i , j := range m .set {
@@ -100,30 +105,30 @@ func (m *memoryJWKSet) KeyWrite(_ context.Context, jwk JWK) error {
100105 return nil
101106}
102107
103- func (m * memoryJWKSet ) JSON (ctx context.Context ) (json.RawMessage , error ) {
108+ func (m * MemoryJWKSet ) JSON (ctx context.Context ) (json.RawMessage , error ) {
104109 jwks , err := m .Marshal (ctx )
105110 if err != nil {
106111 return nil , fmt .Errorf ("failed to marshal JWK Set: %w" , err )
107112 }
108113 return json .Marshal (jwks )
109114}
110- func (m * memoryJWKSet ) JSONPublic (ctx context.Context ) (json.RawMessage , error ) {
115+ func (m * MemoryJWKSet ) JSONPublic (ctx context.Context ) (json.RawMessage , error ) {
111116 return m .JSONWithOptions (ctx , JWKMarshalOptions {}, JWKValidateOptions {})
112117}
113- func (m * memoryJWKSet ) JSONPrivate (ctx context.Context ) (json.RawMessage , error ) {
118+ func (m * MemoryJWKSet ) JSONPrivate (ctx context.Context ) (json.RawMessage , error ) {
114119 marshalOptions := JWKMarshalOptions {
115120 Private : true ,
116121 }
117122 return m .JSONWithOptions (ctx , marshalOptions , JWKValidateOptions {})
118123}
119- func (m * memoryJWKSet ) JSONWithOptions (ctx context.Context , marshalOptions JWKMarshalOptions , validationOptions JWKValidateOptions ) (json.RawMessage , error ) {
124+ func (m * MemoryJWKSet ) JSONWithOptions (ctx context.Context , marshalOptions JWKMarshalOptions , validationOptions JWKValidateOptions ) (json.RawMessage , error ) {
120125 jwks , err := m .MarshalWithOptions (ctx , marshalOptions , validationOptions )
121126 if err != nil {
122127 return nil , fmt .Errorf ("failed to marshal JWK Set with options: %w" , err )
123128 }
124129 return json .Marshal (jwks )
125130}
126- func (m * memoryJWKSet ) Marshal (ctx context.Context ) (JWKSMarshal , error ) {
131+ func (m * MemoryJWKSet ) Marshal (ctx context.Context ) (JWKSMarshal , error ) {
127132 keys , err := m .KeyReadAll (ctx )
128133 if err != nil {
129134 return JWKSMarshal {}, fmt .Errorf ("failed to read snapshot of all keys from storage: %w" , err )
@@ -134,7 +139,7 @@ func (m *memoryJWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) {
134139 }
135140 return jwks , nil
136141}
137- func (m * memoryJWKSet ) MarshalWithOptions (ctx context.Context , marshalOptions JWKMarshalOptions , validationOptions JWKValidateOptions ) (JWKSMarshal , error ) {
142+ func (m * MemoryJWKSet ) MarshalWithOptions (ctx context.Context , marshalOptions JWKMarshalOptions , validationOptions JWKValidateOptions ) (JWKSMarshal , error ) {
138143 jwks := JWKSMarshal {}
139144
140145 keys , err := m .KeyReadAll (ctx )
@@ -203,11 +208,6 @@ type HTTPClientStorageOptions struct {
203208 // Provide the Ctx option to end the goroutine when it's no longer needed.
204209 RefreshInterval time.Duration
205210
206- // Storage is the underlying storage implementation to use.
207- //
208- // This defaults to NewMemoryStorage().
209- Storage Storage
210-
211211 // ValidateOptions are the options to use when validating the JWKs.
212212 ValidateOptions JWKValidateOptions
213213}
@@ -238,10 +238,7 @@ func NewStorageFromHTTP(u *url.URL, options HTTPClientStorageOptions) (Storage,
238238 if options .HTTPMethod == "" {
239239 options .HTTPMethod = http .MethodGet
240240 }
241- store := options .Storage
242- if store == nil {
243- store = NewMemoryStorage ()
244- }
241+ store := NewMemoryStorage ()
245242
246243 refresh := func (ctx context.Context ) error {
247244 req , err := http .NewRequestWithContext (ctx , options .HTTPMethod , u .String (), nil )
@@ -262,6 +259,7 @@ func NewStorageFromHTTP(u *url.URL, options HTTPClientStorageOptions) (Storage,
262259 if err != nil {
263260 return fmt .Errorf ("failed to decode JWK Set response: %w" , err )
264261 }
262+ store .KeyDeleteAll () // Clear local cache in case of key revocation.
265263 for _ , marshal := range jwks .Keys {
266264 marshalOptions := JWKMarshalOptions {
267265 Private : true ,
0 commit comments