Skip to content

Commit 5b86227

Browse files
committed
fix: enforce ECDSA curve validation per RFC 7518 Section 3.4
ECAlgorithm now validates that a key's elliptic curve matches the algorithm being used (e.g. ES256 requires P-256, ES384 requires P-384). Previously, any EC key could be used with any ECDSA algorithm, allowing weaker curves like P-192 to be used for ES256 verification. - Add optional expected_curve parameter to ECAlgorithm.__init__ - Add _validate_curve method called during prepare_key - Update get_default_algorithms to pass correct curves - Fix test_encode_decode_ecdsa_related_algorithms to use correct keys
1 parent 04947d7 commit 5b86227

File tree

3 files changed

+268
-27
lines changed

3 files changed

+268
-27
lines changed

jwt/algorithms.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,12 @@ def get_default_algorithms() -> dict[str, Algorithm]:
171171
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
172172
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
173173
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
174-
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
175-
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
176-
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
177-
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
174+
"ES256": ECAlgorithm(ECAlgorithm.SHA256, SECP256R1),
175+
"ES256K": ECAlgorithm(ECAlgorithm.SHA256, SECP256K1),
176+
"ES384": ECAlgorithm(ECAlgorithm.SHA384, SECP384R1),
177+
"ES521": ECAlgorithm(ECAlgorithm.SHA512, SECP521R1),
178178
"ES512": ECAlgorithm(
179-
ECAlgorithm.SHA512
179+
ECAlgorithm.SHA512, SECP521R1
180180
), # Backward compat for #219 fix
181181
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
182182
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
@@ -576,11 +576,28 @@ class ECAlgorithm(Algorithm):
576576

577577
_crypto_key_types = ALLOWED_EC_KEY_TYPES
578578

579-
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
579+
def __init__(
580+
self,
581+
hash_alg: type[hashes.HashAlgorithm],
582+
expected_curve: type[EllipticCurve] | None = None,
583+
) -> None:
580584
self.hash_alg = hash_alg
585+
self.expected_curve = expected_curve
586+
587+
def _validate_curve(self, key: AllowedECKeys) -> None:
588+
"""Validate that the key's curve matches the expected curve."""
589+
if self.expected_curve is None:
590+
return
591+
592+
if not isinstance(key.curve, self.expected_curve):
593+
raise InvalidKeyError(
594+
f"The key's curve '{key.curve.name}' does not match the expected "
595+
f"curve '{self.expected_curve.name}' for this algorithm"
596+
)
581597

582598
def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
583599
if isinstance(key, self._crypto_key_types):
600+
self._validate_curve(key)
584601
return key
585602

586603
if not isinstance(key, (bytes, str)):
@@ -599,11 +616,15 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
599616

600617
# Explicit check the key to prevent confusing errors from cryptography
601618
self.check_crypto_key_type(public_key)
602-
return cast(EllipticCurvePublicKey, public_key)
619+
ec_public_key = cast(EllipticCurvePublicKey, public_key)
620+
self._validate_curve(ec_public_key)
621+
return ec_public_key
603622
except ValueError:
604623
private_key = load_pem_private_key(key_bytes, password=None)
605624
self.check_crypto_key_type(private_key)
606-
return cast(EllipticCurvePrivateKey, private_key)
625+
ec_private_key = cast(EllipticCurvePrivateKey, private_key)
626+
self._validate_curve(ec_private_key)
627+
return ec_private_key
607628

608629
def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
609630
der_sig = key.sign(msg, ECDSA(self.hash_alg()))

tests/test_algorithms.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,3 +1162,228 @@ def test_rsa_prepare_key_raises_invalid_key_error_on_invalid_pem(self):
11621162

