@@ -211,9 +211,14 @@ def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
211211 with self ._lock :
212212 self ._drain_queue ()
213213 payload = command if command .endswith ("\n " ) else f"{ command } \n "
214- self ._stdin .write (payload )
215- self ._stdin .write (f"printf '{ marker } %s\\ n' $?\n " )
216- self ._stdin .flush ()
214+ try :
215+ self ._stdin .write (payload )
216+ self ._stdin .write (f"printf '{ marker } %s\\ n' $?\n " )
217+ self ._stdin .flush ()
218+ except (BrokenPipeError , OSError ):
219+ # The shell exited before we could write the marker command.
220+ # This happens when commands like 'exit 1' terminate the shell.
221+ return self ._collect_output_after_exit (deadline )
217222
218223 return self ._collect_output (marker , deadline , timeout )
219224
@@ -304,6 +309,80 @@ def _collect_output(
304309 total_bytes = total_bytes ,
305310 )
306311
312+ def _collect_output_after_exit (self , deadline : float ) -> CommandExecutionResult :
313+ """Collect output after the shell exited unexpectedly.
314+
315+ Called when a `BrokenPipeError` occurs while writing to stdin, indicating the
316+ shell process terminated (e.g., due to an 'exit' command).
317+
318+ Args:
319+ deadline: Absolute time by which collection must complete.
320+
321+ Returns:
322+ `CommandExecutionResult` with collected output and the process exit code.
323+ """
324+ collected : list [str ] = []
325+ total_lines = 0
326+ total_bytes = 0
327+ truncated_by_lines = False
328+ truncated_by_bytes = False
329+
330+ # Give reader threads a brief moment to enqueue any remaining output.
331+ drain_timeout = 0.1
332+ drain_deadline = min (time .monotonic () + drain_timeout , deadline )
333+
334+ while True :
335+ remaining = drain_deadline - time .monotonic ()
336+ if remaining <= 0 :
337+ break
338+ try :
339+ source , data = self ._queue .get (timeout = remaining )
340+ except queue .Empty :
341+ break
342+
343+ if data is None :
344+ # EOF marker from a reader thread; continue draining.
345+ continue
346+
347+ total_lines += 1
348+ encoded = data .encode ("utf-8" , "replace" )
349+ total_bytes += len (encoded )
350+
351+ if total_lines > self ._policy .max_output_lines :
352+ truncated_by_lines = True
353+ continue
354+
355+ if (
356+ self ._policy .max_output_bytes is not None
357+ and total_bytes > self ._policy .max_output_bytes
358+ ):
359+ truncated_by_bytes = True
360+ continue
361+
362+ if source == "stderr" :
363+ stripped = data .rstrip ("\n " )
364+ collected .append (f"[stderr] { stripped } " )
365+ if data .endswith ("\n " ):
366+ collected .append ("\n " )
367+ else :
368+ collected .append (data )
369+
370+ # Get exit code from the terminated process.
371+ exit_code : int | None = None
372+ if self ._process :
373+ exit_code = self ._process .poll ()
374+
375+ output = "" .join (collected )
376+ return CommandExecutionResult (
377+ output = output ,
378+ exit_code = exit_code ,
379+ timed_out = False ,
380+ truncated_by_lines = truncated_by_lines ,
381+ truncated_by_bytes = truncated_by_bytes ,
382+ total_lines = total_lines ,
383+ total_bytes = total_bytes ,
384+ )
385+
307386 def _kill_process (self ) -> None :
308387 if not self ._process :
309388 return
0 commit comments