Skip to content

Commit fc576f1

Browse files
authored
Merge pull request #166 from scouturier/feat/cache-oidc-id-token-silent-refresh
feat: cache OIDC id_token for silent credential refresh
2 parents 42939c0 + 74c7e4d commit fc576f1

File tree

4 files changed

+273
-24
lines changed

4 files changed

+273
-24
lines changed

source/credential_provider/__main__.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,8 @@ def get_monitoring_token(self):
584584
exp_time = token_data.get("expires", 0)
585585
now = int(datetime.now(timezone.utc).timestamp())
586586

587-
# Return token if it expires in more than 10 minutes
588-
if exp_time - now > 600:
587+
# Return token if it expires in more than 60 seconds
588+
if exp_time - now > 60:
589589
token = token_data["token"]
590590
# Set in environment for this session
591591
os.environ["CLAUDE_CODE_MONITORING_TOKEN"] = token
@@ -1759,6 +1759,30 @@ def _handle_quota_warning(self, quota_result: dict):
17591759
# End Quota Check Methods
17601760
# ===========================================
17611761

1762+
def _try_silent_refresh(self):
1763+
"""Attempt to refresh AWS credentials using a cached, still-valid OIDC id_token.
1764+
1765+
Returns:
1766+
Tuple of (credentials, id_token, token_claims) if successful, (None, None, None) otherwise.
1767+
"""
1768+
try:
1769+
id_token = self.get_monitoring_token()
1770+
if not id_token:
1771+
self._debug_print("No valid cached id_token for silent refresh")
1772+
return None, None, None
1773+
1774+
self._debug_print("Found valid cached id_token, attempting silent credential refresh...")
1775+
token_claims = jwt.decode(id_token, options={"verify_signature": False})
1776+
1777+
credentials = self.get_aws_credentials(id_token, token_claims)
1778+
self.save_credentials(credentials)
1779+
self.save_monitoring_token(id_token, token_claims)
1780+
self._debug_print("Silent credential refresh succeeded")
1781+
return credentials, id_token, token_claims
1782+
except Exception as e:
1783+
self._debug_print(f"Silent refresh failed, will require browser auth: {e}")
1784+
return None, None, None
1785+
17621786
def run(self):
17631787
"""Main execution flow"""
17641788
try:
@@ -1817,7 +1841,23 @@ def run(self):
18171841
print(json.dumps(cached)) # noqa: S105
18181842
return 0
18191843

1820-
# Authenticate with OIDC provider
1844+
# Try silent refresh using cached id_token before opening browser
1845+
silent_creds, id_token, token_claims = self._try_silent_refresh()
1846+
if silent_creds:
1847+
# Check quota if configured (reuse token/claims already fetched above)
1848+
if self._should_check_quota():
1849+
if id_token and token_claims:
1850+
quota_result = self._check_quota(token_claims, id_token)
1851+
self._save_quota_check_timestamp()
1852+
if not quota_result.get("allowed", True):
1853+
return self._handle_quota_blocked(quota_result)
1854+
else:
1855+
self._handle_quota_warning(quota_result)
1856+
1857+
print(json.dumps(silent_creds))
1858+
return 0
1859+
1860+
# Authenticate with OIDC provider (browser popup - only when id_token is also expired)
18211861
self._debug_print(f"Authenticating with {self.provider_config['name']} for profile '{self.profile}'...")
18221862
id_token, token_claims = self.authenticate_oidc()
18231863

source/otel_helper/__main__.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,23 @@ def get_cache_path():
202202

203203

204204
def read_cached_headers():
205-
"""Read cached OTEL headers if they exist and the token hasn't expired."""
205+
"""Read cached OTEL headers if they exist.
206+
207+
User attributes (email, team, etc.) don't change between sessions,
208+
so cached headers are served regardless of token expiry. Headers are
209+
refreshed opportunistically when a valid token is available.
210+
"""
206211
try:
207212
cache_path = get_cache_path()
208213
if not cache_path.exists():
209214
return None
210215
with open(cache_path) as f:
211216
cached = json.load(f)
212-
# Check if token expires in more than 10 minutes
213-
now = int(time.time())
214-
if cached.get("token_exp", 0) - now > 600:
215-
logger.debug("Using cached OTEL headers (token still valid)")
216-
return cached["headers"]
217-
logger.debug("Cached OTEL headers expired or expiring soon")
218-
return None
217+
headers = cached.get("headers")
218+
if not headers:
219+
return None
220+
logger.debug("Using cached OTEL headers")
221+
return headers
219222
except Exception as e:
220223
logger.debug(f"Failed to read cached headers: {e}")
221224
return None
@@ -332,10 +335,6 @@ def main():
332335

