Launch result #69

Merged (10 commits) on Sep 30, 2024
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,3 @@
# Contributing

We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by GitHub Actions.
109 changes: 85 additions & 24 deletions README.md
@@ -6,51 +6,112 @@
[![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
[![GitHub License](https://img.shields.io/github/license/apoorvkh/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)

Automatically launch functions and initialize distributed PyTorch environments on multiple machines
By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)

**Automatically distribute PyTorch functions onto multiple machines or GPUs**

## Installation

```bash
pip install torchrunx
```

Requirements:
- Operating System: Linux
- Python >= 3.8.1
- PyTorch >= 2.0
- Shared filesystem & passwordless SSH between hosts
Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0

Shared filesystem & SSH access if using multiple machines

## Usage
## Minimal example

Here's a simple example where we distribute `distributed_function` to two hosts (with 2 GPUs each):

```python
# Simple example
def distributed_function():
    pass

def train_model(model, dataset):
    trained_model = train(model, dataset)

    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None
```

```python
import torchrunx as trx

trx.launch(
    func=distributed_function,
    func_kwargs={},
    hostnames=["node1", "node2"], # or just: ["localhost"]
    workers_per_host=2
)

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0] # return from rank 0 (first worker on "localhost")
```
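
For reference, here is a minimal sketch of inspecting every worker's return value. It assumes, based on the `["localhost"][0]` indexing above, that `launch` returns a plain per-hostname mapping of per-worker results; the loop itself is illustrative and not part of the library.

```python
# Hypothetical sketch: collect and print each worker's return value.
# Assumes launch(...) returns {hostname: [return value of each local worker]}.
results = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)
for hostname, worker_returns in results.items():
    for local_rank, value in enumerate(worker_returns):
        # rank 0 on "localhost" returns 'model.pt'; every other worker returns None
        print(f"{hostname}[{local_rank}] -> {value}")
```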

### In a SLURM allocation
## Why should I use this?

[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Convenience:

- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
- If you want to run `python myscript.py` instead of `torchrun myscript.py` (see the sketch at the end of this section)
- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures)

Robustness:

- If you want to run a complex, _modular_ workflow in one script
  - no worries about memory leaks or OS failures
  - don't parallelize your entire script: just the functions you want

Features:

- Our launch utility is super _Pythonic_
  - If you want to run distributed PyTorch functions from Python Notebooks.
- Automatic integration with SLURM

Why not?

- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.
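
For example, a minimal sketch of a script you would start with plain `python myscript.py` (the filename and the `hello` function are hypothetical; only the `trx.launch` call mirrors the API shown above):

```python
# myscript.py -- hypothetical sketch: run with `python myscript.py`,
# no `torchrun` CLI and no manual SSH onto each machine.
import os

import torchrunx as trx

def hello():
    # Each worker is spawned with its distributed environment already set up.
    print(f"Hello from rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']}")

if __name__ == "__main__":
    trx.launch(
        func=hello,
        func_kwargs={},
        hostnames=["localhost"],  # add more hostnames to span machines
        workers_per_host=2
    )
```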

## More complicated example

We could also launch multiple functions, with different GPUs:

```python
trx.launch(
    # ...
    hostnames=trx.slurm_hosts(),
    workers_per_host=trx.slurm_workers()
)

def train_model(model, dataset):
    trained_model = train(model, dataset)

    if int(os.environ["RANK"]) == 0:
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None

def test_model(model_path, test_dataset):
    model = torch.load(model_path)
    accuracy = inference(model, test_dataset)
    return accuracy
```

```python
import torchrunx as trx

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0] # return from rank 0 (first worker on "localhost")

accuracy = trx.launch(
    func=test_model,
    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
    hostnames=["localhost"],
    workers_per_host=1
)["localhost"][0]

print(f'Accuracy: {accuracy}')
```

## Compared to other tools

## Contributing

We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI` and `conda-forge`. Our release pipeline is powered by GitHub Actions.
21 changes: 12 additions & 9 deletions docs/source/contributing.rst
@@ -1,17 +1,20 @@
Contributing
============

Development environment
-----------------------
.. include:: ../../CONTRIBUTING.md
   :parser: myst_parser.sphinx_

Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.
.. Development environment
.. -----------------------

Testing
-------
.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.

``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by GitHub Actions, which are limited to single-agent CPU-only tests due to GitHub's infrastructure.
.. Testing
.. -------

Contributing
------------
.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by GitHub Actions, which are limited to single-agent CPU-only tests due to GitHub's infrastructure.

.. Contributing
.. ------------

Make a pull request with your changes and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
.. Make a pull request with your changes and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
49 changes: 25 additions & 24 deletions src/torchrunx/agent.py
@@ -32,7 +32,7 @@ class WorkerArgs:
    logger_port: int
    main_agent_hostname: str
    main_agent_port: int
    backend: Literal["mpi", "gloo", "nccl", "ucc", None]
    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
    rank: int
    local_rank: int
    local_world_size: int
@@ -67,29 +67,30 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce

    redirect_stdio_to_logger(logger)

    store = dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
        host_name=worker_args.main_agent_hostname,
        port=worker_args.main_agent_port,
        world_size=worker_args.world_size,
        is_master=(worker_args.rank == 0),
    )

    backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")

    dist.init_process_group(
        backend=backend,
        world_size=worker_args.world_size,
        rank=worker_args.rank,
        store=store,
        timeout=datetime.timedelta(seconds=worker_args.timeout),
    )

    os.environ["RANK"] = str(worker_args.rank)
    os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
    os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
    os.environ["WORLD_SIZE"] = str(worker_args.world_size)
    os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
    os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

    if worker_args.backend is not None:
        os.environ["RANK"] = str(worker_args.rank)
        os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
        os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
        os.environ["WORLD_SIZE"] = str(worker_args.world_size)
        os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
        os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

        backend = worker_args.backend
        if backend == "auto":
            backend = "nccl" if torch.cuda.is_available() else "gloo"

        dist.init_process_group(
            backend=backend,
            world_size=worker_args.world_size,
            rank=worker_args.rank,
            store=dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
                host_name=worker_args.main_agent_hostname,
                port=worker_args.main_agent_port,
                world_size=worker_args.world_size,
                is_master=(worker_args.rank == 0),
            ),
            timeout=datetime.timedelta(seconds=worker_args.timeout),
        )

    try:
        return worker_args.function()
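
Since the agent exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, etc. and initializes the process group before invoking `worker_args.function()`, the launched function can use `torch.distributed` collectives directly. A minimal sketch of such a function (the name `sum_ranks` and its body are hypothetical, not part of this diff, and it assumes the workers were launched with a non-`None` backend):

```python
import os

import torch
import torch.distributed as dist

def sum_ranks():
    # RANK / WORLD_SIZE are exported by the agent before this function runs,
    # and dist.init_process_group(...) has already been called.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    t = torch.tensor([float(rank)])
    dist.all_reduce(t)  # defaults to SUM across all workers
    print(f"rank {rank}/{world_size}: sum of all ranks = {t.item()}")
    return t.item()
```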