Skip to content

Significantly improve performance of ShellStream's Expect methods #1207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Renci.SshNet/IServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ internal partial interface IServiceFactory
/// <param name="height">The terminal height in pixels.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <param name="bufferSize">Size of the buffer.</param>
/// <param name="expectSize">Size of the expect buffer.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
Expand All @@ -135,7 +136,8 @@ ShellStream CreateShellStream(ISession session,
uint width,
uint height,
IDictionary<TerminalModes, uint> terminalModeValues,
int bufferSize);
int bufferSize,
int expectSize);

/// <summary>
/// Creates an <see cref="IRemotePathTransformation"/> that encloses a path in double quotes, and escapes
Expand Down
7 changes: 4 additions & 3 deletions src/Renci.SshNet/ServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,23 +187,24 @@ public ISftpResponseFactory CreateSftpResponseFactory()
/// <param name="height">The terminal height in pixels.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <param name="bufferSize">The size of the buffer.</param>
/// <param name="expectSize">The size of the expect buffer.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
/// <exception cref="SshConnectionException">Client is not connected.</exception>
/// <remarks>
/// <para>
/// The <c>TERM</c> environment variable contains an identifier for the text window's capabilities.
/// You can get a detailed list of these cababilities by using the ‘infocmp’ command.
/// You can get a detailed list of these capabilities by using the ‘infocmp’ command.
/// </para>
/// <para>
/// The column/row dimensions override the pixel dimensions(when non-zero). Pixel dimensions refer
/// to the drawable area of the window.
/// </para>
/// </remarks>
public ShellStream CreateShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize)
public ShellStream CreateShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize, int expectSize)
{
return new ShellStream(session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize);
return new ShellStream(session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize, expectSize);
}