333336
# Generate headers dictionary
334337
headers_dict = format_as_headers_dict(user_info)
335-
# Only include Bearer token when ALB JWT validation is enabled (set by installer)
336-
if os.environ.get("OTEL_JWT_AUTH", "").lower() in ("true", "1", "yes"):
337-
headers_dict["authorization"] = f"Bearer {token}"
338-
339338
# In test mode, print detailed output
340339
if TEST_MODE:
341340
print("===== TEST MODE OUTPUT =====\n")

source/otel_helper/otel-helper.sh

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,10 @@ CACHE_DIR="$HOME/.claude-code-session"
66
CACHE_FILE="$CACHE_DIR/${PROFILE}-otel-headers.json"
77
RAW_FILE="$CACHE_DIR/${PROFILE}-otel-headers.raw"
88

9-
if [ -f "$CACHE_FILE" ] && [ -f "$RAW_FILE" ]; then
10-
# Extract token_exp from metadata cache (date +%s is GNU/BSD, works on macOS and Linux)
11-
TOKEN_EXP=$(grep -o '"token_exp": *[0-9]*' "$CACHE_FILE" | grep -o '[0-9]*')
12-
NOW=$(date +%s)
13-
if [ -n "$TOKEN_EXP" ] && [ "$((TOKEN_EXP - NOW))" -gt 600 ]; then
14-
# Serve raw headers directly — no JSON parsing needed
15-
cat "$RAW_FILE"
16-
exit 0
17-
fi
9+
if [ -f "$RAW_FILE" ]; then
10+
# Serve raw headers directly — no JSON parsing needed
11+
cat "$RAW_FILE"
12+
exit 0
1813
fi
1914

