automatic function #52

Merged: 10 commits, Aug 12, 2024
Changes from 9 commits
4 changes: 2 additions & 2 deletions src/torchrunx/__init__.py
@@ -1,4 +1,4 @@
+from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .launcher import Launcher, launch
-from .slurm import slurm_hosts, slurm_workers
 
-__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers"]
+__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"]
36 changes: 34 additions & 2 deletions src/torchrunx/slurm.py → src/torchrunx/environment.py
@@ -3,6 +3,12 @@
 import os
 import subprocess
 
+import torch
+
+
+def in_slurm_job() -> bool:
+    return "SLURM_JOB_ID" in os.environ
+
 
 def slurm_hosts() -> list[str]:
     """Retrieves hostnames of Slurm-allocated nodes.
@@ -11,7 +17,7 @@ def slurm_hosts() -> list[str]:
     :rtype: list[str]
     """
     # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
+    assert in_slurm_job()
     return (
         subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
         .decode()
@@ -29,10 +35,36 @@ def slurm_workers() -> int:
     :rtype: int
     """
     # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
+    assert in_slurm_job()
     if "SLURM_JOB_GPUS" in os.environ:
         # TODO: is it possible to allocate uneven GPUs across nodes?
         return len(os.environ["SLURM_JOB_GPUS"].split(","))
     else:
         # TODO: should we assume that we plan to do one worker per CPU?
         return int(os.environ["SLURM_CPUS_ON_NODE"])
+
+
+def auto_hosts() -> list[str]:
+    """
+    Automatically determine hostname list
+
+    :return: Hostnames in Slurm allocation, or ['localhost']
+    :rtype: list[str]
+    """
+    if in_slurm_job():
+        return slurm_hosts()
+
+    return ["localhost"]
+
+
+def auto_workers() -> int:
+    """
+    Automatically determine number of workers per host
+
+    :return: Workers per host
+    :rtype: int
+    """
+    if in_slurm_job():
+        return slurm_workers()
+
+    return torch.cuda.device_count() or os.cpu_count() or 1
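
As a quick illustration of the resolution order these helpers implement (a sketch, assuming the module is importable as `torchrunx.environment` on this branch):

```python
# Sketch: how auto_hosts()/auto_workers() resolve, per the code above.
from torchrunx.environment import auto_hosts, auto_workers, in_slurm_job

if in_slurm_job():  # i.e. SLURM_JOB_ID is set
    # Hostnames come from `scontrol show hostnames $SLURM_JOB_NODELIST`;
    # workers per host from SLURM_JOB_GPUS, else SLURM_CPUS_ON_NODE.
    print(auto_hosts(), auto_workers())
else:
    # Outside Slurm: single host; workers = GPUs, else CPUs, else 1.
    assert auto_hosts() == ["localhost"]
    assert auto_workers() >= 1
```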
25 changes: 19 additions & 6 deletions src/torchrunx/launcher.py
@@ -20,6 +20,7 @@
 import fabric
 import torch.distributed as dist
 
+from .environment import auto_hosts, auto_workers
 from .utils import (
     AgentPayload,
     AgentStatus,
@@ -78,8 +79,9 @@ def monitor_log(log_file: Path):

 @dataclass
 class Launcher:
-    hostnames: list[str] = field(default_factory=lambda: ["localhost"])
-    workers_per_host: int | list[int] = 1
+    auto: bool = False
+    hostnames: list[str] | None = field(default_factory=lambda: ["localhost"])
+    workers_per_host: int | list[int] | None = 1
     ssh_config_file: str | os.PathLike | None = None
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
     log_dir: os.PathLike | str = "./logs"
@@ -117,6 +119,13 @@ def run(
         :return: A dictionary mapping worker ranks to their output
         :rtype: dict[int, Any]
         """
+
+        if self.auto:
+            if self.workers_per_host is None:
+                self.workers_per_host = auto_workers()
+            if self.hostnames is None:
+                self.hostnames = auto_hosts()
+
         if not dist.is_available():
             raise RuntimeError("The torch.distributed package is not available.")
 
@@ -260,8 +269,9 @@ def launch(
     func: Callable,
     func_args: tuple[Any] = tuple(),
     func_kwargs: dict[str, Any] = {},
-    hostnames: list[str] = ["localhost"],
-    workers_per_host: int | list[int] = 1,
+    auto: bool = False,
+    hostnames: list[str] | None = ["localhost"],
+    workers_per_host: int | list[int] | None = 1,
     ssh_config_file: str | os.PathLike | None = None,
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
     log_dir: os.PathLike | str = "./logs",
@@ -287,10 +297,12 @@ def launch(
     :type func_args: tuple[Any]
     :param func_kwargs: Any keyword arguments to be provided when calling ``func``
     :type func_kwargs: dict[str, Any]
+    :param auto: Automatically determine allocation sizes; supports Slurm allocations. ``hostnames`` and ``workers_per_host`` are assigned automatically if they are set to ``None``, defaults to False
+    :type auto: bool, optional
     :param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
-    :type hostnames: list[str], optional
+    :type hostnames: list[str] | None, optional
     :param workers_per_host: The number of workers per node. Providing an ``int`` implies all nodes should have ``workers_per_host`` workers, while providing a list causes node ``i`` to have ``workers_per_host[i]`` workers, defaults to 1
-    :type workers_per_host: int | list[int], optional
+    :type workers_per_host: int | list[int] | None, optional
     :param ssh_config_file: An SSH configuration file to use when connecting to nodes, defaults to None
     :type ssh_config_file: str | os.PathLike | None, optional
     :param backend: A ``torch.distributed`` `backend string <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_, defaults to None
@@ -308,6 +320,7 @@
     :rtype: dict[int, Any]
     """  # noqa: E501
     return Launcher(
+        auto=auto,
         hostnames=hostnames,
         workers_per_host=workers_per_host,
         ssh_config_file=ssh_config_file,
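
Putting the pieces together, a hypothetical end-to-end call using the new `auto` flag. This is a sketch, not taken from the PR: it assumes this branch of `torchrunx` is installed and that passwordless SSH to the target hosts is available (the launcher connects via `fabric`). Passing ``hostnames=None`` and ``workers_per_host=None`` together with ``auto=True`` lets ``Launcher.run()`` fill them in from ``auto_hosts()`` and ``auto_workers()``.

```python
# Hypothetical usage sketch for the new auto flag.
import socket

import torchrunx


def report() -> str:
    # Runs in every worker process; just report where it ran.
    return socket.gethostname()


if __name__ == "__main__":
    results = torchrunx.launch(
        func=report,
        auto=True,
        hostnames=None,         # resolved via auto_hosts(): Slurm nodes or ["localhost"]
        workers_per_host=None,  # resolved via auto_workers(): GPUs, else CPUs, else 1
    )
    print(results)  # dict mapping worker rank -> hostname string
```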