
Commit 1035b75

final updates to docs
1 parent ff0b90e commit 1035b75

File tree: 12 files changed (+186, -171 lines)


CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@ We use `ruff check` for linting, `ruff format` for formatting, `pyright` for sta

 ## Pull Requests

-Make a pull request with your changes on Github and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in __torchrunx__.
+Make a pull request with your changes on Github and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.

 ## Testing

@@ -16,4 +16,4 @@ At the moment, we run `pytest tests/test_ci.py` (i.e. simple single-node CPU-onl

 ## Documentation

-Our documentation is hosted on Github Pages and is updated with every package release. We build our documentation with `sphinx` using the command: `uv run --group docs python -m sphinx --builder html --doctree-dir docs/_build/.doctrees --conf-dir docs --show-traceback docs/source docs/_build/html`. The documentation will then be generated at `docs/_build/html`.
+Our documentation is hosted on Github Pages and is updated with every package release. We build our documentation with [Sphinx](https://www.sphinx-doc.org): `source scripts/build_docs.sh`. The documentation will then be generated at `docs/_build/html` (and can be rendered with `python -m http.server --directory docs/_build/html`).

README.md

Lines changed: 14 additions & 17 deletions
@@ -21,20 +21,16 @@ It enables complex workflows within a single script and has useful features even
 pip install torchrunx
 ```

-Requires:
-- Linux
-- If using multiple machines: SSH & shared filesystem
+Requires: Linux. If using multiple machines: SSH & shared filesystem.

 ---

-**Dummy example: parallelizing training with `torchrunx`**
+<h4>Example: simple training loop</h4>
+
+Suppose we have some distributed training function (which needs to run on every GPU):

 ```python
-def distributed_training(model: nn.Module, num_steps: int) -> nn.Module:
-    # Environment variables: RANK, LOCAL_RANK, ...
-    # ddp_model = DistributedDataParallel(model, device_ids=[local_rank])
-    ...
-    retun trained_model
+def distributed_training(model: nn.Module, num_steps: int) -> nn.Module: ...
 ```

 <details>
@@ -70,14 +66,14 @@ def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | N

 </details>

+We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using **`torchrunx`**!
+
 ```python
 import torchrunx

-# Launch training on 2 machines x 2 GPUs
-
 launcher = torchrunx.Launcher(
-    hostnames = ["localhost", "second_machine"],
-    workers_per_host = 2
+    hostnames = ["localhost", "second_machine"],  # or IP addresses
+    workers_per_host = 2  # e.g. number of GPUs per host
 )

 results = launcher.run(
@@ -87,16 +83,17 @@ results = launcher.run(
 )
 ```

+Once completed, you can retrieve the results and process them as you wish.
+
 ```python
-# get the results
 trained_model: nn.Module = results.rank(0)
-# or: results.index(hostname="localhost", local_rank=0)
+# or: results.index(hostname="localhost", local_rank=0)

-# and continue your script — e.g. save model to checkpoint
+# and continue your script
 torch.save(trained_model.state_dict(), "output/model.pth")
 ```

-**See examples where we fine-tune LLMs using:**
+**See more examples where we fine-tune LLMs using:**
 - [Transformers](https://torchrun.xyz/examples/transformers.html)
 - [DeepSpeed](https://torchrun.xyz/examples/deepspeed.html)
 - [PyTorch Lightning](https://torchrun.xyz/examples/lightning.html)
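
The full `distributed_training` implementation is collapsed inside the README's `<details>` block and does not appear in this diff. As a rough sketch, assuming standard PyTorch DDP (the model input, loss, and backend below are placeholders), such a function might look like:

```python
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None:
    # torchrunx sets the usual distributed environment variables (RANK, LOCAL_RANK, ...)
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    dist.init_process_group(backend="nccl")  # assumes one NVIDIA GPU per worker
    model = model.to(local_rank)
    ddp_model = DistributedDataParallel(model, device_ids=[local_rank])

    optimizer = torch.optim.AdamW(ddp_model.parameters())
    for _ in range(num_steps):
        inputs = torch.randn(8, 128, device=local_rank)  # placeholder batch
        loss = ddp_model(inputs).sum()  # placeholder loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # only rank 0 returns the model, matching the `results.rank(0)` call above
    return model if rank == 0 else None
```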

docs/source/examples/deepspeed.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Here's an example script that uses `torchrunx` with [DeepSpeed](https://www.deep

 ## Training GPT-2 on WikiText

-Deepspeed requires additional (non-Python) dependencies. Use the following commands to set up a project. Source: [Apoorv's Blog — Managing Project Dependencies](https://blog.apoorvkh.com/posts/project-dependencies.html)
+Deepspeed requires additional (non-Python) dependencies. Use the following commands to set up a project. [source: [Apoorv's Blog — Managing Project Dependencies](https://blog.apoorvkh.com/posts/project-dependencies.html)]

 Pre-requisite: [pixi](https://pixi.sh)

docs/source/how_it_works.md

Lines changed: 4 additions & 0 deletions
@@ -4,12 +4,16 @@ Suppose you want to run a script (`train.py`) on `N` machines (or "nodes") with

 You'll need to start a new process for each GPU. Each process will execute your script in parallel and select its GPU based on the process rank. Your script will also form a [distributed group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) so the processes may communicate with each other (e.g. passing tensors).

+## `torchrun`
+
 Normally, you'd do this by running the `torchrun --node-rank {i} ... train.py ...` command on every machine. In short, you'll end up with a topology like:

 ![torchrun diagram](./artifacts/torchrun.png)

 As a side effect of this structure, every process will run until (1) script completion or (2) another process stops communicating (e.g. if killed by the system for abnormal reasons). The status of other processes is not actively communicated: so if some process is indeed killed, it would take 10 minutes (by default) for the remaining processes to time-out. Also, since this approach parallelizes the entire script, we can't catch and handle these system-level issues as exceptions.

+## `torchrunx` 🔥
+
 `torchrunx` offers a functional interface, with a launcher–worker topology, instead.

 ![torchrunx diagram](./artifacts/torchrunx.png)
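
To make the contrast concrete, here is a minimal sketch of launching the same workload both ways (hostnames, worker counts, and the exact `torchrun` flags are illustrative):

```python
# torchrun (CLI): run this command on every machine; the entire script is parallelized.
#   torchrun --nnodes 2 --node-rank {i} --nproc-per-node 2 \
#            --master-addr main_machine --master-port 29500 train.py

# torchrunx (functional): run the launcher once; only this function is parallelized.
import torchrunx


def distributed_training():
    ...  # executed by every worker process


if __name__ == "__main__":
    torchrunx.Launcher(
        hostnames=["main_machine", "second_machine"],
        workers_per_host=2,
    ).run(distributed_training)
```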

docs/source/usage/general.md

Lines changed: 4 additions & 4 deletions
@@ -34,7 +34,7 @@ You can catch these errors and handle them as you wish!
 ```python
 for config in configs:  # e.g. hyper-parameter sweep
     try:
-        Launcher().run(train, config)
+        torchrunx.Launcher().run(train, config)
     except torch.cuda.OutOfMemoryError:
         print(f"{config} results in OOM... continuing...")
 ```
@@ -44,12 +44,12 @@ If you are expecting intermittent failures, you can catch errors and invoke retr
 ```python
 for retry in range(3):
     try:
-        Launcher().run(train, resume_from_checkpoint=True)
+        torchrunx.Launcher().run(train, resume_from_checkpoint=True)
     except torchrunx.WorkerFailedError as e:
         print(f"Error occurred: {e}")
         print(f"Retrying ({retry}) ...")
-    else:
-        break
+    else:  # if run() is successful
+        break
 ```

 ## Environment variables

docs/source/usage/logging.md

Lines changed: 43 additions & 18 deletions
@@ -1,37 +1,62 @@
 # Custom Logging

-We forward all worker and agent logs (i.e. from {mod}`logging`, {obj}`sys.stdout`, and {obj}`sys.stderr`) to the launcher for processing.
+We forward all agent and worker logs (i.e. from {mod}`logging`, {obj}`sys.stdout`, and {obj}`sys.stderr`) to the launcher process.

-By default, the logs from the rank 0 agent and worker are printed into the launcher's `stdout` stream. Logs from all agents and workers are written to a directory (by the current timestamp) in `$TORCHRUNX_LOG_DIR` (default: `./torchrunx_logs`).
+## Defaults

-You can fully customize how logs are processed using {func}`torchrunx.Launcher.set_logging_handlers`. You should provide it a function that constructs and returns a list of {obj}`logging.Handler` objects. Each {obj}`logging.Handler` controls where logs should be written.
+By default, the logs from the rank 0 agent and rank 0 worker are handled by loggers on the launcher process (and so they should be printed to `stdout`/`stderr`). You may control these logs like:

-We provide some handler utilities that direct a specified worker or agent's logs to a file or stream.
-
-```{eval-rst}
-.. autofunction:: torchrunx.utils.file_handler
+```python
+logging.basicConfig(level=logging.INFO)
+logging.getLogger("torchrunx").setLevel(logging.DEBUG)
+logging.getLogger("torchrunx.node1").setLevel(logging.INFO)
+logging.getLogger("torchrunx.node1.1").setLevel(logging.INFO)  # worker 1 (local rank) on node 1
 ```

-```{eval-rst}
-.. autofunction:: torchrunx.utils.stream_handler
-```
+Also, logs from all agents and workers are written to a directory (by the current timestamp) in `$TORCHRUNX_LOG_DIR` (default: `./torchrunx_logs`). These can be controlled using `$TORCHRUNX_LOG_LEVEL` (default: `INFO`).

-For example, we could construct and pass a handler factory that streams the rank 0 agent and worker logs to the launcher's `stdout`.
+## Customization
+
+You can fully customize how logs are processed using {func}`torchrunx.Launcher.set_logging_handlers`. You should provide it a factory function that constructs and returns a list of {obj}`logging.Handler` objects. Each {obj}`logging.Handler` controls where logs should be written. You can also add a filter to restrict the handler to the logs of a specific agent or worker.
+
+Here's an example:

 ```python
-def rank_0_handlers() -> list[logging.Handler]:
+from torchrunx.utils.log_handling import RedirectHandler, get_handler_filter
+
+def custom_handlers() -> list[logging.Handler]:
+
+    # Handler: redirect logs from (host 0, agent) to logger on launcher process
+    redirect_handler = RedirectHandler()
+    redirect_handler.addFilter(get_handler_filter(
+        hostname=hostnames[0], local_rank=None, log_level=logging.DEBUG
+    ))
+
+    # Handler: output logs from (host 0, worker 0) to "output.txt"
+    file_handler = logging.FileHandler("output.txt")
+    file_handler.addFilter(get_handler_filter(
+        hostname=hostnames[0], local_rank=0, log_level=logging.DEBUG
+    ))
+
     return [
-        stream_handler(hostname=hostnames[0], local_rank=None),  # agent 0
-        stream_handler(hostname=hostnames[0], local_rank=0),  # worker 0
+        redirect_handler,
+        file_handler,
     ]
 ```

 ```python
-torchrunx.Launcher(...).set_logging_handlers(rank_0_handlers).run(...)
+torchrunx.Launcher(...).set_logging_handlers(custom_handlers).run(...)
 ```

-You can also [provide your own ``logging.Handler``](https://docs.python.org/3.9/library/logging.handlers.html#module-logging.handlers) and apply {func}`torchrunx.utils.add_filter_to_handler` to constrain which worker or agent's logs it should process.
+Finally, you can control library-specific logging (within the worker processes) by modifying the distributed function:
+
+```python
+def distributed_function():
+    logging.getLogger("transformers").setLevel(logging.DEBUG)
+
+    logger = logging.getLogger("my_app")
+    logger.info("Hello world!")
+    ...

-```{eval-rst}
-.. autofunction:: torchrunx.utils.add_filter_to_handler
+torchrunx.Launcher(...).run(distributed_function)
 ```

docs/source/usage/slurm.md

Lines changed: 4 additions & 6 deletions
@@ -14,9 +14,8 @@ def distributed_training():

 if __name__ == "__main__":
     torchrunx.Launcher(
-        # optionally specify:
-        # hostnames = "slurm",
-        # workers_per_host = "gpu"
+        hostnames = "slurm",
+        workers_per_host = "gpu"
     ).run(distributed_training)
 ```

@@ -46,9 +45,8 @@ def distributed_training():

 def launch_training():
     torchrunx.Launcher(
-        # optionally specify:
-        # hostnames = "slurm",
-        # workers_per_host = "gpu"
+        hostnames = "slurm",
+        workers_per_host = "gpu"
     ).run(distributed_training)

 if __name__ == "__main__":

src/torchrunx/agent.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
     LauncherAgentGroup,
     get_open_port,
 )
-from .utils.logs import log_records_to_socket, redirect_stdio_to_logger
+from .utils.log_streaming import log_records_to_socket, redirect_stdio_to_logger
 from .worker import WorkerArgs, worker_entrypoint


src/torchrunx/launcher.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,8 @@
     resolve_environment,
 )
 from .utils.errors import ExceptionFromWorker, WorkerFailedError
-from .utils.logs import LoggingServerArgs, default_handlers, start_logging_server
+from .utils.log_handling import default_handlers
+from .utils.log_streaming import LoggingServerArgs, start_logging_server

 DEFAULT_ENV_VARS_FOR_COPY = (
     "PATH",
@@ -80,10 +81,9 @@ def set_logging_handlers(
     ) -> Self:
         """Provide a ``handler_factory`` function to customize processing of agent/worker logs.

-        See `Custom Logging <https://torchrun.xyz/features/logging.html>`_.
-
         Parameters:
             handler_factory: Function that constructs and returns :obj:`logging.Handler` objects.
+                See `Custom Logging <https://torchrun.xyz/usage/logging.html>`_ for more details.
         """
         self.handler_factory = handler_factory
         return self

src/torchrunx/utils/log_handling.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+"""Utilities for intercepting logs in worker processes and handling these in the Launcher."""
+
+from __future__ import annotations
+
+__all__ = [
+    "RedirectHandler",
+    "default_handlers",
+    "file_handlers",
+    "get_handler_filter",
+]
+
+import datetime
+import logging
+import os
+from logging import LogRecord
+from pathlib import Path
+from typing import Callable
+
+
+def get_handler_filter(
+    hostname: str,
+    local_rank: int | None,  # None indicates agent
+    log_level: int = logging.NOTSET,
+) -> Callable[[LogRecord], bool]:
+    """Get an agent- or worker- specific filter to apply to :obj:`logging.Handler`."""
+    return lambda record: (
+        record.hostname == hostname  # pyright: ignore [reportAttributeAccessIssue]
+        and record.local_rank == local_rank  # pyright: ignore [reportAttributeAccessIssue]
+        and record.levelno >= log_level
+    )
+
+
+class RedirectHandler(logging.Handler):
+    """For handling logs from hostname/rank with a corresponding logger in the launcher process."""
+
+    def emit(self, record: LogRecord) -> None:
+        """Handle log record using corresponding logger."""
+        logger = logging.getLogger(record.name)
+        if logger.isEnabledFor(record.levelno):
+            logger.handle(record)
+
+
+def file_handlers(
+    hostnames: list[str],
+    workers_per_host: list[int],
+    log_dir: str | os.PathLike = Path("torchrunx_logs"),
+    log_level: int = logging.NOTSET,
+) -> list[logging.Handler]:
+    """Handler builder function for writing logs for all workers/agents to a directory.
+
+    Files are named with hostname and the local_rank (for workers).
+    """
+    handlers = []
+
+    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+    log_dir = Path(log_dir) / timestamp
+    log_dir.mkdir(parents=True, exist_ok=True)
+
+    formatter = logging.Formatter(
+        "%(asctime)s:%(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )
+
+    for hostname, num_workers in zip(hostnames, workers_per_host):
+        for local_rank in [None, *range(num_workers)]:
+            local_rank_str = f"[{local_rank}]" if local_rank is not None else ""
+            file_path = log_dir / f"{hostname}{local_rank_str}.log"
+
+            h = logging.FileHandler(file_path)
+            h.addFilter(get_handler_filter(hostname, local_rank, log_level=log_level))
+            h.setFormatter(formatter)
+
+            handlers.append(h)
+
+    return handlers
+
+
+def default_handlers(hostnames: list[str], workers_per_host: list[int]) -> list[logging.Handler]:
+    """Constructs default :obj:`logging.Handler` objects.
+
+    Logs for the rank 0 agent and rank 0 worker are redirected to loggers in the launcher process.
+    Logs for all hosts/workers are written to files in ``$TORCHRUNX_LOG_DIR`` (named by timestamp,
+    hostname, local_rank).
+    """
+    log_dir = Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs"))
+
+    file_log_level = os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")
+    if file_log_level.isdigit():
+        file_log_level = int(file_log_level)
+    elif file_log_level in logging._nameToLevel:  # noqa: SLF001
+        file_log_level = logging._nameToLevel[file_log_level]  # noqa: SLF001
+    else:
+        msg = (
+            f"Invalid value for $TORCHRUNX_LOG_LEVEL: {file_log_level}. "
+            f"Should be a positive integer or any of: {', '.join(logging._nameToLevel.keys())}."  # noqa: SLF001
+        )
+        raise ValueError(msg)
+
+    redirect_agent_0_handler = RedirectHandler()
+    redirect_agent_0_handler.addFilter(get_handler_filter(hostnames[0], None))
+
+    redirect_worker_0_handler = RedirectHandler()
+    redirect_worker_0_handler.addFilter(get_handler_filter(hostnames[0], 0))
+
+    return [
+        redirect_agent_0_handler,
+        redirect_worker_0_handler,
+        *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=file_log_level),
+    ]
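
For example, the new builders can be combined with `Launcher.set_logging_handlers`. Here is a small usage sketch (the hostnames, log directory, and `distributed_training` function are illustrative):

```python
import logging

import torchrunx
from torchrunx.utils.log_handling import file_handlers, get_handler_filter


def distributed_training():
    ...  # runs on every worker, as in the README example


def my_handlers() -> list[logging.Handler]:
    hostnames = ["localhost"]

    # per-host and per-worker log files under ./my_logs/<timestamp>/
    handlers = file_handlers(hostnames, workers_per_host=[2], log_dir="my_logs")

    # additionally, echo worker 0's logs to the launcher's console
    console = logging.StreamHandler()
    console.addFilter(get_handler_filter(hostname=hostnames[0], local_rank=0))
    handlers.append(console)

    return handlers


launcher = torchrunx.Launcher(hostnames=["localhost"], workers_per_host=2)
launcher.set_logging_handlers(my_handlers).run(distributed_training)
```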
