Skip to content

Added argument for agent timeout #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/torchrunx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
parser.add_argument("--world-size", type=int)
parser.add_argument("--rank", type=int)
parser.add_argument("--hostname", type=str)
parser.add_argument("--agent-timeout", type=int, default=30)
args = parser.parse_args()

main(
Expand All @@ -22,4 +23,5 @@
logger_hostname=args.launcher_hostname,
logger_port=args.logger_port,
hostname=args.hostname,
agent_timeout=args.agent_timeout,
)
5 changes: 4 additions & 1 deletion src/torchrunx/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def main(
logger_hostname: str,
logger_port: int,
hostname: str,
agent_timeout: int = 30,
) -> None:
"""Main function for agent processes (started on each node).

Expand All @@ -46,6 +47,7 @@ def main(
logger_hostname: Hostname of the logging server.
logger_port: Port for the logging server.
hostname: Hostname of this agent.
agent_timeout: Agent communication timeout (seconds).
"""
# Setup logging & stream logs to server

Expand All @@ -63,6 +65,7 @@ def main(
launcher_port=launcher_port,
world_size=world_size,
rank=rank,
agent_timeout=agent_timeout,
)

agent_rank = launcher_agent_group.rank - 1
Expand Down Expand Up @@ -102,7 +105,7 @@ def main(
local_world_size=num_workers,
world_size=worker_world_size,
hostname=launcher_payload.hostnames[agent_rank],
timeout=launcher_payload.timeout,
timeout=launcher_payload.worker_timeout,
).serialize(),
)
for i in range(num_workers)
Expand Down
15 changes: 12 additions & 3 deletions src/torchrunx/integrations/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,19 @@ def add_torchrunx_argument_group(parser: ArgumentParser) -> None:
)

group.add_argument(
"--timeout",
"--worker-timeout",
type=int,
default=600,
help="Worker process group timeout in seconds. Default: 600.",
)

group.add_argument(
"--agent-timeout",
type=int,
default=180,
help="Agent communication timeout in seconds. Default: 180.",
)

group.add_argument(
"--copy-env-vars",
type=str,
Expand Down Expand Up @@ -105,7 +112,8 @@ def launcher_from_args(args: Namespace) -> Launcher:
else:
backend = _backend # pyright: ignore [reportAssignmentType]

timeout: int = args.timeout
worker_timeout: int = args.worker_timeout
agent_timeout: int = args.agent_timeout

copy_env_vars: tuple[str, ...] = tuple(args.copy_env_vars)

Expand All @@ -123,7 +131,8 @@ def launcher_from_args(args: Namespace) -> Launcher:
workers_per_host=workers_per_host,
ssh_config_file=ssh_config_file,
backend=backend,
timeout=timeout,
worker_timeout=worker_timeout,
agent_timeout=agent_timeout,
copy_env_vars=copy_env_vars,
extra_env_vars=extra_env_vars,
env_file=env_file,
Expand Down
11 changes: 8 additions & 3 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ class Launcher:
"""`Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
for worker process group. By default, NCCL (GPU backend).
Use GLOO for CPU backend. ``None`` for no process group."""
timeout: int = 600
worker_timeout: int = 600
"""Worker process group timeout (seconds)."""
agent_timeout: int = 180
"""Agent communication timeout (seconds)."""
copy_env_vars: tuple[str, ...] = DEFAULT_ENV_VARS_FOR_COPY
"""Environment variables to copy from the launcher process to workers.
Supports Unix pattern matching syntax."""
Expand Down Expand Up @@ -117,7 +119,8 @@ def run( # noqa: C901, PLR0912, PLR0915
)
ssh_config_file = self.ssh_config_file
backend = self.backend
timeout = self.timeout
worker_timeout = self.worker_timeout
agent_timeout = self.agent_timeout

env_vars = {
k: v
Expand Down Expand Up @@ -159,7 +162,7 @@ def handler_factory() -> list[logging.Handler]:
worker_global_ranks=worker_global_ranks,
worker_world_size=sum(workers_per_host),
backend=backend,
timeout=timeout,
worker_timeout=worker_timeout,
)
agent_payloads = None

Expand Down Expand Up @@ -199,6 +202,7 @@ def handler_factory() -> list[logging.Handler]:
env_vars=env_vars,
env_file=env_file,
hostname=hostname,
agent_timeout=agent_timeout,
),
hostname=hostname,
ssh_config_file=ssh_config_file,
Expand All @@ -214,6 +218,7 @@ def handler_factory() -> list[logging.Handler]:
launcher_port=launcher_port,
world_size=world_size,
rank=0,
agent_timeout=agent_timeout,
)

# Sync initial payloads between launcher and agents
Expand Down
5 changes: 3 additions & 2 deletions src/torchrunx/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class LauncherAgentGroup(Generic[FunctionR]):
launcher_port: int
world_size: int
rank: int
agent_timeout: int = 30

def __post_init__(self) -> None:
"""Initialize process group.
Expand All @@ -63,7 +64,7 @@ def __post_init__(self) -> None:
world_size=self.world_size,
is_master=(self.rank == 0),
),
timeout=datetime.timedelta(seconds=30),
timeout=datetime.timedelta(seconds=self.agent_timeout),
)

def _all_gather(self, obj: ObjectT) -> list[ObjectT]:
Expand Down Expand Up @@ -120,7 +121,7 @@ class LauncherPayload:
worker_global_ranks: list[list[int]]
worker_world_size: int
backend: Literal["nccl", "gloo", "mpi", "ucc"] | None
timeout: int
worker_timeout: int


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion src/torchrunx/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def build_launch_command(
env_vars: dict[str, str],
env_file: str | os.PathLike | None,
hostname: str,
agent_timeout: int,
) -> str:
"""Generator for command to launch torchrunx on an agent."""
# shlex.quote prevents shell injection here (resolves S602 in execute_command)
Expand All @@ -147,7 +148,8 @@ def build_launch_command(
f"--logger-port {logger_port} "
f"--world-size {world_size} "
f"--rank {rank} "
f"--hostname {hostname}",
f"--hostname {hostname} "
f"--agent-timeout {agent_timeout}",
)

return " && ".join(commands)
Expand Down