
Commit c4c75fc

Merge pull request #69 from apoorvkh/launch-result ("Launch result")

2 parents: f46486d + 7850aeb

File tree: 8 files changed (+251 −144 lines)


CONTRIBUTING.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -0,0 +1,3 @@
+# Contributing
+
+We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by Github Actions.
```

README.md

Lines changed: 85 additions & 24 deletions
````diff
@@ -6,51 +6,112 @@
 [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
 [![GitHub License](https://img.shields.io/github/license/apoorvkh/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)
 
-Automatically launch functions and initialize distributed PyTorch environments on multiple machines
+By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)
+
+**Automatically distribute PyTorch functions onto multiple machines or GPUs**
 
 ## Installation
 
 ```bash
 pip install torchrunx
 ```
 
-Requirements:
-- Operating System: Linux
-- Python >= 3.8.1
-- PyTorch >= 2.0
-- Shared filesystem & passwordless SSH between hosts
+Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0
+
+Shared filesystem & SSH access if using multiple machines
 
-## Usage
+## Minimal example
+
+Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
 
 ```python
-# Simple example
-def distributed_function():
-    pass
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
 ```
 
 ```python
 import torchrunx as trx
 
-trx.launch(
-    func=distributed_function,
-    func_kwargs={},
-    hostnames=["node1", "node2"], # or just: ["localhost"]
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
     workers_per_host=2
-)
+)["localhost"][0]  # return from rank 0 (first worker on "localhost")
 ```
 
-### In a SLURM allocation
+## Why should I use this?
+
+[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
+
+Whether you have 1 GPU, 8 GPUs, or 8 machines:
+
+Convenience:
+
+- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
+- If you want to run `python myscript.py` instead of `torchrun myscript.py`
+- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures)
+
+Robustness:
+
+- If you want to run a complex, _modular_ workflow in one script
+  - no worries about memory leaks or OS failures
+  - don't parallelize your entire script: just the functions you want
+
+Features:
+
+- Our launch utility is super _Pythonic_
+  - If you want to run distributed PyTorch functions from Python notebooks
+- Automatic integration with SLURM
+
+Why not?
+
+- We don't support fault tolerance via torch elastic. That's probably only useful if you are using ~1000 GPUs. Maybe someone can make a PR.
+
+## More complicated example
+
+We could also launch multiple functions, with different GPUs:
 
 ```python
-trx.launch(
-    # ...
-    hostnames=trx.slurm_hosts(),
-    workers_per_host=trx.slurm_workers()
-)
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
+
+def test_model(model_path, test_dataset):
+    model = torch.load(model_path)
+    accuracy = inference(model, test_dataset)
+    return accuracy
 ```
 
-## Compared to other tools
-
-## Contributing
-
-We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI` and `conda-forge`. Our release pipeline is powered by Github Actions.
+```python
+import torchrunx as trx
+
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
+    workers_per_host=2
+)["localhost"][0]  # return from rank 0 (first worker on "localhost")
+
+accuracy = trx.launch(
+    func=test_model,
+    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
+    hostnames=["localhost"],
+    workers_per_host=1
+)["localhost"][0]
+
+print(f'Accuracy: {accuracy}')
+```
````
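The headline change of this PR is that `trx.launch` now returns the workers' results, which the README indexes as `["localhost"][0]`. A minimal sketch of the implied result shape, using plain Python data (the values below are illustrative, not produced by `torchrunx`): a mapping from each hostname to the list of that host's per-worker return values.

```python
# Hypothetical result of trx.launch(..., hostnames=["localhost", "other_node"],
# workers_per_host=2), assuming a hostname -> per-rank-returns mapping.
# Only the global rank 0 worker returned a path; all others returned None.
launch_result = {
    "localhost": ["model.pt", None],
    "other_node": [None, None],
}

# Return value of the first worker (rank 0) on "localhost",
# mirroring the README's indexing pattern:
model_path = launch_result["localhost"][0]
print(model_path)  # model.pt
```

This makes it easy to collect a single value (e.g. a checkpoint path) produced by rank 0 while ignoring the `None`s from the other workers.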

docs/source/contributing.rst

Lines changed: 12 additions & 9 deletions
```diff
@@ -1,17 +1,20 @@
 Contributing
 ============
 
-Development environment
------------------------
+.. include:: ../../CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
 
-Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.
+.. Development environment
+.. -----------------------
 
-Testing
--------
+.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.
 
-``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github Actions, which are limited to single-agent CPU-only tests due to Github's infrastructure.
+.. Testing
+.. -------
 
-Contributing
-------------
+.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github Actions, which are limited to single-agent CPU-only tests due to Github's infrastructure.
+
+.. Contributing
+.. ------------
 
-Make a pull request with your changes and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
+.. Make a pull request with your changes and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
```
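For the ``.. include::`` directive with ``:parser: myst_parser.sphinx_`` to render the Markdown file, the `myst-parser` package must be importable in the docs build environment; it is typically also enabled as a Sphinx extension. A hypothetical `docs/source/conf.py` fragment (assumed for illustration, not part of this commit):

```python
# docs/source/conf.py -- hypothetical fragment, not from this commit.
# Enables MyST so Sphinx can parse Markdown, e.g. the included CONTRIBUTING.md.
extensions = [
    "myst_parser",
]
```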

src/torchrunx/agent.py

Lines changed: 25 additions & 24 deletions
```diff
@@ -32,7 +32,7 @@ class WorkerArgs:
     logger_port: int
     main_agent_hostname: str
     main_agent_port: int
-    backend: Literal["mpi", "gloo", "nccl", "ucc", None]
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
     rank: int
     local_rank: int
     local_world_size: int
@@ -67,29 +67,30 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
 
     redirect_stdio_to_logger(logger)
 
-    store = dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
-        host_name=worker_args.main_agent_hostname,
-        port=worker_args.main_agent_port,
-        world_size=worker_args.world_size,
-        is_master=(worker_args.rank == 0),
-    )
-
-    backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")
-
-    dist.init_process_group(
-        backend=backend,
-        world_size=worker_args.world_size,
-        rank=worker_args.rank,
-        store=store,
-        timeout=datetime.timedelta(seconds=worker_args.timeout),
-    )
-
-    os.environ["RANK"] = str(worker_args.rank)
-    os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
-    os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
-    os.environ["WORLD_SIZE"] = str(worker_args.world_size)
-    os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
-    os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+    if worker_args.backend is not None:
+        os.environ["RANK"] = str(worker_args.rank)
+        os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
+        os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
+        os.environ["WORLD_SIZE"] = str(worker_args.world_size)
+        os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
+        os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+
+        backend = worker_args.backend
+        if backend == "auto":
+            backend = "nccl" if torch.cuda.is_available() else "gloo"
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=worker_args.world_size,
+            rank=worker_args.rank,
+            store=dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
+                host_name=worker_args.main_agent_hostname,
+                port=worker_args.main_agent_port,
+                world_size=worker_args.world_size,
+                is_master=(worker_args.rank == 0),
+            ),
+            timeout=datetime.timedelta(seconds=worker_args.timeout),
+        )
 
     try:
         return worker_args.function()
```
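The key behavioral change in this hunk: `backend=None` now means "skip process-group setup entirely" (previously `None` fell through to an NCCL/Gloo default), and the new `"auto"` value resolves to `"nccl"` when CUDA is available and `"gloo"` otherwise. A standalone sketch of that dispatch, with CUDA availability passed as a parameter so it runs anywhere (the function name is illustrative, not the library's API):

```python
from typing import Literal, Optional

Backend = Optional[Literal["nccl", "gloo", "mpi", "ucc", "auto"]]

def resolve_backend(backend: Backend, cuda_available: bool) -> Optional[str]:
    """Mirror the patched entrypoint's dispatch: None disables
    init_process_group; "auto" picks NCCL on GPU hosts, Gloo on CPU."""
    if backend is None:
        return None  # worker runs without a process group
    if backend == "auto":
        return "nccl" if cuda_available else "gloo"
    return backend  # explicit backend passed through unchanged

print(resolve_backend("auto", cuda_available=True))   # nccl
print(resolve_backend("auto", cuda_available=False))  # gloo
print(resolve_backend(None, cuda_available=True))     # None
```

Hoisting the environment-variable assignments (`RANK`, `MASTER_ADDR`, etc.) inside the same `if` keeps non-distributed workers free of misleading torch.distributed state.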

0 commit comments
