
Commit bb52034

Merge branch 'main' of github.com:apoorvkh/torchrunx into python-3.9

2 parents 1735920 + b8b18f6

File tree: 10 files changed (+379, −391 lines)

.github/workflows/docs-preview.yml

Lines changed: 0 additions & 16 deletions
This file was deleted.

README.md

Lines changed: 54 additions & 77 deletions
@@ -1,6 +1,7 @@
 # torchrunx 🔥
 
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/pyproject.toml)
+[![PyTorch Version](https://img.shields.io/badge/torch-%3E%3D2.0-orange)](https://github.com/pytorch/pytorch)
 [![PyPI - Version](https://img.shields.io/pypi/v/torchrunx)](https://pypi.org/project/torchrunx/)
 ![Tests](https://img.shields.io/github/actions/workflow/status/apoorvkh/torchrunx/.github%2Fworkflows%2Fmain.yml)
 [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
@@ -16,102 +17,78 @@ By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.co
 pip install torchrunx
 ```
 
-Requires: Linux, Python >= 3.9, PyTorch >= 2.0
+**Requires:** Linux (with shared filesystem & SSH access if using multiple machines)
 
-Shared filesystem & SSH access if using multiple machines
+## Demo
 
-## Minimal example
+Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
 
-Here's a simple example where we distribute `distributed_function` to two hosts (with 2 GPUs each):
+<details>
+<summary>Training code</summary>
 
-```python
-def train_model(model, dataset):
-    trained_model = train(model, dataset)
-
-    if int(os.environ["RANK"]) == 0:
-        torch.save(learned_model, 'model.pt')
-        return 'model.pt'
-
-    return None
-```
-
-```python
-import torchrunx as trx
-
-model_path = trx.launch(
-    func=train_model,
-    func_kwargs={'model': my_model, 'training_dataset': mnist_train},
-    hostnames=["localhost", "other_node"],
-    workers_per_host=2
-)["localhost"][0] # return from rank 0 (first worker on "localhost")
-```
-
-## Why should I use this?
-
-[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
-
-Whether you have 1 GPU, 8 GPUs, or 8 machines:
-
-Convenience:
+```python
+import os
+import torch
 
-- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
-- If you want to run `python myscript.py` instead of `torchrun myscript.py`
-- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures)
+def train():
+    rank = int(os.environ['RANK'])
+    local_rank = int(os.environ['LOCAL_RANK'])
 
-Robustness:
+    model = torch.nn.Linear(10, 10).to(local_rank)
+    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+    optimizer = torch.optim.AdamW(ddp_model.parameters())
 
-- If you want to run a complex, _modular_ workflow in one script
-  - no worries about memory leaks or OS failures
-  - don't parallelize your entire script: just the functions you want
-
-Features:
+    optimizer.zero_grad()
+    outputs = ddp_model(torch.randn(5, 10))
+    labels = torch.randn(5, 10).to(local_rank)
+    torch.nn.functional.mse_loss(outputs, labels).backward()
+    optimizer.step()
 
-- Our launch utility is super _Pythonic_
-- If you want to run distributed PyTorch functions from Python Notebooks.
-- Automatic integration with SLURM
+    if rank == 0:
+        return model
+```
 
-Why not?
+You could also use `transformers.Trainer` (or similar) to automatically handle all the multi-GPU / DDP code above.
+</details>
 
-- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.
 
-## More complicated example
+```python
+import torchrunx as trx
 
-We could also launch multiple functions, with different GPUs:
+if __name__ == "__main__":
+    trained_model = trx.launch(
+        func=train,
+        hostnames=["localhost", "other_node"],
+        workers_per_host=2 # num. GPUs
+    ).value(rank=0) # get returned object
 
-```python
-def train_model(model, dataset):
-    trained_model = train(model, dataset)
+    torch.save(trained_model.state_dict(), "model.pth")
+```
 
-    if int(os.environ["RANK"]) == 0:
-        torch.save(learned_model, 'model.pt')
-        return 'model.pt'
+### [Full API](https://torchrunx.readthedocs.io/stable/api.html)
+### [Advanced Usage](https://torchrunx.readthedocs.io/stable/advanced.html)
 
-    return None
+## Why should I use this?
 
-def test_model(model_path, test_dataset):
-    model = torch.load(model_path)
-    accuracy = inference(model, test_dataset)
-    return accuracy
-```
+Whether you have 1 GPU, 8 GPUs, or 8 machines.
 
-```python
-import torchrunx as trx
+__Features:__
 
-model_path = trx.launch(
-    func=train_model,
-    func_kwargs={'model': my_model, 'training_dataset': mnist_train},
-    hostnames=["localhost", "other_node"],
-    workers_per_host=2
-)["localhost"][0] # return from rank 0 (first worker on "localhost")
+- Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_
+- Return objects from your workers
+- Run `python script.py` instead of `torchrun script.py`
+- Launch multi-node functions, even from Python Notebooks
+- Fine-grained control over logging, environment variables, exception handling, etc.
+- Automatic integration with SLURM
 
+__Robustness:__
 
+- If you want to run a complex, _modular_ workflow in __one__ script
+  - don't parallelize your entire script: just the functions you want!
+  - no worries about memory leaks or OS failures
 
-accuracy = trx.launch(
-    func=test_model,
-    func_kwargs={'model': learned_model, 'test_dataset': mnist_test},
-    hostnames=["localhost"],
-    workers_per_host=1
-)["localhost"][0]
+__Convenience:__
 
-print(f'Accuracy: {accuracy}')
-```
+- If you don't want to:
+  - set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
  - manually SSH into every machine and `torchrun --master-ip --master-port ...`, babysit failed processes, etc.
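
Editor's note: the new README retrieves worker return values through the result object of `launch()` (the `.value(rank=0)` call above). As a supplementary sketch, not part of this commit, here is how per-rank return values could be collected with that same API; the `report` function and the single-host setup are hypothetical, and it is assumed that `.value()` accepts any global rank.

```python
# Hypothetical sketch built on the launch() / .value(rank=...) API shown in the README diff above.
import os

import torchrunx as trx


def report() -> str:
    # Each worker returns an object to the launcher ("Return objects from your workers").
    return f"hello from global rank {os.environ['RANK']}"


if __name__ == "__main__":
    result = trx.launch(
        func=report,
        hostnames=["localhost"],
        workers_per_host=2,
    )
    print(result.value(rank=0))  # e.g. "hello from global rank 0"
    print(result.value(rank=1))  # assumes .value() works for any global rank
```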

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 sphinx==6.2.1
 furo
 myst-parser
-sphinx-toolbox
+sphinx-toolbox

docs/source/advanced.rst

Lines changed: 29 additions & 11 deletions
@@ -1,6 +1,33 @@
 Advanced Usage
 ==============
 
+Multiple functions in one script
+--------------------------------
+
+We could also launch multiple functions (e.g. train on many GPUs, test on one GPU):
+
+.. code-block:: python
+
+    import torchrunx as trx
+
+    trained_model = trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8
+    ).value(rank=0)
+
+    accuracy = trx.launch(
+        func=test,
+        func_kwargs={'model': model},
+        hostnames=["localhost"],
+        workers_per_host=1
+    ).value(rank=0)
+
+    print(f'Accuracy: {accuracy}')
+
+``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation.
+
+
 Environment Detection
 ---------------------
 
@@ -61,18 +88,9 @@ For example, the `python ... --help` command will then result in:
 Custom Logging
 --------------
 
-Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a :mod:`torchrunx.DefaultLogSpec` is instantiated, causing logs at the worker and agent levels to be logged to files under ``'./logs'``, and the rank 0 worker's output streams are streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
-
-Custom logging classes can be subclassed from the :mod:`torchrunx.LogSpec` class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The :mod:`torchrunx.DefaultLogSpec` maps all the loggers to :mod:`logging.Filehandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs the launcher's ``stdout`` stream.
-
-.. autoclass:: torchrunx.LogSpec
-   :members:
-
-.. autoclass:: torchrunx.DefaultLogSpec
-   :members:
+Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a ``DefaultLogSpec`` is instantiated, causing logs at the worker and agent levels to be written to files under ``'./logs'``, while the rank 0 worker's output streams are streamed to the launcher's ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and worker logs have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
 
-..
-   TODO: example log structure
+Custom logging classes can be subclassed from the ``LogSpec`` class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The ``DefaultLogSpec`` maps all the loggers to :mod:`logging.FileHandler` objects pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs to the launcher's ``stdout`` stream.
 
 Propagating Exceptions
 ----------------------
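
Editor's note: the Custom Logging section above describes the ``get_map`` contract (logger name -> list of ``logging.Handler``) but, after this commit, no longer ships an example. The following is a hedged sketch, not part of this commit: the class name ``FileOnlyLogSpec``, its constructor signature, and the host list are hypothetical, and ``torchrunx.LogSpec`` is assumed to still be importable.

```python
# Hypothetical LogSpec subclass, following the get_map() contract described above.
# Logger names follow "{agent hostname}" (agents) and "{agent hostname}[{local rank}]" (workers).
import logging
import os

import torchrunx


class FileOnlyLogSpec(torchrunx.LogSpec):  # assumes torchrunx.LogSpec is importable
    def __init__(self, hostnames: list, workers_per_host: int):
        self.hostnames = hostnames
        self.workers_per_host = workers_per_host

    def get_map(self) -> dict:
        os.makedirs("logs", exist_ok=True)
        handlers = {}
        for hostname in self.hostnames:
            # Agent logs: one file per agent.
            handlers[hostname] = [logging.FileHandler(f"logs/{hostname}.log")]
            for local_rank in range(self.workers_per_host):
                # Worker logs: one file per worker.
                handlers[f"{hostname}[{local_rank}]"] = [
                    logging.FileHandler(f"logs/{hostname}[{local_rank}].log")
                ]
        # Additionally stream the first worker's logs to the launcher's stdout.
        handlers[f"{self.hostnames[0]}[0]"].append(logging.StreamHandler())
        return handlers
```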

docs/source/api.rst

Lines changed: 3 additions & 4 deletions
@@ -1,8 +1,7 @@
 API
 =============
 
-..
-   TODO: examples, environmental variables available to workers (e.g. RANK, LOCAL_RANK)
+.. autofunction:: torchrunx.launch(func: Callable, ...)
 
-.. automodule:: torchrunx
-   :members: launch, slurm_hosts, slurm_workers
+.. autoclass:: torchrunx.LaunchResult
+   :members:

docs/source/conf.py

Lines changed: 6 additions & 1 deletion
@@ -20,8 +20,13 @@
     'myst_parser',
     'sphinx_toolbox.sidebar_links',
     'sphinx_toolbox.github',
+    'sphinx.ext.autodoc.typehints',
+    #"sphinx_autodoc_typehints",
 ]
 
+autodoc_typehints = "both"
+#typehints_defaults = 'comma'
+
 github_username = 'apoorvkh'
 github_repository = 'torchrunx'
 
@@ -43,4 +48,4 @@
 epub_show_urls = 'footnote'
 
 # code block syntax highlighting
-#pygments_style = 'sphinx'
+#pygments_style = 'sphinx'

docs/source/index.rst

Lines changed: 3 additions & 8 deletions
@@ -1,14 +1,9 @@
-Getting Started
-===============
-
 .. include:: ../../README.md
    :parser: myst_parser.sphinx_
 
-Contents
---------
-
 .. toctree::
-   :maxdepth: 2
+   :hidden:
+   :maxdepth: 1
 
    api
    advanced
@@ -17,4 +12,4 @@ Contents
 
 .. sidebar-links::
    :github:
-   :pypi: torchrunx
+   :pypi: torchrunx
