11"""Multi-node multi-GPU example on ogbn-papers100M.
22
3- To run:
3+ Example of how to run using srun:
44srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
55 --container-name=cont --container-image=<image_url> \
66 --container-mounts=/ogb-papers100m/:/workspace/dataset \
77python3 path_to_script.py
88"""
99import os
1010import time
11+ from typing import Optional
1112
1213import torch
1314import torch .distributed as dist
1415import torch .nn .functional as F
1516from ogb .nodeproppred import PygNodePropPredDataset
1617from torch .nn .parallel import DistributedDataParallel
18+ from torchmetrics import Accuracy
1719
1820from torch_geometric .loader import NeighborLoader
19- from torch_geometric .nn import GCNConv
21+ from torch_geometric .nn import GCN
2022
2123
2224def get_num_workers () -> int :
@@ -31,21 +33,7 @@ def get_num_workers() -> int:
3133 return num_workers
3234
3335
34- class GCN (torch .nn .Module ):
35- def __init__ (self , in_channels , hidden_channels , out_channels ):
36- super ().__init__ ()
37- self .conv1 = GCNConv (in_channels , hidden_channels )
38- self .conv2 = GCNConv (hidden_channels , out_channels )
39-
40- def forward (self , x , edge_index ):
41- x = F .dropout (x , p = 0.5 , training = self .training )
42- x = self .conv1 (x , edge_index ).relu ()
43- x = F .dropout (x , p = 0.5 , training = self .training )
44- x = self .conv2 (x , edge_index )
45- return x
46-
47-
48- def run (world_size , data , split_idx , model ):
36+ def run (world_size , data , split_idx , model , acc , wall_clock_start ):
4937 local_id = int (os .environ ['LOCAL_RANK' ])
5038 rank = torch .distributed .get_rank ()
5139 torch .cuda .set_device (local_id )
@@ -54,38 +42,48 @@ def run(world_size, data, split_idx, model):
5442 print (f'Using { nprocs } GPUs...' )
5543
5644 split_idx ['train' ] = split_idx ['train' ].split (
57- split_idx ['train' ].size (0 ) // world_size ,
58- dim = 0 ,
59- )[rank ].clone ()
45+ split_idx ['train' ].size (0 ) // world_size , dim = 0 )[rank ].clone ()
46+ split_idx ['valid' ] = split_idx ['valid' ].split (
47+ split_idx ['valid' ].size (0 ) // world_size , dim = 0 )[rank ].clone ()
48+ split_idx ['test' ] = split_idx ['test' ].split (
49+ split_idx ['test' ].size (0 ) // world_size , dim = 0 )[rank ].clone ()
6050
6151 model = DistributedDataParallel (model .to (device ), device_ids = [local_id ])
62- optimizer = torch .optim .Adam (model .parameters (), lr = 0.01 )
52+ optimizer = torch .optim .Adam (model .parameters (), lr = 0.001 ,
53+ weight_decay = 5e-4 )
6354
6455 kwargs = dict (
6556 data = data ,
66- batch_size = 128 ,
57+ batch_size = 1024 ,
6758 num_workers = get_num_workers (),
68- num_neighbors = [50 , 50 ],
59+ num_neighbors = [30 , 30 ],
6960 )
7061
7162 train_loader = NeighborLoader (
7263 input_nodes = split_idx ['train' ],
7364 shuffle = True ,
65+ drop_last = True ,
7466 ** kwargs ,
7567 )
76- if rank == 0 :
77- val_loader = NeighborLoader (input_nodes = split_idx ['valid' ], ** kwargs )
78- test_loader = NeighborLoader (input_nodes = split_idx ['test' ], ** kwargs )
68+ val_loader = NeighborLoader (input_nodes = split_idx ['valid' ], ** kwargs )
69+ test_loader = NeighborLoader (input_nodes = split_idx ['test' ], ** kwargs )
7970
8071 val_steps = 1000
8172 warmup_steps = 100
73+ acc = acc .to (device )
74+ dist .barrier ()
75+ torch .cuda .synchronize ()
8276 if rank == 0 :
77+ prep_time = round (time .perf_counter () - wall_clock_start , 2 )
78+ print ("Total time before training begins (prep_time)=" , prep_time ,
79+ "seconds" )
8380 print ("Beginning training..." )
8481
85- for epoch in range (1 , 4 ):
82+ for epoch in range (1 , 21 ):
8683 model .train ()
8784 for i , batch in enumerate (train_loader ):
8885 if i == warmup_steps :
86+ torch .cuda .synchronize ()
8987 start = time .time ()
9088 batch = batch .to (device )
9189 optimizer .zero_grad ()
@@ -98,53 +96,56 @@ def run(world_size, data, split_idx, model):
9896 if rank == 0 and i % 10 == 0 :
9997 print (f'Epoch: { epoch :02d} , Iteration: { i } , Loss: { loss :.4f} ' )
10098
99+ dist .barrier ()
100+ torch .cuda .synchronize ()
101101 if rank == 0 :
102- sec_per_iter = (time .time () - start ) / (i - warmup_steps )
102+ sec_per_iter = (time .time () - start ) / (i + 1 - warmup_steps )
103103 print (f"Avg Training Iteration Time: { sec_per_iter :.6f} s/iter" )
104104
105+ @torch .no_grad ()
106+ def test (loader : NeighborLoader , num_steps : Optional [int ] = None ):
105107 model .eval ()
106- total_correct = total_examples = 0
107- for i , batch in enumerate (val_loader ):
108- if i >= val_steps :
108+ for j , batch in enumerate (loader ):
109+ if num_steps is not None and j >= num_steps :
109110 break
110- if i == warmup_steps :
111- start = time .time ()
112-
113111 batch = batch .to (device )
114- with torch .no_grad ():
115- out = model (batch .x , batch .edge_index )[:batch .batch_size ]
116- pred = out .argmax (dim = - 1 )
112+ out = model (batch .x , batch .edge_index )[:batch .batch_size ]
117113 y = batch .y [:batch .batch_size ].view (- 1 ).to (torch .long )
114+ acc (out , y )
115+ acc_sum = acc .compute ()
116+ return acc_sum
118117
119- total_correct += int ((pred == y ).sum ())
120- total_examples += y .size (0 )
118+ eval_acc = test (val_loader , num_steps = val_steps )
119+ if rank == 0 :
120+ print (f"Val Accuracy: { eval_acc :.4f} %" , )
121121
122- print (f"Val Acc: { total_correct / total_examples :.4f} " )
123- sec_per_iter = (time .time () - start ) / (i - warmup_steps )
124- print (f"Avg Inference Iteration Time: { sec_per_iter :.6f} s/iter" )
122+ acc .reset ()
123+ dist .barrier ()
125124
125+ test_acc = test (test_loader )
126126 if rank == 0 :
127- model .eval ()
128- total_correct = total_examples = 0
129- for i , batch in enumerate (test_loader ):
130- batch = batch .to (device )
131- with torch .no_grad ():
132- out = model (batch .x , batch .edge_index )[:batch .batch_size ]
133- pred = out .argmax (dim = - 1 )
134- y = batch .y [:batch .batch_size ].view (- 1 ).to (torch .long )
127+ print (f"Test Accuracy: { test_acc :.4f} %" , )
135128
136- total_correct += int ((pred == y ).sum ())
137- total_examples += y .size (0 )
138- print (f"Test Acc: { total_correct / total_examples :.4f} " )
129+ dist .barrier ()
130+ acc .reset ()
131+ torch .cuda .synchronize ()
132+
133+ if rank == 0 :
134+ total_time = round (time .perf_counter () - wall_clock_start , 2 )
135+ print ("Total Program Runtime (total_time) =" , total_time , "seconds" )
136+ print ("total_time - prep_time =" , total_time - prep_time , "seconds" )
139137
140138
if __name__ == '__main__':
    # Start the wall clock before any setup so the timings printed by
    # `run` (prep_time / total_time) include process-group initialization
    # and dataset loading.
    wall_clock_start = time.perf_counter()

    # Setup multi-node:
    torch.distributed.init_process_group("nccl")
    # Verify initialization *before* querying the world size, and raise
    # explicitly rather than `assert` so the check survives `python -O`.
    if not dist.is_initialized():
        raise RuntimeError("Distributed cluster not initialized")
    nprocs = dist.get_world_size()

    dataset = PygNodePropPredDataset(name='ogbn-papers100M')
    split_idx = dataset.get_idx_split()

    # NOTE(review): positional arguments are presumably
    # (in_channels, hidden_channels, num_layers, out_channels) -- confirm
    # against the installed torch_geometric.nn.GCN signature.
    model = GCN(dataset.num_features, 256, 2, dataset.num_classes)
    acc = Accuracy(task="multiclass", num_classes=dataset.num_classes)

    data = dataset[0]
    # Flatten the labels to a 1-D tensor before handing the graph to `run`.
    data.y = data.y.reshape(-1)

    run(nprocs, data, split_idx, model, acc, wall_clock_start)
0 commit comments