Skip to content

Commit a71c329

Browse files
committed
GH-606: ML-KEM key exchange implementation using Bouncy Castle
Refactor the KEM-based KEX paths a little bit; provide the ML-KEMs, and add the DH factories combining the ML-KEMs with the base curves and hashes. KexTest tests that the new key exchanges do work between an Apache MINA sshd client and server. Add an integration test that verifies that the new ML-KEM kex works against an OpenSSH 9.9 server (it only has mlkem768x25519, not the other two variants using ECDH nistp256/384, so we can't test those).
1 parent 5fa5b64 commit a71c329

15 files changed

Lines changed: 519 additions & 41 deletions

File tree

sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@
3333
import org.apache.sshd.common.config.keys.OpenSshCertificate;
3434
import org.apache.sshd.common.digest.Digest;
3535
import org.apache.sshd.common.kex.AbstractDH;
36+
import org.apache.sshd.common.kex.CurveSizeIndicator;
3637
import org.apache.sshd.common.kex.DHFactory;
3738
import org.apache.sshd.common.kex.KexProposalOption;
3839
import org.apache.sshd.common.kex.KeyEncapsulationMethod;
3940
import org.apache.sshd.common.kex.KeyExchange;
4041
import org.apache.sshd.common.kex.KeyExchangeFactory;
41-
import org.apache.sshd.common.kex.XDH;
4242
import org.apache.sshd.common.keyprovider.KeyPairProvider;
4343
import org.apache.sshd.common.session.Session;
4444
import org.apache.sshd.common.signature.Signature;
@@ -154,14 +154,15 @@ public boolean next(int cmd, Buffer buffer) throws Exception {
154154
} else {
155155
try {
156156
int l = kemClient.getEncapsulationLength();
157-
if (dh instanceof XDH) {
158-
if (f.length != l + ((XDH) dh).getKeySize()) {
157+
if (dh instanceof CurveSizeIndicator) {
158+
int expectedLength = l + ((CurveSizeIndicator) dh).getByteLength();
159+
if (f.length != expectedLength) {
159160
throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
160-
"Wrong F length (should be 1071 bytes): " + f.length);
161+
"Wrong F length (should be " + expectedLength + " bytes): " + f.length);
161162
}
162-
} else {
163+
} else if (f.length <= l) {
163164
throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
164-
"Key encapsulation only supported for XDH");
165+
"Strange F length: " + f.length + " <= " + l);
165166
}
166167
dh.setF(Arrays.copyOfRange(f, l, f.length));
167168
Digest keyHash = dh.getHash();
@@ -170,6 +171,7 @@ public boolean next(int cmd, Buffer buffer) throws Exception {
170171
keyHash.update(dh.getK());
171172
k = keyHash.digest();
172173
} catch (IllegalArgumentException ex) {
174+
log.error("Key encapsulation error", ex);
173175
throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
174176
"Key encapsulation error: " + ex.getMessage());
175177
}

