Skip to content

Commit 01db49a

Browse files
authored
HTTP client only overwrites and appends JWK to local cache during refresh (#41)
1 parent ce57ded commit 01db49a

File tree

6 files changed

+147
-33
lines changed

6 files changed

+147
-33
lines changed

examples/storage_operations/go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ go 1.21
55
replace github.com/MicahParks/jwkset => ../..
66

77
require (
8-
github.com/MicahParks/jwkset v0.5.6
9-
github.com/google/uuid v1.5.0
8+
github.com/MicahParks/jwkset v0.6.0
9+
github.com/google/uuid v1.6.0
1010
)
1111

12-
require golang.org/x/time v0.5.0 // indirect
12+
require golang.org/x/time v0.9.0 // indirect

examples/storage_operations/go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
2-
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
3-
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
4-
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
1+
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
2+
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
3+
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
4+
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ module github.com/MicahParks/jwkset
22

33
go 1.21
44

5-
require golang.org/x/time v0.5.0
5+
require golang.org/x/time v0.9.0
6+
7+
retract [v0.5.0, v0.5.15] // HTTP client only overwrites and appends JWK to local cache during refresh: https://github.com/MicahParks/jwkset/issues/40

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
2-
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
1+
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
2+
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=

http_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"net/http"
77
"net/http/httptest"
8+
"net/url"
89
"strings"
910
"sync"
1011
"testing"
@@ -144,6 +145,119 @@ func TestClient(t *testing.T) {
144145
}
145146
}
146147

