11import argparse
2+ import os
23import os .path as osp
34import time
45
56import cupy
67import psutil
78import rmm
89import torch
10+ import torch .distributed as dist
911from rmm .allocators .cupy import rmm_cupy_allocator
1012from rmm .allocators .torch import rmm_torch_allocator
1113
3133cudf .set_option ("spill" , True )
3234
3335
36+ # ---------------- Distributed helpers ----------------
def safe_get_rank():
    """Return the rank of the current process, or 0 outside distributed runs.

    The lookup is guarded by ``dist.is_available()`` because a PyTorch build
    without distributed support does not expose ``dist.is_initialized`` or
    ``dist.get_rank``; calling them unguarded would raise ``AttributeError``.

    Returns:
        int: ``dist.get_rank()`` when a default process group has been
        initialized, otherwise ``0``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
39+
40+
def safe_get_world_size():
    """Return the number of participating processes, or 1 when not distributed.

    The lookup is guarded by ``dist.is_available()`` because a PyTorch build
    without distributed support does not expose ``dist.is_initialized`` or
    ``dist.get_world_size``; calling them unguarded would raise
    ``AttributeError``.

    Returns:
        int: ``dist.get_world_size()`` when a default process group has been
        initialized, otherwise ``1``.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
43+
44+
def init_distributed():
    """Initialize the default ``torch.distributed`` process group.

    Reads the launcher-provided environment variables (``RANK``,
    ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``,
    ``MASTER_PORT``). Any that are missing are filled in with
    single-process defaults, so the script also runs without a launcher.
    A process group is created even when ``WORLD_SIZE`` is 1 (as a group of
    one).

    The call is a no-op when a process group already exists.

    Raises:
        RuntimeError: If ``WORLD_SIZE`` is greater than 1 but this PyTorch
            build has no distributed support.
    """
    # torch.distributed exposes no other API when it is unavailable, so this
    # flag must be consulted before every other `dist.*` call below.
    distributed_available = dist.is_available()

    # Already initialized? Nothing to do.
    if distributed_available and dist.is_initialized():
        return

    # Defaults for the single-GPU / single-process fallback.
    default_env = {
        "RANK": "0",
        "LOCAL_RANK": "0",
        "WORLD_SIZE": "1",
        "LOCAL_WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
    }

    # setdefault keeps whatever the launcher already exported.
    for key, value in default_env.items():
        os.environ.setdefault(key, value)

    # Bind this process to its GPU before any communicator is created.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    world_size = int(os.environ["WORLD_SIZE"])

    if not distributed_available:
        if world_size > 1:
            raise RuntimeError(
                f"WORLD_SIZE={world_size} was requested but torch.distributed "
                "is not available in this PyTorch build")
        print("Running in single-GPU / single-process mode")
        return

    # NCCL only works with CUDA devices; fall back to Gloo on CPU-only hosts
    # so the single-process path still comes up.
    backend = "nccl" if use_cuda else "gloo"

    if world_size > 1:
        dist.init_process_group(backend=backend, init_method="env://")
        rank = os.environ["RANK"]
        print(f"Initialized distributed: rank {rank}, "
              f"world_size {world_size}")
    else:
        print("Running in single-GPU / single-process mode")
        # Still create a group of one so later collective calls such as
        # barrier() work unchanged.
        dist.init_process_group(backend=backend, init_method="env://",
                                rank=0, world_size=1)