sshd-core/src/main/java/org/apache/sshd/common/BaseBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ public class BaseBuilder<T extends AbstractFactoryManager, S extends BaseBuilder
8888
public static final List<BuiltinDHFactories> DEFAULT_KEX_PREFERENCE = Collections.unmodifiableList(
8989
Arrays.asList(
9090
BuiltinDHFactories.sntrup761x25519,
91+
BuiltinDHFactories.sntrup761x25519_openssh,
92+
BuiltinDHFactories.mlkem768x25519,
93+
BuiltinDHFactories.mlkem1024nistp384,
94+
BuiltinDHFactories.mlkem768nistp256,
9195
BuiltinDHFactories.curve25519,
9296
BuiltinDHFactories.curve25519_libssh,
9397
BuiltinDHFactories.curve448,

sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinDHFactories.java

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,86 @@ public boolean isSupported() {
302302
return MontgomeryCurve.x448.isSupported() && BuiltinDigests.sha512.isSupported();
303303
}
304304
},
305+
/**
306+
* @see <a href= "https://datatracker.ietf.org/doc/html/draft-kampanakis-curdle-ssh-pq-ke-04">PQ/T Hybrid Key
307+
* Exchange in SSH</a>
308+
*/
309+
mlkem768x25519(Constants.MLKEM768_25519_SHA256) {
310+
@Override
311+
public XDH create(Object... params) throws Exception {
312+
if (!GenericUtils.isEmpty(params)) {
313+
throw new IllegalArgumentException("No accepted parameters for " + getName());
314+
}
315+
return new XDH(MontgomeryCurve.x25519, true) {
316+
317+
@Override
318+
public KeyEncapsulationMethod getKeyEncapsulation() {
319+
return BuiltinKEM.mlkem768;
320+
}
321+
322+
@Override
323+
public Digest getHash() throws Exception {
324+
return BuiltinDigests.sha256.create();
325+
}
326+
};
327+
}
328+
329+
@Override
330+
public boolean isSupported() {
331+
return MontgomeryCurve.x25519.isSupported() && BuiltinDigests.sha256.isSupported()
332+
&& BuiltinKEM.mlkem768.isSupported();
333+
}
334+
},
335+
/**
336+
* @see <a href= "https://datatracker.ietf.org/doc/html/draft-kampanakis-curdle-ssh-pq-ke-04">PQ/T Hybrid Key
337+
* Exchange in SSH</a>
338+
*/
339+
mlkem768nistp256(Constants.MLKEM768_NISTP256_SHA256) {
340+
@Override
341+
public ECDH create(Object... params) throws Exception {
342+
if (!GenericUtils.isEmpty(params)) {
343+
throw new IllegalArgumentException("No accepted parameters for " + getName());
344+
}
345+
return new ECDH(ECCurves.nistp256, true) {
346+
347+
@Override
348+
public KeyEncapsulationMethod getKeyEncapsulation() {
349+
return BuiltinKEM.mlkem768;
350+
}
351+
352+
};
353+
}
354+
355+
@Override
356+
public boolean isSupported() {
357+
return ECCurves.nistp256.isSupported() && BuiltinKEM.mlkem768.isSupported();
358+
}
359+
},
360+
/**
361+
* @see <a href= "https://datatracker.ietf.org/doc/html/draft-kampanakis-curdle-ssh-pq-ke-04">PQ/T Hybrid Key
362+
* Exchange in SSH</a>
363+
*/
364+
mlkem1024nistp384(Constants.MLKEM1024_NISTP384_SHA384) {
365+
@Override
366+
public ECDH create(Object... params) throws Exception {
367+
if (!GenericUtils.isEmpty(params)) {
368+
throw new IllegalArgumentException("No accepted parameters for " + getName());
369+
}
370+
return new ECDH(ECCurves.nistp384, true) {
371+
372+
@Override
373+
public KeyEncapsulationMethod getKeyEncapsulation() {
374+
return BuiltinKEM.mlkem1024;
375+
}
376+
377+
};
378+
}
379+
380+
@Override
381+
public boolean isSupported() {
382+
return ECCurves.nistp384.isSupported() && BuiltinKEM.mlkem1024.isSupported();
383+
}
384+
},
305385
/**
306386
* @see <a href=
307387
* "https://www.ietf.org/archive/id/draft-josefsson-ntruprime-ssh-02.html">draft-josefsson-ntruprime-ssh-02.html</a>
@@ -524,6 +604,9 @@ public static final class Constants {
524604
public static final String CURVE25519_SHA256 = "curve25519-sha256";
525605
public static final String CURVE25519_SHA256_LIBSSH = CURVE25519_SHA256 + "@libssh.org";
526606
public static final String CURVE448_SHA512 = "curve448-sha512";
607+
public static final String MLKEM768_25519_SHA256 = "mlkem768x25519-sha256";
608+
public static final String MLKEM768_NISTP256_SHA256 = "mlkem768nistp256-sha256";
609+
public static final String MLKEM1024_NISTP384_SHA384 = "mlkem1024nistp384-sha384";
527610
public static final String SNTRUP761_25519_SHA512 = "sntrup761x25519-sha512";
528611
public static final String SNTRUP761_25519_SHA512_OPENSSH = SNTRUP761_25519_SHA512 + "@openssh.com";
529612

sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinKEM.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,44 @@
2626
*/
2727
public enum BuiltinKEM implements KeyEncapsulationMethod, NamedResource, OptionalFeature {
2828

29+
mlkem768("mlkem768") {
30+
31+
@Override
32+
public Client getClient() {
33+
return MLKEM.getClient(MLKEM.Parameters.mlkem768);
34+
}
35+
36+
@Override
37+
public Server getServer() {
38+
return MLKEM.getServer(MLKEM.Parameters.mlkem768);
39+
}
40+
41+
@Override
42+
public boolean isSupported() {
43+
return MLKEM.Parameters.mlkem768.isSupported();
44+
}
45+
46+
},
47+
48+
mlkem1024("mlkem1024") {
49+
50+
@Override
51+
public Client getClient() {
52+
return MLKEM.getClient(MLKEM.Parameters.mlkem1024);
53+
}
54+
55+
@Override
56+
public Server getServer() {
57+
return MLKEM.getServer(MLKEM.Parameters.mlkem1024);
58+
}
59+
60+
@Override
61+
public boolean isSupported() {
62+
return MLKEM.Parameters.mlkem1024.isSupported();
63+
}
64+
65+
},
66+
2967
sntrup761("sntrup761") {
3068

3169
@Override
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sshd.common.kex;
20+
21+
/**
22+
* @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
23+
*/
24+
public interface CurveSizeIndicator {
25+
26+
/**
27+
* Retrieves the length of a point coordinate in bytes.
28+
*
29+
* @return the length
30+
*/
31+
int getByteLength();
32+
}

sshd-core/src/main/java/org/apache/sshd/common/kex/ECDH.java

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,31 +43,41 @@
4343
public class ECDH extends AbstractDH {
4444
public static final String KEX_TYPE = "ECDH";
4545

46+
private final boolean raw;
47+
4648
private ECCurves curve;
4749
private ECParameterSpec params;
4850
private ECPoint f;
4951

50-
public ECDH() throws Exception {
51-
this((ECParameterSpec) null);
52-
}
53-
5452
public ECDH(String curveName) throws Exception {
55-
this(ValidateUtils.checkNotNull(ECCurves.fromCurveName(curveName), "Unknown curve name: %s", curveName));
53+
this(curveName, false);
5654
}
5755

5856
public ECDH(ECCurves curve) throws Exception {
59-
this(Objects.requireNonNull(curve, "No known curve instance provided").getParameters());
60-
this.curve = curve;
57+
this(curve, false);
6158
}
6259

6360
public ECDH(ECParameterSpec paramSpec) throws Exception {
61+
this(paramSpec, false);
62+
}
63+
64+
public ECDH(String curveName, boolean raw) throws Exception {
65+
this(ValidateUtils.checkNotNull(ECCurves.fromCurveName(curveName), "Unknown curve name: %s", curveName), raw);
66+
}
67+
68+
public ECDH(ECCurves curve, boolean raw) throws Exception {
69+
this(Objects.requireNonNull(curve, "No known curve instance provided").getParameters(), raw);
70+
this.curve = curve;
71+
}
72+
73+
public ECDH(ECParameterSpec paramSpec, boolean raw) throws Exception {
6474
myKeyAgree = SecurityUtils.getKeyAgreement(KEX_TYPE);
65-
params = paramSpec; // do not check for null-ity since in some cases it can be
75+
params = Objects.requireNonNull(paramSpec, "No EC curve parameters provided");
76+
this.raw = raw;
6677
}
6778

6879
@Override
6980
protected byte[] calculateE() throws Exception {
70-
Objects.requireNonNull(params, "No ECParameterSpec(s)");
7181
KeyPairGenerator myKpairGen = SecurityUtils.getKeyPairGenerator(KeyUtils.EC_ALGORITHM);
7282
myKpairGen.initialize(params);
7383

@@ -81,22 +91,17 @@ protected byte[] calculateE() throws Exception {
8191

8292
@Override
8393
protected byte[] calculateK() throws Exception {
84-
Objects.requireNonNull(params, "No ECParameterSpec(s)");
8594
Objects.requireNonNull(f, "Missing 'f' value");
8695
ECPublicKeySpec keySpec = new ECPublicKeySpec(f, params);
8796
KeyFactory myKeyFac = SecurityUtils.getKeyFactory(KeyUtils.EC_ALGORITHM);
8897
PublicKey yourPubKey = myKeyFac.generatePublic(keySpec);
8998
myKeyAgree.doPhase(yourPubKey, true);
90-
return stripLeadingZeroes(myKeyAgree.generateSecret());
91-
}
92-
93-
public void setCurveParameters(ECParameterSpec paramSpec) {
94-
params = paramSpec;
99+
byte[] secret = myKeyAgree.generateSecret();
100+
return raw ? secret : stripLeadingZeroes(secret);
95101
}
96102

97103
@Override
98104
public void setF(byte[] f) {
99-
Objects.requireNonNull(params, "No ECParameterSpec(s)");
100105
Objects.requireNonNull(f, "No 'f' value specified");
101106
this.f = ECCurves.octetStringToEcPoint(f);
102107
}
@@ -117,12 +122,14 @@ public void putF(Buffer buffer, byte[] f) {
117122

118123
@Override
119124
public Digest getHash() throws Exception {
125+
return findCurve().getDigestForParams();
126+
}
127+
128+
private ECCurves findCurve() {
120129
if (curve == null) {
121-
Objects.requireNonNull(params, "No ECParameterSpec(s)");
122130
curve = Objects.requireNonNull(ECCurves.fromCurveParameters(params), "Unknown curve parameters");
123131
}
124-
125-
return curve.getDigestForParams();
132+
return curve;
126133
}
127134

128135
@Override

sshd-core/src/main/java/org/apache/sshd/common/kex/KeyEncapsulationMethod.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ interface Client {
6363
*/
6464
interface Server {
6565

66+
/**
67+
* Retrieves the required length of the KEM public key, in bytes.
68+
*
69+
* @return the length of the key
70+
*/
71+
int getPublicKeyLength();
72+
6673
/**
6774
* Initializes the KEM with a public key received from a client and prepares an encapsulated secret.
6875
*

0 commit comments

Comments
 (0)