diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..fa436244
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,3 @@
+# Contributing
+
+We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by GitHub Actions.
diff --git a/README.md b/README.md
index e65a3336..51f1f772 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,9 @@
 [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
 [![GitHub License](https://img.shields.io/github/license/apoorvkh/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)
 
-Automatically launch functions and initialize distributed PyTorch environments on multiple machines
+By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)
+
+**Automatically distribute PyTorch functions onto multiple machines or GPUs**
 
 ## Installation
 
@@ -14,43 +16,102 @@ Automatically launch functions and initialize distributed PyTorch environments o
 pip install torchrunx
 ```
 
-Requirements:
-- Operating System: Linux
-- Python >= 3.8.1
-- PyTorch >= 2.0
-- Shared filesystem & passwordless SSH between hosts
+Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0
+
+A shared filesystem & SSH access between hosts are also required when using multiple machines
 
-## Usage
+## Minimal example
+
+Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
 
 ```python
-# Simple example
-def distributed_function():
-    pass
+import os
+import torch
+
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
 ```
 
 ```python
 import torchrunx as trx
 
-trx.launch(
-    func=distributed_function,
-    func_kwargs={},
-    hostnames=["node1", "node2"],  # or just: ["localhost"]
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
     workers_per_host=2
-)
+).value(0)  # return from rank 0 (first worker on "localhost")
 ```
 
-### In a SLURM allocation
+## Why should I use this?
+
+[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
+
+Whether you have 1 GPU, 8 GPUs, or 8 machines:
+
+Convenience:
+
+- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
+- If you want to run `python myscript.py` instead of `torchrun myscript.py`
+- If you don't want to manually SSH into every machine and run `torchrun --master_addr --master_port ...` (and babysit these machines for hanging failures)
+
+Robustness:
+
+- If you want to run a complex, _modular_ workflow in one script
+  - no worries about memory leaks or OS failures
+  - don't parallelize your entire script: just the functions you want
+
+Features:
+
+- Our launch utility is super _Pythonic_
+- Run distributed PyTorch functions from Python notebooks
+- Automatic integration with SLURM (see the example below)
+
+Why not?
+
+- We don't support fault tolerance via torch elastic. This is probably only useful if you are using 1000+ GPUs. Contributions are welcome!
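+
+For example, the SLURM integration mentioned above looks roughly like this (a sketch: it assumes you are inside a SLURM allocation, and `my_function` stands in for your own function):
+
+```python
+import torchrunx as trx
+
+# hostnames and workers per host are resolved from the SLURM environment
+results = trx.launch(
+    func=my_function,
+    hostnames="slurm",
+    workers_per_host="slurm",
+)
+
+# return values from all workers, ordered by global rank
+print(results.all(by="rank"))
+```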
+
+## More complicated example
+
+We could also launch multiple functions (e.g. training, then testing), each using a different set of workers:
 
 ```python
-trx.launch(
-    # ...
-    hostnames=trx.slurm_hosts(),
-    workers_per_host=trx.slurm_workers()
-)
+import os
+import torch
+
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
+
+def test_model(model_path, test_dataset):
+    model = torch.load(model_path)
+    accuracy = inference(model, test_dataset)
+    return accuracy
 ```
 
-## Compared to other tools
+```python
+import torchrunx as trx
+
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
+    workers_per_host=2
+).value(0)  # return from rank 0 (first worker on "localhost")
 
-## Contributing
-We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI` and `conda-forge`. Our release pipeline is powered by Github Actions.
+
+accuracy = trx.launch(
+    func=test_model,
+    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
+    hostnames=["localhost"],
+    workers_per_host=1
+).value(0)
+
+print(f'Accuracy: {accuracy}')
+```
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index fb9554e3..707e376c 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -1,17 +1,20 @@
 Contributing
 ============
 
-Development environment
------------------------
+.. include:: ../../CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
 
-Ensure you have the latest development environment installed. After cloning our repository, `install pixi `_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff `_ for linting and formatting, `pyright `_ for type checking, and ``pytest`` for testing.
+.. Development environment
+.. -----------------------
 
-Testing
--------
+.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi `_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff `_ for linting and formatting, `pyright `_ for type checking, and ``pytest`` for testing.
 
-``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
+.. Testing
+.. -------
 
-Contributing
-------------
+.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
+
+.. Contributing
+.. ------------
 
-Make a pull request with your changes and we'll try to look at soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix.
If adding a new feature, explain why it's meaningful and belongs in **torchrunx**. \ No newline at end of file diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 04d1ec92..0f43506b 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -32,7 +32,7 @@ class WorkerArgs: logger_port: int main_agent_hostname: str main_agent_port: int - backend: Literal["mpi", "gloo", "nccl", "ucc", None] + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None rank: int local_rank: int local_world_size: int @@ -67,29 +67,30 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce redirect_stdio_to_logger(logger) - store = dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] - host_name=worker_args.main_agent_hostname, - port=worker_args.main_agent_port, - world_size=worker_args.world_size, - is_master=(worker_args.rank == 0), - ) - - backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo") - - dist.init_process_group( - backend=backend, - world_size=worker_args.world_size, - rank=worker_args.rank, - store=store, - timeout=datetime.timedelta(seconds=worker_args.timeout), - ) - - os.environ["RANK"] = str(worker_args.rank) - os.environ["LOCAL_RANK"] = str(worker_args.local_rank) - os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) - os.environ["WORLD_SIZE"] = str(worker_args.world_size) - os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname - os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) + if worker_args.backend is not None: + os.environ["RANK"] = str(worker_args.rank) + os.environ["LOCAL_RANK"] = str(worker_args.local_rank) + os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) + os.environ["WORLD_SIZE"] = str(worker_args.world_size) + os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname + os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) + + backend = worker_args.backend + if backend == "auto": + backend = "nccl" if torch.cuda.is_available() else "gloo" + + dist.init_process_group( + backend=backend, + world_size=worker_args.world_size, + rank=worker_args.rank, + store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] + host_name=worker_args.main_agent_hostname, + port=worker_args.main_agent_port, + world_size=worker_args.world_size, + is_master=(worker_args.rank == 0), + ), + timeout=datetime.timedelta(seconds=worker_args.timeout), + ) try: return worker_args.function() diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index fa73ae04..cd4a6098 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -10,23 +10,19 @@ import subprocess import sys from dataclasses import dataclass -from functools import partial +from functools import partial, reduce from logging import Handler from multiprocessing import Process +from operator import add from pathlib import Path -from typing import Any, Callable, Literal, Sequence +from typing import Any, Callable, Literal, overload import fabric import torch.distributed as dist from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers from .logging_utils import LogRecordSocketReceiver, default_handlers -from .utils import ( - LauncherAgentGroup, - LauncherPayload, - WorkerException, - get_open_port, -) +from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: @@ -80,13 +76,13 @@ def build_logging_server( ) -def build_command( 
+def build_launch_command( launcher_hostname: str, launcher_port: int, logger_port: int, world_size: int, rank: int, - env_vars: Sequence[str], + env_vars: list[str] | tuple[str], env_file: str | os.PathLike | None, ) -> str: # shlex.quote prevents shell injection here (resolves S602 in execute_command) @@ -122,33 +118,35 @@ def build_command( return " && ".join(commands) -def is_localhost(hostname_or_ip: str) -> bool: - # check if host is "loopback" address (i.e. designated to send to self) - try: - ip = ipaddress.ip_address(hostname_or_ip) - except ValueError: - ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip)) - if ip.is_loopback: - return True - # else compare local interface addresses between host and localhost - host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)] - localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] - return len(set(host_addrs) & set(localhost_addrs)) > 0 - - def execute_command( command: str, hostname: str, ssh_config_file: str | os.PathLike | None = None, ) -> None: - if is_localhost(hostname): + is_localhost = True + _hostname_or_ip = hostname + try: + _ip = ipaddress.ip_address(_hostname_or_ip) + except ValueError: + _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) + if not _ip.is_loopback: + # compare local interface addresses between host and localhost + _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] + _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] + is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 + + if is_localhost: # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.8/library/subprocess.html#security-considerations) # Made sure to shlex.quote arguments in build_command to prevent shell injection subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # noqa: S602 else: + runtime_ssh_path = ssh_config_file + if isinstance(ssh_config_file, os.PathLike): + runtime_ssh_path = str(ssh_config_file) + with fabric.Connection( host=hostname, - config=fabric.Config(runtime_ssh_path=ssh_config_file), + config=fabric.Config(runtime_ssh_path=runtime_ssh_path), ) as conn: conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True) @@ -158,9 +156,9 @@ class Launcher: hostnames: list[str] | Literal["auto", "slurm"] = "auto" workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto" ssh_config_file: str | os.PathLike | None = None - backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto" log_handlers: list[Handler] | Literal["auto"] | None = "auto" - env_vars: Sequence[str] = ( + env_vars: tuple[str] = ( # pyright: ignore [reportAssignmentType] "PATH", "LD_LIBRARY", "LIBRARY_PATH", @@ -178,7 +176,7 @@ def run( # noqa: C901, PLR0912 func: Callable, func_args: tuple[Any] | None = None, func_kwargs: dict[str, Any] | None = None, - ) -> dict[str, dict[int, Any]]: + ) -> LaunchResult: """ Launch a distributed PyTorch function on the specified nodes. 
See :mod:`torchrunx.launch` @@ -231,7 +229,7 @@ def run( # noqa: C901, PLR0912 for i, hostname in enumerate(hostnames): execute_command( - command=build_command( + command=build_launch_command( launcher_hostname=launcher_hostname, launcher_port=launcher_port, logger_port=log_receiver.port, @@ -282,7 +280,7 @@ def run( # noqa: C901, PLR0912 # raises specific exception if any agent fails for s in agent_statuses: - for value in s.return_values.values(): + for value in s.return_values: if isinstance(value, WorkerException): raise value.exception @@ -307,10 +305,7 @@ def run( # noqa: C901, PLR0912 ssh_config_file=self.ssh_config_file, ) - return { - hostname: agent_status.return_values - for hostname, agent_status in zip(hostnames, agent_statuses) - } + return LaunchResult(hostnames=hostnames, agent_statuses=agent_statuses) def launch( @@ -320,9 +315,9 @@ def launch( hostnames: list[str] | Literal["auto", "slurm"] = "auto", workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto", ssh_config_file: str | os.PathLike | None = None, - backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None, - log_handlers: list[Handler] | Literal["auto"] = "auto", - env_vars: Sequence[str] = ( + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto", + log_handlers: list[Handler] | Literal["auto"] | None = "auto", + env_vars: tuple[str] = ( # pyright: ignore [reportArgumentType] "PATH", "LD_LIBRARY", "LIBRARY_PATH", @@ -334,7 +329,7 @@ def launch( ), env_file: str | os.PathLike | None = None, timeout: int = 600, -) -> dict[str, dict[int, Any]]: +) -> LaunchResult: """ Launch a distributed PyTorch function on the specified nodes. @@ -376,3 +371,48 @@ def launch( env_file=env_file, timeout=timeout, ).run(func=func, func_args=func_args, func_kwargs=func_kwargs) + + +class LaunchResult: + def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None: + self.hostnames: list[str] = hostnames + self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses] + + @overload + def all(self) -> dict[str, list[Any]]: + pass + + @overload + def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]: + pass + + @overload + def all(self, by: Literal["rank"]) -> list[Any]: + pass + + def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: + if by == "hostname": + return dict(zip(self.hostnames, self.return_values)) + elif by == "rank": # noqa: RET505 + return reduce(add, self.return_values) + + msg = "Invalid argument: expected by=('hostname' | 'rank')" + raise TypeError(msg) + + def values(self, hostname: str) -> list[Any]: + host_idx = self.hostnames.index(hostname) + return self.return_values[host_idx] + + def value(self, rank: int) -> Any: + if rank < 0: + msg = f"Rank {rank} must be larger than 0" + raise ValueError(msg) + + for values in self.return_values: + if rank >= len(values): + rank -= len(values) + else: + return values[rank] + + msg = f"Rank {rank} larger than world_size" + raise ValueError(msg) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 0fafec9d..3770e93d 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -20,51 +20,6 @@ def get_open_port() -> int: return s.getsockname()[1] -@dataclass -class WorkerException: - exception: Exception - - -@dataclass -class LauncherPayload: - fn: Callable - hostnames: list[str] - worker_global_ranks: list[list[int]] - worker_world_size: int - backend: Literal["mpi", "gloo", "nccl", "ucc", None] - timeout: int - - -@dataclass 
-class AgentPayload: - hostname: str - port: int - process_id: int - - -@dataclass -class AgentStatus: - state: Literal["running", "failed", "done"] - return_values: dict[int, Any | WorkerException] = field(default_factory=dict) - - @classmethod - def from_result(cls, result: RunProcsResult | None) -> Self: - if result is None: - return cls(state="running") - - return_values = result.return_values - - if any(isinstance(v, WorkerException) for v in return_values.values()): - state = "failed" - else: - state = "done" - - return cls( - state=state, - return_values=return_values, - ) - - @dataclass class LauncherAgentGroup: launcher_hostname: str @@ -115,3 +70,48 @@ def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]: def shutdown(self) -> None: dist.destroy_process_group(group=self.group) + + +@dataclass +class LauncherPayload: + fn: Callable + hostnames: list[str] + worker_global_ranks: list[list[int]] + worker_world_size: int + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None + timeout: int + + +@dataclass +class AgentPayload: + hostname: str + port: int + process_id: int + + +@dataclass +class WorkerException: + exception: Exception + + +@dataclass +class AgentStatus: + state: Literal["running", "failed", "done"] + return_values: list[Any | WorkerException] = field( + default_factory=list + ) # indexed by local rank + + @classmethod + def from_result(cls, result: RunProcsResult | None) -> Self: + if result is None: + return cls(state="running") + + return_values = list(result.return_values.values()) + + failed = any(isinstance(v, WorkerException) for v in return_values) + state = "failed" if failed else "done" + + return cls( + state=state, + return_values=return_values, + ) diff --git a/tests/test_ci.py b/tests/test_ci.py index f72f3ef4..64cd1e93 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -37,8 +37,7 @@ def dist_func() -> torch.Tensor: backend="gloo", # log_dir="./test_logs" ) - results = next(iter(r.values())) - assert torch.all(results[0] == results[1]) + assert torch.all(r.value(0) == r.value(1)) def test_logging() -> None: diff --git a/tests/test_func.py b/tests/test_func.py index 8fb264bf..e8033b4e 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -13,7 +13,7 @@ def test_launch() -> None: workers_per_host="slurm", ) - result_values = [v for host_results in result.values() for v in host_results.values()] + result_values = result.all(by="rank") t = True for i in range(len(result_values)):