Misc refactoring #61

Merged
merged 7 commits on Sep 12, 2024
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -86,4 +86,4 @@ jobs:
cache: false
environments: default
activate-environment: default
- run: pytest tests/test_CI.py
- run: pytest tests/test_ci.py
4 changes: 2 additions & 2 deletions pixi.lock

Some generated files are not rendered by default.

21 changes: 19 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "torchrunx"
version = "0.1.3"
version = "0.2.0"
authors = [
{name = "Apoorv Khandelwal", email = "[email protected]"},
{name = "Peter Curtin", email = "[email protected]"},
@@ -41,7 +41,24 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F", "B", "UP", "I"]
select = ["ALL"]
ignore = [
"D", # documentation
"ANN101", "ANN102", "ANN401", # self / cls / Any annotations
"BLE001", # blind exceptions
"TD", # todo syntax
"FIX002", # existing todos
"PLR0913", # too many arguments
"DTZ005", # datetime timezone
"S301", # bandit: pickle
"S603", "S607", # bandit: subprocess
"COM812", "ISC001", # conflict with formatter
]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
"S101", # allow asserts
"T201" # allow prints
]

[tool.pyright]
include = ["src", "tests"]
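
Switching `select` from a handful of rule families to `"ALL"` (tempered by the explicit `ignore` list above) enables the rest of Ruff's rules, and the `per-file-ignores` table keeps `assert` and `print` legal in tests. As a minimal sketch of the style the broader rule set nudges code toward — using a hypothetical helper, not from this repository — rules such as EM101/TRY003 prefer building the exception message in a variable before raising, the same pattern adopted in `environment.py` further down.

```python
# Hypothetical helper (not part of torchrunx), shown only to illustrate the
# message-variable-before-raise style that the stricter Ruff rule set favors.
def require_positive(value: int) -> int:
    if value <= 0:
        msg = f"expected a positive integer, got {value}"
        raise ValueError(msg)
    return value
```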
4 changes: 4 additions & 0 deletions src/torchrunx/__init__.py
@@ -1,6 +1,10 @@
from .launcher import Launcher, launch
from .logging_utils import add_filter_to_handler, file_handler, stream_handler

__all__ = [
"Launcher",
"launch",
"add_filter_to_handler",
"file_handler",
"stream_handler",
]
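
Defining `__all__` pins down the package's public surface: star imports and documentation tools now see exactly these five names. A small sketch of the effect, assuming `torchrunx` is installed:

```python
# Sketch: with __all__ defined in torchrunx/__init__.py, the public surface is
# exactly the five exported names.
import torchrunx

print(torchrunx.__all__)
# ['Launcher', 'launch', 'add_filter_to_handler', 'file_handler', 'stream_handler']
```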
66 changes: 28 additions & 38 deletions src/torchrunx/agent.py
@@ -6,14 +6,14 @@
import socket
import sys
import tempfile
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Literal

import cloudpickle
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing import start_processes
from typing_extensions import Self
import torch.distributed.elastic.multiprocessing as dist_mp

from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
from .utils import (
@@ -40,16 +40,20 @@ class WorkerArgs:
hostname: str
timeout: int

def to_bytes(self) -> bytes:
return cloudpickle.dumps(self)
def serialize(self) -> SerializedWorkerArgs:
return SerializedWorkerArgs(worker_args=self)

@classmethod
def from_bytes(cls, serialized: bytes) -> Self:
return cloudpickle.loads(serialized)

class SerializedWorkerArgs:
def __init__(self, worker_args: WorkerArgs) -> None:
self.bytes = cloudpickle.dumps(worker_args)

def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
worker_args = WorkerArgs.from_bytes(serialized_worker_args)
def deserialize(self) -> WorkerArgs:
return cloudpickle.loads(self.bytes)


def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
worker_args: WorkerArgs = serialized_worker_args.deserialize()

logger = logging.getLogger()
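
The `to_bytes`/`from_bytes` pair is replaced by a dedicated `SerializedWorkerArgs` wrapper, so the pickled payload travels as an opaque object and is only unpickled inside the worker via `deserialize()`. A standalone sketch of the same round-trip pattern, with a simplified stand-in dataclass (the fields here are illustrative, not the real `WorkerArgs`):

```python
# Standalone sketch of the serialize/deserialize round trip used above.
# `Payload` is a simplified stand-in, not the actual WorkerArgs dataclass.
from dataclasses import dataclass

import cloudpickle


@dataclass
class Payload:
    rank: int
    hostname: str


class SerializedPayload:
    def __init__(self, payload: Payload) -> None:
        self.bytes = cloudpickle.dumps(payload)  # pickle once, on the sender side

    def deserialize(self) -> Payload:
        return cloudpickle.loads(self.bytes)  # unpickle inside the worker process


serialized = SerializedPayload(Payload(rank=0, hostname="node0"))
restored = serialized.deserialize()
assert restored == Payload(rank=0, hostname="node0")
```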

@@ -63,18 +67,14 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:

redirect_stdio_to_logger(logger)

store = dist.TCPStore( # pyright: ignore[reportPrivateImportUsage]
store = dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
host_name=worker_args.main_agent_hostname,
port=worker_args.main_agent_port,
world_size=worker_args.world_size,
is_master=(worker_args.rank == 0),
)

backend = worker_args.backend
if backend is None:
backend = "nccl" if torch.cuda.is_available() else "gloo"

logger.debug(f"using backend: {backend}")
backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")

dist.init_process_group(
backend=backend,
@@ -91,19 +91,17 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)

logger.debug(f"executing function: {worker_args.function}")

try:
return worker_args.function()
except Exception as e:
logger.error(e)
traceback.print_exc()
return WorkerException(exception=e)
finally:
sys.stdout.flush()
sys.stderr.flush()


def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
agent_rank = launcher_agent_group.rank - 1

payload = AgentPayload(
@@ -132,16 +130,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_

redirect_stdio_to_logger(logger)

if torch.__version__ >= "2.3":
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs

log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
else:
log_kwargs = {"log_dir": tempfile.mkdtemp()}

# spawn workers

ctx = start_processes(
ctx = dist_mp.start_processes(
name=f"{hostname}_",
entrypoint=entrypoint,
args={
@@ -159,31 +150,30 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
world_size=worker_world_size,
hostname=launcher_payload.hostnames[agent_rank],
timeout=launcher_payload.timeout,
).to_bytes(),
).serialize(),
)
for i in range(num_workers)
},
envs={i: {} for i in range(num_workers)},
**log_kwargs, # pyright: ignore [reportArgumentType]
**(
{"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
if torch.__version__ >= "2.3"
else {"log_dir": tempfile.mkdtemp()}
), # pyright: ignore [reportArgumentType]
)
logger.info("starting processes")

try:
status = None
while True:
if status is None or status.state == "running":
status = AgentStatus.from_result(
result=ctx.wait(5), worker_global_ranks=worker_global_ranks
)
status = AgentStatus.from_result(ctx.wait(5))

agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)

if all(s.state == "done" for s in agent_statuses):
break
elif any(s.state == "failed" for s in agent_statuses):
all_done = all(s.state == "done" for s in agent_statuses)
any_failed = any(s.state == "failed" for s in agent_statuses)
if all_done or any_failed:
break
except:
raise
finally:
ctx.close()
sys.stdout.flush()
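
The version-dependent logging arguments for `start_processes` are now inlined into the call: PyTorch >= 2.3 expects `logs_specs=DefaultLogsSpecs(...)`, while older releases take `log_dir=`. A sketch isolating just that branch (the string comparison on `torch.__version__` mirrors the check used in `agent.py`):

```python
# Sketch: choose the start_processes logging kwargs based on the torch version,
# matching the inline conditional in agent.py.
import tempfile

import torch
import torch.distributed.elastic.multiprocessing as dist_mp

log_kwargs = (
    {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
    if torch.__version__ >= "2.3"
    else {"log_dir": tempfile.mkdtemp()}
)
# passed as: dist_mp.start_processes(..., **log_kwargs)
```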
17 changes: 11 additions & 6 deletions src/torchrunx/environment.py
@@ -17,7 +17,9 @@ def slurm_hosts() -> list[str]:
:rtype: list[str]
"""
# TODO: sanity check SLURM variables, commands
assert in_slurm_job()
if not in_slurm_job():
msg = "Not in a SLURM job"
raise RuntimeError(msg)
return (
subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
.decode()
@@ -35,15 +37,18 @@ def slurm_workers() -> int:
:rtype: int
"""
# TODO: sanity check SLURM variables, commands
assert in_slurm_job()
if not in_slurm_job():
msg = "Not in a SLURM job"
raise RuntimeError(msg)

if "SLURM_JOB_GPUS" in os.environ:
# TODO: is it possible to allocate uneven GPUs across nodes?
return len(os.environ["SLURM_JOB_GPUS"].split(","))
elif "SLURM_GPUS_PER_NODE" in os.environ:
if "SLURM_GPUS_PER_NODE" in os.environ:
return int(os.environ["SLURM_GPUS_PER_NODE"])
else:
# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])

# TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])


def auto_hosts() -> list[str]:
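
The `assert in_slurm_job()` guards become explicit `RuntimeError`s (asserts are stripped under `python -O`, and Ruff's S101 bans them outside tests), so callers outside a SLURM allocation get a catchable error with a message. A hedged usage sketch of the new behavior (the localhost fallback is illustrative, not part of the library):

```python
# Sketch: outside a SLURM job, the helpers now raise RuntimeError instead of
# AssertionError, so the failure can be handled explicitly.
from torchrunx.environment import slurm_hosts, slurm_workers

try:
    hosts = slurm_hosts()
    workers_per_host = slurm_workers()
except RuntimeError:
    # Illustrative fallback for running outside SLURM; not part of torchrunx.
    hosts, workers_per_host = ["localhost"], 1

print(hosts, workers_per_host)
```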