11631163
# Check that the exception message is correct
11641164
assert "Could not parse the provided public key." in str(excinfo.value)
1165+
1166+
1167+
@crypto_required
1168+
class TestECCurveValidation:
1169+
"""Tests for ECDSA curve validation per RFC 7518 Section 3.4."""
1170+
1171+
def test_ec_curve_validation_rejects_wrong_curve_for_es256(self):
1172+
"""ES256 should reject keys that are not P-256."""
1173+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
1174+
1175+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
1176+
1177+
# P-384 key should be rejected
1178+
with open(key_path("jwk_ec_key_P-384.json")) as keyfile:
1179+
p384_key = ECAlgorithm.from_jwk(keyfile.read())
1180+
1181+
with pytest.raises(InvalidKeyError) as excinfo:
1182+
algo.prepare_key(p384_key)
1183+
assert "secp384r1" in str(excinfo.value)
1184+
assert "secp256r1" in str(excinfo.value)
1185+
1186+
def test_ec_curve_validation_rejects_wrong_curve_for_es384(self):
1187+
"""ES384 should reject keys that are not P-384."""
1188+
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1
1189+
1190+
algo = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
1191+
1192+
# P-256 key should be rejected
1193+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1194+
p256_key = ECAlgorithm.from_jwk(keyfile.read())
1195+
1196+
with pytest.raises(InvalidKeyError) as excinfo:
1197+
algo.prepare_key(p256_key)
1198+
assert "secp256r1" in str(excinfo.value)
1199+
assert "secp384r1" in str(excinfo.value)
1200+
1201+
def test_ec_curve_validation_rejects_wrong_curve_for_es512(self):
1202+
"""ES512 should reject keys that are not P-521."""
1203+
from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1
1204+
1205+
algo = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
1206+
1207+
# P-256 key should be rejected
1208+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1209+
p256_key = ECAlgorithm.from_jwk(keyfile.read())
1210+
1211+
with pytest.raises(InvalidKeyError) as excinfo:
1212+
algo.prepare_key(p256_key)
1213+
assert "secp256r1" in str(excinfo.value)
1214+
assert "secp521r1" in str(excinfo.value)
1215+
1216+
def test_ec_curve_validation_rejects_wrong_curve_for_es256k(self):
1217+
"""ES256K should reject keys that are not secp256k1."""
1218+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
1219+
1220+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
1221+
1222+
# P-256 key should be rejected
1223+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1224+
p256_key = ECAlgorithm.from_jwk(keyfile.read())
1225+
1226+
with pytest.raises(InvalidKeyError) as excinfo:
1227+
algo.prepare_key(p256_key)
1228+
assert "secp256r1" in str(excinfo.value)
1229+
assert "secp256k1" in str(excinfo.value)
1230+
1231+
def test_ec_curve_validation_accepts_correct_curve_for_es256(self):
1232+
"""ES256 should accept P-256 keys."""
1233+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
1234+
1235+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
1236+
1237+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1238+
key = algo.from_jwk(keyfile.read())
1239+
prepared = algo.prepare_key(key)
1240+
assert prepared is key
1241+
1242+
def test_ec_curve_validation_accepts_correct_curve_for_es384(self):
1243+
"""ES384 should accept P-384 keys."""
1244+
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1
1245+
1246+
algo = ECAlgorithm(ECAlgorithm.SHA384, SECP384R1)
1247+
1248+
with open(key_path("jwk_ec_key_P-384.json")) as keyfile:
1249+
key = algo.from_jwk(keyfile.read())
1250+
prepared = algo.prepare_key(key)
1251+
assert prepared is key
1252+
1253+
def test_ec_curve_validation_accepts_correct_curve_for_es512(self):
1254+
"""ES512 should accept P-521 keys."""
1255+
from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1
1256+
1257+
algo = ECAlgorithm(ECAlgorithm.SHA512, SECP521R1)
1258+
1259+
with open(key_path("jwk_ec_key_P-521.json")) as keyfile:
1260+
key = algo.from_jwk(keyfile.read())
1261+
prepared = algo.prepare_key(key)
1262+
assert prepared is key
1263+
1264+
def test_ec_curve_validation_accepts_correct_curve_for_es256k(self):
1265+
"""ES256K should accept secp256k1 keys."""
1266+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
1267+
1268+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256K1)
1269+
1270+
with open(key_path("jwk_ec_key_secp256k1.json")) as keyfile:
1271+
key = algo.from_jwk(keyfile.read())
1272+
prepared = algo.prepare_key(key)
1273+
assert prepared is key
1274+
1275+
def test_ec_curve_validation_rejects_p192_for_es256(self):
1276+
"""ES256 should reject P-192 keys (weaker than P-256)."""
1277+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
1278+
1279+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
1280+
1281+
with open(key_path("testkey_ec_secp192r1.priv")) as keyfile:
1282+
with pytest.raises(InvalidKeyError) as excinfo:
1283+
algo.prepare_key(keyfile.read())
1284+
assert "secp192r1" in str(excinfo.value)
1285+
assert "secp256r1" in str(excinfo.value)
1286+
1287+
def test_ec_algorithm_without_expected_curve_accepts_any_curve(self):
1288+
"""ECAlgorithm without expected_curve should accept any curve (backwards compat)."""
1289+
algo = ECAlgorithm(ECAlgorithm.SHA256)
1290+
1291+
# Should accept P-256
1292+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1293+
p256_key = algo.from_jwk(keyfile.read())
1294+
algo.prepare_key(p256_key)
1295+
1296+
# Should accept P-384
1297+
with open(key_path("jwk_ec_key_P-384.json")) as keyfile:
1298+
p384_key = algo.from_jwk(keyfile.read())
1299+
algo.prepare_key(p384_key)
1300+
1301+
# Should accept P-521
1302+
with open(key_path("jwk_ec_key_P-521.json")) as keyfile:
1303+
p521_key = algo.from_jwk(keyfile.read())
1304+
algo.prepare_key(p521_key)
1305+
1306+
# Should accept secp256k1
1307+
with open(key_path("jwk_ec_key_secp256k1.json")) as keyfile:
1308+
secp256k1_key = algo.from_jwk(keyfile.read())
1309+
algo.prepare_key(secp256k1_key)
1310+
1311+
def test_default_algorithms_have_correct_expected_curve(self):
1312+
"""Default algorithms returned by get_default_algorithms should have expected_curve set."""
1313+
from cryptography.hazmat.primitives.asymmetric.ec import (
1314+
SECP256K1,
1315+
SECP256R1,
1316+
SECP384R1,
1317+
SECP521R1,
1318+
)
1319+
1320+
from jwt.algorithms import get_default_algorithms
1321+
1322+
algorithms = get_default_algorithms()
1323+
1324+
es256 = algorithms["ES256"]
1325+
assert isinstance(es256, ECAlgorithm)
1326+
assert es256.expected_curve == SECP256R1
1327+
1328+
es256k = algorithms["ES256K"]
1329+
assert isinstance(es256k, ECAlgorithm)
1330+
assert es256k.expected_curve == SECP256K1
1331+
1332+
es384 = algorithms["ES384"]
1333+
assert isinstance(es384, ECAlgorithm)
1334+
assert es384.expected_curve == SECP384R1
1335+
1336+
es521 = algorithms["ES521"]
1337+
assert isinstance(es521, ECAlgorithm)
1338+
assert es521.expected_curve == SECP521R1
1339+
1340+
es512 = algorithms["ES512"]
1341+
assert isinstance(es512, ECAlgorithm)
1342+
assert es512.expected_curve == SECP521R1
1343+
1344+
def test_ec_curve_validation_with_pem_key(self):
1345+
"""Curve validation should work with PEM-formatted keys."""
1346+
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
1347+
1348+
algo = ECAlgorithm(ECAlgorithm.SHA256, SECP256R1)
1349+
1350+
# P-256 PEM key should be accepted
1351+
with open(key_path("testkey_ec.priv")) as keyfile:
1352+
algo.prepare_key(keyfile.read())
1353+
1354+
# P-192 PEM key should be rejected
1355+
with open(key_path("testkey_ec_secp192r1.priv")) as keyfile:
1356+
with pytest.raises(InvalidKeyError):
1357+
algo.prepare_key(keyfile.read())
1358+
1359+
def test_jwt_encode_decode_rejects_wrong_curve(self):
1360+
"""Integration test: jwt.encode/decode should reject wrong curve keys."""
1361+
import jwt
1362+
1363+
# Use P-384 key with ES256 algorithm (expects P-256)
1364+
with open(key_path("jwk_ec_key_P-384.json")) as keyfile:
1365+
p384_key = ECAlgorithm.from_jwk(keyfile.read())
1366+
1367+
# Encoding should fail
1368+
with pytest.raises(InvalidKeyError):
1369+
jwt.encode({"hello": "world"}, p384_key, algorithm="ES256")
1370+
1371+
# Create a valid token with P-256 key
1372+
with open(key_path("jwk_ec_key_P-256.json")) as keyfile:
1373+
p256_key = ECAlgorithm.from_jwk(keyfile.read())
1374+
1375+
token = jwt.encode({"hello": "world"}, p256_key, algorithm="ES256")
1376+
1377+
# Decoding with wrong curve key should fail
1378+
with open(key_path("jwk_ec_pub_P-384.json")) as keyfile:
1379+
p384_pub_key = ECAlgorithm.from_jwk(keyfile.read())
1380+
1381+
with pytest.raises(InvalidKeyError):
1382+
jwt.decode(token, p384_pub_key, algorithms=["ES256"])
1383+
1384+
# Decoding with correct curve key should succeed
1385+
with open(key_path("jwk_ec_pub_P-256.json")) as keyfile:
1386+
p256_pub_key = ECAlgorithm.from_jwk(keyfile.read())
1387+
1388+
decoded = jwt.decode(token, p256_pub_key, algorithms=["ES256"])
1389+
assert decoded == {"hello": "world"}

