Skip to content

Commit 721bf15

Browse files
authored
fix(langchain): resolve race condition in ShellSession.execute() (#34535)
Addresses a flaky test When executing `exit 1` as a startup command, the shell process terminates immediately. The code then tries to write a marker command (`printf '...'`) to stdin, but the pipe is already broken because the shell has exited, causing `BrokenPipeError`.
1 parent dcfd9c0 commit 721bf15

File tree

1 file changed

+82
-3
lines changed

1 file changed

+82
-3
lines changed

libs/langchain_v1/langchain/agents/middleware/shell_tool.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,14 @@ def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
211211
with self._lock:
212212
self._drain_queue()
213213
payload = command if command.endswith("\n") else f"{command}\n"
214-
self._stdin.write(payload)
215-
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
216-
self._stdin.flush()
214+
try:
215+
self._stdin.write(payload)
216+
self._stdin.write(f"printf '{marker} %s\\n' $?\n")
217+
self._stdin.flush()
218+
except (BrokenPipeError, OSError):
219+
# The shell exited before we could write the marker command.
220+
# This happens when commands like 'exit 1' terminate the shell.
221+
return self._collect_output_after_exit(deadline)
217222

218223
return self._collect_output(marker, deadline, timeout)
219224

@@ -304,6 +309,80 @@ def _collect_output(
304309
total_bytes=total_bytes,
305310
)
306311

312+
def _collect_output_after_exit(self, deadline: float) -> CommandExecutionResult:
313+
"""Collect output after the shell exited unexpectedly.
314+
315+
Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
316+
shell process terminated (e.g., due to an 'exit' command).
317+
318+
Args:
319+
deadline: Absolute time by which collection must complete.
320+
321+
Returns:
322+
`CommandExecutionResult` with collected output and the process exit code.
323+
"""
324+
collected: list[str] = []
325+
total_lines = 0
326+
total_bytes = 0
327+
truncated_by_lines = False
328+
truncated_by_bytes = False
329+
330+
# Give reader threads a brief moment to enqueue any remaining output.
331+
drain_timeout = 0.1
332+
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
333+
334+
while True:
335+
remaining = drain_deadline - time.monotonic()
336+
if remaining <= 0:
337+
break
338+
try:
339+
source, data = self._queue.get(timeout=remaining)
340+
except queue.Empty:
341+
break
342+
343+
if data is None:
344+
# EOF marker from a reader thread; continue draining.
345+
continue
346+
347+
total_lines += 1
348+
encoded = data.encode("utf-8", "replace")
349+
total_bytes += len(encoded)
350+
351+
if total_lines > self._policy.max_output_lines:
352+
truncated_by_lines = True
353+
continue
354+
355+
if (
356+
self._policy.max_output_bytes is not None
357+
and total_bytes > self._policy.max_output_bytes
358+
):
359+
truncated_by_bytes = True
360+
continue
361+
362+
if source == "stderr":
363+
stripped = data.rstrip("\n")
364+
collected.append(f"[stderr] {stripped}")
365+
if data.endswith("\n"):
366+
collected.append("\n")
367+
else:
368+
collected.append(data)
369+
370+
# Get exit code from the terminated process.
371+
exit_code: int | None = None
372+
if self._process:
373+
exit_code = self._process.poll()
374+
375+
output = "".join(collected)
376+
return CommandExecutionResult(
377+
output=output,
378+
exit_code=exit_code,
379+
timed_out=False,
380+
truncated_by_lines=truncated_by_lines,
381+
truncated_by_bytes=truncated_by_bytes,
382+
total_lines=total_lines,
383+
total_bytes=total_bytes,
384+
)
385+
307386
def _kill_process(self) -> None:
308387
if not self._process:
309388
return

0 commit comments

Comments
 (0)