Skip to content

Commit 4cbc660

Browse files
authored
Merge pull request #40 from apoorvkh/argument-class
Change `launch` to `Launcher` dataclass
2 parents 952bb54 + cc9c1d0 commit 4cbc660

File tree

5 files changed

+239
-207
lines changed

5 files changed

+239
-207
lines changed

.github/workflows/docs-preview.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
name: Readthedocs preview
22
on:
33
pull_request_target:
4-
types:
5-
- opened
64
paths:
75
- "docs/**"
86

docs/source/api.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
torchrunx API
1+
API
22
=============
33

44
.. automodule:: torchrunx
5-
:members:
5+
:members:

src/torchrunx/__init__.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,4 @@
1-
from __future__ import annotations
1+
from .launcher import Launcher, launch
2+
from .slurm import slurm_hosts, slurm_workers
23

3-
from .launcher import launch
4-
5-
6-
def slurm_hosts() -> list[str]:
7-
"""Retrieves hostnames of Slurm-allocated nodes.
8-
9-
:return: Hostnames of nodes in current Slurm allocation
10-
:rtype: list[str]
11-
"""
12-
import os
13-
import subprocess
14-
15-
# TODO: sanity check SLURM variables, commands
16-
assert "SLURM_JOB_ID" in os.environ
17-
return (
18-
subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
19-
.decode()
20-
.strip()
21-
.split("\n")
22-
)
23-
24-
25-
def slurm_workers() -> int:
26-
"""
27-
| Determines number of workers per node in current Slurm allocation using
28-
| the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.
29-
30-
:return: The implied number of workers per node
31-
:rtype: int
32-
"""
33-
import os
34-
35-
# TODO: sanity check SLURM variables, commands
36-
assert "SLURM_JOB_ID" in os.environ
37-
if "SLURM_JOB_GPUS" in os.environ:
38-
# TODO: is it possible to allocate uneven GPUs across nodes?
39-
return len(os.environ["SLURM_JOB_GPUS"].split(","))
40-
else:
41-
# TODO: should we assume that we plan to do one worker per CPU?
42-
return int(os.environ["SLURM_CPUS_ON_NODE"])
43-
44-
45-
__all__ = ["launch", "slurm_hosts", "slurm_workers"]
4+
__all__ = ["Launcher", "launch", "slurm_hosts", "slurm_workers"]

0 commit comments

Comments (0)