Skip to content

Commit a7714d0

Browse files
committed
cleanup and format
1 parent 02695ca commit a7714d0

File tree

3 files changed

+12
-24
lines changed

3 files changed

+12
-24
lines changed

tests/test_CI.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import os
2-
import sys
32
import tempfile
43

54
import pytest
65
import torch
76
import torch.distributed as dist
87

9-
sys.path.append("../src")
10-
11-
import torchrunx # noqa: I001
8+
import torchrunx as trx
129

1310

1411
def test_simple_localhost():
@@ -31,7 +28,7 @@ def dist_func():
3128

3229
return o.detach()
3330

34-
r = torchrunx.launch(
31+
r = trx.launch(
3532
func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
3633
)
3734

@@ -44,9 +41,7 @@ def dist_func():
4441
print(f"worker rank: {rank}")
4542

4643
tmp = tempfile.mkdtemp()
47-
torchrunx.launch(
48-
func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir=tmp
49-
)
44+
trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir=tmp)
5045

5146
log_files = next(os.walk(tmp), (None, None, []))[2]
5247

@@ -69,7 +64,7 @@ def error_func():
6964
raise ValueError("abcdefg")
7065

7166
with pytest.raises(RuntimeError) as excinfo:
72-
torchrunx.launch(
67+
trx.launch(
7368
func=error_func,
7469
func_kwargs={},
7570
workers_per_host=1,

tests/test_func.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import torch
44
import torch.distributed as dist
55

6-
import torchrunx
6+
import torchrunx as trx
77

88

99
def test_launch():
10-
result = torchrunx.launch(
10+
result = trx.launch(
1111
func=simple_matmul,
12-
hostnames=torchrunx.slurm_hosts(),
13-
workers_per_host=torchrunx.slurm_workers(),
12+
hostnames=trx.slurm_hosts(),
13+
workers_per_host=trx.slurm_workers(),
1414
)
1515

1616
t = True
@@ -19,8 +19,6 @@ def test_launch():
1919

2020
assert t, "Not all tensors equal"
2121

22-
dist.destroy_process_group()
23-
2422

2523
def simple_matmul():
2624
rank = int(os.environ["RANK"])

tests/test_train.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
import os
2-
import sys
32

4-
sys.path.append("../src")
5-
6-
import torch.distributed as dist
7-
8-
import torchrunx
3+
import torchrunx as trx
94

105

116
def worker():
@@ -38,10 +33,10 @@ def forward(self, x):
3833

3934

4035
def test_distributed_train():
41-
torchrunx.launch(
36+
trx.launch(
4237
worker,
43-
hostnames=torchrunx.slurm_hosts(),
44-
workers_per_host=torchrunx.slurm_workers(),
38+
hostnames=trx.slurm_hosts(),
39+
workers_per_host=trx.slurm_workers(),
4540
backend="nccl",
4641
)
4742

0 commit comments

Comments
 (0)