Skip to content

Handle unknown channel messages correctly #1363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/Renci.SshNet/Channels/Channel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Globalization;
using System.Net.Sockets;
using System.Threading;

Expand Down Expand Up @@ -715,8 +714,14 @@ private void OnChannelRequest(object sender, MessageEventArgs<ChannelRequestMess
}
else
{
// TODO: we should also send a SSH_MSG_CHANNEL_FAILURE message
throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Request '{0}' is not supported.", e.Message.RequestName));
var unknownRequestInfo = new UnknownRequestInfo(e.Message.RequestName);
unknownRequestInfo.Load(e.Message.RequestData);

if (unknownRequestInfo.WantReply)
{
var reply = new ChannelFailureMessage(RemoteChannelNumber);
SendMessage(reply);
}
}
}
catch (Exception ex)
Expand Down
5 changes: 5 additions & 0 deletions src/Renci.SshNet/Channels/IChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ internal interface IChannel : IDisposable
/// </remarks>
uint LocalPacketSize { get; }

/// <summary>
/// Gets the remote channel number.
/// </summary>
uint RemoteChannelNumber { get; }

/// <summary>
/// Gets the maximum size of a data packet that can be sent using the channel.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
namespace Renci.SshNet.Messages.Connection
{
/// <summary>
/// Represents an unknown request information that we can't handle.
/// </summary>
internal sealed class UnknownRequestInfo : RequestInfo
{
/// <summary>
/// Gets the name of the request.
/// </summary>
public override string RequestName { get; }

/// <summary>
/// Initializes a new instance of the <see cref="UnknownRequestInfo"/> class.
/// <paramref name="requestName">The name of the unknown request.</paramref>
/// </summary>
internal UnknownRequestInfo(string requestName)
{
RequestName = requestName;
}
}
}
4 changes: 2 additions & 2 deletions src/Renci.SshNet/SshCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,15 @@ private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e)

if (exitStatusInfo.WantReply)
{
var replyMessage = new ChannelSuccessMessage(_channel.LocalChannelNumber);
var replyMessage = new ChannelSuccessMessage(_channel.RemoteChannelNumber);
_session.SendMessage(replyMessage);
}
}
else
{
if (e.Info.WantReply)
{
var replyMessage = new ChannelFailureMessage(_channel.LocalChannelNumber);
var replyMessage = new ChannelFailureMessage(_channel.RemoteChannelNumber);
_session.SendMessage(replyMessage);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using System;
using System.Collections.Generic;

using Microsoft.VisualStudio.TestTools.UnitTesting;

using Moq;

using Renci.SshNet.Common;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Connection;

namespace Renci.SshNet.Tests.Classes.Channels
{
[TestClass]
public class ChannelTest_OnSessionChannelRequestReceived_HandleUnknownMessage : ChannelTestBase
{
private uint _localWindowSize;
private uint _localPacketSize;
private uint _localChannelNumber;
private uint _remoteChannelNumber;
private uint _remoteWindowSize;
private uint _remotePacketSize;
private ChannelStub _channel;
private IList<ExceptionEventArgs> _channelExceptionRegister;
private UnknownRequestInfoWithWantReply _requestInfo;

protected override void SetupData()
{
var random = new Random();

_localWindowSize = (uint) random.Next(1000, int.MaxValue);
_localPacketSize = _localWindowSize - 1;
_localChannelNumber = (uint) random.Next(0, int.MaxValue);
_remoteChannelNumber = (uint) random.Next(0, int.MaxValue);
_remoteWindowSize = (uint) random.Next(0, int.MaxValue);
_remotePacketSize = (uint) random.Next(0, int.MaxValue);
_channelExceptionRegister = new List<ExceptionEventArgs>();
_requestInfo = new UnknownRequestInfoWithWantReply();
}

protected override void SetupMocks()
{
_ = SessionMock.Setup(p => p.ConnectionInfo)
.Returns(new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "password")));
_ = SessionMock.Setup(p => p.SendMessage(It.IsAny<Message>()));
}

protected override void Arrange()
{
base.Arrange();

_channel = new ChannelStub(SessionMock.Object, _localChannelNumber, _localWindowSize, _localPacketSize);
_channel.InitializeRemoteChannelInfo(_remoteChannelNumber, _remoteWindowSize, _remotePacketSize);
_channel.SetIsOpen(true);
_channel.Exception += (sender, args) => _channelExceptionRegister.Add(args);
}

protected override void Act()
{
SessionMock.Raise(s => s.ChannelRequestReceived += null,
new MessageEventArgs<ChannelRequestMessage>(new ChannelRequestMessage(_localChannelNumber, _requestInfo)));
}

[TestMethod]
public void FailureMessageWasSent()
{
SessionMock.Verify(p => p.SendMessage(It.Is<ChannelFailureMessage>(m => m.LocalChannelNumber == _channel.RemoteChannelNumber)), Times.Once);
}

[TestMethod]
public void NoExceptionShouldHaveFired()
{
Assert.AreEqual(0, _channelExceptionRegister.Count);
}
}

internal class UnknownRequestInfoWithWantReply : RequestInfo
{
public override string RequestName
{
get
{
return nameof(UnknownRequestInfoWithWantReply);
}
}

internal UnknownRequestInfoWithWantReply()
{
WantReply = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ public void Open()

public uint LocalPacketSize => throw new NotImplementedException();

public uint RemoteChannelNumber => throw new NotImplementedException();

public uint RemotePacketSize => throw new NotImplementedException();

public bool IsOpen => throw new NotImplementedException();
Expand Down