Commit 9e2d5f4

update all docs

1 parent: 3a68eb6

File tree: 6 files changed, +38 -38 lines


README.md

Lines changed: 8 additions & 7 deletions

@@ -56,12 +56,13 @@ Here's a simple example where we "train" a model on two nodes (with 2 GPUs each)
 import torchrunx as trx

 if __name__ == "__main__":
-    trained_model = trx.launch(
+    result = trx.launch(
         func=train,
         hostnames=["localhost", "other_node"],
-        workers_per_host=2 # num. GPUs
-    ).value(rank=0) # get returned object
+        workers_per_host=2 # number of GPUs
+    )

+    trained_model = result.rank(0)
     torch.save(trained_model.state_dict(), "model.pth")
 ```

@@ -70,9 +71,9 @@ if __name__ == "__main__":

 ## Why should I use this?

-Whether you have 1 GPU, 8 GPUs, or 8 machines.
+Whether you have 1 GPU, 8 GPUs, or 8 machines:

-__Features:__
+__Features__

 - Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_
 - Return objects from your workers

@@ -81,13 +82,13 @@ __Features:__
 - Fine-grained control over logging, environment variables, exception handling, etc.
 - Automatic integration with SLURM

-__Robustness:__
+__Robustness__

 - If you want to run a complex, _modular_ workflow in __one__ script
   - don't parallelize your entire script: just the functions you want!
   - no worries about memory leaks or OS failures

-__Convenience:__
+__Convenience__

 - If you don't want to:
   - set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
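
Pieced together, the updated README example reads as follows. This is a minimal runnable sketch: the `train` function is a hypothetical stand-in, and the single-machine hostnames are illustrative; `launch()` and `LaunchResult.rank()` are as introduced in this commit.

```python
# Sketch of the updated README API. `train` is hypothetical.
import torch
import torch.nn as nn

import torchrunx as trx


def train() -> nn.Module:
    # Hypothetical worker: a real script would wrap the model in DDP and run
    # a training loop here. torchrunx initializes the distributed environment.
    return nn.Linear(10, 1)


if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost"],  # single machine for this sketch
        workers_per_host=2,       # number of GPUs
    )
    trained_model = result.rank(0)  # return value of the global rank-0 worker
    torch.save(trained_model.state_dict(), "model.pth")
```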

docs/source/advanced.rst

Lines changed: 10 additions & 10 deletions

@@ -14,19 +14,19 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GPU)
     func=train,
     hostnames=["node1", "node2"],
     workers_per_host=8
-).value(rank=0)
+).rank(0)

 accuracy = trx.launch(
     func=test,
-    func_kwargs={'model': model},
+    func_args=(trained_model,),
     hostnames=["localhost"],
     workers_per_host=1
-).value(rank=0)
+).rank(0)

 print(f'Accuracy: {accuracy}')

-:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation.
+:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) before the subsequent invocation.

 Launcher class
 --------------

@@ -85,9 +85,9 @@ Raises a ``RuntimeError`` if ``hostnames="slurm"`` or ``workers_per_host="slurm"``

 Propagating exceptions
 ----------------------

-Exceptions that are raised in Workers will be raised by the launcher process.
+Exceptions that are raised in workers will be raised by the launcher process.

-A :mod:`torchrunx.AgentKilledError` will be raised if any agent dies unexpectedly (e.g. if force-killed by the OS, due to segmentation faults or OOM).
+A :mod:`torchrunx.AgentFailedError` or :mod:`torchrunx.WorkerFailedError` will be raised if any agent or worker dies unexpectedly (e.g. if sent a signal from the OS, due to segmentation faults or OOM).

 Environment variables
 ---------------------

@@ -100,14 +100,14 @@ Environment variables in the launcher process that match the ``default_env_vars`` argument are

 Custom logging
 --------------

-We forward all logs (i.e. from ``logging`` and ``stdio``) from workers and agents to the Launcher. By default, the logs from the first agent and its first worker are printed into the Launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank.
+We forward all logs (i.e. from :mod:`logging` and :mod:`sys.stdin`/:mod:`sys.stdout`) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank.

-``logging.Handler`` objects can be provided via the ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams).
+:mod:`logging.Handler` objects can be provided via the ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams).

 We provide some utilities to help:

-.. autofunction:: torchrunx.add_filter_to_handler
-
 .. autofunction:: torchrunx.file_handler

 .. autofunction:: torchrunx.stream_handler
+
+
+.. autofunction:: torchrunx.add_filter_to_handler
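
Per the "Propagating exceptions" section above, worker exceptions surface in the launcher process. A minimal sketch of catching the exceptions renamed in this commit, assuming both classes are importable from the top-level ``torchrunx`` namespace (as the docs/source/api.rst autoclass directives suggest):

```python
# Sketch: handling agent/worker failures at the launcher.
import torchrunx as trx


def train():
    ...  # hypothetical worker function


if __name__ == "__main__":
    try:
        trx.launch(func=train, hostnames=["localhost"], workers_per_host=2)
    except (trx.AgentFailedError, trx.WorkerFailedError) as e:
        # an agent or worker died unexpectedly (e.g. OS signal, segfault, OOM)
        print(f"torchrunx run failed: {e}")
        raise
```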

docs/source/api.rst

Lines changed: 3 additions & 1 deletion

@@ -6,4 +6,6 @@ API
 .. autoclass:: torchrunx.LaunchResult
    :members:

-.. autoclass:: torchrunx.AgentKilledError
+.. autoclass:: torchrunx.AgentFailedError
+
+.. autoclass:: torchrunx.WorkerFailedError

docs/source/conf.py

Lines changed: 1 addition & 1 deletion

@@ -17,8 +17,8 @@
     "myst_parser",
     "sphinx_toolbox.sidebar_links",
     "sphinx_toolbox.github",
-    "sphinx.ext.autodoc.typehints",
     "sphinx.ext.napoleon",
+    "sphinx.ext.autodoc.typehints",
     "sphinx.ext.linkcode",
 ]

src/torchrunx/launcher.py

Lines changed: 9 additions & 12 deletions

@@ -39,10 +39,7 @@

 @dataclass
 class Launcher:
-    """Alias class for ``torchrunx.launch``.
-
-    Useful for sequential invocations on the same configuration or for specifying arguments via CLI.
-    """
+    """Useful for sequential invocations or for specifying arguments via CLI."""

     hostnames: list[str] | Literal["auto", "slurm"] = "auto"
     workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"

@@ -69,7 +66,7 @@ def run( # noqa: C901, PLR0912
         func_kwargs: dict[str, Any] | None = None,
         log_handlers: list[Handler] | Literal["auto"] | None = "auto",
     ) -> LaunchResult:
-        """Run a function using the configuration in ``torchrunx.Launcher``."""
+        """Run a function using the :mod:`torchrunx.Launcher` configuration."""
         if not dist.is_available():
             msg = "The torch.distributed package is not available."
             raise RuntimeError(msg)

@@ -267,21 +264,21 @@ class LaunchResult:
     hostnames: list[str]
     return_values: list[list[Any]]

-    def by_hostname(self) -> dict[str, list[Any]]:
+    def by_hostnames(self) -> dict[str, list[Any]]:
         """All return values from workers, indexed by host and local rank."""
         return dict(zip(self.hostnames, self.return_values))

-    def by_rank(self) -> list[Any]:
+    def by_ranks(self) -> list[Any]:
         """All return values from workers, indexed by global rank."""
         return reduce(add, self.return_values)

-    def get(self, hostname: str, rank: int) -> Any:
-        """Get return value from worker (indexed by host and local rank)."""
+    def index(self, hostname: str, rank: int) -> Any:
+        """Get return value from worker by host and local rank."""
         return self.return_values[self.hostnames.index(hostname)][rank]

-    def rank(self, idx: int) -> Any:
-        """Get return value from worker (indexed by global rank)."""
-        return self.by_rank()[idx]
+    def rank(self, i: int) -> Any:
+        """Get return value from worker by global rank."""
+        return self.by_ranks()[i]


 def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
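
For reference, the renamed ``LaunchResult`` accessors compose like this. A minimal sketch: constructing the object directly is illustrative only (``launch()`` normally builds it), and it assumes ``LaunchResult`` is a plain dataclass over the two fields shown in the diff.

```python
# Sketch of the renamed LaunchResult accessors (fields per the diff above).
from torchrunx import LaunchResult

result = LaunchResult(
    hostnames=["node1", "node2"],
    return_values=[["a0", "a1"], ["b0", "b1"]],  # one list per host, by local rank
)

assert result.by_hostnames() == {"node1": ["a0", "a1"], "node2": ["b0", "b1"]}
assert result.by_ranks() == ["a0", "a1", "b0", "b1"]  # global rank order
assert result.index("node2", 1) == "b1"  # by host + local rank
assert result.rank(2) == "b0"            # by global rank
```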

src/torchrunx/utils/logging.py

Lines changed: 7 additions & 7 deletions

@@ -41,10 +41,10 @@ def add_filter_to_handler(
     local_rank: int | None,  # None indicates agent
     log_level: int = logging.NOTSET,
 ) -> None:
-    """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled.
+    """A filter for :mod:`logging.Handler` such that only specific agent/worker logs are handled.

     Args:
-        handler: ``logging.Handler`` to be modified.
+        handler: Handler to be modified.
         hostname: Name of specified host.
         local_rank: Rank of specified worker (or ``None`` for agent).
         log_level: Minimum log level to capture.

@@ -63,7 +63,7 @@ def _filter(record: WorkerLogRecord) -> bool:
 def stream_handler(
     hostname: str, local_rank: int | None, log_level: int = logging.NOTSET
 ) -> Handler:
-    """logging.Handler builder function for writing logs to stdout."""
+    """Handler builder function for writing logs from specified hostname/rank to stdout."""
     handler = logging.StreamHandler(stream=sys.stdout)
     add_filter_to_handler(handler, hostname, local_rank, log_level=log_level)
     handler.setFormatter(

@@ -82,7 +82,7 @@ def file_handler(
     file_path: str | os.PathLike,
     log_level: int = logging.NOTSET,
 ) -> Handler:
-    """logging.Handler builder function for writing logs to a file."""
+    """Handler builder function for writing logs from specified hostname/rank to a file."""
     handler = logging.FileHandler(file_path)
     add_filter_to_handler(handler, hostname, local_rank, log_level=log_level)
     formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")

@@ -96,7 +96,7 @@ def file_handlers(
     log_dir: str | os.PathLike = Path("torchrunx_logs"),
     log_level: int = logging.NOTSET,
 ) -> list[Handler]:
-    """Builder function for writing logs for all workers/agents to a directory.
+    """Handler builder function for writing logs for all workers/agents to a directory.

     Files are named with timestamp, hostname, and the local_rank (for workers).
     """

@@ -123,9 +123,9 @@ def default_handlers(
     log_dir: str | os.PathLike = Path("torchrunx_logs"),
     log_level: int = logging.INFO,
 ) -> list[Handler]:
-    """A default set of logging.Handlers to be used when ``launch(log_handlers="auto")``.
+    """Default :mod:`logging.Handler`s for ``log_handlers="auto"`` in :mod:`torchrunx.launch`.

-    Logs for host[0] and its local_rank[0] worker are written to the launcher process stdout.
+    Logs for ``host[0]`` and its ``local_rank[0]`` worker are written to launcher process stdout.
     Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname,
     local_rank).
     """
