add pg_timeout flag #44

Merged 2 commits on Jul 18, 2024

Summary: this PR adds a worker process group timeout option (`timeout`, in seconds, default 600). It is exposed on `launch()` and `Launcher`, carried through `LauncherPayload`, and passed by each agent to `dist.init_process_group` as a `datetime.timedelta`.

src/torchrunx/agent.py (9 changes: 8 additions & 1 deletion)

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 import os
 import socket
 import sys
@@ -33,6 +34,7 @@ class WorkerArgs:
     local_world_size: int
     world_size: int
     log_file: os.PathLike
+    timeout: int
 
     def to_bytes(self) -> bytes:
         return cloudpickle.dumps(self)
@@ -81,7 +83,11 @@ def entrypoint(serialized_worker_args: bytes):
     if backend is None:
         backend = "nccl" if torch.cuda.is_available() else "gloo"
     dist.init_process_group(
-        backend=backend, world_size=worker_args.world_size, rank=worker_args.rank, store=store
+        backend=backend,
+        world_size=worker_args.world_size,
+        rank=worker_args.rank,
+        store=store,
+        timeout=datetime.timedelta(seconds=worker_args.timeout),
     )
 
     os.environ["RANK"] = str(worker_args.rank)
@@ -130,6 +136,7 @@ def main(launcher_agent_group: LauncherAgentGroup):
                 local_world_size=num_workers,
                 world_size=worker_world_size,
                 log_file=worker_log_files[i],
+                timeout=launcher_payload.timeout,
             ).to_bytes(),
         )
         for i in range(num_workers)
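The agent-side change converts the integer `timeout` (seconds) into a `datetime.timedelta` before handing it to `torch.distributed`. Below is a minimal standalone sketch of the same pattern, separate from torchrunx's own rendezvous: a single-process gloo group with a placeholder address and port.

import datetime
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # placeholder rendezvous address
os.environ.setdefault("MASTER_PORT", "29500")      # placeholder port

# Same pattern as the agent: an integer number of seconds wrapped in a timedelta.
dist.init_process_group(
    backend="gloo",
    rank=0,
    world_size=1,
    timeout=datetime.timedelta(seconds=600),  # matches the new default of 600 s
)
dist.destroy_process_group()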
src/torchrunx/launcher.py (6 changes: 6 additions & 0 deletions)

@@ -96,6 +96,7 @@ class Launcher:
         ]
     )
     env_file: str | os.PathLike | None = None
+    timeout: int = 600
 
     def run(
         self,
@@ -209,6 +210,7 @@ def run(
             worker_global_ranks=worker_global_ranks,
             worker_log_files=worker_log_files,
             backend=self.backend,
+            timeout=self.timeout,
         )
 
         agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:]  # pyright: ignore[reportAssignmentType]
@@ -270,6 +272,7 @@ def launch(
         "NCCL*",
     ],
     env_file: str | os.PathLike | None = None,
+    timeout: int = 600,
 ) -> dict[int, Any]:
     """
     Launch a distributed PyTorch function on the specified nodes.
@@ -292,6 +295,8 @@
     :type env_vars: list[str], optional
     :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
     :type env_file: str | os.PathLike | None, optional
+    :param timeout: Worker process group timeout, defaults to 600
+    :type timeout: int, optional
     :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
     :return: A dictionary mapping worker ranks to their output
     :rtype: dict[int, Any]
@@ -304,4 +309,5 @@
         log_dir=log_dir,
         env_vars=env_vars,
         env_file=env_file,
+        timeout=timeout,
     ).run(func=func, func_kwargs=func_kwargs)
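On the launcher side, the new keyword appears on both `Launcher` and `launch()`. A hedged usage sketch follows: `train` and its empty kwargs are placeholders, and the remaining `launch` arguments (hosts, workers per host, etc.) are assumed to have workable defaults.

from torchrunx.launcher import launch

def train():
    # placeholder distributed function executed by every worker
    return "ok"

if __name__ == "__main__":
    # Raise the worker process group timeout from the 600 s default to 20 minutes.
    results = launch(
        func=train,
        func_kwargs={},
        timeout=1200,
    )
    print(results)  # dict mapping worker ranks to train()'s return value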
src/torchrunx/utils.py (1 change: 1 addition & 0 deletions)

@@ -29,6 +29,7 @@ class LauncherPayload:
     worker_global_ranks: list[list[int]]
     worker_log_files: list[list[Path]]
     backend: Literal["mpi", "gloo", "nccl", "ucc", None]
+    timeout: int
 
 
 @dataclass
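The `timeout` field added to `LauncherPayload` travels to the agents like the other payload fields, as part of a serialized dataclass (the agent-side `WorkerArgs.to_bytes` shown above uses cloudpickle). A toy round-trip, where `PayloadSketch` is a hypothetical stand-in rather than torchrunx's actual `LauncherPayload`:

from dataclasses import dataclass

import cloudpickle

@dataclass
class PayloadSketch:
    # hypothetical stand-in carrying only the new field
    timeout: int = 600

restored = cloudpickle.loads(cloudpickle.dumps(PayloadSketch(timeout=1200)))
assert restored.timeout == 1200  # the configured timeout survives serialization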