moving unshared utils into other files #39

Merged
merged 2 commits on Jul 12, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -39,7 +39,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F"]
select = ["E", "F", "I"]

[tool.pyright]
include = ["src", "tests"]
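Note: the newly selected "I" rules are Ruff's isort-compatible import-sorting checks, which is what enforces the import reshuffling in the files below. A minimal sketch of the layout those rules expect (module names are only illustrative, not part of the diff):

# Standard-library imports: one group, alphabetized.
import os
import socket

# Third-party imports follow in their own group, separated by a blank line.
import torch.distributed as dist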
43 changes: 42 additions & 1 deletion src/torchrunx/__init__.py
@@ -1,4 +1,45 @@
from __future__ import annotations

from .launcher import launch
from .utils import slurm_hosts, slurm_workers


def slurm_hosts() -> list[str]:
"""Retrieves hostnames of Slurm-allocated nodes.

:return: Hostnames of nodes in current Slurm allocation
:rtype: list[str]
"""
import os
import subprocess

# TODO: sanity check SLURM variables, commands
assert "SLURM_JOB_ID" in os.environ
return (
subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
.decode()
.strip()
.split("\n")
)


def slurm_workers() -> int:
"""
| Determines number of workers per node in current Slurm allocation using
| the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.

:return: The implied number of workers per node
:rtype: int
"""
import os

# TODO: sanity check SLURM variables, commands
assert "SLURM_JOB_ID" in os.environ
if "SLURM_JOB_GPUS" in os.environ:
# TODO: is it possible to allocate uneven GPUs across nodes?
return len(os.environ["SLURM_JOB_GPUS"].split(","))
else:
# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])


__all__ = ["launch", "slurm_hosts", "slurm_workers"]
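A brief usage sketch of the two Slurm helpers now defined directly in __init__.py (rather than imported from .utils); it assumes the code runs inside a Slurm allocation, since both functions assert that SLURM_JOB_ID is set (hostnames shown are illustrative):

import torchrunx

hostnames = torchrunx.slurm_hosts()            # e.g. ["node041", "node042"]
workers_per_host = torchrunx.slurm_workers()   # GPUs per node if allocated, else CPUs per node
print(f"{workers_per_host} workers on each of {len(hostnames)} hosts")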
4 changes: 2 additions & 2 deletions src/torchrunx/__main__.py
@@ -1,6 +1,6 @@
from argparse import ArgumentParser

from . import agent
from .agent import main
from .utils import LauncherAgentGroup

if __name__ == "__main__":
@@ -18,4 +18,4 @@
rank=args.rank,
)

agent.main(launcher_agent_group)
main(launcher_agent_group)
26 changes: 25 additions & 1 deletion src/torchrunx/agent.py
@@ -2,6 +2,7 @@

import os
import socket
import sys
from dataclasses import dataclass
from typing import Callable, Literal

@@ -17,7 +18,6 @@
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
WorkerTee,
get_open_port,
)

@@ -42,6 +42,30 @@ def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)


class WorkerTee(object):
def __init__(self, name: os.PathLike | str, mode: str):
self.file = open(name, mode)
self.stdout = sys.stdout
sys.stdout = self

def __enter__(self):
return self

def __exit__(self, exception_type, exception_value, exception_traceback):
self.__del__()

def __del__(self):
sys.stdout = self.stdout
self.file.close()

def write(self, data):
self.file.write(data)
self.stdout.write(data)

def flush(self):
self.file.flush()


def entrypoint(serialized_worker_args: bytes):
worker_args = WorkerArgs.from_bytes(serialized_worker_args)

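A minimal sketch of how the WorkerTee class moved into agent.py can be used (the log path is hypothetical): while the context manager is active, anything written to stdout also lands in the file, and on exit sys.stdout is restored and the file is closed.

from torchrunx.agent import WorkerTee

with WorkerTee("worker_0.log", "a"):
    print("visible on stdout and appended to worker_0.log")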
66 changes: 59 additions & 7 deletions src/torchrunx/launcher.py
@@ -2,29 +2,79 @@

import datetime
import fnmatch
import io
import ipaddress
import itertools
import os
import socket
import subprocess
import sys
import time
from collections import ChainMap
from functools import partial
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal

import fabric
import torch.distributed as dist

from .utils import (
AgentPayload,
AgentStatus,
LauncherAgentGroup,
LauncherPayload,
execute_command,
get_open_port,
monitor_log,
)


def is_localhost(hostname_or_ip: str) -> bool:
# check if host is "loopback" address (i.e. designated to send to self)
try:
ip = ipaddress.ip_address(hostname_or_ip)
except ValueError:
ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
if ip.is_loopback:
return True
# else compare local interface addresses between host and localhost
host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)]
localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
return len(set(host_addrs) & set(localhost_addrs)) > 0


def execute_command(
command: str,
hostname: str,
ssh_config_file: str | os.PathLike | None = None,
outfile: str | os.PathLike | None = None,
) -> None:
# TODO: permit different stderr / stdout
if is_localhost(hostname):
_outfile = subprocess.DEVNULL
if outfile is not None:
_outfile = open(outfile, "w")
subprocess.Popen(command, shell=True, stdout=_outfile, stderr=_outfile)
else:
with fabric.Connection(
host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
) as conn:
if outfile is None:
outfile = "/dev/null"
conn.run(f"{command} >> {outfile} 2>&1 &", asynchronous=True)


def monitor_log(log_file: Path):
log_file.touch()
f = open(log_file, "r")
print(f.read())
f.seek(0, io.SEEK_END)
while True:
new = f.read()
if len(new) != 0:
print(new)
time.sleep(0.1)


def launch(
func: Callable,
func_kwargs: dict[str, Any],
@@ -83,16 +133,18 @@ def launch(

# launch command

env_export_string = ""
env_exports = []
for k, v in os.environ.items():
for e in env_vars:
if any(fnmatch.fnmatch(k, e)):
env_exports.append(f"{k}={v}")
if any(fnmatch.fnmatch(k, e) for e in env_vars):
env_exports.append(f"{k}={v}")

env_export_string = ""
if len(env_exports) > 0:
env_export_string = f"export {' '.join(env_exports)} && "

env_file_string = f"source {env_file} && " if env_file is not None else ""
env_file_string = ""
if env_file is not None:
env_file_string = f"source {env_file} && "

launcher_hostname = socket.getfqdn()
launcher_port = get_open_port()
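A rough usage sketch of the three helpers now private to launcher.py (the hostname and file names are hypothetical): is_localhost decides between a local subprocess and an SSH connection via fabric, execute_command starts the given command in the background, and monitor_log tails its output file.

from pathlib import Path

from torchrunx.launcher import execute_command, is_localhost, monitor_log

assert is_localhost("localhost")  # loopback resolves to this machine

execute_command("echo hello from the agent", hostname="localhost", outfile="agent.log")
# monitor_log(Path("agent.log"))  # note: loops forever tailing the file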
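And a small sketch of the environment-variable forwarding rewritten in launch(): env_vars holds fnmatch-style patterns, and any matching variables are re-exported in front of the remote agent command (the patterns below are just an example).

import fnmatch
import os

env_vars = ["PYTHON*", "CUDA*"]  # illustrative patterns
env_exports = [
    f"{k}={v}"
    for k, v in os.environ.items()
    if any(fnmatch.fnmatch(k, e) for e in env_vars)
]
env_export_string = f"export {' '.join(env_exports)} && " if env_exports else ""
print(env_export_string)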