2015
# Cache miss - fall back to full PyInstaller binary (which writes the cache)
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# ABOUTME: Tests for OIDC id_token caching and silent credential refresh (issue #153)
2+
# ABOUTME: Verifies that expired AWS credentials can be refreshed without browser popup
3+
"""Tests for silent credential refresh using cached OIDC id_token."""
4+
5+
import json
6+
import time
7+
from unittest.mock import MagicMock, patch
8+
9+
import jwt as pyjwt
10+
import pytest
11+
12+
13+
def _make_id_token(exp_offset=3600, email="test@example.com"):
14+
"""Create a minimal JWT id_token for testing.
15+
16+
Args:
17+
exp_offset: Seconds from now until expiration (positive = future).
18+
email: Email claim to embed.
19+
"""
20+
claims = {
21+
"sub": "user-123",
22+
"email": email,
23+
"iss": "https://test.okta.com",
24+
"aud": "test-client-id",
25+
"exp": int(time.time()) + exp_offset,
26+
"iat": int(time.time()),
27+
"nonce": "test-nonce",
28+
}
29+
# Encode without signing (matches how the provider decodes with verify_signature=False)
30+
return pyjwt.encode(claims, "secret", algorithm="HS256"), claims
31+
32+
33+
def _make_config():
34+
"""Return a minimal config dict for MultiProviderAuth."""
35+
return {
36+
"profiles": {
37+
"TestProfile": {
38+
"provider_domain": "test.okta.com",
39+
"client_id": "test-client-id",
40+
"identity_pool_id": "us-east-1:test-pool",
41+
"aws_region": "us-east-1",
42+
"credential_storage": "session",
43+
}
44+
}
45+
}
46+
47+
48+
def _make_aws_credentials(exp_offset=900):
49+
"""Return fake AWS credentials dict."""
50+
from datetime import datetime, timezone, timedelta
51+
52+
exp = datetime.now(timezone.utc) + timedelta(seconds=exp_offset)
53+
return {
54+
"Version": 1,
55+
"AccessKeyId": "AKIAIOSFODNN7EXAMPLE",
56+
"SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
57+
"SessionToken": "FwoGZXIvYXdzEBYaDH...",
58+
"Expiration": exp.isoformat(),
59+
}
60+
61+
62+
@pytest.fixture
63+
def auth_instance(tmp_path):
64+
"""Create a MultiProviderAuth instance with mocked config."""
65+
config_file = tmp_path / "config.json"
66+
config_file.write_text(json.dumps(_make_config()))
67+
68+
with patch("credential_provider.__main__.Path") as mock_path_cls:
69+
# Make _load_config find our temp config
70+
mock_home = MagicMock()
71+
mock_path_cls.home.return_value = mock_home
72+
mock_home.__truediv__ = lambda self, key: tmp_path / key if key == "claude-code-with-bedrock" else MagicMock()
73+
74+
# Also mock __file__ parent for binary dir config lookup
75+
mock_file_parent = MagicMock()
76+
mock_file_parent.__truediv__ = lambda self, key: MagicMock(exists=lambda: False)
77+
mock_path_cls.return_value = mock_file_parent
78+
79+
# Simpler approach: just patch _load_config and _init_credential_storage
80+
with patch("credential_provider.__main__.MultiProviderAuth._load_config") as mock_load, \
81+
patch("credential_provider.__main__.MultiProviderAuth._init_credential_storage"):
82+
mock_load.return_value = {
83+
"provider_domain": "test.okta.com",
84+
"client_id": "test-client-id",
85+
"identity_pool_id": "us-east-1:test-pool",
86+
"aws_region": "us-east-1",
87+
"credential_storage": "session",
88+
"provider_type": "okta",
89+
"federation_type": "cognito",
90+
"max_session_duration": 28800,
91+
}
92+
93+
from credential_provider.__main__ import MultiProviderAuth
94+
instance = MultiProviderAuth(profile="TestProfile")
95+
instance.cache_dir = tmp_path / "cache"
96+
instance.cache_dir.mkdir(parents=True, exist_ok=True)
97+
return instance
98+
99+
100+
class TestSilentRefresh:
101+
"""Tests for _try_silent_refresh method."""
102+
103+
def test_silent_refresh_succeeds_with_valid_id_token(self, auth_instance):
104+
"""When a valid id_token is cached, silent refresh should return new AWS creds."""
105+
id_token, claims = _make_id_token(exp_offset=3600)
106+
aws_creds = _make_aws_credentials()
107+
108+
with patch.object(auth_instance, "get_monitoring_token", return_value=id_token), \
109+
patch.object(auth_instance, "get_aws_credentials", return_value=aws_creds) as mock_get_creds, \
110+
patch.object(auth_instance, "save_credentials") as mock_save, \
111+
patch.object(auth_instance, "save_monitoring_token") as mock_save_token:
112+
113+
creds, returned_token, returned_claims = auth_instance._try_silent_refresh()
114+
115+
assert creds is not None
116+
assert creds["AccessKeyId"] == aws_creds["AccessKeyId"]
117+
assert returned_token == id_token
118+
assert returned_claims["sub"] == claims["sub"]
119+
mock_get_creds.assert_called_once()
120+
mock_save.assert_called_once_with(aws_creds)
121+
# Verify the id_token is re-persisted so the next refresh also works
122+
mock_save_token.assert_called_once_with(id_token, claims)
123+
124+
def test_silent_refresh_returns_none_when_id_token_expired(self, auth_instance):
125+
"""When cached id_token is within the 60-second expiry buffer, get_monitoring_token
126+
returns None and silent refresh must not attempt an STS exchange."""
127+
with patch.object(auth_instance, "get_monitoring_token", return_value=None) as mock_get_token, \
128+
patch.object(auth_instance, "get_aws_credentials") as mock_get_creds:
129+
130+
creds, id_token, token_claims = auth_instance._try_silent_refresh()
131+
132+
assert creds is None
133+
assert id_token is None
134+
assert token_claims is None
135+
mock_get_token.assert_called_once()
136+
# STS must never be called when the token is expired
137+
mock_get_creds.assert_not_called()
138+
139+
def test_silent_refresh_returns_none_when_no_cached_token(self, auth_instance):
140+
"""When no id_token is cached, silent refresh should return None."""
141+
with patch.object(auth_instance, "get_monitoring_token", return_value=None):
142+
creds, id_token, token_claims = auth_instance._try_silent_refresh()
143+
assert creds is None
144+
assert id_token is None
145+
assert token_claims is None
146+
147+
def test_silent_refresh_returns_none_when_sts_exchange_fails(self, auth_instance):
148+
"""When id_token is valid but STS exchange fails, should return None (fallback to browser)."""
149+
id_token, _ = _make_id_token(exp_offset=3600)
150+
151+
with patch.object(auth_instance, "get_monitoring_token", return_value=id_token), \
152+
patch.object(auth_instance, "get_aws_credentials", side_effect=Exception("STS error")):
153+
154+
creds, returned_token, returned_claims = auth_instance._try_silent_refresh()
155+
assert creds is None
156+
assert returned_token is None
157+
assert returned_claims is None
158+
159+
def test_silent_refresh_not_called_when_aws_creds_valid(self, auth_instance):
160+
"""When AWS credentials are still valid, silent refresh should not be attempted."""
161+
aws_creds = _make_aws_credentials(exp_offset=3600)
162+
163+
with patch.object(auth_instance, "get_cached_credentials", return_value=aws_creds), \
164+
patch.object(auth_instance, "_try_silent_refresh") as mock_silent, \
165+
patch.object(auth_instance, "_should_recheck_quota", return_value=False):
166+
167+
# Capture stdout
168+
with patch("builtins.print"):
169+
auth_instance.run()
170+
171+
mock_silent.assert_not_called()
172+
173+
def test_run_uses_silent_refresh_before_browser(self, auth_instance):
174+
"""When AWS creds expired but id_token valid, run() should use silent refresh."""
175+
aws_creds = _make_aws_credentials(exp_offset=3600)
176+
177+
with patch.object(auth_instance, "get_cached_credentials", return_value=None), \
178+
patch("socket.socket") as mock_socket_cls, \
179+
patch.object(auth_instance, "_try_silent_refresh", return_value=(aws_creds, None, None)), \
180+
patch.object(auth_instance, "_should_check_quota", return_value=False), \
181+
patch.object(auth_instance, "authenticate_oidc") as mock_browser, \
182+
patch("builtins.print"):
183+
184+
# Mock socket to simulate port available
185+
mock_socket = MagicMock()
186+
mock_socket_cls.return_value = mock_socket
187+
188+
result = auth_instance.run()
189+
190+
assert result == 0
191+
mock_browser.assert_not_called()
192+
193+
def test_run_falls_back_to_browser_when_silent_refresh_fails(self, auth_instance):
194+
"""When silent refresh fails, run() should fall back to browser auth."""
195+
id_token, claims = _make_id_token(exp_offset=3600)
196+
aws_creds = _make_aws_credentials(exp_offset=3600)
197+
198+
with patch.object(auth_instance, "get_cached_credentials", return_value=None), \
199+
patch("socket.socket") as mock_socket_cls, \
200+
patch.object(auth_instance, "_try_silent_refresh", return_value=(None, None, None)), \
201+
patch.object(auth_instance, "authenticate_oidc", return_value=(id_token, claims)) as mock_browser, \
202+
patch.object(auth_instance, "_should_check_quota", return_value=False), \
203+
patch.object(auth_instance, "get_aws_credentials", return_value=aws_creds), \
204+
patch.object(auth_instance, "save_credentials"), \
205+
patch.object(auth_instance, "save_monitoring_token"), \
206+
patch("builtins.print"):
207+
208+
mock_socket = MagicMock()
209+
mock_socket_cls.return_value = mock_socket
210+
211+
result = auth_instance.run()
212+
213+
assert result == 0
214+
mock_browser.assert_called_once()
215+

0 commit comments

Comments
 (0)