Skip to content

Commit 34b5123

Browse files
Send the client key exchange init in Connect (#1274)
* Send the client key exchange init in Connect * Add a test --------- Co-authored-by: Wojciech Nagórski <[email protected]>
1 parent 326ce14 commit 34b5123

12 files changed

+135
-106
lines changed

src/Renci.SshNet/Security/IKeyExchange.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ public interface IKeyExchange : IDisposable
3838
/// Starts the key exchange algorithm.
3939
/// </summary>
4040
/// <param name="session">The session.</param>
41-
/// <param name="message">Key exchange init message.</param>
42-
void Start(Session session, KeyExchangeInitMessage message);
41+
/// <param name="message">The key exchange init message received from the server.</param>
42+
/// <param name="sendClientInitMessage">Whether to send a key exchange init message in response.</param>
43+
void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage);
4344

4445
/// <summary>
4546
/// Finishes the key exchange algorithm.

src/Renci.SshNet/Security/KeyExchange.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,15 @@ public byte[] ExchangeHash
6161
/// </summary>
6262
public event EventHandler<HostKeyEventArgs> HostKeyReceived;
6363

64-
/// <summary>
65-
/// Starts key exchange algorithm.
66-
/// </summary>
67-
/// <param name="session">The session.</param>
68-
/// <param name="message">Key exchange init message.</param>
69-
public virtual void Start(Session session, KeyExchangeInitMessage message)
64+
/// <inheritdoc/>
65+
public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
7066
{
7167
Session = session;
7268

73-
SendMessage(session.ClientInitMessage);
69+
if (sendClientInitMessage)
70+
{
71+
SendMessage(session.ClientInitMessage);
72+
}
7473

7574
// Determine encryption algorithm
7675
var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys

src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,10 @@ protected override bool ValidateExchangeHash()
7676
return ValidateExchangeHash(_hostKey, _signature);
7777
}
7878

79-
/// <summary>
80-
/// Starts key exchange algorithm.
81-
/// </summary>
82-
/// <param name="session">The session.</param>
83-
/// <param name="message">Key exchange init message.</param>
84-
public override void Start(Session session, KeyExchangeInitMessage message)
79+
/// <inheritdoc/>
80+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
8581
{
86-
base.Start(session, message);
82+
base.Start(session, message, sendClientInitMessage);
8783

8884
_serverPayload = message.GetBytes();
8985
_clientPayload = Session.ClientInitMessage.GetBytes();

src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,10 @@ protected override byte[] CalculateHash()
3939
return Hash(groupExchangeHashData.GetBytes());
4040
}
4141

