Skip to content

Send the client key exchange init in Connect #1274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/Renci.SshNet/Security/IKeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ public interface IKeyExchange : IDisposable
/// Starts the key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
void Start(Session session, KeyExchangeInitMessage message);
/// <param name="message">The key exchange init message received from the server.</param>
/// <param name="sendClientInitMessage">Whether to send a key exchange init message in response.</param>
void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage);

/// <summary>
/// Finishes the key exchange algorithm.
Expand Down
13 changes: 6 additions & 7 deletions src/Renci.SshNet/Security/KeyExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ public byte[] ExchangeHash
/// </summary>
public event EventHandler<HostKeyEventArgs> HostKeyReceived;

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public virtual void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
Session = session;

SendMessage(session.ClientInitMessage);
if (sendClientInitMessage)
{
SendMessage(session.ClientInitMessage);
}

// Determine encryption algorithm
var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys
Expand Down
10 changes: 3 additions & 7 deletions src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ protected override bool ValidateExchangeHash()
return ValidateExchangeHash(_hostKey, _signature);
}

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

_serverPayload = message.GetBytes();
_clientPayload = Session.ClientInitMessage.GetBytes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,10 @@ protected override byte[] CalculateHash()
return Hash(groupExchangeHashData.GetBytes());
}

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

// Register SSH_MSG_KEX_DH_GEX_GROUP message
Session.RegisterMessage("SSH_MSG_KEX_DH_GEX_GROUP");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ internal abstract class KeyExchangeDiffieHellmanGroupShaBase : KeyExchangeDiffie
/// </value>
public abstract BigInteger GroupPrime { get; }

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

Session.RegisterMessage("SSH_MSG_KEXDH_REPLY");

Expand Down
10 changes: 3 additions & 7 deletions src/Renci.SshNet/Security/KeyExchangeEC.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,10 @@ protected override bool ValidateExchangeHash()
return ValidateExchangeHash(_hostKey, _signature);
}

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

_serverPayload = message.GetBytes();
_clientPayload = Session.ClientInitMessage.GetBytes();
Expand Down
10 changes: 3 additions & 7 deletions src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,10 @@ protected override int HashSize
get { return 256; }
}

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");

Expand Down
10 changes: 3 additions & 7 deletions src/Renci.SshNet/Security/KeyExchangeECDH.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@ internal abstract class KeyExchangeECDH : KeyExchangeEC
private ECDHCBasicAgreement _keyAgreement;
private ECDomainParameters _domainParameters;

/// <summary>
/// Starts key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">Key exchange init message.</param>
public override void Start(Session session, KeyExchangeInitMessage message)
/// <inheritdoc/>
public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
base.Start(session, message);
base.Start(session, message, sendClientInitMessage);

Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");

Expand Down
44 changes: 26 additions & 18 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,7 @@ public class Session : ISession
/// <summary>
/// WaitHandle to signal that key exchange was completed.
/// </summary>
private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(initialState: false);

/// <summary>
/// WaitHandle to signal that key exchange is in progress.
/// </summary>
private bool _keyExchangeInProgress;
private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim(initialState: false);

/// <summary>
/// Exception that need to be thrown by waiting thread.
Expand Down Expand Up @@ -643,6 +638,11 @@ public void Connect()
// Some server implementations might sent this message first, prior to establishing encryption algorithm
RegisterMessage("SSH_MSG_USERAUTH_BANNER");

// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
SendMessage(ClientInitMessage);

// Mark the message listener threads as started
_ = _messageListenerCompleted.Reset();

Expand All @@ -651,7 +651,7 @@ public void Connect()
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);

// Wait for key exchange to be completed
WaitOnHandle(_keyExchangeCompletedWaitHandle);
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);

// If sessionId is not set then its not connected
if (SessionId is null)
Expand Down Expand Up @@ -757,6 +757,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
// Some server implementations might sent this message first, prior to establishing encryption algorithm
RegisterMessage("SSH_MSG_USERAUTH_BANNER");

// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
SendMessage(ClientInitMessage);

// Mark the message listener threads as started
_ = _messageListenerCompleted.Reset();

Expand All @@ -765,7 +770,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);

// Wait for key exchange to be completed
WaitOnHandle(_keyExchangeCompletedWaitHandle);
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);

// If sessionId is not set then its not connected
if (SessionId is null)
Expand Down Expand Up @@ -1046,10 +1051,10 @@ internal void SendMessage(Message message)
throw new SshConnectionException("Client not connected.");
}

if (_keyExchangeInProgress && message is not IKeyExchangedAllowed)
if (!_keyExchangeCompletedWaitHandle.IsSet && message is not IKeyExchangedAllowed)
{
// Wait for key exchange to be completed
WaitOnHandle(_keyExchangeCompletedWaitHandle);
WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
}

DiagnosticAbstraction.Log(string.Format("[{0}] Sending message '{1}' to server: '{2}'.", ToHex(SessionId), message.GetType().Name, message));
Expand Down Expand Up @@ -1394,9 +1399,15 @@ internal void OnKeyExchangeDhGroupExchangeReplyReceived(KeyExchangeDhGroupExchan
/// <param name="message"><see cref="KeyExchangeInitMessage"/> message.</param>
internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
{
_keyExchangeInProgress = true;
// If _keyExchangeCompletedWaitHandle is already set, then this is a key
// re-exchange initiated by the server, and we need to send our own init
// message.
// Otherwise, the wait handle is not set and this received init is part of the
// initial connection for which we have already sent our init, so we shouldn't
// send another one.
var sendClientInitMessage = _keyExchangeCompletedWaitHandle.IsSet;

_ = _keyExchangeCompletedWaitHandle.Reset();
_keyExchangeCompletedWaitHandle.Reset();

// Disable messages that are not key exchange related
_sshMessageFactory.DisableNonKeyExchangeMessages();
Expand All @@ -1411,7 +1422,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
_keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived;

// Start the algorithm implementation
_keyExchange.Start(this, message);
_keyExchange.Start(this, message, sendClientInitMessage);

KeyExchangeInitReceived?.Invoke(this, new MessageEventArgs<KeyExchangeInitMessage>(message));
}
Expand Down Expand Up @@ -1477,9 +1488,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));

// Signal that key exchange completed
_ = _keyExchangeCompletedWaitHandle.Set();

_keyExchangeInProgress = false;
_keyExchangeCompletedWaitHandle.Set();
}

/// <summary>
Expand Down Expand Up @@ -1967,15 +1976,14 @@ private void RaiseError(Exception exp)
private void Reset()
{
_ = _exceptionWaitHandle?.Reset();
_ = _keyExchangeCompletedWaitHandle?.Reset();
_keyExchangeCompletedWaitHandle?.Reset();
_ = _messageListenerCompleted?.Set();

SessionId = null;
_isDisconnectMessageSent = false;
_isDisconnecting = false;
_isAuthenticated = false;
_exception = null;
_keyExchangeInProgress = false;
}

private static SshConnectionException CreateConnectionAbortedByServerException()
Expand Down
Loading