/// <summary>
Expand Down
109 changes: 79 additions & 30 deletions src/Renci.SshNet/ShellStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public class ShellStream : Stream
private readonly Encoding _encoding;
private readonly int _bufferSize;
private readonly Queue<byte> _incoming;
private readonly int _expectSize;
private readonly Queue<byte> _expect;
private readonly Queue<byte> _outgoing;
private IChannelSession _channel;
private AutoResetEvent _dataReceived = new AutoResetEvent(initialState: false);
Expand Down Expand Up @@ -76,15 +78,28 @@ internal int BufferSize
/// <param name="height">The terminal height in pixels.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <param name="bufferSize">The size of the buffer.</param>
/// <param name="expectSize">The size of the expect buffer.</param>
/// <exception cref="SshException">The channel could not be opened.</exception>
/// <exception cref="SshException">The pseudo-terminal request was not accepted by the server.</exception>
/// <exception cref="SshException">The request to start a shell was not accepted by the server.</exception>
internal ShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize)
internal ShellStream(ISession session, string terminalName, uint columns, uint rows, uint width, uint height, IDictionary<TerminalModes, uint> terminalModeValues, int bufferSize, int expectSize)
{
if (bufferSize <= 0)
{
throw new ArgumentException($"{nameof(bufferSize)} must be between 1 and {int.MaxValue}.");
}

if (expectSize <= 0)
{
throw new ArgumentException($"{nameof(expectSize)} must be between 1 and {int.MaxValue}.");
}

_encoding = session.ConnectionInfo.Encoding;
_session = session;
_bufferSize = bufferSize;
_incoming = new Queue<byte>();
_expectSize = expectSize;
_expect = new Queue<byte>(_expectSize);
_outgoing = new Queue<byte>();

_channel = _session.CreateChannelSession();
Expand Down Expand Up @@ -248,35 +263,40 @@ public void Expect(params ExpectAction[] expectActions)
public void Expect(TimeSpan timeout, params ExpectAction[] expectActions)
{
var expectedFound = false;
var text = string.Empty;
var matchText = string.Empty;

do
{
lock (_incoming)
{
if (_incoming.Count > 0)
if (_expect.Count > 0)
{
text = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count);
matchText = _encoding.GetString(_expect.ToArray(), 0, _expect.Count);
}

if (text.Length > 0)
if (matchText.Length > 0)
{
foreach (var expectAction in expectActions)
{
var match = expectAction.Expect.Match(text);
var match = expectAction.Expect.Match(matchText);

if (match.Success)
{
var result = text.Substring(0, match.Index + match.Length);
var charCount = _encoding.GetByteCount(result);
var returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);

for (var i = 0; i < charCount && _incoming.Count > 0; i++)
// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually going to remove what it should from both _incoming and _expect?

If _incoming looks like: aaaaabbbbbcccccZ
and _expect (and so matchText) looks like: cccZ
and we are expecting Z
then returnText == "cccZ" and returnLength == 4

Then we are going to dequeue aaaa from _incoming and nothing from _expect (because _incoming.Count > _expect.Count + 4). So we end up with:

_incoming looks like: abbbbbcccccZ
and _expect still looks like: cccZ

Have I got that right? Is that expected? (it feels wrong)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to add a test to try to replicate this and make sure its accounted for. I have a feeling that the dequeueing may be slightly off since for things to happen as you've mentioned, data would have to accumulate and be read in different ways within the same workflow.

{
// Remove processed items from the queue
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}

expectAction.Action(result);
expectAction.Action(returnText);
expectedFound = true;
}
}
Expand Down Expand Up @@ -349,27 +369,33 @@ public string Expect(Regex regex)
/// </returns>
public string Expect(Regex regex, TimeSpan timeout)
{
var result = string.Empty;
var matchText = string.Empty;
string returnText;

while (true)
{
lock (_incoming)
{
if (_incoming.Count > 0)
if (_expect.Count > 0)
{
result = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count);
matchText = _encoding.GetString(_expect.ToArray(), 0, _expect.Count);
}

var match = regex.Match(result);
var match = regex.Match(matchText);

if (match.Success)
{
result = result.Substring(0, match.Index + match.Length);
var charCount = _encoding.GetByteCount(result);
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);

// Remove processed items from the queue
for (var i = 0; i < charCount && _incoming.Count > 0; i++)
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}

Expand All @@ -390,7 +416,7 @@ public string Expect(Regex regex, TimeSpan timeout)
}
}

return result;
return returnText;
}

/// <summary>
Expand Down Expand Up @@ -446,7 +472,8 @@ public IAsyncResult BeginExpect(AsyncCallback callback, object state, params Exp
public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object state, params ExpectAction[] expectActions)
#pragma warning restore CA1859 // Use concrete types when possible for improved performance
{
var text = string.Empty;
var matchText = string.Empty;
string returnText;

// Create new AsyncResult object
var asyncResult = new ExpectAsyncResult(callback, state);
Expand All @@ -461,31 +488,36 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object
{
lock (_incoming)
{
if (_incoming.Count > 0)
if (_expect.Count > 0)
{
text = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count);
matchText = _encoding.GetString(_expect.ToArray(), 0, _expect.Count);
}

if (text.Length > 0)
if (matchText.Length > 0)
{
foreach (var expectAction in expectActions)
{
var match = expectAction.Expect.Match(text);
var match = expectAction.Expect.Match(matchText);

if (match.Success)
{
var result = text.Substring(0, match.Index + match.Length);
var charCount = _encoding.GetByteCount(result);
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);

for (var i = 0; i < match.Index + match.Length && _incoming.Count > 0; i++)
// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
// Remove processed items from the queue
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}

expectAction.Action(result);
expectAction.Action(returnText);
callback?.Invoke(asyncResult);
expectActionResult = result;
expectActionResult = returnText;
}
}
}
Expand Down Expand Up @@ -584,6 +616,11 @@ public string ReadLine(TimeSpan timeout)
// remove processed bytes from the queue
for (var i = 0; i < bytesProcessed; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}

Expand Down Expand Up @@ -620,6 +657,7 @@ public string Read()
lock (_incoming)
{
text = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count);
_expect.Clear();
_incoming.Clear();
}

Expand Down Expand Up @@ -649,6 +687,11 @@ public override int Read(byte[] buffer, int offset, int count)
{
for (; i < count && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

buffer[offset + i] = _incoming.Dequeue();
}
}
Expand Down Expand Up @@ -800,6 +843,12 @@ private void Channel_DataReceived(object sender, ChannelDataEventArgs e)
foreach (var b in e.Data)
{
_incoming.Enqueue(b);
if (_expect.Count == _expectSize)
{
_ = _expect.Dequeue();
}

_expect.Enqueue(b);
}
}

