@@ -160,12 +160,7 @@ public class Session : ISession
160
160
/// <summary>
161
161
/// WaitHandle to signal that key exchange was completed.
162
162
/// </summary>
163
- private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent ( initialState : false ) ;
164
-
165
- /// <summary>
166
- /// WaitHandle to signal that key exchange is in progress.
167
- /// </summary>
168
- private bool _keyExchangeInProgress ;
163
+ private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim ( initialState : false ) ;
169
164
170
165
/// <summary>
171
166
/// Exception that need to be thrown by waiting thread.
@@ -643,6 +638,11 @@ public void Connect()
643
638
// Some server implementations might sent this message first, prior to establishing encryption algorithm
644
639
RegisterMessage ( "SSH_MSG_USERAUTH_BANNER" ) ;
645
640
641
+ // Send our key exchange init.
642
+ // We need to do this before starting the message listener to avoid the case where we receive the server
643
+ // key exchange init and we continue the key exchange before having sent our own init.
644
+ SendMessage ( ClientInitMessage ) ;
645
+
646
646
// Mark the message listener threads as started
647
647
_ = _messageListenerCompleted . Reset ( ) ;
648
648
@@ -651,7 +651,7 @@ public void Connect()
651
651
_ = ThreadAbstraction . ExecuteThreadLongRunning ( MessageListener ) ;
652
652
653
653
// Wait for key exchange to be completed
654
- WaitOnHandle ( _keyExchangeCompletedWaitHandle ) ;
654
+ WaitOnHandle ( _keyExchangeCompletedWaitHandle . WaitHandle ) ;
655
655
656
656
// If sessionId is not set then its not connected
657
657
if ( SessionId is null )
@@ -757,6 +757,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
757
757
// Some server implementations might sent this message first, prior to establishing encryption algorithm
758
758
RegisterMessage ( "SSH_MSG_USERAUTH_BANNER" ) ;
759
759
760
+ // Send our key exchange init.
761
+ // We need to do this before starting the message listener to avoid the case where we receive the server
762
+ // key exchange init and we continue the key exchange before having sent our own init.
763
+ SendMessage ( ClientInitMessage ) ;
764
+
760
765
// Mark the message listener threads as started
761
766
_ = _messageListenerCompleted . Reset ( ) ;
762
767
@@ -765,7 +770,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
765
770
_ = ThreadAbstraction . ExecuteThreadLongRunning ( MessageListener ) ;
766
771
767
772
// Wait for key exchange to be completed
768
- WaitOnHandle ( _keyExchangeCompletedWaitHandle ) ;
773
+ WaitOnHandle ( _keyExchangeCompletedWaitHandle . WaitHandle ) ;
769
774
770
775
// If sessionId is not set then its not connected
771
776
if ( SessionId is null )
@@ -1046,10 +1051,10 @@ internal void SendMessage(Message message)
1046
1051
throw new SshConnectionException ( "Client not connected." ) ;
1047
1052
}
1048
1053
1049
- if ( _keyExchangeInProgress && message is not IKeyExchangedAllowed )
1054
+ if ( ! _keyExchangeCompletedWaitHandle . IsSet && message is not IKeyExchangedAllowed )
1050
1055
{
1051
1056
// Wait for key exchange to be completed
1052
- WaitOnHandle ( _keyExchangeCompletedWaitHandle ) ;
1057
+ WaitOnHandle ( _keyExchangeCompletedWaitHandle . WaitHandle ) ;
1053
1058
}
1054
1059
1055
1060
DiagnosticAbstraction . Log ( string . Format ( "[{0}] Sending message '{1}' to server: '{2}'." , ToHex ( SessionId ) , message . GetType ( ) . Name , message ) ) ;
@@ -1394,9 +1399,15 @@ internal void OnKeyExchangeDhGroupExchangeReplyReceived(KeyExchangeDhGroupExchan
1394
1399
/// <param name="message"><see cref="KeyExchangeInitMessage"/> message.</param>
1395
1400
internal void OnKeyExchangeInitReceived ( KeyExchangeInitMessage message )
1396
1401
{
1397
- _keyExchangeInProgress = true ;
1402
+ // If _keyExchangeCompletedWaitHandle is already set, then this is a key
1403
+ // re-exchange initiated by the server, and we need to send our own init
1404
+ // message.
1405
+ // Otherwise, the wait handle is not set and this received init is part of the
1406
+ // initial connection for which we have already sent our init, so we shouldn't
1407
+ // send another one.
1408
+ var sendClientInitMessage = _keyExchangeCompletedWaitHandle . IsSet ;
1398
1409
1399
- _ = _keyExchangeCompletedWaitHandle . Reset ( ) ;
1410
+ _keyExchangeCompletedWaitHandle . Reset ( ) ;
1400
1411
1401
1412
// Disable messages that are not key exchange related
1402
1413
_sshMessageFactory . DisableNonKeyExchangeMessages ( ) ;
@@ -1411,7 +1422,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
1411
1422
_keyExchange . HostKeyReceived += KeyExchange_HostKeyReceived ;
1412
1423
1413
1424
// Start the algorithm implementation
1414
- _keyExchange . Start ( this , message ) ;
1425
+ _keyExchange . Start ( this , message , sendClientInitMessage ) ;
1415
1426
1416
1427
KeyExchangeInitReceived ? . Invoke ( this , new MessageEventArgs < KeyExchangeInitMessage > ( message ) ) ;
1417
1428
}
@@ -1477,9 +1488,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
1477
1488
NewKeysReceived ? . Invoke ( this , new MessageEventArgs < NewKeysMessage > ( message ) ) ;
1478
1489
1479
1490
// Signal that key exchange completed
1480
- _ = _keyExchangeCompletedWaitHandle . Set ( ) ;
1481
-
1482
- _keyExchangeInProgress = false ;
1491
+ _keyExchangeCompletedWaitHandle . Set ( ) ;
1483
1492
}
1484
1493
1485
1494
/// <summary>
@@ -1967,15 +1976,14 @@ private void RaiseError(Exception exp)
1967
1976
private void Reset ( )
1968
1977
{
1969
1978
_ = _exceptionWaitHandle ? . Reset ( ) ;
1970
- _ = _keyExchangeCompletedWaitHandle ? . Reset ( ) ;
1979
+ _keyExchangeCompletedWaitHandle ? . Reset ( ) ;
1971
1980
_ = _messageListenerCompleted ? . Set ( ) ;
1972
1981
1973
1982
SessionId = null ;
1974
1983
_isDisconnectMessageSent = false ;
1975
1984
_isDisconnecting = false ;
1976
1985
_isAuthenticated = false ;
1977
1986
_exception = null ;
1978
- _keyExchangeInProgress = false ;
1979
1987
}
1980
1988
1981
1989
private static SshConnectionException CreateConnectionAbortedByServerException ( )
0 commit comments