148+
func TestClientCacheReplacement(t *testing.T) {
149+
ctx, cancel := context.WithCancel(context.Background())
150+
defer cancel()
151+
152+
kid := "my-key-id"
153+
secret := []byte("my-hmac-secret")
154+
serverStore := NewMemoryStorage()
155+
marshalOptions := JWKMarshalOptions{
156+
Private: true,
157+
}
158+
metadata := JWKMetadataOptions{
159+
KID: kid,
160+
}
161+
options := JWKOptions{
162+
Marshal: marshalOptions,
163+
Metadata: metadata,
164+
}
165+
jwk, err := NewJWKFromKey(secret, options)
166+
if err != nil {
167+
t.Fatalf("Failed to create a JWK from the given HMAC secret.\nError: %s", err)
168+
}
169+
err = serverStore.KeyWrite(ctx, jwk)
170+
if err != nil {
171+
t.Fatalf("Failed to write the given JWK to the store.\nError: %s", err)
172+
}
173+
rawJWKS, err := serverStore.JSON(ctx)
174+
if err != nil {
175+
t.Fatalf("Failed to get the JSON.\nError: %s", err)
176+
}
177+
178+
rawJWKSMux := sync.RWMutex{}
179+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
180+
rawJWKSMux.RLock()
181+
defer rawJWKSMux.RUnlock()
182+
_, _ = w.Write(rawJWKS)
183+
}))
184+
185+
u, err := url.ParseRequestURI(server.URL)
186+
if err != nil {
187+
t.Fatalf("Failed to parse the URL.\nError: %s", err)
188+
}
189+
190+
refreshInterval := 50 * time.Millisecond
191+
httpOptions := HTTPClientStorageOptions{
192+
Ctx: ctx,
193+
RefreshInterval: refreshInterval,
194+
}
195+
clientStore, err := NewStorageFromHTTP(u, httpOptions)
196+
if err != nil {
197+
t.Fatalf("Failed to create a new HTTP client.\nError: %s", err)
198+
}
199+
200+
jwk, err = clientStore.KeyRead(ctx, kid)
201+
if err != nil {
202+
t.Fatalf("Failed to read the JWK.\nError: %s", err)
203+
}
204+
205+
if !bytes.Equal(jwk.Key().([]byte), secret) {
206+
t.Fatalf("The key read from the HTTP client did not match the original key.")
207+
}
208+
209+
jwks, err := clientStore.KeyReadAll(ctx)
210+
if err != nil {
211+
t.Fatalf("Failed to read all the JWKs.\nError: %s", err)
212+
}
213+
if len(jwks) != 1 {
214+
t.Fatalf("Expected to read 1 JWK, but got %d.", len(jwks))
215+
}
216+
if !bytes.Equal(jwks[0].Key().([]byte), secret) {
217+
t.Fatalf("The key read from the HTTP client did not match the original key.")
218+
}
219+
220+
otherKeyID := myKeyID + "2"
221+
options.Metadata.KID = otherKeyID
222+
otherSecret := []byte("my-other-hmac-secret")
223+
jwk, err = NewJWKFromKey(otherSecret, options)
224+
if err != nil {
225+
t.Fatalf("Failed to create a JWK from the given HMAC secret.\nError: %s", err)
226+
}
227+
err = serverStore.KeyWrite(ctx, jwk)
228+
if err != nil {
229+
t.Fatalf("Failed to write the given JWK to the store.\nError: %s", err)
230+
}
231+
ok, err := serverStore.KeyDelete(ctx, kid)
232+
if err != nil {
233+
t.Fatalf("Failed to delete the given JWK from the store.\nError: %s", err)
234+
}
235+
if !ok {
236+
t.Fatalf("Expected the key to be deleted.")
237+
}
238+
rawJWKSMux.Lock()
239+
rawJWKS, err = serverStore.JSON(ctx)
240+
rawJWKSMux.Unlock()
241+
if err != nil {
242+
t.Fatalf("Failed to get the JSON.\nError: %s", err)
243+
}
244+
time.Sleep(2 * refreshInterval)
245+
246+
jwks, err = clientStore.KeyReadAll(ctx)
247+
if err != nil {
248+
t.Fatalf("Failed to read the JWK.\nError: %s", err)
249+
}
250+
if len(jwks) != 1 {
251+
t.Fatalf("Expected to read 1 JWK, but got %d.", len(jwks))
252+
}
253+
if jwks[0].marshal.KID != otherKeyID {
254+
t.Fatalf("The key read from the HTTP client did not match the original key.")
255+
}
256+
if !bytes.Equal(jwks[0].Key().([]byte), otherSecret) {
257+
t.Fatalf("The key read from the HTTP client did not match the original key.")
258+
}
259+
}
260+
147261
func TestClientError(t *testing.T) {
148262
_, err := NewHTTPClient(HTTPClientOptions{})
149263
if err == nil {

storage.go

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,19 @@ type Storage interface {
4949
MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error)
5050
}
5151

52-
var _ Storage = &memoryJWKSet{}
52+
var _ Storage = &MemoryJWKSet{}
5353

54-
type memoryJWKSet struct {
54+
type MemoryJWKSet struct {
5555
set []JWK
5656
mux sync.RWMutex
5757
}
5858

5959
// NewMemoryStorage creates a new in-memory Storage implementation.
60-
func NewMemoryStorage() Storage {
61-
return &memoryJWKSet{}
60+
func NewMemoryStorage() *MemoryJWKSet {
61+
return &MemoryJWKSet{}
6262
}
6363

64-
func (m *memoryJWKSet) KeyDelete(_ context.Context, keyID string) (ok bool, err error) {
64+
func (m *MemoryJWKSet) KeyDelete(_ context.Context, keyID string) (ok bool, err error) {
6565
m.mux.Lock()
6666
defer m.mux.Unlock()
6767
for i, jwk := range m.set {
@@ -72,7 +72,12 @@ func (m *memoryJWKSet) KeyDelete(_ context.Context, keyID string) (ok bool, err
7272
}
7373
return ok, nil
7474
}
75-
func (m *memoryJWKSet) KeyRead(_ context.Context, keyID string) (JWK, error) {
75+
func (m *MemoryJWKSet) KeyDeleteAll() {
76+
m.mux.Lock()
77+
defer m.mux.Unlock()
78+
m.set = make([]JWK, 0)
79+
}
80+
func (m *MemoryJWKSet) KeyRead(_ context.Context, keyID string) (JWK, error) {
7681
m.mux.RLock()
7782
defer m.mux.RUnlock()
7883
for _, jwk := range m.set {
@@ -82,12 +87,12 @@ func (m *memoryJWKSet) KeyRead(_ context.Context, keyID string) (JWK, error) {
8287
}
8388
return JWK{}, fmt.Errorf("%w: kid %q", ErrKeyNotFound, keyID)
8489
}
85-
func (m *memoryJWKSet) KeyReadAll(_ context.Context) ([]JWK, error) {
90+
func (m *MemoryJWKSet) KeyReadAll(_ context.Context) ([]JWK, error) {
8691
m.mux.RLock()
8792
defer m.mux.RUnlock()
8893
return slices.Clone(m.set), nil
8994
}
90-
func (m *memoryJWKSet) KeyWrite(_ context.Context, jwk JWK) error {
95+
func (m *MemoryJWKSet) KeyWrite(_ context.Context, jwk JWK) error {
9196
m.mux.Lock()
9297
defer m.mux.Unlock()
9398
for i, j := range m.set {
@@ -100,30 +105,30 @@ func (m *memoryJWKSet) KeyWrite(_ context.Context, jwk JWK) error {
100105
return nil
101106
}
102107

103-
func (m *memoryJWKSet) JSON(ctx context.Context) (json.RawMessage, error) {
108+
func (m *MemoryJWKSet) JSON(ctx context.Context) (json.RawMessage, error) {
104109
jwks, err := m.Marshal(ctx)
105110
if err != nil {
106111
return nil, fmt.Errorf("failed to marshal JWK Set: %w", err)
107112
}
108113
return json.Marshal(jwks)
109114
}
110-
func (m *memoryJWKSet) JSONPublic(ctx context.Context) (json.RawMessage, error) {
115+
func (m *MemoryJWKSet) JSONPublic(ctx context.Context) (json.RawMessage, error) {
111116
return m.JSONWithOptions(ctx, JWKMarshalOptions{}, JWKValidateOptions{})
112117
}
113-
func (m *memoryJWKSet) JSONPrivate(ctx context.Context) (json.RawMessage, error) {
118+
func (m *MemoryJWKSet) JSONPrivate(ctx context.Context) (json.RawMessage, error) {
114119
marshalOptions := JWKMarshalOptions{
115120
Private: true,
116121
}
117122
return m.JSONWithOptions(ctx, marshalOptions, JWKValidateOptions{})
118123
}
119-
func (m *memoryJWKSet) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) {
124+
func (m *MemoryJWKSet) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) {
120125
jwks, err := m.MarshalWithOptions(ctx, marshalOptions, validationOptions)
121126
if err != nil {
122127
return nil, fmt.Errorf("failed to marshal JWK Set with options: %w", err)
123128
}
124129
return json.Marshal(jwks)
125130
}
126-
func (m *memoryJWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) {
131+
func (m *MemoryJWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) {
127132
keys, err := m.KeyReadAll(ctx)
128133
if err != nil {
129134
return JWKSMarshal{}, fmt.Errorf("failed to read snapshot of all keys from storage: %w", err)
@@ -134,7 +139,7 @@ func (m *memoryJWKSet) Marshal(ctx context.Context) (JWKSMarshal, error) {
134139
}
135140
return jwks, nil
136141
}
137-
func (m *memoryJWKSet) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) {
142+
func (m *MemoryJWKSet) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) {
138143
jwks := JWKSMarshal{}
139144

140145
keys, err := m.KeyReadAll(ctx)
@@ -203,11 +208,6 @@ type HTTPClientStorageOptions struct {
203208
// Provide the Ctx option to end the goroutine when it's no longer needed.
204209
RefreshInterval time.Duration
205210

206-
// Storage is the underlying storage implementation to use.
207-
//
208-
// This defaults to NewMemoryStorage().
209-
Storage Storage
210-
211211
// ValidateOptions are the options to use when validating the JWKs.
212212
ValidateOptions JWKValidateOptions
213213
}
@@ -238,10 +238,7 @@ func NewStorageFromHTTP(u *url.URL, options HTTPClientStorageOptions) (Storage,
238238
if options.HTTPMethod == "" {
239239
options.HTTPMethod = http.MethodGet
240240
}
241-
store := options.Storage
242-
if store == nil {
243-
store = NewMemoryStorage()
244-
}
241+
store := NewMemoryStorage()
245242

246243
refresh := func(ctx context.Context) error {
247244
req, err := http.NewRequestWithContext(ctx, options.HTTPMethod, u.String(), nil)
@@ -262,6 +259,7 @@ func NewStorageFromHTTP(u *url.URL, options HTTPClientStorageOptions) (Storage,
262259
if err != nil {
263260
return fmt.Errorf("failed to decode JWK Set response: %w", err)
264261
}
262+
store.KeyDeleteAll() // Clear local cache in case of key revocation.
265263
for _, marshal := range jwks.Keys {
266264
marshalOptions := JWKMarshalOptions{
267265
Private: true,

0 commit comments

Comments
 (0)