Skip to content

Commit 826bb3f

Browse files
committed
Make JWKS client cache threadsafe
1 parent ceecaee commit 826bb3f

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

tests/test_session.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import AsyncMock, Mock, patch
33
import jwt
44
from datetime import datetime, timezone
5+
import concurrent.futures
56

67
from tests.conftest import with_jwks_mock
78
from workos.session import AsyncSession, Session, _get_jwks_client, _jwks_cache
@@ -22,9 +23,9 @@
2223
class SessionFixtures:
2324
@pytest.fixture(autouse=True)
2425
def clear_jwks_cache(self):
25-
_jwks_cache._clients.clear()
26+
_jwks_cache.clear()
2627
yield
27-
_jwks_cache._clients.clear()
28+
_jwks_cache.clear()
2829

2930
@pytest.fixture
3031
def session_constants(self):
@@ -520,3 +521,20 @@ def test_jwks_client_caching_different_urls(self):
520521
# Should be different instances
521522
assert client1 is not client2
522523
assert id(client1) != id(client2)
524+
525+
def test_jwks_cache_thread_safety(self):
526+
url = "https://api.workos.com/sso/jwks/thread_test"
527+
clients = []
528+
529+
def get_client():
530+
return _get_jwks_client(url)
531+
532+
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
533+
futures = [executor.submit(get_client) for _ in range(10)]
534+
clients = [future.result() for future in futures]
535+
536+
first_client = clients[0]
537+
for client in clients[1:]:
538+
assert (
539+
client is first_client
540+
), "All concurrent calls should return the same instance"

workos/session.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import TYPE_CHECKING, List, Protocol
33

44
import json
5+
import threading
56
from typing import Any, Dict, Optional, Union, cast
67
import jwt
78
from jwt import PyJWKClient
@@ -24,11 +25,28 @@
2425
class _JWKSClientCache:
2526
def __init__(self) -> None:
2627
self._clients: Dict[str, PyJWKClient] = {}
28+
self._lock = threading.Lock()
2729

2830
def get_client(self, jwks_url: str) -> PyJWKClient:
29-
if jwks_url not in self._clients:
30-
self._clients[jwks_url] = PyJWKClient(jwks_url)
31-
return self._clients[jwks_url]
31+
if jwks_url in self._clients:
32+
return self._clients[jwks_url]
33+
34+
with self._lock:
35+
if jwks_url in self._clients:
36+
return self._clients[jwks_url]
37+
38+
client = PyJWKClient(jwks_url)
39+
self._clients[jwks_url] = client
40+
return client
41+
42+
def clear(self) -> None:
43+
"""Intended primarily for test cleanup and manual cache invalidation.
44+
45+
Warning: If called concurrently with get_client(), some newly created
46+
clients might be lost due to lock acquisition ordering.
47+
"""
48+
with self._lock:
49+
self._clients.clear()
3250

3351

3452
# Module-level cache instance

0 commit comments

Comments
 (0)