examples to tests #45

Merged
merged 16 commits on Aug 17, 2024
7 changes: 3 additions & 4 deletions .github/workflows/main.yml
@@ -36,8 +36,8 @@ jobs:
           pixi-version: v0.27.1
           frozen: true
           cache: true
-          environments: default
-          activate-environment: default
+          environments: extra
+          activate-environment: extra
       - run: pyright
         if: success() || failure()

@@ -86,5 +86,4 @@ jobs:
           cache: false
           environments: default
           activate-environment: default
-
-      - run: pytest tests
+      - run: pytest tests/test_CI.py
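
Note: pyright now type-checks against the new extra environment (so the optional dependencies added in pixi.toml below resolve), while CI runs only tests/test_CI.py — the remaining tests assume a Slurm allocation. A marker-based split would be an alternative (hypothetical sketch, not what this PR does; the `slurm` marker name is invented):

    import pytest

    @pytest.mark.slurm  # CI could then deselect these with `pytest -m "not slurm"`
    def test_distributed_train():
        ...
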
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ logs/
 test_logs/
 _build/
 out/
+output/

 # Byte-compiled / optimized / DLL files
 __pycache__/
6,839 changes: 2,365 additions & 4,474 deletions pixi.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions pixi.toml
@@ -19,6 +19,13 @@ pytest = "*"
 build = "*"
 twine = "*"

+[feature.extra.pypi-dependencies]
+transformers = "*"
+submitit = "*"
+setuptools = "*"
+accelerate = "*"
+
 [environments]
 default = { features = ["package", "dev"], solve-group = "default" }
 dev = { features = ["dev"], solve-group = "default" }
+extra = { features = ["package", "dev", "extra"], solve-group = "default"}
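
The extra feature carries the heavy optional dependencies (transformers, submitit, accelerate) that the relocated tests import; sharing solve-group = "default" keeps the common packages locked to identical versions across all three environments. If these tests ever run in an environment without the extras, pytest's importorskip could turn hard import errors into skips (hypothetical sketch, not part of this PR):

    import pytest

    # At the top of tests/test_submitit.py: skip the module when extras are absent
    submitit = pytest.importorskip("submitit")
    transformers = pytest.importorskip("transformers")
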
1 change: 1 addition & 0 deletions src/torchrunx/launcher.py
@@ -262,6 +262,7 @@ def run(
         raise
     finally:
         print_process.kill()
+        dist.destroy_process_group()

     return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses]))
     return return_values
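
With dist.destroy_process_group() in the finally block, launch() now tears down the launcher's process group itself, on success or failure — which is why the per-test dist.destroy_process_group() calls are deleted from tests/test_CI.py below. A sketch of what this enables, reusing the call signature shown in this diff:

    # Two launches in the same pytest process no longer trip over a stale
    # default process group left behind by the first call:
    r1 = trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo")
    r2 = trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo")
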
48 changes: 23 additions & 25 deletions tests/test_CI.py
@@ -1,13 +1,11 @@
 import os
-import shutil
-import sys
+import tempfile

+import pytest
 import torch
 import torch.distributed as dist

-sys.path.append("../src")
-
-import torchrunx  # noqa: I001
+import torchrunx as trx


 def test_simple_localhost():
@@ -30,38 +28,27 @@ def dist_func():

         return o.detach()

-    r = torchrunx.launch(
-        func=dist_func,
-        func_kwargs={},
-        workers_per_host=2,
-        backend="gloo",
+    r = trx.launch(
+        func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
     )

     assert torch.all(r[0] == r[1])

-    dist.destroy_process_group()
-

 def test_logging():
     def dist_func():
         rank = int(os.environ["RANK"])
         print(f"worker rank: {rank}")

-    try:
-        shutil.rmtree("./test_logs", ignore_errors=True)
-    except FileNotFoundError:
-        pass
-
-    torchrunx.launch(
-        func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
-    )
+    tmp = tempfile.mkdtemp()
+    trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir=tmp)

-    log_files = next(os.walk("./test_logs"), (None, None, []))[2]
+    log_files = next(os.walk(tmp), (None, None, []))[2]

     assert len(log_files) == 3

     for file in log_files:
-        with open("./test_logs/" + file, "r") as f:
+        with open(f"{tmp}/{file}", "r") as f:
             if file.endswith("0.log"):
                 assert f.read() == "worker rank: 0\n"
             elif file.endswith("1.log"):
@@ -71,7 +58,18 @@ def dist_func():
                 assert "worker rank: 0" in contents
                 assert "worker rank: 1" in contents

-    # clean up
-    shutil.rmtree("./test_logs", ignore_errors=True)
-
-    dist.destroy_process_group()
+
+def test_error():
+    def error_func():
+        raise ValueError("abcdefg")
+
+    with pytest.raises(RuntimeError) as excinfo:
+        trx.launch(
+            func=error_func,
+            func_kwargs={},
+            workers_per_host=1,
+            backend="gloo",
+            log_dir=tempfile.mkdtemp(),
+        )
+
+    assert "abcdefg" in str(excinfo.value)
17 changes: 8 additions & 9 deletions examples/slurm_poc.py → tests/test_func.py
@@ -3,22 +3,21 @@
 import torch
 import torch.distributed as dist

-import torchrunx
-
-# this is not a pytest test, but a functional test designed to be run on a slurm allocation
+import torchrunx as trx


 def test_launch():
-    result = torchrunx.launch(
+    result = trx.launch(
         func=simple_matmul,
-        hostnames=torchrunx.slurm_hosts(),
-        workers_per_host=torchrunx.slurm_workers(),
+        hostnames=trx.slurm_hosts(),
+        workers_per_host=trx.slurm_workers(),
     )

+    t = True
     for i in range(len(result)):
-        assert torch.all(result[i] == result[0]), "Not all tensors equal"
-    print(result[0])
-    print("PASS")
+        t = t and torch.all(result[i] == result[0])
+
+    assert t, "Not all tensors equal"


 def simple_matmul():
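
One subtlety in the rewritten loop: torch.all(...) returns a zero-dimensional bool tensor, so after the first iteration t is a tensor and the final assert relies on Tensor.__bool__ (valid for 0-dim tensors). An explicit conversion would keep the flag a plain bool (minor variant, not in the PR):

    t = True
    for i in range(len(result)):
        t = t and bool(torch.all(result[i] == result[0]))  # bool() collapses the 0-dim tensor

    assert t, "Not all tensors equal"
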
41 changes: 30 additions & 11 deletions examples/submitit_train.py → tests/test_submitit.py
@@ -23,35 +23,54 @@ def __getitem__(self, index):
             "labels": self.labels[index],
         }

+
 def main():
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     train_dataset = DummyDataset()

     ## Training

     training_arguments = TrainingArguments(
-        output_dir = "output",
-        do_train = True,
-        per_device_train_batch_size = 16,
-        max_steps = 20,
+        output_dir="output",
+        do_train=True,
+        per_device_train_batch_size=16,
+        max_steps=20,
     )

     trainer = Trainer(
-        model=model, # type: ignore
+        model=model,  # type: ignore
         args=training_arguments,
-        train_dataset=train_dataset
+        train_dataset=train_dataset,
     )

     trainer.train()


 def launch():
     trx.launch(
-        func=main,
-        func_kwargs={},
-        hostnames=trx.slurm_hosts(),
-        workers_per_host=trx.slurm_workers()
+        func=main, func_kwargs={}, hostnames=trx.slurm_hosts(), workers_per_host=trx.slurm_workers()
     )


+def test_submitit():
+    executor = submitit.SlurmExecutor(folder="logs")
+
+    executor.update_parameters(
+        time=60,
+        nodes=1,
+        ntasks_per_node=1,
+        mem="32G",
+        cpus_per_task=4,
+        gpus_per_node=2,
+        constraint="geforce3090",
+        partition="3090-gcondo",
+        stderr_to_stdout=True,
+        use_srun=False,
+    )
+
+    executor.submit(launch).result()
+
+
 if __name__ == "__main__":
     executor = submitit.SlurmExecutor(folder="logs")

@@ -68,4 +87,4 @@ def launch():
         use_srun=False,
     )

-    executor.submit(launch)
+    executor.submit(launch)
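
Unlike the old __main__ path, test_submitit chains .result() onto executor.submit(launch): Job.result() blocks until the Slurm job completes and re-raises any exception from launch, so a remote failure fails the pytest run instead of being silently dropped. (The constraint and partition values are specific to the authors' cluster.) Minimal sketch of the same submitit Job API:

    job = executor.submit(launch)
    print(job.job_id)  # Slurm job id, available immediately after submission
    job.result()       # wait for completion; re-raises remote exceptions
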
28 changes: 8 additions & 20 deletions examples/distributed_train.py → tests/test_train.py
@@ -1,8 +1,6 @@
 import os
 import socket
-import subprocess
-
-import torchrunx
+import torchrunx as trx


 def worker():
@@ -34,24 +32,14 @@ def forward(self, x):
     loss.sum().backward()


-def resolve_node_ips(nodelist):
-    # Expand the nodelist into individual hostnames
-    hostnames = (
-        subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
-        .decode()
-        .strip()
-        .split("\n")
-    )
-    # Resolve each hostname to an IP address
-    ips = [socket.gethostbyname(hostname) for hostname in hostnames]
-    return ips
+def test_distributed_train():
+    trx.launch(
+        worker,
+        hostnames=trx.slurm_hosts(),
+        workers_per_host=trx.slurm_workers(),
+        backend="nccl",
+    )


 if __name__ == "__main__":
-    torchrunx.launch(
-        worker,
-        {},
-        hostnames=torchrunx.slurm_hosts(),
-        workers_per_host=torchrunx.slurm_workers(),
-        backend="nccl",
-    )
+    test_distributed_train()
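
The hand-rolled resolve_node_ips helper is deleted; trx.slurm_hosts() now supplies the hostnames. Judging from the removed code, a minimal equivalent might look like this (hypothetical sketch — the real implementation lives inside torchrunx):

    import os
    import subprocess

    def slurm_hosts() -> list[str]:
        # Expand $SLURM_JOB_NODELIST (e.g. "gpu[01-04]") into individual hostnames
        nodelist = os.environ["SLURM_JOB_NODELIST"]
        return (
            subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
            .decode()
            .strip()
            .split("\n")
        )
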