Commit 4f065cd

Merge pull request #52 from apoorvkh/automatic
automatic function
2 parents 0bd7964 + c1f522a commit 4f065cd

3 files changed: +57, -10 lines

src/torchrunx/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
+from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .launcher import Launcher, launch
-from .slurm import slurm_hosts, slurm_workers
 
-__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers"]
+__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"]

src/torchrunx/slurm.py renamed to src/torchrunx/environment.py

Lines changed: 34 additions & 2 deletions
@@ -3,6 +3,12 @@
 import os
 import subprocess
 
+import torch
+
+
+def in_slurm_job() -> bool:
+    return "SLURM_JOB_ID" in os.environ
+
 
 def slurm_hosts() -> list[str]:
     """Retrieves hostnames of Slurm-allocated nodes.
@@ -11,7 +17,7 @@ def slurm_hosts() -> list[str]:
     :rtype: list[str]
     """
     # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
+    assert in_slurm_job()
     return (
         subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
         .decode()
@@ -29,10 +35,36 @@ def slurm_workers() -> int:
     :rtype: int
     """
     # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
+    assert in_slurm_job()
     if "SLURM_JOB_GPUS" in os.environ:
         # TODO: is it possible to allocate uneven GPUs across nodes?
         return len(os.environ["SLURM_JOB_GPUS"].split(","))
     else:
         # TODO: should we assume that we plan to do one worker per CPU?
         return int(os.environ["SLURM_CPUS_ON_NODE"])
+
+
+def auto_hosts() -> list[str]:
+    """
+    Automatically determine hostname list
+
+    :return: Hostnames in Slurm allocation, or ['localhost']
+    :rtype: list[str]
+    """
+    if in_slurm_job():
+        return slurm_hosts()
+
+    return ["localhost"]
+
+
+def auto_workers() -> int:
+    """
+    Automatically determine number of workers per host
+
+    :return: Workers per host
+    :rtype: int
+    """
+    if in_slurm_job():
+        return slurm_workers()
+
+    return torch.cuda.device_count() or os.cpu_count() or 1
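
A minimal usage sketch of the two new helpers, assuming torchrunx is installed and importable; the hostnames in the comments are illustrative only, since actual values come from the Slurm allocation (or the local machine) at runtime.

    from torchrunx import auto_hosts, auto_workers

    hosts = auto_hosts()      # inside Slurm: node list via scontrol; otherwise ["localhost"]
    workers = auto_workers()  # inside Slurm: GPUs (or CPUs) per node; otherwise GPU count, CPU count, or 1

    print(hosts, workers)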

src/torchrunx/launcher.py

Lines changed: 21 additions & 6 deletions
@@ -20,6 +20,7 @@
 import fabric
 import torch.distributed as dist
 
+from .environment import auto_hosts, auto_workers
 from .utils import (
     AgentPayload,
     AgentStatus,
@@ -78,8 +79,9 @@ def monitor_log(log_file: Path):
 
 @dataclass
 class Launcher:
-    hostnames: list[str] = field(default_factory=lambda: ["localhost"])
-    workers_per_host: int | list[int] = 1
+    auto: bool = False
+    hostnames: list[str] | None = field(default_factory=lambda: ["localhost"])
+    workers_per_host: int | list[int] | None = 1
     ssh_config_file: str | os.PathLike | None = None
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
     log_dir: os.PathLike | str = "./logs"
@@ -117,6 +119,15 @@ def run(
         :return: A dictionary mapping worker ranks to their output
         :rtype: dict[int, Any]
         """
+
+        if self.auto:
+            if self.hostnames is None:
+                self.hostnames = auto_hosts()
+            if self.workers_per_host is None:
+                self.workers_per_host = auto_workers()
+
+        assert self.hostnames is not None and self.workers_per_host is not None
+
         if not dist.is_available():
             raise RuntimeError("The torch.distributed package is not available.")
 
@@ -260,8 +271,9 @@ def launch(
     func: Callable,
     func_args: tuple[Any] = tuple(),
     func_kwargs: dict[str, Any] = {},
-    hostnames: list[str] = ["localhost"],
-    workers_per_host: int | list[int] = 1,
+    auto: bool = False,
+    hostnames: list[str] | None = ["localhost"],
+    workers_per_host: int | list[int] | None = 1,
     ssh_config_file: str | os.PathLike | None = None,
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
     log_dir: os.PathLike | str = "./logs",
@@ -287,10 +299,12 @@ def launch(
     :type func_args: tuple[Any]
     :param func_kwargs: Any keyword arguments to be provided when calling ``func``
     :type func_kwargs: dict[str, Any]
+    :param auto: Automatically determine allocation sizes; supports Slurm allocations. ``hostnames`` and ``workers_per_host`` are automatically assigned if they are set to ``None``, defaults to False
+    :type auto: bool, optional
     :param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
-    :type hostnames: list[str], optional
+    :type hostnames: list[str] | None, optional
     :param workers_per_host: The number of workers per node. Providing an ``int`` implies all nodes should have ``workers_per_host`` workers, meanwhile providing a list causes node ``i`` to have ``worker_per_host[i]`` workers, defaults to 1
-    :type workers_per_host: int | list[int], optional
+    :type workers_per_host: int | list[int] | None, optional
     :param ssh_config_file: An SSH configuration file to use when connecting to nodes, defaults to None
     :type ssh_config_file: str | os.PathLike | None, optional
     :param backend: A ``torch.distributed`` `backend string <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_, defaults to None
@@ -308,6 +322,7 @@ def launch(
     :rtype: dict[int, Any]
     """  # noqa: E501
     return Launcher(
+        auto=auto,
         hostnames=hostnames,
         workers_per_host=workers_per_host,
         ssh_config_file=ssh_config_file,
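
A minimal launch sketch for the new ``auto`` flag, assuming torchrunx is installed; ``train`` is a hypothetical placeholder for a user-defined worker function, and ``backend="gloo"`` is chosen only for a CPU illustration. Note that ``auto`` only fills in arguments explicitly passed as ``None``, since the defaults remain ["localhost"] and 1.

    import torchrunx

    def train():
        # hypothetical worker function; each worker runs this and returns a value
        return "done"

    results = torchrunx.launch(
        func=train,
        auto=True,              # resolve allocation from Slurm if available
        hostnames=None,         # filled by auto_hosts() because it is None
        workers_per_host=None,  # filled by auto_workers() because it is None
        backend="gloo",         # assumption: gloo for a CPU-only illustration
    )
    print(results)  # dict mapping worker ranks to return values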
