Skip to content

Commit a1251ab

Browse files
puririshi98, pre-commit-ci[bot], rusty1s
authored
Improvements to multinode papers100m default hyperparams and adding eval on all ranks (#8823)
> using main branch of PyG (16 GraceHopper nodes):
> Val Acc: 0.4546
> Test Acc: 0.3770
>
> using PR branch (2 GraceHopper nodes due to availability):
> Validation Accuracy: 51.1759%
> Test Accuracy: 44.5692%

PR ready

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
1 parent cfdb4ce commit a1251ab

File tree

2 files changed

+58
-56
lines changed

2 files changed

+58
-56
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818

1919
### Changed
2020

21+
- Improved the default hyperparameters of the multi-node `ogbn-papers100m` example and added evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823))
2122
- Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978))
2223
- Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021))
2324
- Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009))
Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
"""Multi-node multi-GPU example on ogbn-papers100m.
22
3-
To run:
3+
Example way to run using srun:
44
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
55
--container-name=cont --container-image=<image_url> \
66
--container-mounts=/ogb-papers100m/:/workspace/dataset
77
python3 path_to_script.py
88
"""
99
import os
1010
import time
11+
from typing import Optional
1112

1213
import torch
1314
import torch.distributed as dist
1415
import torch.nn.functional as F
1516
from ogb.nodeproppred import PygNodePropPredDataset
1617
from torch.nn.parallel import DistributedDataParallel
18+
from torchmetrics import Accuracy
1719

1820
from torch_geometric.loader import NeighborLoader
19-
from torch_geometric.nn import GCNConv
21+
from torch_geometric.nn import GCN
2022

2123

2224
def get_num_workers() -> int:
@@ -31,21 +33,7 @@ def get_num_workers() -> int:
3133
return num_workers
3234

3335

34-
class GCN(torch.nn.Module):
35-
def __init__(self, in_channels, hidden_channels, out_channels):
36-
super().__init__()
37-
self.conv1 = GCNConv(in_channels, hidden_channels)
38-
self.conv2 = GCNConv(hidden_channels, out_channels)
39-
40-
def forward(self, x, edge_index):
41-
x = F.dropout(x, p=0.5, training=self.training)
42-
x = self.conv1(x, edge_index).relu()
43-
x = F.dropout(x, p=0.5, training=self.training)
44-
x = self.conv2(x, edge_index)
45-
return x
46-
47-
48-
def run(world_size, data, split_idx, model):
36+
def run(world_size, data, split_idx, model, acc, wall_clock_start):
4937
local_id = int(os.environ['LOCAL_RANK'])
5038
rank = torch.distributed.get_rank()
5139
torch.cuda.set_device(local_id)
@@ -54,38 +42,48 @@ def run(world_size, data, split_idx, model):
5442
print(f'Using {nprocs} GPUs...')
5543

