11import os
2- import shutil
3- import sys
2+ import tempfile
43
4+ import pytest
55import torch
66import torch .distributed as dist
77
8- sys .path .append ("../src" )
9-
10- import torchrunx # noqa: I001
8+ import torchrunx as trx
119
1210
1311def test_simple_localhost ():
@@ -30,38 +28,27 @@ def dist_func():
3028
3129 return o .detach ()
3230
33- r = torchrunx .launch (
34- func = dist_func ,
35- func_kwargs = {},
36- workers_per_host = 2 ,
37- backend = "gloo" ,
31+ r = trx .launch (
32+ func = dist_func , func_kwargs = {}, workers_per_host = 2 , backend = "gloo" , log_dir = "./test_logs"
3833 )
3934
4035 assert torch .all (r [0 ] == r [1 ])
4136
42- dist .destroy_process_group ()
43-
4437
4538def test_logging ():
4639 def dist_func ():
4740 rank = int (os .environ ["RANK" ])
4841 print (f"worker rank: { rank } " )
4942
50- try :
51- shutil .rmtree ("./test_logs" , ignore_errors = True )
52- except FileNotFoundError :
53- pass
54-
55- torchrunx .launch (
56- func = dist_func , func_kwargs = {}, workers_per_host = 2 , backend = "gloo" , log_dir = "./test_logs"
57- )
43+ tmp = tempfile .mkdtemp ()
44+ trx .launch (func = dist_func , func_kwargs = {}, workers_per_host = 2 , backend = "gloo" , log_dir = tmp )
5845
59- log_files = next (os .walk ("./test_logs" ), (None , None , []))[2 ]
46+ log_files = next (os .walk (tmp ), (None , None , []))[2 ]
6047
6148 assert len (log_files ) == 3
6249
6350 for file in log_files :
64- with open ("./test_logs/" + file , "r" ) as f :
51+ with open (f" { tmp } / { file } " , "r" ) as f :
6552 if file .endswith ("0.log" ):
6653 assert f .read () == "worker rank: 0\n "
6754 elif file .endswith ("1.log" ):
@@ -71,7 +58,18 @@ def dist_func():
7158 assert "worker rank: 0" in contents
7259 assert "worker rank: 1" in contents
7360
74- # clean up
75- shutil .rmtree ("./test_logs" , ignore_errors = True )
7661
77- dist .destroy_process_group ()
62+ def test_error ():
63+ def error_func ():
64+ raise ValueError ("abcdefg" )
65+
66+ with pytest .raises (RuntimeError ) as excinfo :
67+ trx .launch (
68+ func = error_func ,
69+ func_kwargs = {},
70+ workers_per_host = 1 ,
71+ backend = "gloo" ,
72+ log_dir = tempfile .mkdtemp (),
73+ )
74+
75+ assert "abcdefg" in str (excinfo .value )
0 commit comments