diff --git a/pyproject.toml b/pyproject.toml index be8c0d6d..6f029d8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] line-length = 100 src = ["src", "tests"] [tool.ruff.lint] -extend-select = ["I"] +select = ["E", "F", "B", "UP", "I"] [tool.pyright] include = ["src", "tests"] diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 5a7be1c8..f4dfab33 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -7,7 +7,7 @@ import sys import tempfile from dataclasses import dataclass -from typing import Callable, Literal +from typing import Any, Callable, Literal import cloudpickle import torch @@ -20,7 +20,7 @@ AgentPayload, AgentStatus, LauncherAgentGroup, - LauncherPayload, + WorkerException, get_open_port, ) @@ -48,7 +48,7 @@ def from_bytes(cls, serialized: bytes) -> Self: return cloudpickle.loads(serialized) -def entrypoint(serialized_worker_args: bytes): +def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException: worker_args = WorkerArgs.from_bytes(serialized_worker_args) logger = logging.getLogger() @@ -93,13 +93,14 @@ def entrypoint(serialized_worker_args: bytes): logger.debug(f"executing function: {worker_args.function}") - r = worker_args.function() - - # flush streams - sys.stdout.flush() - sys.stderr.flush() - - return r + try: + return worker_args.function() + except Exception as e: + logger.error(e) + return WorkerException(exception=e) + finally: + sys.stdout.flush() + sys.stderr.flush() def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int): @@ -111,9 +112,8 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ process_id=os.getpid(), ) - all_payloads = launcher_agent_group.sync_payloads(payload=payload) - launcher_payload: LauncherPayload = all_payloads[0] # pyright: ignore[reportAssignmentType] - main_agent_payload: AgentPayload = all_payloads[1] # pyright: ignore[reportAssignmentType] + launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload) + main_agent_payload = agent_payloads[0] hostname = launcher_payload.hostnames[agent_rank] worker_world_size = launcher_payload.worker_world_size @@ -169,20 +169,19 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ logger.info("starting processes") try: - status = AgentStatus() + status = None while True: - if status.is_running(): + if status is None or status.state == "running": status = AgentStatus.from_result( result=ctx.wait(5), worker_global_ranks=worker_global_ranks ) agent_statuses = launcher_agent_group.sync_agent_statuses(status=status) - if all(s.is_done() for s in agent_statuses): + if all(s.state == "done" for s in agent_statuses): + break + elif any(s.state == "failed" for s in agent_statuses): break - - if any(s.is_failed() for s in agent_statuses): - raise RuntimeError() except: raise finally: diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 1f9db00c..89cde697 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -9,11 +9,11 @@ import subprocess import sys from collections import ChainMap -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import partial from logging import Handler from multiprocessing import Process -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, Sequence import fabric import torch.distributed as dist @@ -21,10 +21,9 @@ from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers from .logging_utils import LogRecordSocketReceiver, default_handlers from .utils import ( - AgentPayload, - AgentStatus, LauncherAgentGroup, LauncherPayload, + WorkerException, get_open_port, ) @@ -59,22 +58,20 @@ def execute_command( @dataclass class Launcher: - hostnames: list[str] | Literal["auto", "slurm"] = field(default_factory=lambda: ["localhost"]) - workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1 + hostnames: list[str] | Literal["auto", "slurm"] = "auto" + workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto" ssh_config_file: str | os.PathLike | None = None backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None log_handlers: list[Handler] | Literal["auto"] | None = "auto" - env_vars: list[str] = field( - default_factory=lambda: [ - "PATH", - "LD_LIBRARY", - "LIBRARY_PATH", - "PYTHON*", - "CUDA*", - "TORCH*", - "PYTORCH*", - "NCCL*", - ] + env_vars: Sequence[str] = ( + "PATH", + "LD_LIBRARY", + "LIBRARY_PATH", + "PYTHON*", + "CUDA*", + "TORCH*", + "PYTORCH*", + "NCCL*", ) env_file: str | os.PathLike | None = None timeout: int = 600 @@ -82,8 +79,8 @@ class Launcher: def run( self, func: Callable, - func_args: tuple[Any] = tuple(), - func_kwargs: dict[str, Any] = {}, + func_args: tuple[Any] | None = None, + func_kwargs: dict[str, Any] | None = None, ) -> dict[int, Any]: """ Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch` @@ -205,6 +202,11 @@ def run( host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1]) worker_global_ranks.append(list(host_ranks)) + if func_args is None: + func_args = tuple() + if func_kwargs is None: + func_kwargs = dict() + payload = LauncherPayload( fn=partial(func, *func_args, **func_kwargs), hostnames=self.hostnames, @@ -214,30 +216,23 @@ def run( timeout=self.timeout, ) - agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType] + launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload) agent_pids = [p.process_id for p in agent_payloads] # loop to monitor agent statuses (until failed or done) try: while True: - agent_statuses = launcher_agent_group.sync_agent_statuses(status=AgentStatus()) + agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) + + for s in agent_statuses: + if s.state == "failed": + for value in s.return_values.values(): + if isinstance(value, WorkerException): + raise value.exception - if all(s.is_done() for s in agent_statuses): + if all(s.state == "done" for s in agent_statuses): break - if any(s.is_failed() for s in agent_statuses): - # TODO: cleaner way to print these? - e = "" - for i, s in enumerate(agent_statuses): - if s is not None and s.is_failed(): - for k, v in s.failures.items(): - e += f"Node {i}, local worker {k} exited with error: " - if isinstance(v.message, str): - e += f"{v.message}\n" - else: - e += f"{v.message['message']}\n" - e += f"{v.message['extraInfo']['py_callstack']}\n\n" - raise RuntimeError(e) except: # cleanup: SIGTERM all agents for agent_pid, agent_hostname in zip(agent_pids, self.hostnames): @@ -259,14 +254,14 @@ def run( def launch( func: Callable, - func_args: tuple[Any] = tuple(), - func_kwargs: dict[str, Any] = {}, - hostnames: list[str] | Literal["auto", "slurm"] = ["localhost"], - workers_per_host: int | list[int] | Literal["auto", "slurm"] = 1, + func_args: tuple[Any] | None = None, + func_kwargs: dict[str, Any] | None = None, + hostnames: list[str] | Literal["auto", "slurm"] = "auto", + workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto", ssh_config_file: str | os.PathLike | None = None, backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None, log_handlers: list[Handler] | Literal["auto"] = "auto", - env_vars: list[str] = [ + env_vars: Sequence[str] = ( "PATH", "LD_LIBRARY", "LIBRARY_PATH", @@ -275,7 +270,7 @@ def launch( "TORCH*", "PYTORCH*", "NCCL*", - ], + ), env_file: str | os.PathLike | None = None, timeout: int = 600, ) -> dict[int, Any]: diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 43f040c6..3a14d342 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -9,7 +9,6 @@ import cloudpickle import torch.distributed as dist from torch.distributed.elastic.multiprocessing.api import RunProcsResult -from torch.distributed.elastic.multiprocessing.errors import ProcessFailure from typing_extensions import Self @@ -20,6 +19,11 @@ def get_open_port() -> int: return port +@dataclass +class WorkerException: + exception: Exception + + @dataclass class LauncherPayload: fn: Callable @@ -39,33 +43,25 @@ class AgentPayload: @dataclass class AgentStatus: - running: bool = True - failed: bool = False - return_values: dict[int, Any] = field(default_factory=dict) - failures: dict[int, ProcessFailure] = field(default_factory=dict) - stdouts: dict[int, str] = field(default_factory=dict) - stderrs: dict[int, str] = field(default_factory=dict) + state: Literal["running", "failed", "done"] + return_values: dict[int, Any | WorkerException] = field(default_factory=dict) @classmethod def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self: if result is None: - return cls() + return cls(state="running") - return cls( - running=False, - failed=result.is_failed(), - return_values={worker_global_ranks[k]: v for k, v in result.return_values.items()}, - failures={worker_global_ranks[k]: v for k, v in result.failures.items()}, - ) + return_values = result.return_values - def is_running(self) -> bool: - return self.running + if any(isinstance(v, WorkerException) for v in return_values.values()): + state = "failed" + else: + state = "done" - def is_failed(self) -> bool: - return self.failed - - def is_done(self) -> bool: - return not self.running and not self.failed + return cls( + state=state, + return_values={worker_global_ranks[k]: v for k, v in return_values.items()}, + ) @dataclass @@ -98,15 +94,18 @@ def _deserialize(self, serialized: bytes) -> Any: def _all_gather(self, object: Any) -> list: """gather object from every rank to list on every rank""" object_bytes = self._serialize(object) - object_list = [bytes()] * self.world_size + object_list = [b""] * self.world_size dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group) object_list = [self._deserialize(o) for o in object_list] return object_list def sync_payloads( self, payload: LauncherPayload | AgentPayload - ) -> list[LauncherPayload | AgentPayload]: - return self._all_gather(object=payload) - - def sync_agent_statuses(self, status: AgentStatus) -> list[AgentStatus]: - return self._all_gather(object=status)[1:] + ) -> tuple[LauncherPayload, list[AgentPayload]]: + payloads = self._all_gather(object=payload) + launcher_payload = payloads[0] + agent_payloads = payloads[1:] + return launcher_payload, agent_payloads + + def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]: + return self._all_gather(object=status)[1:] # [0] is launcher (status=None) diff --git a/tests/test_CI.py b/tests/test_CI.py index f507798e..b86cad64 100644 --- a/tests/test_CI.py +++ b/tests/test_CI.py @@ -61,7 +61,7 @@ def dist_func(): assert len(log_files) == 3 for file in log_files: - with open(f"{tmp}/{file}", "r") as f: + with open(f"{tmp}/{file}") as f: contents = f.read() print(contents) if file.endswith("[0].log"): @@ -79,7 +79,7 @@ def error_func(): tmp = tempfile.mkdtemp() os.environ["TORCHRUNX_DIR"] = tmp - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(ValueError) as excinfo: trx.launch( func=error_func, func_kwargs={},