5644
split_idx['train'] = split_idx['train'].split(
57-
split_idx['train'].size(0) // world_size,
58-
dim=0,
59-
)[rank].clone()
45+
split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
46+
split_idx['valid'] = split_idx['valid'].split(
47+
split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
48+
split_idx['test'] = split_idx['test'].split(
49+
split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
6050

6151
model = DistributedDataParallel(model.to(device), device_ids=[local_id])
62-
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
52+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
53+
weight_decay=5e-4)
6354

6455
kwargs = dict(
6556
data=data,
66-
batch_size=128,
57+
batch_size=1024,
6758
num_workers=get_num_workers(),
68-
num_neighbors=[50, 50],
59+
num_neighbors=[30, 30],
6960
)
7061

7162
train_loader = NeighborLoader(
7263
input_nodes=split_idx['train'],
7364
shuffle=True,
65+
drop_last=True,
7466
**kwargs,
7567
)
76-
if rank == 0:
77-
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
78-
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
68+
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
69+
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
7970

8071
val_steps = 1000
8172
warmup_steps = 100
73+
acc = acc.to(device)
74+
dist.barrier()
75+
torch.cuda.synchronize()
8276
if rank == 0:
77+
prep_time = round(time.perf_counter() - wall_clock_start, 2)
78+
print("Total time before training begins (prep_time)=", prep_time,
79+
"seconds")
8380
print("Beginning training...")
8481

85-
for epoch in range(1, 4):
82+
for epoch in range(1, 21):
8683
model.train()
8784
for i, batch in enumerate(train_loader):
8885
if i == warmup_steps:
86+
torch.cuda.synchronize()
8987
start = time.time()
9088
batch = batch.to(device)
9189
optimizer.zero_grad()
@@ -98,53 +96,56 @@ def run(world_size, data, split_idx, model):
9896
if rank == 0 and i % 10 == 0:
9997
print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')
10098

99+
dist.barrier()
100+
torch.cuda.synchronize()
101101
if rank == 0:
102-
sec_per_iter = (time.time() - start) / (i - warmup_steps)
102+
sec_per_iter = (time.time() - start) / (i + 1 - warmup_steps)
103103
print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter")
104104

105+
@torch.no_grad()
106+
def test(loader: NeighborLoader, num_steps: Optional[int] = None):
105107
model.eval()
106-
total_correct = total_examples = 0
107-
for i, batch in enumerate(val_loader):
108-
if i >= val_steps:
108+
for j, batch in enumerate(loader):
109+
if num_steps is not None and j >= num_steps:
109110
break
110-
if i == warmup_steps:
111-
start = time.time()
112-
113111
batch = batch.to(device)
114-
with torch.no_grad():
115-
out = model(batch.x, batch.edge_index)[:batch.batch_size]
116-
pred = out.argmax(dim=-1)
112+
out = model(batch.x, batch.edge_index)[:batch.batch_size]
117113
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
114+
acc(out, y)
115+
acc_sum = acc.compute()
116+
return acc_sum
118117

119-
total_correct += int((pred == y).sum())
120-
total_examples += y.size(0)
118+
eval_acc = test(val_loader, num_steps=val_steps)
119+
if rank == 0:
120+
print(f"Val Accuracy: {eval_acc:.4f}%", )
121121

122-
print(f"Val Acc: {total_correct / total_examples:.4f}")
123-
sec_per_iter = (time.time() - start) / (i - warmup_steps)
124-
print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")
122+
acc.reset()
123+
dist.barrier()
125124

125+
test_acc = test(test_loader)
126126
if rank == 0:
127-
model.eval()
128-
total_correct = total_examples = 0
129-
for i, batch in enumerate(test_loader):
130-
batch = batch.to(device)
131-
with torch.no_grad():
132-
out = model(batch.x, batch.edge_index)[:batch.batch_size]
133-
pred = out.argmax(dim=-1)
134-
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
127+
print(f"Test Accuracy: {test_acc:.4f}%", )
135128

136-
total_correct += int((pred == y).sum())
137-
total_examples += y.size(0)
138-
print(f"Test Acc: {total_correct / total_examples:.4f}")
129+
dist.barrier()
130+
acc.reset()
131+
torch.cuda.synchronize()
132+
133+
if rank == 0:
134+
total_time = round(time.perf_counter() - wall_clock_start, 2)
135+
print("Total Program Runtime (total_time) =", total_time, "seconds")
136+
print("total_time - prep_time =", total_time - prep_time, "seconds")
139137

140138

141139
if __name__ == '__main__':
140+
wall_clock_start = time.perf_counter()
142141
# Setup multi-node:
143142
torch.distributed.init_process_group("nccl")
144143
nprocs = dist.get_world_size()
145144
assert dist.is_initialized(), "Distributed cluster not initialized"
146145
dataset = PygNodePropPredDataset(name='ogbn-papers100M')
147146
split_idx = dataset.get_idx_split()
148-
model = GCN(dataset.num_features, 64, dataset.num_classes)
149-
150-
run(nprocs, dataset[0], split_idx, model)
147+
model = GCN(dataset.num_features, 256, 2, dataset.num_classes)
148+
acc = Accuracy(task="multiclass", num_classes=dataset.num_classes)
149+
data = dataset[0]
150+
data.y = data.y.reshape(-1)
151+
run(nprocs, data, split_idx, model, acc, wall_clock_start)

0 commit comments

Comments
 (0)