Expand Down
71 changes: 67 additions & 4 deletions src/Renci.SshNet/SshClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ public Shell CreateShell(Encoding encoding, string input, Stream output, Stream
/// <remarks>
/// <para>
/// The <c>TERM</c> environment variable contains an identifier for the text window's capabilities.
/// You can get a detailed list of these cababilities by using the ‘infocmp’ command.
/// You can get a detailed list of these capabilities by using the ‘infocmp’ command.
/// </para>
/// <para>
/// The column/row dimensions override the pixel dimensions(when nonzero). Pixel dimensions refer
Expand All @@ -416,7 +416,38 @@ public Shell CreateShell(Encoding encoding, string input, Stream output, Stream
/// </remarks>
public ShellStream CreateShellStream(string terminalName, uint columns, uint rows, uint width, uint height, int bufferSize)
{
return CreateShellStream(terminalName, columns, rows, width, height, bufferSize, terminalModeValues: null);
return CreateShellStream(terminalName, columns, rows, width, height, bufferSize, bufferSize * 2, terminalModeValues: null);
}

/// <summary>
/// Creates the shell stream.
/// </summary>
/// <param name="terminalName">The <c>TERM</c> environment variable.</param>
/// <param name="columns">The terminal width in columns.</param>
/// <param name="rows">The terminal width in rows.</param>
/// <param name="width">The terminal width in pixels.</param>
/// <param name="height">The terminal height in pixels.</param>
/// <param name="bufferSize">The size of the buffer.</param>
/// <param name="expectSize">The size of the expect buffer.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
/// <exception cref="SshConnectionException">Client is not connected.</exception>
/// <remarks>
/// <para>
/// The <c>TERM</c> environment variable contains an identifier for the text window's capabilities.
/// You can get a detailed list of these capabilities by using the ‘infocmp’ command.
/// </para>
/// <para>
/// The column/row dimensions override the pixel dimensions(when non-zero). Pixel dimensions refer
/// to the drawable area of the window.
/// </para>
/// </remarks>
public ShellStream CreateShellStream(string terminalName, uint columns, uint rows, uint width, uint height, int bufferSize, int expectSize)
{
EnsureSessionIsOpen();

return CreateShellStream(terminalName, columns, rows, width, height, bufferSize, expectSize, terminalModeValues: null);
}

/// <summary>
Expand All @@ -436,7 +467,7 @@ public ShellStream CreateShellStream(string terminalName, uint columns, uint row
/// <remarks>
/// <para>
/// The <c>TERM</c> environment variable contains an identifier for the text window's capabilities.
/// You can get a detailed list of these cababilities by using the ‘infocmp’ command.
/// You can get a detailed list of these capabilities by using the ‘infocmp’ command.
/// </para>
/// <para>
/// The column/row dimensions override the pixel dimensions(when non-zero). Pixel dimensions refer
Expand All @@ -447,7 +478,39 @@ public ShellStream CreateShellStream(string terminalName, uint columns, uint row
{
EnsureSessionIsOpen();

return ServiceFactory.CreateShellStream(Session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize);
return CreateShellStream(terminalName, columns, rows, width, height, bufferSize, bufferSize * 2, terminalModeValues);
}

/// <summary>
/// Creates the shell stream.
/// </summary>
/// <param name="terminalName">The <c>TERM</c> environment variable.</param>
/// <param name="columns">The terminal width in columns.</param>
/// <param name="rows">The terminal width in rows.</param>
/// <param name="width">The terminal width in pixels.</param>
/// <param name="height">The terminal height in pixels.</param>
/// <param name="bufferSize">The size of the buffer.</param>
/// <param name="expectSize">The size of the expect buffer.</param>
/// <param name="terminalModeValues">The terminal mode values.</param>
/// <returns>
/// The created <see cref="ShellStream"/> instance.
/// </returns>
/// <exception cref="SshConnectionException">Client is not connected.</exception>
/// <remarks>
/// <para>
/// The <c>TERM</c> environment variable contains an identifier for the text window's capabilities.
/// You can get a detailed list of these capabilities by using the ‘infocmp’ command.
/// </para>
/// <para>
/// The column/row dimensions override the pixel dimensions(when non-zero). Pixel dimensions refer
/// to the drawable area of the window.
/// </para>
/// </remarks>
public ShellStream CreateShellStream(string terminalName, uint columns, uint rows, uint width, uint height, int bufferSize, int expectSize, IDictionary<TerminalModes, uint> terminalModeValues)
{
EnsureSessionIsOpen();

return ServiceFactory.CreateShellStream(Session, terminalName, columns, rows, width, height, terminalModeValues, bufferSize, expectSize);
}

/// <summary>
Expand Down
Loading