tests/test_api_jws.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -646,32 +646,27 @@ def test_rsa_related_algorithms(self, jws):
646646
assert "PS512" not in jws_algorithms
647647

648648
@pytest.mark.parametrize(
649-
"algo",
649+
"algo,priv_key_file,pub_key_file",
650650
[
651-
"ES256",
652-
"ES256K",
653-
"ES384",
654-
"ES512",
651+
("ES256", "jwk_ec_key_P-256.json", "jwk_ec_pub_P-256.json"),
652+
("ES256K", "jwk_ec_key_secp256k1.json", "jwk_ec_pub_secp256k1.json"),
653+
("ES384", "jwk_ec_key_P-384.json", "jwk_ec_pub_P-384.json"),
654+
("ES512", "jwk_ec_key_P-521.json", "jwk_ec_pub_P-521.json"),
655655
],
656656
)
657657
@crypto_required
658-
def test_encode_decode_ecdsa_related_algorithms(self, jws, payload, algo):
659-
# PEM-formatted EC key
660-
with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file:
661-
priv_eckey = load_pem_private_key(ec_priv_file.read(), password=None)
662-
jws_message = jws.encode(payload, priv_eckey, algorithm=algo)
663-
664-
with open(key_path("testkey_ec.pub"), "rb") as ec_pub_file:
665-
pub_eckey = load_pem_public_key(ec_pub_file.read())
666-
jws.decode(jws_message, pub_eckey, algorithms=[algo])
658+
def test_encode_decode_ecdsa_related_algorithms(
659+
self, jws, payload, algo, priv_key_file, pub_key_file
660+
):
661+
from jwt.algorithms import ECAlgorithm
667662

668-
# string-formatted key
669-
with open(key_path("testkey_ec.priv")) as ec_priv_file:
670-
priv_eckey = ec_priv_file.read() # type: ignore[assignment]
663+
# Load keys from JWK files (each algorithm requires its specific curve)
664+
with open(key_path(priv_key_file)) as priv_file:
665+
priv_eckey = ECAlgorithm.from_jwk(priv_file.read())
671666
jws_message = jws.encode(payload, priv_eckey, algorithm=algo)
672667

673-
with open(key_path("testkey_ec.pub")) as ec_pub_file:
674-
pub_eckey = ec_pub_file.read() # type: ignore[assignment]
668+
with open(key_path(pub_key_file)) as pub_file:
669+
pub_eckey = ECAlgorithm.from_jwk(pub_file.read())
675670
jws.decode(jws_message, pub_eckey, algorithms=[algo])
676671

677672
def test_ecdsa_related_algorithms(self, jws):

0 commit comments

Comments
 (0)