From 09cf8ce488c11881c145bfb8e08179ebe41e0d3c Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Fri, 12 Jul 2024 15:47:54 -0400
Subject: [PATCH 1/2] moving unshared utils into other files

---
 pyproject.toml            |   2 +-
 src/torchrunx/__init__.py |  41 ++++++++-
 src/torchrunx/__main__.py |   4 +-
 src/torchrunx/agent.py    |  26 +++++-
 src/torchrunx/launcher.py |  54 +++++++++++-
 src/torchrunx/utils.py    | 174 +++++++------------------------------
 6 files changed, 151 insertions(+), 150 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 18bad431..146bc7f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
 line-length = 100
 src = ["src", "tests"]
 [tool.ruff.lint]
-select = ["E", "F"]
+select = ["E", "F", "I"]
 
 [tool.pyright]
 include = ["src", "tests"]
diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index f4c0d71f..9bde09c1 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,4 +1,43 @@
 from .launcher import launch
-from .utils import slurm_hosts, slurm_workers
+
+
+def slurm_hosts() -> list[str]:
+    """Retrieves hostnames of Slurm-allocated nodes.
+
+    :return: Hostnames of nodes in current Slurm allocation
+    :rtype: list[str]
+    """
+    import os
+    import subprocess
+
+    # TODO: sanity check SLURM variables, commands
+    assert "SLURM_JOB_ID" in os.environ
+    return (
+        subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
+        .decode()
+        .strip()
+        .split("\n")
+    )
+
+
+def slurm_workers() -> int:
+    """
+    | Determines number of workers per node in current Slurm allocation using
+    | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.
+
+    :return: The implied number of workers per node
+    :rtype: int
+    """
+    import os
+
+    # TODO: sanity check SLURM variables, commands
+    assert "SLURM_JOB_ID" in os.environ
+    if "SLURM_JOB_GPUS" in os.environ:
+        # TODO: is it possible to allocate uneven GPUs across nodes?
+        return len(os.environ["SLURM_JOB_GPUS"].split(","))
+    else:
+        # TODO: should we assume that we plan to do one worker per CPU?
+        return int(os.environ["SLURM_CPUS_ON_NODE"])
+
 
 __all__ = ["launch", "slurm_hosts", "slurm_workers"]
diff --git a/src/torchrunx/__main__.py b/src/torchrunx/__main__.py
index 76cc92ff..f458d37d 100644
--- a/src/torchrunx/__main__.py
+++ b/src/torchrunx/__main__.py
@@ -1,6 +1,6 @@
 from argparse import ArgumentParser
 
-from . import agent
+from .agent import main
 from .utils import LauncherAgentGroup
 
 if __name__ == "__main__":
@@ -18,4 +18,4 @@
         rank=args.rank,
     )
 
-    agent.main(launcher_agent_group)
+    main(launcher_agent_group)
diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 6a206154..7145b2a0 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -2,6 +2,7 @@
 
 import os
 import socket
+import sys
 from dataclasses import dataclass
 from typing import Callable, Literal
 
@@ -17,7 +18,6 @@
     AgentStatus,
     LauncherAgentGroup,
     LauncherPayload,
-    WorkerTee,
     get_open_port,
 )
 
@@ -42,6 +42,30 @@ def from_bytes(cls, serialized: bytes) -> Self:
         return cloudpickle.loads(serialized)
 
 
+class WorkerTee(object):
+    def __init__(self, name: os.PathLike | str, mode: str):
+        self.file = open(name, mode)
+        self.stdout = sys.stdout
+        sys.stdout = self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        self.__del__()
+
+    def __del__(self):
+        sys.stdout = self.stdout
+        self.file.close()
+
+    def write(self, data):
+        self.file.write(data)
+        self.stdout.write(data)
+
+    def flush(self):
+        self.file.flush()
+
+
 def entrypoint(serialized_worker_args: bytes):
     worker_args = WorkerArgs.from_bytes(serialized_worker_args)
 
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 491fa8a0..ec8d8efe 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -2,16 +2,21 @@
 
 import datetime
 import fnmatch
+import io
+import ipaddress
 import itertools
 import os
 import socket
+import subprocess
 import sys
+import time
 from collections import ChainMap
 from functools import partial
 from multiprocessing import Process
 from pathlib import Path
 from typing import Any, Callable, Literal
 
+import fabric
 import torch.distributed as dist
 
 from .utils import (
@@ -19,12 +24,57 @@
     AgentStatus,
     LauncherAgentGroup,
     LauncherPayload,
-    execute_command,
     get_open_port,
-    monitor_log,
 )
 
 
+def is_localhost(hostname_or_ip: str) -> bool:
+    # check if host is "loopback" address (i.e. designated to send to self)
+    try:
+        ip = ipaddress.ip_address(hostname_or_ip)
+    except ValueError:
+        ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
+    if ip.is_loopback:
+        return True
+    # else compare local interface addresses between host and localhost
+    host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)]
+    localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
+    return len(set(host_addrs) & set(localhost_addrs)) > 0
+
+
+def execute_command(
+    command: str,
+    hostname: str,
+    ssh_config_file: str | os.PathLike | None = None,
+    outfile: str | os.PathLike | None = None,
+) -> None:
+    # TODO: permit different stderr / stdout
+    if is_localhost(hostname):
+        _outfile = subprocess.DEVNULL
+        if outfile is not None:
+            _outfile = open(outfile, "w")
+        subprocess.Popen(command, shell=True, stdout=_outfile, stderr=_outfile)
+    else:
+        with fabric.Connection(
+            host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
+        ) as conn:
+            if outfile is None:
+                outfile = "/dev/null"
+            conn.run(f"{command} >> {outfile} 2>&1 &", asynchronous=True)
+
+
+def monitor_log(log_file: Path):
+    log_file.touch()
+    f = open(log_file, "r")
+    print(f.read())
+    f.seek(0, io.SEEK_END)
+    while True:
+        new = f.read()
+        if len(new) != 0:
+            print(new)
+        time.sleep(0.1)
+
+
 def launch(
     func: Callable,
     func_kwargs: dict[str, Any],
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index f1b319c0..a82a7e53 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -1,60 +1,19 @@
 from __future__ import annotations
 
 import datetime
-import io
-import ipaddress
-import os
 import socket
-import subprocess
-import sys
-import time
 from contextlib import closing
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable, Literal
 
 import cloudpickle
-import fabric
 import torch.distributed as dist
 from torch.distributed.elastic.multiprocessing.api import RunProcsResult
 from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
 from typing_extensions import Self
 
 
-def slurm_hosts() -> list[str]:
-    """Retrieves hostnames of Slurm-allocated nodes.
-
-    :return: Hostnames of nodes in current Slurm allocation
-    :rtype: list[str]
-    """
-    # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
-    return (
-        subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
-        .decode()
-        .strip()
-        .split("\n")
-    )
-
-
-def slurm_workers() -> int:
-    """
-    | Determines number of workers per node in current Slurm allocation using
-    | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.
-
-    :return: The implied number of workers per node
-    :rtype: int
-    """
-    # TODO: sanity check SLURM variables, commands
-    assert "SLURM_JOB_ID" in os.environ
-    if "SLURM_JOB_GPUS" in os.environ:
-        # TODO: is it possible to allocate uneven GPUs across nodes?
-        return len(os.environ["SLURM_JOB_GPUS"].split(","))
-    else:
-        # TODO: should we assume that we plan to do one worker per CPU?
-        return int(os.environ["SLURM_CPUS_ON_NODE"])
-
-
 def get_open_port() -> int:
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(("", 0))
@@ -62,41 +21,6 @@ def get_open_port() -> int:
         port = s.getsockname()[1]
     return port
 
-def is_localhost(hostname_or_ip: str) -> bool:
-    # check if host is "loopback" address (i.e. designated to send to self)
-    try:
-        ip = ipaddress.ip_address(hostname_or_ip)
-    except ValueError:
-        ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
-    if ip.is_loopback:
-        return True
-    # else compare local interface addresses between host and localhost
-    host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)]
-    localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
-    return len(set(host_addrs) & set(localhost_addrs)) > 0
-
-
-def execute_command(
-    command: str,
-    hostname: str,
-    ssh_config_file: str | os.PathLike | None = None,
-    outfile: str | os.PathLike | None = None,
-) -> None:
-    # TODO: permit different stderr / stdout
-    if is_localhost(hostname):
-        _outfile = subprocess.DEVNULL
-        if outfile is not None:
-            _outfile = open(outfile, "w")
-        subprocess.Popen(command, shell=True, stdout=_outfile, stderr=_outfile)
-    else:
-        with fabric.Connection(
-            host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
-        ) as conn:
-            if outfile is None:
-                outfile = "/dev/null"
-            conn.run(f"{command} >> {outfile} 2>&1 &", asynchronous=True)
-
-
 @dataclass
 class LauncherPayload:
     fn: Callable
@@ -114,6 +38,37 @@ class AgentPayload:
     process_id: int
 
 
+@dataclass
+class AgentStatus:
+    running: bool = True
+    failed: bool = False
+    return_values: dict[int, Any] = field(default_factory=dict)
+    failures: dict[int, ProcessFailure] = field(default_factory=dict)
+    stdouts: dict[int, str] = field(default_factory=dict)
+    stderrs: dict[int, str] = field(default_factory=dict)
+
+    @classmethod
+    def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self:
+        if result is None:
+            return cls()
+
+        return cls(
+            running=False,
+            failed=result.is_failed(),
+            return_values={worker_global_ranks[k]: v for k, v in result.return_values.items()},
+            failures={worker_global_ranks[k]: v for k, v in result.failures.items()},
+        )
+
+    def is_running(self) -> bool:
+        return self.running
+
+    def is_failed(self) -> bool:
+        return self.failed
+
+    def is_done(self) -> bool:
+        return not self.running and not self.failed
+
+
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
@@ -156,70 +111,3 @@ def sync_payloads(
 
     def sync_agent_statuses(self, status: AgentStatus) -> list[AgentStatus]:
         return self._all_gather(object=status)[1:]
-
-
-@dataclass
-class AgentStatus:
-    running: bool = True
-    failed: bool = False
-    return_values: dict[int, Any] = field(default_factory=dict)
-    failures: dict[int, ProcessFailure] = field(default_factory=dict)
-    stdouts: dict[int, str] = field(default_factory=dict)
-    stderrs: dict[int, str] = field(default_factory=dict)
-
-    @classmethod
-    def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self:
-        if result is None:
-            return cls()
-
-        return cls(
-            running=False,
-            failed=result.is_failed(),
-            return_values={worker_global_ranks[k]: v for k, v in result.return_values.items()},
-            failures={worker_global_ranks[k]: v for k, v in result.failures.items()},
-        )
-
-    def is_running(self) -> bool:
-        return self.running
-
-    def is_failed(self) -> bool:
-        return self.failed
-
-    def is_done(self) -> bool:
-        return not self.running and not self.failed
-
-
-class WorkerTee(object):
-    def __init__(self, name: os.PathLike | str, mode: str):
-        self.file = open(name, mode)
-        self.stdout = sys.stdout
-        sys.stdout = self
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.__del__()
-
-    def __del__(self):
-        sys.stdout = self.stdout
-        self.file.close()
-
-    def write(self, data):
-        self.file.write(data)
-        self.stdout.write(data)
-
-    def flush(self):
-        self.file.flush()
-
-
-def monitor_log(log_file: Path):
-    log_file.touch()
-    f = open(log_file, "r")
-    print(f.read())
-    f.seek(0, io.SEEK_END)
-    while True:
-        new = f.read()
-        if len(new) != 0:
-            print(new)
-        time.sleep(0.1)

From 59b2b4a2884fdd5048955ea6ce82496ebfd2de25 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Fri, 12 Jul 2024 15:53:28 -0400
Subject: [PATCH 2/2] pyright fixes

---
 src/torchrunx/__init__.py |  2 ++
 src/torchrunx/launcher.py | 12 +++++++-----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index 9bde09c1..b811e659 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .launcher import launch
 
 
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index ec8d8efe..3a2f45da 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -133,16 +133,18 @@ def launch(
 
     # launch command
 
-    env_export_string = ""
     env_exports = []
     for k, v in os.environ.items():
-        for e in env_vars:
-            if any(fnmatch.fnmatch(k, e)):
-                env_exports.append(f"{k}={v}")
+        if any(fnmatch.fnmatch(k, e) for e in env_vars):
+            env_exports.append(f"{k}={v}")
+
+    env_export_string = ""
    if len(env_exports) > 0:
         env_export_string = f"export {' '.join(env_exports)} && "
 
-    env_file_string = f"source {env_file} && " if env_file is not None else ""
+    env_file_string = ""
+    if env_file is not None:
+        env_file_string = f"source {env_file} && "
 
     launcher_hostname = socket.getfqdn()
     launcher_port = get_open_port()
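
With PATCH 1/2 applied, slurm_hosts() and slurm_workers() are defined in src/torchrunx/__init__.py (and exported via __all__), so callers import them from the package root rather than from torchrunx.utils. A minimal usage sketch follows; the hosts and workers_per_host keyword arguments of launch() are assumptions for illustration, since these hunks only show its func and func_kwargs parameters:

    import torchrunx

    def train():
        # hypothetical worker function; launch() runs it once per worker process
        ...

    # inside a Slurm allocation (both helpers assert that SLURM_JOB_ID is set)
    torchrunx.launch(
        func=train,
        func_kwargs={},
        hosts=torchrunx.slurm_hosts(),               # node list via `scontrol show hostnames`
        workers_per_host=torchrunx.slurm_workers(),  # from SLURM_JOB_GPUS or SLURM_CPUS_ON_NODE
    )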