42-
/// <summary>
43-
/// Starts key exchange algorithm.
44-
/// </summary>
45-
/// <param name="session">The session.</param>
46-
/// <param name="message">Key exchange init message.</param>
47-
public override void Start(Session session, KeyExchangeInitMessage message)
42+
/// <inheritdoc/>
43+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
4844
{
49-
base.Start(session, message);
45+
base.Start(session, message, sendClientInitMessage);
5046

5147
// Register SSH_MSG_KEX_DH_GEX_GROUP message
5248
Session.RegisterMessage("SSH_MSG_KEX_DH_GEX_GROUP");

src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,10 @@ internal abstract class KeyExchangeDiffieHellmanGroupShaBase : KeyExchangeDiffie
1313
/// </value>
1414
public abstract BigInteger GroupPrime { get; }
1515

16-
/// <summary>
17-
/// Starts key exchange algorithm.
18-
/// </summary>
19-
/// <param name="session">The session.</param>
20-
/// <param name="message">Key exchange init message.</param>
21-
public override void Start(Session session, KeyExchangeInitMessage message)
16+
/// <inheritdoc/>
17+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
2218
{
23-
base.Start(session, message);
19+
base.Start(session, message, sendClientInitMessage);
2420

2521
Session.RegisterMessage("SSH_MSG_KEXDH_REPLY");
2622

src/Renci.SshNet/Security/KeyExchangeEC.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,10 @@ protected override bool ValidateExchangeHash()
7878
return ValidateExchangeHash(_hostKey, _signature);
7979
}
8080

81-
/// <summary>
82-
/// Starts key exchange algorithm.
83-
/// </summary>
84-
/// <param name="session">The session.</param>
85-
/// <param name="message">Key exchange init message.</param>
86-
public override void Start(Session session, KeyExchangeInitMessage message)
81+
/// <inheritdoc/>
82+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
8783
{
88-
base.Start(session, message);
84+
base.Start(session, message, sendClientInitMessage);
8985

9086
_serverPayload = message.GetBytes();
9187
_clientPayload = Session.ClientInitMessage.GetBytes();

src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,10 @@ protected override int HashSize
2929
get { return 256; }
3030
}
3131

32-
/// <summary>
33-
/// Starts key exchange algorithm.
34-
/// </summary>
35-
/// <param name="session">The session.</param>
36-
/// <param name="message">Key exchange init message.</param>
37-
public override void Start(Session session, KeyExchangeInitMessage message)
32+
/// <inheritdoc/>
33+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
3834
{
39-
base.Start(session, message);
35+
base.Start(session, message, sendClientInitMessage);
4036

4137
Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
4238

src/Renci.SshNet/Security/KeyExchangeECDH.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,10 @@ internal abstract class KeyExchangeECDH : KeyExchangeEC
2424
private ECDHCBasicAgreement _keyAgreement;
2525
private ECDomainParameters _domainParameters;
2626

27-
/// <summary>
28-
/// Starts key exchange algorithm.
29-
/// </summary>
30-
/// <param name="session">The session.</param>
31-
/// <param name="message">Key exchange init message.</param>
32-
public override void Start(Session session, KeyExchangeInitMessage message)
27+
/// <inheritdoc/>
28+
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
3329
{
34-
base.Start(session, message);
30+
base.Start(session, message, sendClientInitMessage);
3531

3632
Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
3733

src/Renci.SshNet/Session.cs

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,7 @@ public class Session : ISession
160160
/// <summary>
161161
/// WaitHandle to signal that key exchange was completed.
162162
/// </summary>
163-
private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(initialState: false);
164-
165-
/// <summary>
166-
/// WaitHandle to signal that key exchange is in progress.
167-
/// </summary>
168-
private bool _keyExchangeInProgress;
163+
private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim(initialState: false);
169164

170165
/// <summary>
171166
/// Exception that need to be thrown by waiting thread.
@@ -643,6 +638,11 @@ public void Connect()
643638
// Some server implementations might sent this message first, prior to establishing encryption algorithm
644639
RegisterMessage("SSH_MSG_USERAUTH_BANNER");
645640

641+
// Send our key exchange init.
642+
// We need to do this before starting the message listener to avoid the case where we receive the server
643+
// key exchange init and we continue the key exchange before having sent our own init.
644+
SendMessage(ClientInitMessage);
645+
646646
// Mark the message listener threads as started
647647
_ = _messageListenerCompleted.Reset();
648648

@@ -651,7 +651,7 @@ public void Connect()
651651
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
652652

653653
// Wait for key exchange to be completed
654-
WaitOnHandle(_keyExchangeCompletedWaitHandle);
654+
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
655655

656656
// If sessionId is not set then its not connected
657657
if (SessionId is null)
@@ -757,6 +757,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
757757
// Some server implementations might sent this message first, prior to establishing encryption algorithm
758758
RegisterMessage("SSH_MSG_USERAUTH_BANNER");
759759

760+
// Send our key exchange init.
761+
// We need to do this before starting the message listener to avoid the case where we receive the server
762+
// key exchange init and we continue the key exchange before having sent our own init.
763+
SendMessage(ClientInitMessage);
764+
760765
// Mark the message listener threads as started
761766
_ = _messageListenerCompleted.Reset();
762767

@@ -765,7 +770,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
765770
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
766771

767772
// Wait for key exchange to be completed
768-
WaitOnHandle(_keyExchangeCompletedWaitHandle);
773+
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
769774

770775
// If sessionId is not set then its not connected
771776
if (SessionId is null)
@@ -1046,10 +1051,10 @@ internal void SendMessage(Message message)
10461051
throw new SshConnectionException("Client not connected.");
10471052
}
10481053

1049-
if (_keyExchangeInProgress && message is not IKeyExchangedAllowed)
1054+
if (!_keyExchangeCompletedWaitHandle.IsSet && message is not IKeyExchangedAllowed)
10501055
{
10511056
// Wait for key exchange to be completed
1052-
WaitOnHandle(_keyExchangeCompletedWaitHandle);
1057+
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
10531058
}
10541059

10551060
DiagnosticAbstraction.Log(string.Format("[{0}] Sending message '{1}' to server: '{2}'.", ToHex(SessionId), message.GetType().Name, message));
@@ -1394,9 +1399,15 @@ internal void OnKeyExchangeDhGroupExchangeReplyReceived(KeyExchangeDhGroupExchan
13941399
/// <param name="message"><see cref="KeyExchangeInitMessage"/> message.</param>
13951400
internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
13961401
{
1397-
_keyExchangeInProgress = true;
1402+
// If _keyExchangeCompletedWaitHandle is already set, then this is a key
1403+
// re-exchange initiated by the server, and we need to send our own init
1404+
// message.
1405+
// Otherwise, the wait handle is not set and this received init is part of the
1406+
// initial connection for which we have already sent our init, so we shouldn't
1407+
// send another one.
1408+
var sendClientInitMessage = _keyExchangeCompletedWaitHandle.IsSet;
13981409

1399-
_ = _keyExchangeCompletedWaitHandle.Reset();
1410+
_keyExchangeCompletedWaitHandle.Reset();
14001411

14011412
// Disable messages that are not key exchange related
14021413
_sshMessageFactory.DisableNonKeyExchangeMessages();
@@ -1411,7 +1422,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
14111422
_keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived;
14121423

14131424
// Start the algorithm implementation
1414-
_keyExchange.Start(this, message);
1425+
_keyExchange.Start(this, message, sendClientInitMessage);
14151426

14161427
KeyExchangeInitReceived?.Invoke(this, new MessageEventArgs<KeyExchangeInitMessage>(message));
14171428
}
@@ -1477,9 +1488,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
14771488
NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));
14781489

14791490
// Signal that key exchange completed
1480-
_ = _keyExchangeCompletedWaitHandle.Set();
1481-
1482-
_keyExchangeInProgress = false;
1491+
_keyExchangeCompletedWaitHandle.Set();
14831492
}
14841493

14851494
/// <summary>
@@ -1967,15 +1976,14 @@ private void RaiseError(Exception exp)
19671976
private void Reset()
19681977
{
19691978
_ = _exceptionWaitHandle?.Reset();
1970-
_ = _keyExchangeCompletedWaitHandle?.Reset();
1979+
_keyExchangeCompletedWaitHandle?.Reset();
19711980
_ = _messageListenerCompleted?.Set();
19721981

19731982
SessionId = null;
19741983
_isDisconnectMessageSent = false;
19751984
_isDisconnecting = false;
19761985
_isAuthenticated = false;
19771986
_exception = null;
1978-
_keyExchangeInProgress = false;
19791987
}
19801988

19811989
private static SshConnectionException CreateConnectionAbortedByServerException()

0 commit comments

Comments
 (0)