Skip to content

Commit cfe9f42

Browse files
authored
Merge pull request #39 from apoorvkh/moving-functions
moving unshared utils into other files
2 parents b5241af + 59b2b4a commit cfe9f42

File tree

6 files changed

+160
-155
lines changed

6 files changed

+160
-155
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
3939
line-length = 100
4040
src = ["src", "tests"]
4141
[tool.ruff.lint]
42-
select = ["E", "F"]
42+
select = ["E", "F", "I"]
4343

4444
[tool.pyright]
4545
include = ["src", "tests"]

src/torchrunx/__init__.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,45 @@
1+
from __future__ import annotations
2+
13
from .launcher import launch
2-
from .utils import slurm_hosts, slurm_workers
4+
5+
6+
def slurm_hosts() -> list[str]:
    """Retrieves hostnames of Slurm-allocated nodes.

    :raises RuntimeError: If called outside of a Slurm allocation
    :return: Hostnames of nodes in current Slurm allocation
    :rtype: list[str]
    """
    import os
    import subprocess

    # TODO: sanity check SLURM variables, commands
    # A bare `assert` is stripped under `python -O`; raise explicitly so the
    # check survives optimized runs.
    if "SLURM_JOB_ID" not in os.environ:
        raise RuntimeError("slurm_hosts() must be called inside a Slurm allocation")
    return (
        subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
        .decode()
        .strip()
        .split("\n")
    )
23+
24+
25+
def slurm_workers() -> int:
    """
    | Determines number of workers per node in current Slurm allocation using
    | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.

    :raises RuntimeError: If called outside of a Slurm allocation
    :return: The implied number of workers per node
    :rtype: int
    """
    import os

    # TODO: sanity check SLURM variables, commands
    # A bare `assert` is stripped under `python -O`; raise explicitly so the
    # check survives optimized runs.
    if "SLURM_JOB_ID" not in os.environ:
        raise RuntimeError("slurm_workers() must be called inside a Slurm allocation")
    if "SLURM_JOB_GPUS" in os.environ:
        # TODO: is it possible to allocate uneven GPUs across nodes?
        return len(os.environ["SLURM_JOB_GPUS"].split(","))
    else:
        # TODO: should we assume that we plan to do one worker per CPU?
        return int(os.environ["SLURM_CPUS_ON_NODE"])
43+
344

445
__all__ = ["launch", "slurm_hosts", "slurm_workers"]

src/torchrunx/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from argparse import ArgumentParser
22

3-
from . import agent
3+
from .agent import main
44
from .utils import LauncherAgentGroup
55

66
if __name__ == "__main__":
@@ -18,4 +18,4 @@
1818
rank=args.rank,
1919
)
2020

21-
agent.main(launcher_agent_group)
21+
main(launcher_agent_group)

src/torchrunx/agent.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import socket
5+
import sys
56
from dataclasses import dataclass
67
from typing import Callable, Literal
78

@@ -17,7 +18,6 @@
1718
AgentStatus,
1819
LauncherAgentGroup,
1920
LauncherPayload,
20-
WorkerTee,
2121
get_open_port,
2222
)
2323

@@ -42,6 +42,30 @@ def from_bytes(cls, serialized: bytes) -> Self:
4242
return cloudpickle.loads(serialized)
4343

4444

45+
class WorkerTee(object):
    """Mirror everything written to stdout into a log file (like Unix ``tee``).

    Installs itself as ``sys.stdout`` on construction and restores the
    original stream when closed (via ``with``, explicit ``close()``, or GC).
    """

    def __init__(self, name: os.PathLike | str, mode: str):
        self.file = open(name, mode)
        self.stdout = sys.stdout
        sys.stdout = self

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def __del__(self):
        # Guard: if __init__ failed before `self.file` was set, there is
        # nothing to clean up.
        if hasattr(self, "file"):
            self.close()

    def close(self):
        # Idempotent cleanup: __exit__ and garbage collection may both run.
        # Only the first call restores stdout and closes the file, so a
        # later GC pass cannot clobber a stream installed after this tee.
        if not self.file.closed:
            sys.stdout = self.stdout
            self.file.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)

    def flush(self):
        self.file.flush()
67+
68+
4569
def entrypoint(serialized_worker_args: bytes):
4670
worker_args = WorkerArgs.from_bytes(serialized_worker_args)
4771

src/torchrunx/launcher.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,79 @@
22

