
Commit 870179f

puririshi98, pre-commit-ci[bot], and akihironitta authored
Improvements for Papers100m single gpu and single node multi gpu examples (Cugraph, GATConv, better default hyperparams, eval on all ranks) (#8173)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 08eb6b9 commit 870179f

File tree: 7 files changed (+721, -131 lines)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))
 - Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))
 - Added option to pass custom `from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))
 - Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))

examples/README.md

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see
 - [`ogbn_products_sage.py`](./ogbn_products_sage.py) and [`ogbn_products_gat.py`](./ogbn_products_gat.py) show how to train [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) and [`GAT`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAT.html) models on the `ogbn-products` dataset.
 - [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) showcases how to train deep GNNs on the `ogbn-proteins` dataset.
 - [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100M` dataset, containing approximately 1.6B edges.
+- [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph).
 
 For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).
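In an environment with NVIDIA's cuGraph packages installed (they require CUDA), the new single-GPU script should be runnable directly from the `examples/` directory, e.g. `python ogbn_papers_100m_cugraph.py`; any script-specific flags are defined in the script itself.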
examples/multi_gpu/README.md

Lines changed: 2 additions & 1 deletion

@@ -8,7 +8,8 @@
 | [`distributed_sampling.py`](./distributed_sampling.py) | single-node | Example for training GNNs on a homogeneous graph with neighbor sampling. |
 | [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph with neighbor sampling on multiple nodes. |
 | [`distributed_sampling_multinode.sbatch`](./distributed_sampling_multinode.sbatch) | multi-node | Example for submitting a training job to a Slurm cluster using [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py). |
-| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on a homogeneous graph. |
+| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on the `ogbn-papers100M` homogeneous graph with ~1.6B edges. |
+| [`papers100m_gcn_cugraph.py`](./papers100m_gcn_cugraph.py) | single-node | Example for accelerating GNN training on `ogbn-papers100M` using [CuGraph](https://github.com/rapidsai/cugraph). |
 | [`papers100m_gcn_multinode.py`](./papers100m_gcn_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph on multiple nodes. |
 | [`mag240m_graphsage.py`](./mag240m_graphsage.py) | single-node | Example for training GNNs on a large heterogeneous graph. |
 | [`taobao.py`](./taobao.py) | single-node | Example for training link prediction GNNs on a heterogeneous graph. |
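A headline change in the diff below is that `papers100m_gcn.py` now shards the validation and test indices across ranks as well, so every rank participates in evaluation. A minimal sketch of that sharding pattern in isolation (hypothetical sizes, plain `torch`; not part of the commit):

```python
import torch

world_size = 4            # hypothetical number of GPUs / ranks
idx = torch.arange(1000)  # stand-in for split_idx['train'|'valid'|'test']

# The pattern the example uses: contiguous, equally sized shards.
# tensor.split() yields equal chunks plus a smaller trailing chunk when
# the length does not divide evenly; indexing by rank drops that remainder.
shards = idx.split(idx.size(0) // world_size, dim=0)
for rank in range(world_size):
    shard = shards[rank].clone()  # clone() decouples it from shared storage
    print(rank, shard.shape)      # -> each rank gets 250 indices
```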
examples/multi_gpu/papers100m_gcn.py

Lines changed: 147 additions & 92 deletions
@@ -1,4 +1,6 @@
+import argparse
 import os
+import tempfile
 import time
 
 import torch
@@ -7,136 +9,189 @@
 import torch.nn.functional as F
 from ogb.nodeproppred import PygNodePropPredDataset
 from torch.nn.parallel import DistributedDataParallel
+from torchmetrics import Accuracy
 
+import torch_geometric
 from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import GCNConv
 
 
-def get_num_workers(world_size: int) -> int:
-    num_workers = None
+def get_num_workers(world_size):
+    num_work = None
     if hasattr(os, "sched_getaffinity"):
         try:
-            num_workers = len(os.sched_getaffinity(0)) // (2 * world_size)
+            num_work = len(os.sched_getaffinity(0)) / (2 * world_size)
         except Exception:
            pass
-    if num_workers is None:
-        num_workers = os.cpu_count() // (2 * world_size)
-    return num_workers
+    if num_work is None:
+        num_work = os.cpu_count() / (2 * world_size)
+    return int(num_work)
 
 
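The reworked helper now divides in floating point and truncates once at the end; the intent is unchanged: leave roughly half the cores for the training processes and split the rest among data-loading workers. For intuition, on a hypothetical 64-core machine driving 4 GPUs:

```python
# Hypothetical machine: 64 schedulable cores, 4 training processes.
cores, world_size = 64, 4
num_work = int(cores / (2 * world_size))  # 64 / 8 -> 8.0 -> 8
assert num_work == 8  # i.e. 8 DataLoader workers per NeighborLoader
```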
-class GCN(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
-        super().__init__()
-        self.conv1 = GCNConv(in_channels, hidden_channels)
-        self.conv2 = GCNConv(hidden_channels, out_channels)
+def run_train(rank, data, world_size, model, epochs, batch_size, fan_out,
+              split_idx, num_classes, wall_clock_start, tempdir=None,
+              num_layers=3):
 
-    def forward(self, x, edge_index=None):
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv1(x, edge_index).relu()
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv2(x, edge_index)
-        return x
-
-
-def run(rank, world_size, data, split_idx, model):
+    # init pytorch worker
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
 
-    split_idx['train'] = split_idx['train'].split(
-        split_idx['train'].size(0) // world_size,
-        dim=0,
-    )[rank].clone()
-
-    model = DistributedDataParallel(model.to(rank), device_ids=[rank])
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    if world_size > 1:
+        split_idx['train'] = split_idx['train'].split(
+            split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
+        split_idx['valid'] = split_idx['valid'].split(
+            split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
+        split_idx['test'] = split_idx['test'].split(
+            split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
+    model = model.to(rank)
+    model = DistributedDataParallel(model, device_ids=[rank])
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
+                                 weight_decay=0.0005)
 
     kwargs = dict(
-        data=data,
-        batch_size=128,
-        num_workers=get_num_workers(world_size),
-        num_neighbors=[50, 50],
+        num_neighbors=[fan_out] * num_layers,
+        batch_size=batch_size,
     )
-    train_loader = NeighborLoader(
-        input_nodes=split_idx['train'],
-        shuffle=True,
-        **kwargs,
-    )
-    if rank == 0:
-        val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
-        test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
-
-    val_steps = 1000
-    warmup_steps = 100
+    num_work = get_num_workers(world_size)
+    train_loader = NeighborLoader(data, input_nodes=split_idx['train'],
+                                  num_workers=num_work, shuffle=True,
+                                  drop_last=True, **kwargs)
+    val_loader = NeighborLoader(data, input_nodes=split_idx['valid'],
+                                num_workers=num_work, **kwargs)
+    test_loader = NeighborLoader(data, input_nodes=split_idx['test'],
+                                 num_workers=num_work, **kwargs)
+
+    eval_steps = 1000
+    warmup_steps = 20
+    acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank)
+    dist.barrier()
+    torch.cuda.synchronize()
     if rank == 0:
+        prep_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total time before training begins (prep_time) =", prep_time,
+              "seconds")
         print("Beginning training...")
-
-    for epoch in range(1, 4):
-        model.train()
+    for epoch in range(epochs):
         for i, batch in enumerate(train_loader):
             if i == warmup_steps:
+                torch.cuda.synchronize()
                 start = time.time()
             batch = batch.to(rank)
+            batch_size = batch.num_sampled_nodes[0]
+            batch.y = batch.y.to(torch.long)
             optimizer.zero_grad()
-            y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-            out = model(batch.x, batch.edge_index)[:batch.batch_size]
-            loss = F.cross_entropy(out, y)
+            out = model(batch.x, batch.edge_index)
+            loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])
             loss.backward()
             optimizer.step()
-
             if rank == 0 and i % 10 == 0:
-                print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')
-
+                print("Epoch: " + str(epoch) + ", Iteration: " + str(i) +
+                      ", Loss: " + str(loss))
+        nb = i + 1.0
+        dist.barrier()
+        torch.cuda.synchronize()
         if rank == 0:
-            sec_per_iter = (time.time() - start) / (i - warmup_steps)
-            print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter")
-
-        model.eval()
-        total_correct = total_examples = 0
+            print("Average Training Iteration Time:",
+                  (time.time() - start) / (nb - warmup_steps), "s/iter")
+        with torch.no_grad():
             for i, batch in enumerate(val_loader):
-                if i >= val_steps:
+                if i >= eval_steps:
                     break
-                if i == warmup_steps:
-                    start = time.time()
 
                 batch = batch.to(rank)
-                with torch.no_grad():
-                    out = model(batch.x, batch.edge_index)[:batch.batch_size]
-                    pred = out.argmax(dim=-1)
-                    y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-
-                total_correct += int((pred == y).sum())
-                total_examples += y.size(0)
-
-            print(f"Val Acc: {total_correct / total_examples:.4f}")
-            sec_per_iter = (time.time() - start) / (i - warmup_steps)
-            print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")
-
-    if rank == 0:
-        model.eval()
-        total_correct = total_examples = 0
+                batch_size = batch.num_sampled_nodes[0]
+
+                batch.y = batch.y.to(torch.long)
+                out = model(batch.x, batch.edge_index)
+                acc_i = acc(  # noqa
+                    out[:batch_size].softmax(dim=-1), batch.y[:batch_size])
+        acc_sum = acc.compute()
+        if rank == 0:
+            print(f"Validation Accuracy: {acc_sum * 100.0:.4f}%")
+        dist.barrier()
+        acc.reset()
+
+    with torch.no_grad():
         for i, batch in enumerate(test_loader):
             batch = batch.to(rank)
-            with torch.no_grad():
-                out = model(batch.x, batch.edge_index)[:batch.batch_size]
-                pred = out.argmax(dim=-1)
-                y = batch.y[:batch.batch_size].view(-1).to(torch.long)
+            batch_size = batch.num_sampled_nodes[0]
 
-            total_correct += int((pred == y).sum())
-            total_examples += y.size(0)
-        print(f"Test Acc: {total_correct / total_examples:.4f}")
+            batch.y = batch.y.to(torch.long)
+            out = model(batch.x, batch.edge_index)
+            acc_i = acc(  # noqa
+                out[:batch_size].softmax(dim=-1), batch.y[:batch_size])
+    acc_sum = acc.compute()
+    if rank == 0:
+        print(f"Test Accuracy: {acc_sum * 100.0:.4f}%")
+    dist.barrier()
+    acc.reset()
+    if rank == 0:
+        total_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total Program Runtime (total_time) =", total_time, "seconds")
+        print("total_time - prep_time =", total_time - prep_time, "seconds")
 
 
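Two details of the new evaluation path deserve a note. First, `batch.num_sampled_nodes[0]` is the number of seed nodes in a `NeighborLoader` batch (seed nodes come first in the batch), which is why outputs and labels are sliced with `[:batch_size]`. Second, accuracy is now a `torchmetrics` metric updated on every rank; once a process group is initialized, `compute()` synchronizes the metric state across ranks by default (torchmetrics' `sync_on_compute` behavior), so each rank reports the same global number. A minimal single-process sketch of the update/compute/reset cycle (toy tensors, not part of the commit):

```python
import torch
from torchmetrics import Accuracy

acc = Accuracy(task="multiclass", num_classes=3)

# Each iteration mimics one evaluation batch: calling the metric object
# updates its internal correct/total counters.
for _ in range(5):
    logits = torch.randn(8, 3)          # toy model output
    target = torch.randint(0, 3, (8,))  # toy labels
    acc(logits.softmax(dim=-1), target)

print(f"Accuracy: {acc.compute() * 100.0:.4f}%")  # aggregated over batches
acc.reset()  # clear state before the next evaluation phase
```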
 if __name__ == '__main__':
-    dataset = PygNodePropPredDataset(name='ogbn-papers100M')
-    split_idx = dataset.get_idx_split()
-    model = GCN(dataset.num_features, 64, dataset.num_classes)
 
-    world_size = torch.cuda.device_count()
-    print('Let\'s use', world_size, 'GPUs!')
-    mp.spawn(
-        run,
-        args=(world_size, dataset[0], split_idx, model),
-        nprocs=world_size,
-        join=True,
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hidden_channels', type=int, default=256)
+    parser.add_argument('--num_layers', type=int, default=2)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--epochs', type=int, default=20)
+    parser.add_argument('--batch_size', type=int, default=1024)
+    parser.add_argument('--fan_out', type=int, default=30)
+    parser.add_argument(
+        "--use_gat_conv",
+        action='store_true',
+        help="Whether or not to use GATConv. (Defaults to using GCNConv)",
     )
+    parser.add_argument(
+        "--n_gat_conv_heads",
+        type=int,
+        default=4,
+        help="If using GATConv, number of attention heads to use",
+    )
+    parser.add_argument(
+        "--n_devices", type=int, default=-1,
+        help="1-8 to use that many GPUs. Defaults to all available GPUs")
+
+    args = parser.parse_args()
+    wall_clock_start = time.perf_counter()
+
+    dataset = PygNodePropPredDataset(name='ogbn-papers100M',
+                                     root='/datasets/ogb_datasets')
+    split_idx = dataset.get_idx_split()
+    data = dataset[0]
+    data.y = data.y.reshape(-1)
+    if args.use_gat_conv:
+        model = torch_geometric.nn.models.GAT(dataset.num_features,
+                                              args.hidden_channels,
+                                              args.num_layers,
+                                              dataset.num_classes,
+                                              heads=args.n_gat_conv_heads)
+    else:
+        model = torch_geometric.nn.models.GCN(
+            dataset.num_features,
+            args.hidden_channels,
+            args.num_layers,
+            dataset.num_classes,
+        )
+
+    print("Data =", data)
+    if args.n_devices == -1:
+        world_size = torch.cuda.device_count()
+    else:
+        world_size = args.n_devices
+    print('Let\'s use', world_size, 'GPUs!')
+    with tempfile.TemporaryDirectory() as tempdir:
+        if world_size > 1:
+            mp.spawn(
+                run_train,
+                args=(data, world_size, model, args.epochs, args.batch_size,
+                      args.fan_out, split_idx, dataset.num_classes,
+                      wall_clock_start, tempdir, args.num_layers),
+                nprocs=world_size, join=True)
+        else:
+            run_train(0, data, world_size, model, args.epochs, args.batch_size,
+                      args.fan_out, split_idx, dataset.num_classes,
+                      wall_clock_start, tempdir, args.num_layers)
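With the new argument parser, typical invocations would be `python papers100m_gcn.py` (GCN on all visible GPUs) or `python papers100m_gcn.py --use_gat_conv --n_gat_conv_heads 4 --n_devices 2` (GAT on two GPUs). With the defaults `--fan_out 30` and `--num_layers 2`, the loaders sample with `num_neighbors=[30, 30]`, i.e. up to 30 neighbors per node at each of two hops. Two quirks remain as committed: `--lr` is parsed but never forwarded (the optimizer is constructed with `lr=0.01`), and the dataset `root` is hard-coded to `/datasets/ogb_datasets`.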
