[Docs](https://torchrunx.readthedocs.io)
[License](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)

By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)

**Automatically distribute PyTorch functions onto multiple machines or GPUs**

## Installation

```bash
pip install torchrunx
```

Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0

A shared filesystem and SSH access between hosts are also required if using multiple machines.

## Minimal example

Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):

```python
import os

import torch

def train_model(model, dataset):
    trained_model = train(model, dataset)

    if int(os.environ["RANK"]) == 0:
        # only the rank-0 worker saves the model and returns its path
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None
```

```python
import torchrunx as trx

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],  # or just: ["localhost"]
    workers_per_host=2
)["localhost"][0]  # return value from rank 0 (first worker on "localhost")
```
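
The object returned by `trx.launch` maps each hostname to the list of values returned by that host's workers. Here's the same call as a sketch with the unpacking spelled out (the example values assume rank 0 lands on the first `"localhost"` worker, as above):

```python
results = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)
# results maps hostname -> per-worker return values, e.g.
# {"localhost": ['model.pt', None], "other_node": [None, None]}
model_path = results["localhost"][0]
```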

## Why should I use this?

[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.

Whether you have 1 GPU, 8 GPUs, or 8 machines:

Convenience:

- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself (see the sketch after this list)
- If you want to run `python myscript.py` instead of `torchrun myscript.py`
- If you don't want to manually SSH into every machine and run `torchrun --master-ip --master-port ...` (or babysit those machines for hanging failures)
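
For instance, launched functions can use `torch.distributed` collectives directly, because `torchrunx` initializes the process group for every worker. A minimal sketch (the function and tensor are purely illustrative):

```python
import torch
import torch.distributed as dist

def sum_of_ranks():
    # the launcher has already initialized the process group,
    # so collectives work out of the box
    t = torch.tensor([float(dist.get_rank())])
    dist.all_reduce(t)  # default op: sum across all workers
    return t.item()
```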

Robustness:

- If you want to run a complex, _modular_ workflow in one script
  - no worries about memory leaks or OS failures
  - don't parallelize your entire script: just the functions you want

Features:

- Our launch utility is super _Pythonic_
- Run distributed PyTorch functions from Python notebooks
- Automatic integration with SLURM (example below)
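
For example, in a SLURM allocation you can fill in the hosts and workers from the job environment instead of hardcoding them. A sketch using the `trx.slurm_hosts()` / `trx.slurm_workers()` utilities (see the docs for the current API):

```python
import torchrunx as trx

trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=trx.slurm_hosts(),          # hosts in this SLURM job
    workers_per_host=trx.slurm_workers()  # workers (e.g. GPUs) per host
)
```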

Why not?

- We don't support fault tolerance via torch elastic. That's probably only useful if you're running on 1000+ GPUs. Maybe someone can make a PR!

## More complicated example

We could also launch multiple functions in sequence, each with its own hosts and workers:

```python
import os

import torch

def train_model(model, dataset):
    trained_model = train(model, dataset)

    if int(os.environ["RANK"]) == 0:
        # only the rank-0 worker saves the model and returns its path
        torch.save(trained_model, 'model.pt')
        return 'model.pt'

    return None

def test_model(model_path, test_dataset):
    model = torch.load(model_path)
    accuracy = inference(model, test_dataset)
    return accuracy
```

```python
import torchrunx as trx

model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]  # return value from rank 0 (first worker on "localhost")

accuracy = trx.launch(
    func=test_model,
    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
    hostnames=["localhost"],
    workers_per_host=1
)["localhost"][0]

print(f'Accuracy: {accuracy}')
```