33
import datetime
44
import fnmatch
5+
import io
6+
import ipaddress
57
import itertools
68
import os
79
import socket
10+
import subprocess
811
import sys
12+
import time
913
from collections import ChainMap
1014
from functools import partial
1115
from multiprocessing import Process
1216
from pathlib import Path
1317
from typing import Any, Callable, Literal
1418

19+
import fabric
1520
import torch.distributed as dist
1621

1722
from .utils import (
1823
AgentPayload,
1924
AgentStatus,
2025
LauncherAgentGroup,
2126
LauncherPayload,
22-
execute_command,
2327
get_open_port,
24-
monitor_log,
2528
)
2629

2730

31+
def is_localhost(hostname_or_ip: str) -> bool:
    """Return True if ``hostname_or_ip`` refers to the machine we are running on."""
    # Resolve to an IP address; pass through if the string is already numeric.
    try:
        ip = ipaddress.ip_address(hostname_or_ip)
    except ValueError:
        ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
    # Loopback addresses (127.0.0.0/8, ::1) are always the local machine.
    if ip.is_loopback:
        return True
    # Otherwise, treat the host as local iff it shares an interface address
    # with this machine.
    target_addrs = {info[4][0] for info in socket.getaddrinfo(str(ip), None)}
    local_addrs = {info[4][0] for info in socket.getaddrinfo(socket.gethostname(), None)}
    return not target_addrs.isdisjoint(local_addrs)
43+
44+
45+
def execute_command(
    command: str,
    hostname: str,
    ssh_config_file: str | os.PathLike | None = None,
    outfile: str | os.PathLike | None = None,
) -> None:
    """Launch ``command`` in the background on ``hostname``.

    Runs via ``subprocess`` when ``hostname`` is the local machine, otherwise
    over SSH with :mod:`fabric`. Output is discarded unless ``outfile`` is
    given, in which case stdout and stderr are appended there.

    :param command: Shell command line to execute
    :param hostname: Target host (name or IP)
    :param ssh_config_file: Optional SSH config used for the remote connection
    :param outfile: Optional path receiving the command's stdout/stderr
    """
    # TODO: permit different stderr / stdout
    if is_localhost(hostname):
        if outfile is None:
            subprocess.Popen(
                command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
        else:
            # Close the parent's copy of the file handle once the child has
            # inherited it — the original left it open (fd leak /
            # ResourceWarning). The child keeps its own descriptor.
            with open(outfile, "w") as f:
                subprocess.Popen(command, shell=True, stdout=f, stderr=f)
    else:
        with fabric.Connection(
            host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
        ) as conn:
            if outfile is None:
                outfile = "/dev/null"
            conn.run(f"{command} >> {outfile} 2>&1 &", asynchronous=True)
64+
65+
66+
def monitor_log(log_file: Path):
    """Continuously stream the contents of ``log_file`` to stdout.

    Blocks forever (intended to run in a background process). Creates the
    file if it does not yet exist, prints its current contents, then tails
    new content as it is appended.

    :param log_file: Path of the log file to follow
    """
    log_file.touch()
    # Context manager so the handle is released if the hosting process is
    # torn down via an exception rather than a hard kill.
    with open(log_file, "r") as f:
        # ``end=""`` avoids injecting newlines that are not in the log;
        # ``flush=True`` keeps the mirror live despite block buffering.
        print(f.read(), end="", flush=True)
        f.seek(0, io.SEEK_END)
        while True:
            new = f.read()
            if new:
                print(new, end="", flush=True)
            time.sleep(0.1)
76+
77+
2878
def launch(
2979
func: Callable,
3080
func_kwargs: dict[str, Any],
@@ -83,16 +133,18 @@ def launch(
83133

84134
# launch command
85135

86-
env_export_string = ""
87136
env_exports = []
88137
for k, v in os.environ.items():
89-
for e in env_vars:
90-
if any(fnmatch.fnmatch(k, e)):
91-
env_exports.append(f"{k}={v}")
138+
if any(fnmatch.fnmatch(k, e) for e in env_vars):
139+
env_exports.append(f"{k}={v}")
140+
141+
env_export_string = ""
92142
if len(env_exports) > 0:
93143
env_export_string = f"export {' '.join(env_exports)} && "
94144

95-
env_file_string = f"source {env_file} && " if env_file is not None else ""
145+
env_file_string = ""
146+
if env_file is not None:
147+
env_file_string = f"source {env_file} && "
96148

97149
launcher_hostname = socket.getfqdn()
98150
launcher_port = get_open_port()

0 commit comments

Comments
 (0)