Skip to content

Commit 659d2ee

Browse files
author
K Raajeive
committed
Added commandTimeoutMillis in executeRemoteCommand
Added commandTimeoutMillis in executeRemoteCommand
1 parent 67768dc commit 659d2ee

2 files changed

Lines changed: 168 additions & 9 deletions

File tree

sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import java.util.Iterator;
3838
import java.util.Map;
3939
import java.util.Set;
40+
import java.util.concurrent.TimeUnit;
4041

4142
import org.apache.sshd.client.ClientAuthenticationManager;
4243
import org.apache.sshd.client.ClientFactoryManager;
@@ -95,6 +96,11 @@ enum ClientSessionEvent {
9596
AUTHED
9697
}
9798

99+
/**
100+
* Minimum timeout value. When used in {@link #executeRemoteCommand}, the command execution will wait indefinitely.
101+
*/
102+
long MIN_TIMEOUT = 0L;
103+
98104
Set<ClientChannelEvent> REMOTE_COMMAND_WAIT_EVENTS = Collections.unmodifiableSet(EnumSet.of(ClientChannelEvent.CLOSED));
99105

100106
/**
@@ -235,12 +241,31 @@ ChannelExec createExecChannel(byte[] command, PtyChannelConfigurationHolder ptyC
235241
* error or a non-zero exit status was received. If this happens, then a {@link RemoteException}
236242
* is thrown with a cause of {@link ServerException} containing the remote captured standard
237243
* error - including CR/LF(s)
238-
* @see #executeRemoteCommand(String, OutputStream, Charset)
244+
* @see #executeRemoteCommand(String, long)
239245
*/
240246
default String executeRemoteCommand(String command) throws IOException {
247+
return executeRemoteCommand(command, MIN_TIMEOUT);
248+
}
249+
250+
/**
251+
* Execute a command that requires no input and returns its output
252+
*
253+
* @param command The command to execute
254+
* @param timeoutMillis Timeout (in milliseconds) for the remote command execution. Applies to both channel opening
255+
* and result waiting. A zero or negative value means no timeout.
256+
* @return The command's standard output result
257+
* @return The command's standard output result (assumed to be in US-ASCII)
258+
* @throws IOException If failed to execute the command - including if <U>anything</U> was written to the standard
259+
* error or a non-zero exit status was received. If this happens, then a
260+
* {@link RemoteException} is thrown with a cause of {@link ServerException} containing the
261+
* remote captured standard error - including CR/LF(s)
262+
* @see #executeRemoteCommand(String, OutputStream, Charset)
263+
* @see #executeRemoteCommand(String, OutputStream, Charset, long)
264+
*/
265+
default String executeRemoteCommand(String command, long timeoutMillis) throws IOException {
241266
try (ByteArrayOutputStream stderr = new ByteArrayOutputStream()) {
242267
try {
243-
return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII);
268+
return executeRemoteCommand(command, stderr, StandardCharsets.US_ASCII, timeoutMillis);
244269
} finally {
245270
if (stderr.size() > 0) {
246271
String errorMessage = stderr.toString(StandardCharsets.US_ASCII.name());
@@ -264,15 +289,38 @@ default String executeRemoteCommand(String command) throws IOException {
264289
* was output to the standard error stream, but does check the reported exit status (if any) for
265290
* non-zero value. If non-zero exit status received then a {@link RemoteException} is thrown
266291
* with' a {@link ServerException} cause containing the exits value
267-
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset)
292+
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, long)
268293
*/
269294
default String executeRemoteCommand(String command, OutputStream stderr, Charset charset) throws IOException {
295+
return executeRemoteCommand(command, stderr, charset, MIN_TIMEOUT);
296+
}
297+
298+
/**
299+
* Execute a command that requires no input and returns its output
300+
*
301+
* @param command The command to execute - without a terminating LF
302+
* @param stderr Standard error output stream - if {@code null} then error stream data is ignored.
303+
* <B>Note:</B> if the stream is not {@code null} then it will be left <U>open</U> when this
304+
* method returns or exception is thrown
305+
* @param charset The command {@link Charset} for input/output/error - if {@code null} then US_ASCII is
306+
* assumed
307+
* @param timeoutMillis Timeout (in milliseconds) for the remote command execution. Applies to both channel opening
308+
* and result waiting. A zero or negative value means no timeout.
309+
* @return The command's standard output result
310+
* @throws IOException If failed to manage the command channel - <B>Note:</B> the code does not check if anything
311+
* was output to the standard error stream, but does check the reported exit status (if any)
312+
* for non-zero value. If non-zero exit status received then a {@link RemoteException} is
313+
* thrown with' a {@link ServerException} cause containing the exits value
314+
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, long)
315+
*/
316+
default String executeRemoteCommand(String command, OutputStream stderr, Charset charset, long timeoutMillis)
317+
throws IOException {
270318
if (charset == null) {
271319
charset = StandardCharsets.US_ASCII;
272320
}
273321

274322
try (ByteArrayOutputStream stdout = new ByteArrayOutputStream(Byte.MAX_VALUE)) {
275-
executeRemoteCommand(command, stdout, stderr, charset);
323+
executeRemoteCommand(command, stdout, stderr, charset, timeoutMillis);
276324
byte[] outBytes = stdout.toByteArray();
277325
return new String(outBytes, charset);
278326
}
@@ -290,26 +338,69 @@ default String executeRemoteCommand(String command, OutputStream stderr, Charset
290338
* thrown
291339
* @param charset The command {@link Charset} for output/error - if {@code null} then US_ASCII is assumed
292340
* @throws IOException If failed to execute the command or got a non-zero exit status
293-
* @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
341+
* @see #executeRemoteCommand(String, OutputStream, OutputStream, Charset, long)
294342
*/
295343
default void executeRemoteCommand(
296344
String command, OutputStream stdout, OutputStream stderr, Charset charset)
297345
throws IOException {
346+
executeRemoteCommand(command, stdout, stderr, charset, MIN_TIMEOUT);
347+
}
348+
349+
/**
350+
* Execute a command that requires no input and redirects its STDOUT/STDERR streams to the user-provided ones
351+
*
352+
* @param command The command to execute - without a terminating LF.
353+
* @param stdout Standard output stream - if {@code null} then stream data is ignored. <b>Note:</b> if the
354+
* stream is not {@code null}, it will be left <u>open</u> when this method returns or an
355+
* exception is thrown.
356+
* @param stderr Error output stream - if {@code null} then error stream data is ignored. <b>Note:</b> if
357+
* the stream is not {@code null}, it will be left <u>open</u> when this method returns or an
358+
* exception is thrown.
359+
* @param charset The charset to use for encoding the command and decoding the output/error streams. If
360+
* {@code null}, US-ASCII is assumed.
361+
* @param timeoutMillis Timeout (in milliseconds) for the remote command execution. Applies to both channel opening
362+
* and result waiting. A zero or negative value means no timeout.
363+
* @throws IOException If the command execution fails, times out, or returns a non-zero exit code. A
364+
* {@link RemoteException} may be thrown if the remote side reports an error.
365+
* @see ClientChannel#open()#verify(long, java.util.concurrent.TimeUnit)
366+
* @see ClientChannel#waitFor(Collection, long)
367+
* @see ClientChannel#validateCommandExitStatusCode(String, Integer) validateCommandExitStatusCode
368+
*/
369+
default void executeRemoteCommand(
370+
String command, OutputStream stdout, OutputStream stderr, Charset charset, long timeoutMillis)
371+
throws IOException {
372+
298373
if (charset == null) {
299374
charset = StandardCharsets.US_ASCII;
300375
}
301376

377+
if (timeoutMillis < 0) {
378+
throw new IllegalArgumentException("Timeout must be non-negative");
379+
}
380+
302381
try (OutputStream channelErr = (stderr == null) ? new NullOutputStream() : new NoCloseOutputStream(stderr);
303382
OutputStream channelOut = (stdout == null) ? new NullOutputStream() : new NoCloseOutputStream(stdout);
304383
ClientChannel channel = createExecChannel(command, charset, null, Collections.emptyMap())) {
384+
305385
channel.setOut(channelOut);
306386
channel.setErr(channelErr);
307-
channel.open().await(); // TODO use verify and a configurable timeout
308387

309-
// TODO use a configurable timeout
310-
Collection<ClientChannelEvent> waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, 0L);
388+
long waitTimeout;
389+
if (timeoutMillis > 0) {
390+
long startTime = System.currentTimeMillis();
391+
channel.open().verify(timeoutMillis, TimeUnit.MILLISECONDS);
392+
393+
long elapsed = System.currentTimeMillis() - startTime;
394+
waitTimeout = Math.max(1, timeoutMillis - elapsed);
395+
} else {
396+
channel.open().verify(); // wait indefinitely
397+
waitTimeout = 0L; // waitFor will also wait indefinitely
398+
}
399+
400+
Collection<ClientChannelEvent> waitMask = channel.waitFor(REMOTE_COMMAND_WAIT_EVENTS, waitTimeout);
311401
if (waitMask.contains(ClientChannelEvent.TIMEOUT)) {
312-
throw new SocketTimeoutException("Failed to retrieve command result in time: " + command);
402+
throw new SocketTimeoutException(String.format(
403+
"Failed to retrieve command '%s' result within timeout of %d ms", command, timeoutMillis));
313404
}
314405

315406
Integer exitStatus = channel.getExitStatus();

sshd-core/src/test/java/org/apache/sshd/client/session/ClientSessionTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.io.IOException;
2424
import java.io.InputStream;
2525
import java.io.OutputStream;
26+
import java.net.SocketTimeoutException;
2627
import java.nio.charset.StandardCharsets;
2728
import java.rmi.RemoteException;
2829
import java.rmi.ServerException;
@@ -240,6 +241,73 @@ protected boolean handleCommandLine(String command) throws Exception {
240241
assertEquals(Integer.toString(expectedErrorCode), actualErrorMessage, "Mismatched captured error code");
241242
}
242243

244+
@Test
245+
void executeCommandMethodWithConfigurableTimeout() throws Exception {
246+
String expectedCommand = getCurrentTestName() + "-CMD";
247+
String expectedResponse = getCurrentTestName() + "-RSP";
248+
long timeoutMillis = 10000L;
249+
sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
250+
private boolean cmdProcessed;
251+
252+
@Override
253+
protected boolean handleCommandLine(String command) throws Exception {
254+
assertEquals(expectedCommand, command, "Mismatched incoming command");
255+
assertFalse(cmdProcessed, "Duplicated command call");
256+
OutputStream stdout = getOutputStream();
257+
Thread.sleep(500L);
258+
stdout.write(expectedResponse.getBytes(StandardCharsets.US_ASCII));
259+
stdout.flush();
260+
cmdProcessed = true;
261+
return false;
262+
}
263+
});
264+
265+
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
266+
.verify(CONNECT_TIMEOUT)
267+
.getSession()) {
268+
session.addPasswordIdentity(getCurrentTestName());
269+
session.auth().verify(AUTH_TIMEOUT);
270+
271+
// NOTE !!! The LF is only because we are using a buffered reader on the server end to read the command
272+
String actualResponse = session.executeRemoteCommand(expectedCommand + "\n", timeoutMillis);
273+
assertEquals(expectedResponse, actualResponse, "Mismatched command response");
274+
}
275+
}
276+
277+
278+
@Test
279+
void exceptionThrownOnExecuteCommandTimeout() throws Exception {
280+
String expectedCommand = getCurrentTestName() + "-CMD";
281+
long timeoutMillis = 500;
282+
283+
sshd.setCommandFactory((session, command) -> new CommandExecutionHelper(command) {
284+
private boolean cmdProcessed;
285+
286+
@Override
287+
protected boolean handleCommandLine(String command) throws Exception {
288+
assertEquals(expectedCommand, command, "Mismatched incoming command");
289+
assertFalse(cmdProcessed, "Duplicated command call");
290+
Thread.sleep(timeoutMillis + 200);
291+
OutputStream stdout = getOutputStream();
292+
stdout.write(command.getBytes(StandardCharsets.US_ASCII));
293+
stdout.flush();
294+
cmdProcessed = true;
295+
return false;
296+
}
297+
});
298+
299+
try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
300+
.verify(CONNECT_TIMEOUT)
301+
.getSession()) {
302+
session.addPasswordIdentity(getCurrentTestName());
303+
session.auth().verify(AUTH_TIMEOUT);
304+
305+
assertThrows(SocketTimeoutException.class, () -> {
306+
session.executeRemoteCommand(expectedCommand + "\n", timeoutMillis);
307+
});
308+
}
309+
}
310+
243311
// see SSHD-859
244312
@Test
245313
void connectionContextPropagation() throws Exception {

0 commit comments

Comments
 (0)