84+
85+
86+ # ------------------------------------------------------
87+
88+
3489def arg_parse ():
3590 parser = argparse .ArgumentParser (
3691 formatter_class = argparse .ArgumentDefaultsHelpFormatter , )
@@ -98,15 +153,16 @@ def arg_parse():
98153
99154
100155def create_loader (
156+ input_nodes ,
157+ stage_name ,
101158 data ,
102159 num_neighbors ,
103- input_nodes ,
104160 replace ,
105161 batch_size ,
106- stage_name ,
107162 shuffle = False ,
108163):
109- print (f'Creating { stage_name } loader...' )
164+ if safe_get_rank () == 0 :
165+ print (f'Creating { stage_name } loader...' )
110166
111167 return NeighborLoader (
112168 data ,
@@ -118,7 +174,7 @@ def create_loader(
118174 )
119175
120176
121- def train (model , train_loader ):
177+ def train (model , train_loader , optimizer ):
122178 model .train ()
123179
124180 total_loss = total_correct = total_examples = 0
@@ -156,17 +212,26 @@ def test(model, loader):
156212
157213
158214if __name__ == '__main__' :
215+ # init DDP if needed
216+ init_distributed ()
217+
159218 args = arg_parse ()
160219 torch_geometric .seed_everything (123 )
220+
161221 if "papers" in str (args .dataset ) and (psutil .virtual_memory ().total /
162222 (1024 ** 3 )) < 390 :
163- print ("Warning: may not have enough RAM to use this many GPUs." )
164- print ("Consider upgrading RAM if an error occurs." )
165- print ("Estimated RAM Needed: ~390GB." )
223+ if safe_get_rank () == 0 :
224+ print ("Warning: may not have enough RAM to use this many GPUs." )
225+ print ("Consider upgrading RAM if an error occurs." )
226+ print ("Estimated RAM Needed: ~390GB." )
227+
166228 wall_clock_start = time .perf_counter ()
167229
168230 root = osp .join (args .dataset_dir , args .dataset_subdir )
169- print ('The root is: ' , root )
231+
232+ if safe_get_rank () == 0 :
233+ print ('The root is: ' , root )
234+
170235 dataset = PygNodePropPredDataset (name = args .dataset , root = root )
171236 split_idx = dataset .get_idx_split ()
172237
@@ -188,33 +253,30 @@ def test(model, loader):
188253 size = (data .num_nodes , data .num_nodes ),
189254 )] = data .edge_index
190255
191- feature_store = cugraph_pyg .data .TensorDictFeatureStore ()
256+ feature_store = cugraph_pyg .data .FeatureStore ()
192257 feature_store ['node' , 'x' , None ] = data .x
193258 feature_store ['node' , 'y' , None ] = data .y
194259
195260 data = (feature_store , graph_store )
196261
197- print (f"Training { args .dataset } with { args .model } model." )
262+ if safe_get_rank () == 0 :
263+ print (f"Training { args .dataset } with { args .model } model." )
264+
198265 if args .model == "GAT" :
199266 model = torch_geometric .nn .models .GAT (dataset .num_features ,
200267 args .hidden_channels ,
201268 args .num_layers ,
202269 dataset .num_classes ,
203270 heads = args .num_heads ).cuda ()
204271 elif args .model == "GCN" :
205- model = torch_geometric .nn .models .GCN (
206- dataset .num_features ,
207- args .hidden_channels ,
208- args .num_layers ,
209- dataset .num_classes ,
210- ).cuda ()
272+ model = torch_geometric .nn .models .GCN (dataset .num_features ,
273+ args .hidden_channels ,
274+ args .num_layers ,
275+ dataset .num_classes ).cuda ()
211276 elif args .model == "SAGE" :
212277 model = torch_geometric .nn .models .GraphSAGE (
213- dataset .num_features ,
214- args .hidden_channels ,
215- args .num_layers ,
216- dataset .num_classes ,
217- ).cuda ()
278+ dataset .num_features , args .hidden_channels , args .num_layers ,
279+ dataset .num_classes ).cuda ()
218280 elif args .model == 'SGFormer' :
219281 # TODO add support for this with disjoint sampling
220282 model = torch_geometric .nn .models .SGFormer (
@@ -227,7 +289,7 @@ def test(model, loader):
227289 gnn_dropout = args .dropout ,
228290 ).cuda ()
229291 else :
230- raise ValueError ('Unsupported model type: {args.model}' )
292+ raise ValueError (f 'Unsupported model type: { args .model } ' )
231293
232294 optimizer = torch .optim .Adam (model .parameters (), lr = args .lr ,
233295 weight_decay = args .wd )
@@ -239,69 +301,54 @@ def test(model, loader):
239301 batch_size = args .batch_size ,
240302 )
241303
242- train_loader = create_loader (
243- input_nodes = split_idx ['train' ],
244- stage_name = 'train' ,
245- shuffle = True ,
246- ** loader_kwargs ,
247- )
304+ train_loader = create_loader (split_idx ['train' ], 'train' , ** loader_kwargs ,
305+ shuffle = True )
306+ val_loader = create_loader (split_idx ['valid' ], 'val' , ** loader_kwargs )
307+ test_loader = create_loader (split_idx ['test' ], 'test' , ** loader_kwargs )
248308
249- val_loader = create_loader (
250- input_nodes = split_idx ['valid' ],
251- stage_name = 'val' ,
252- ** loader_kwargs ,
253- )
309+ if dist .is_initialized ():
310+ dist .barrier () # sync before training
254311
255- test_loader = create_loader (
256- input_nodes = split_idx ['test' ],
257- stage_name = 'test' ,
258- ** loader_kwargs ,
259- )
260- prep_time = round (time .perf_counter () - wall_clock_start , 2 )
261- print ("Total time before training begins (prep_time) =" , prep_time ,
262- "seconds" )
263- print ("Beginning training..." )
264- val_accs = []
265- times = []
266- train_times = []
267- inference_times = []
312+ if safe_get_rank () == 0 :
313+ prep_time = round (time .perf_counter () - wall_clock_start , 2 )
314+ print ("Total time before training begins (prep_time) =" , prep_time ,
315+ "seconds" )
316+ print ("Beginning training..." )
317+
318+ val_accs , times , train_times , inference_times = [], [], [], []
268319 best_val = 0.
269320 start = time .perf_counter ()
270- epochs = args .epochs
271- for epoch in range (1 , epochs + 1 ):
321+ for epoch in range (1 , args .epochs + 1 ):
272322 train_start = time .perf_counter ()
273- loss , train_acc = train (model , train_loader )
323+ loss , train_acc = train (model , train_loader , optimizer )
274324 train_end = time .perf_counter ()
275325 train_times .append (train_end - train_start )
276326 inference_start = time .perf_counter ()
277327 train_acc = test (model , train_loader )
278328 val_acc = test (model , val_loader )
279-
280329 inference_times .append (time .perf_counter () - inference_start )
281330 val_accs .append (val_acc )
282- print (f'Epoch { epoch :02d} , Loss: { loss :.4f} , Approx. Train:'
283- f' { train_acc :.4f} Time: { train_end - train_start :.4f} s' )
284- print (f'Train: { train_acc :.4f} , Val: { val_acc :.4f} , ' )
331+
332+ if safe_get_rank () == 0 :
333+ print (f'Epoch { epoch :02d} , Loss: { loss :.4f} , '
334+ f'Train: { train_acc :.4f} , Val: { val_acc :.4f} , '
335+ f'Time: { train_end - train_start :.4f} s' )
285336
286337 times .append (time .perf_counter () - train_start )
287- if val_acc > best_val :
288- best_val = val_acc
289-
290- print (f"Total time used: is { time .perf_counter ()- start :.4f} " )
291- val_acc = torch .tensor (val_accs )
292- print ('============================' )
293- print ("Average Epoch Time on training: {:.4f}" .format (
294- torch .tensor (train_times ).mean ()))
295- print ("Average Epoch Time on inference: {:.4f}" .format (
296- torch .tensor (inference_times ).mean ()))
297- print (f"Average Epoch Time: { torch .tensor (times ).mean ():.4f} " )
298- print (f"Median time per epoch: { torch .tensor (times ).median ():.4f} s" )
299- print (f'Final Validation: { val_acc .mean ():.4f} ± { val_acc .std ():.4f} ' )
300- print (f"Best validation accuracy: { best_val :.4f} " )
301-
302- print ("Testing..." )
303- final_test_acc = test (model , test_loader )
304- print (f'Test Accuracy: { final_test_acc :.4f} ' )
305-
306- total_time = round (time .perf_counter () - wall_clock_start , 2 )
307- print ("Total Program Runtime (total_time) =" , total_time , "seconds" )
338+ best_val = max (best_val , val_acc )
339+
340+ if safe_get_rank () == 0 :
341+ print (f"Total time used: { time .perf_counter ()- start :.4f} " )
342+ print ("Final Validation: {:.4f} ± {:.4f}" .format (
343+ torch .tensor (val_accs ).mean (),
344+ torch .tensor (val_accs ).std ()))
345+ print (f"Best validation accuracy: { best_val :.4f} " )
346+ print ("Testing..." )
347+ final_test_acc = test (model , test_loader )
348+ print (f'Test Accuracy: { final_test_acc :.4f} ' )
349+ total_time = round (time .perf_counter () - wall_clock_start , 2 )
350+ print ("Total Program Runtime (total_time) =" , total_time , "seconds" )
351+
352+ if dist .is_initialized ():
353+ dist .barrier ()
354+ dist .destroy_process_group ()
0 commit comments