Commit 4022811

Merge pull request #45 from apoorvkh/testing
examples to tests
2 parents: fd40065 + a7714d0

File tree: 9 files changed (+2446 −4543 lines)

.github/workflows/main.yml
.gitignore
pixi.lock
pixi.toml
src/torchrunx/launcher.py
tests/test_CI.py
examples/slurm_poc.py → tests/test_func.py
examples/submitit_train.py → tests/test_submitit.py
examples/distributed_train.py → tests/test_train.py

.github/workflows/main.yml

Lines changed: 3 additions & 4 deletions
@@ -36,8 +36,8 @@ jobs:
           pixi-version: v0.27.1
           frozen: true
           cache: true
-          environments: default
-          activate-environment: default
+          environments: extra
+          activate-environment: extra
       - run: pyright
         if: success() || failure()

@@ -86,5 +86,4 @@ jobs:
           cache: false
           environments: default
           activate-environment: default
-
-      - run: pytest tests
+      - run: pytest tests/test_CI.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ logs/
 test_logs/
 _build/
 out/
+output/

 # Byte-compiled / optimized / DLL files
 __pycache__/

pixi.lock

Lines changed: 2365 additions & 4474 deletions
Some generated files are not rendered by default.

pixi.toml

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,13 @@ pytest = "*"
 build = "*"
 twine = "*"

+[feature.extra.pypi-dependencies]
+transformers = "*"
+submitit = "*"
+setuptools = "*"
+accelerate = "*"
+
 [environments]
 default = { features = ["package", "dev"], solve-group = "default" }
 dev = { features = ["dev"], solve-group = "default" }
+extra = { features = ["package", "dev", "extra"], solve-group = "default"}
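The new extra environment carries the heavy test-only dependencies (transformers, submitit, accelerate), while CI's default environment runs only tests/test_CI.py. A minimal sketch, assuming pytest's standard importorskip, of how the Slurm-dependent test modules could guard themselves if collected outside the extra environment (hypothetical; this PR instead selects tests by path in the workflow):

import pytest

# Skip the whole module gracefully wherever these extras are not installed.
transformers = pytest.importorskip("transformers")
submitit = pytest.importorskip("submitit")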

src/torchrunx/launcher.py

Lines changed: 1 addition & 0 deletions
@@ -262,6 +262,7 @@ def run(
         raise
     finally:
         print_process.kill()
+        dist.destroy_process_group()

     return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses]))
     return return_values
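Moving dist.destroy_process_group() into the launcher's finally block means callers no longer tear the process group down themselves, which is why the tests below drop their manual calls. A minimal sketch of what this enables, assuming the launch signature used in tests/test_CI.py (hello is a hypothetical worker):

import os
import tempfile

import torchrunx as trx


def hello():
    print(f"worker rank: {os.environ['RANK']}")


# Back-to-back launches in one process: cleanup now happens inside launch(),
# with no dist.destroy_process_group() needed between calls.
trx.launch(func=hello, func_kwargs={}, workers_per_host=2,
           backend="gloo", log_dir=tempfile.mkdtemp())
trx.launch(func=hello, func_kwargs={}, workers_per_host=2,
           backend="gloo", log_dir=tempfile.mkdtemp())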

tests/test_CI.py

Lines changed: 23 additions & 25 deletions
@@ -1,13 +1,11 @@
 import os
-import shutil
-import sys
+import tempfile

+import pytest
 import torch
 import torch.distributed as dist

-sys.path.append("../src")
-
-import torchrunx  # noqa: I001
+import torchrunx as trx


 def test_simple_localhost():
@@ -30,38 +28,27 @@ def dist_func():

         return o.detach()

-    r = torchrunx.launch(
-        func=dist_func,
-        func_kwargs={},
-        workers_per_host=2,
-        backend="gloo",
+    r = trx.launch(
+        func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
     )

     assert torch.all(r[0] == r[1])

-    dist.destroy_process_group()
-

 def test_logging():
     def dist_func():
         rank = int(os.environ["RANK"])
         print(f"worker rank: {rank}")

-    try:
-        shutil.rmtree("./test_logs", ignore_errors=True)
-    except FileNotFoundError:
-        pass
-
-    torchrunx.launch(
-        func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
-    )
+    tmp = tempfile.mkdtemp()
+    trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir=tmp)

-    log_files = next(os.walk("./test_logs"), (None, None, []))[2]
+    log_files = next(os.walk(tmp), (None, None, []))[2]

     assert len(log_files) == 3

     for file in log_files:
-        with open("./test_logs/" + file, "r") as f:
+        with open(f"{tmp}/{file}", "r") as f:
             if file.endswith("0.log"):
                 assert f.read() == "worker rank: 0\n"
             elif file.endswith("1.log"):
@@ -71,7 +58,18 @@ def dist_func():
                 assert "worker rank: 0" in contents
                 assert "worker rank: 1" in contents

-    # clean up
-    shutil.rmtree("./test_logs", ignore_errors=True)

-    dist.destroy_process_group()
+def test_error():
+    def error_func():
+        raise ValueError("abcdefg")
+
+    with pytest.raises(RuntimeError) as excinfo:
+        trx.launch(
+            func=error_func,
+            func_kwargs={},
+            workers_per_host=1,
+            backend="gloo",
+            log_dir=tempfile.mkdtemp(),
+        )
+
+    assert "abcdefg" in str(excinfo.value)

examples/slurm_poc.py renamed to tests/test_func.py

Lines changed: 8 additions & 9 deletions
@@ -3,22 +3,21 @@
 import torch
 import torch.distributed as dist

-import torchrunx
-
-# this is not a pytest test, but a functional test designed to be run on a slurm allocation
+import torchrunx as trx


 def test_launch():
-    result = torchrunx.launch(
+    result = trx.launch(
         func=simple_matmul,
-        hostnames=torchrunx.slurm_hosts(),
-        workers_per_host=torchrunx.slurm_workers(),
+        hostnames=trx.slurm_hosts(),
+        workers_per_host=trx.slurm_workers(),
     )

+    t = True
     for i in range(len(result)):
-        assert torch.all(result[i] == result[0]), "Not all tensors equal"
-    print(result[0])
-    print("PASS")
+        t = t and torch.all(result[i] == result[0])
+
+    assert t, "Not all tensors equal"


 def simple_matmul():
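The accumulator loop works because torch.all returns a one-element tensor and Python's `and` coerces it to bool. An equivalent formulation with torch.equal (a sketch, not what the PR does) avoids the accumulator and reads more directly:

assert all(
    torch.equal(result[i], result[0]) for i in range(len(result))
), "Not all tensors equal"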

examples/submitit_train.py renamed to tests/test_submitit.py

Lines changed: 30 additions & 11 deletions
@@ -23,35 +23,54 @@ def __getitem__(self, index):
             "labels": self.labels[index],
         }

+
 def main():
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     train_dataset = DummyDataset()

     ## Training

     training_arguments = TrainingArguments(
-        output_dir = "output",
-        do_train = True,
-        per_device_train_batch_size = 16,
-        max_steps = 20,
+        output_dir="output",
+        do_train=True,
+        per_device_train_batch_size=16,
+        max_steps=20,
     )

     trainer = Trainer(
-        model=model, # type: ignore
+        model=model,  # type: ignore
         args=training_arguments,
-        train_dataset=train_dataset
+        train_dataset=train_dataset,
     )

     trainer.train()

+
 def launch():
     trx.launch(
-        func=main,
-        func_kwargs={},
-        hostnames=trx.slurm_hosts(),
-        workers_per_host=trx.slurm_workers()
+        func=main, func_kwargs={}, hostnames=trx.slurm_hosts(), workers_per_host=trx.slurm_workers()
     )
+
+
+def test_submitit():
+    executor = submitit.SlurmExecutor(folder="logs")
+
+    executor.update_parameters(
+        time=60,
+        nodes=1,
+        ntasks_per_node=1,
+        mem="32G",
+        cpus_per_task=4,
+        gpus_per_node=2,
+        constraint="geforce3090",
+        partition="3090-gcondo",
+        stderr_to_stdout=True,
+        use_srun=False,
+    )
+
+    executor.submit(launch).result()
+
+
 if __name__ == "__main__":
     executor = submitit.SlurmExecutor(folder="logs")

@@ -68,4 +87,4 @@ def launch():
         use_srun=False,
     )

-    executor.submit(launch)
+    executor.submit(launch)
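The important difference from the __main__ path is the trailing .result(): submitit's Job.result() blocks until the Slurm job finishes and re-raises any exception from the job, so a failing launch actually fails the pytest run instead of being fire-and-forget:

job = executor.submit(launch)  # returns a submitit.Job immediately
job.result()  # blocks until completion; raises if the job failed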

examples/distributed_train.py renamed to tests/test_train.py

Lines changed: 8 additions & 20 deletions
@@ -1,8 +1,6 @@
 import os
-import socket
-import subprocess

-import torchrunx
+import torchrunx as trx


 def worker():
@@ -34,24 +32,14 @@ def forward(self, x):
     loss.sum().backward()


-def resolve_node_ips(nodelist):
-    # Expand the nodelist into individual hostnames
-    hostnames = (
-        subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
-        .decode()
-        .strip()
-        .split("\n")
+def test_distributed_train():
+    trx.launch(
+        worker,
+        hostnames=trx.slurm_hosts(),
+        workers_per_host=trx.slurm_workers(),
+        backend="nccl",
     )
-    # Resolve each hostname to an IP address
-    ips = [socket.gethostbyname(hostname) for hostname in hostnames]
-    return ips


 if __name__ == "__main__":
-    torchrunx.launch(
-        worker,
-        {},
-        hostnames=torchrunx.slurm_hosts(),
-        workers_per_host=torchrunx.slurm_workers(),
-        backend="nccl",
-    )
+    test_distributed_train()
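Like test_func.py and test_submitit.py, this module assumes a Slurm allocation with GPUs (backend="nccl"); CI sidesteps all three by running only tests/test_CI.py. A hypothetical guard, if these ever run under plain pytest on a GPU-less machine:

import pytest
import torch

# Not in this PR: skip the NCCL-backed test when no GPU is available.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="NCCL requires GPUs")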
