Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions tests/tools/test_code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,30 @@ def test_convenience_helpers_present(self):
self.assertIn("def json_parse(", src)
self.assertIn("def shell_quote(", src)
self.assertIn("def retry(", src)
self.assertIn("import json, os, socket, shlex, time", src)
self.assertIn("import json, os, socket, shlex, threading, time", src)

def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
# Generated file-transport stubs must derive the RPC directory from
# tempfile.gettempdir() rather than a hard-coded "/tmp/hermes_rpc".
src = generate_hermes_tools_module(["terminal"], transport="file")
# NOTE(review): the next two assertions look like the pre-change and
# post-change import lines from a diff rendering; only the variant with
# `threading` can match the generated module — confirm the first line
# was meant to be removed.
self.assertIn("import json, os, shlex, tempfile, time", src)
self.assertIn("import json, os, shlex, tempfile, threading, time", src)
# The fallback path must be built from tempfile.gettempdir()...
self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
# ...and the old env-var default of a literal /tmp path must be gone.
self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)

def test_uds_transport_serializes_concurrent_calls(self):
    """Regression guard for the UDS RPC race: the generated stub module
    must take a module-level lock around the whole send+recv round-trip,
    otherwise threads sharing the socket can consume each other's
    responses."""
    generated = generate_hermes_tools_module(["terminal"], transport="uds")
    for needle in ("_call_lock = threading.Lock()", "with _call_lock:"):
        self.assertIn(needle, generated)

def test_file_transport_serializes_seq_allocation(self):
    """Regression guard: the file-transport stub must bump `_seq` inside a
    lock; an unguarded read-modify-write lets two threads draw the same
    sequence number and overwrite each other's request files."""
    generated = generate_hermes_tools_module(["terminal"], transport="file")
    for needle in ("_seq_lock = threading.Lock()", "with _seq_lock:"):
        self.assertIn(needle, generated)


class TestExecuteCodeRemoteTempDir(unittest.TestCase):
def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
Expand Down Expand Up @@ -226,6 +242,64 @@ def test_runtime_exception(self):
result = self._run("raise ValueError('test error')")
self.assertEqual(result["status"], "error")

def test_concurrent_tool_calls_match_responses(self):
"""Regression for the UDS RPC race: multiple threads inside the
sandbox calling terminal() concurrently must each receive their own
response, not another thread's.

Before the fix, `_sock` and the recv-loop were shared without a
lock, so responses (written FIFO by the single-threaded server)
got delivered to whichever client thread happened to win the
recv() race. That surfaced as each thread seeing another thread's
output.

The mock dispatcher sleeps briefly to guarantee the requests
overlap on the socket.
"""
# Sandbox-side script: fan out N concurrent terminal() calls, then check
# that each thread's reply contains its own TAG-<i> marker. It prints a
# single "OK N/N" or "MISMATCH ..." line that the assertions below match.
code = '''
import threading
from concurrent.futures import ThreadPoolExecutor
from hermes_tools import terminal

N = 10

def call(i):
r = terminal(f"echo TAG-{i}")
return i, r.get("output", "")

with ThreadPoolExecutor(max_workers=N) as ex:
results = list(ex.map(call, range(N)))

mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
if mismatches:
print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
else:
print(f"OK {N}/{N}")
'''

# Parent-side mock dispatcher: answers "terminal" calls itself (after a
# short sleep so concurrent requests genuinely overlap on the socket) and
# delegates anything else to the shared mock handler.
def slow_mock(function_name, function_args, task_id=None, user_task=None):
import time as _t
if function_name == "terminal":
_t.sleep(0.05) # ensure requests overlap on the socket
cmd = function_args.get("command", "")
# Echo semantics: strip leading "echo " and return the rest
out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
return json.dumps({"output": out, "exit_code": 0})
return _mock_handle_function_call(
function_name, function_args, task_id=task_id, user_task=user_task
)

# NOTE(review): assumes the RPC server dispatches tool calls through
# model_tools.handle_function_call — confirm the patch target matches the
# other tests in this module.
with patch("model_tools.handle_function_call", side_effect=slow_mock):
raw = execute_code(
code=code,
task_id="test-concurrent",
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
)
# execute_code returns a JSON string; the sandbox's stdout lands in
# result["output"].
result = json.loads(raw)
self.assertEqual(result["status"], "success", msg=result)
self.assertIn("OK 10/10", result["output"],
msg=f"Concurrent tool calls mismatched: {result['output']!r}")

def test_excluded_tool_returns_error(self):
"""Script calling a tool not in the allow-list gets an error from RPC."""
code = """
Expand Down
42 changes: 27 additions & 15 deletions tools/code_execution_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,14 @@ def retry(fn, max_attempts=3, delay=2):

_UDS_TRANSPORT_HEADER = '''\
"""Auto-generated Hermes tools RPC stubs."""
import json, os, socket, shlex, time
import json, os, socket, shlex, threading, time

_sock = None
# The RPC server handles a single client connection serially and has no
# request-id in the protocol, so concurrent _call() invocations from multiple
# threads (e.g. ThreadPoolExecutor) would race on the shared socket and get
# each other's responses. Serialize the entire send+recv round-trip.
_call_lock = threading.Lock()
''' + _COMMON_HELPERS + '''\

def _connect():
Expand All @@ -239,17 +244,18 @@ def _connect():

def _call(tool_name, args):
"""Send a tool call to the parent process and return the parsed result."""
conn = _connect()
request = json.dumps({"tool": tool_name, "args": args}) + "\\n"
conn.sendall(request.encode())
buf = b""
while True:
chunk = conn.recv(65536)
if not chunk:
raise RuntimeError("Agent process disconnected")
buf += chunk
if buf.endswith(b"\\n"):
break
with _call_lock:
conn = _connect()
conn.sendall(request.encode())
buf = b""
while True:
chunk = conn.recv(65536)
if not chunk:
raise RuntimeError("Agent process disconnected")
buf += chunk
if buf.endswith(b"\\n"):
break
raw = buf.decode().strip()
result = json.loads(raw)
if isinstance(result, str):
Expand All @@ -265,24 +271,30 @@ def _call(tool_name, args):

_FILE_TRANSPORT_HEADER = '''\
"""Auto-generated Hermes tools RPC stubs (file-based transport)."""
import json, os, shlex, tempfile, time
import json, os, shlex, tempfile, threading, time

_RPC_DIR = os.environ.get("HERMES_RPC_DIR") or os.path.join(tempfile.gettempdir(), "hermes_rpc")
_seq = 0
# `_seq += 1` is not atomic (read-modify-write), so concurrent _call()
# invocations from multiple threads could allocate the same sequence number
# and clobber each other's request files. Guard seq allocation with a lock.
_seq_lock = threading.Lock()
''' + _COMMON_HELPERS + '''\

def _call(tool_name, args):
"""Send a tool call request via file-based RPC and wait for response."""
global _seq
_seq += 1
seq_str = f"{_seq:06d}"
with _seq_lock:
_seq += 1
seq = _seq
seq_str = f"{seq:06d}"
req_file = os.path.join(_RPC_DIR, f"req_{seq_str}")
res_file = os.path.join(_RPC_DIR, f"res_{seq_str}")

# Write request atomically (write to .tmp, then rename)
tmp = req_file + ".tmp"
with open(tmp, "w") as f:
json.dump({"tool": tool_name, "args": args, "seq": _seq}, f)
json.dump({"tool": tool_name, "args": args, "seq": seq}, f)
os.rename(tmp, req_file)

# Wait for response with adaptive polling
Expand Down
Loading