
Commit 82f31cf

andreazanetti, pre-commit-ci[bot], and rusty1s authored
Adds an example for Hierarchical Sampling (#7244)
It compares one epoch of training with and without Hierarchical Sampling. With `pyg-lib>0.1.0`, the sampler returns the number of sampled nodes/edges per hop in [neighbor_sampler.py](https://github.com/pyg-team/pytorch_geometric/blob/e3e63d66e52aa9ca4553274f0572f1f066d99c41/torch_geometric/sampler/neighbor_sampler.py#L241). Leveraging this, [training_benchmark.py](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/training/training_benchmark.py) relies on the `BasicGNN` base class, whose [forward pass performs the trimming when requested](https://github.com/pyg-team/pytorch_geometric/blob/e3e63d66e52aa9ca4553274f0572f1f066d99c41/torch_geometric/nn/models/basic_gnn.py#L201) (enabled via the `--trim` flag of `training_benchmark.py`). This example therefore mimics what `training_benchmark.py` does, to make evident to the user what this trimming/Hierarchical Sampling is about, how to test it, and what advantage it brings.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
1 parent: 2395d70 · commit: 82f31cf
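For intuition, the trimming described above can also be written out by hand in a message-passing loop. Below is a minimal sketch, assuming `torch_geometric.utils.trim_to_layer` and the per-hop metadata returned by `NeighborLoader`; the class name `TrimmedSAGE` is illustrative only, and the committed example instead uses `GraphSAGE`, whose `BasicGNN` forward handles this internally:

```python
# A minimal sketch of per-layer trimming, assuming the mini-batch carries
# `num_sampled_nodes`/`num_sampled_edges` metadata (requires pyg-lib).
# Illustrative only; `GraphSAGE`/`BasicGNN` does this internally.
import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer


class TrimmedSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            SAGEConv(in_channels if i == 0 else hidden_channels,
                     hidden_channels) for i in range(num_layers))

    def forward(self, x, edge_index, num_sampled_nodes_per_hop,
                num_sampled_edges_per_hop):
        for i, conv in enumerate(self.convs):
            # At layer `i`, drop the outermost hop of nodes and edges; they
            # were only needed as message sources for the previous layers:
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)
            x = conv(x, edge_index).relu()
        return x
```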

File tree: 2 files changed (+60, −0 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
 - Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
 - Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
 - Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))

examples/hierarchical_sampling.py

Lines changed: 59 additions & 0 deletions
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models.basic_gnn import GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
loader = NeighborLoader(data, input_nodes=data.train_mask,
                        num_neighbors=[20, 10, 5], shuffle=True, **kwargs)

model = GraphSAGE(
    dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=3,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(trim=False):
    for batch in tqdm(loader):
        optimizer.zero_grad()
        batch = batch.to(device)

        if not trim:
            out = model(batch.x, batch.edge_index)
        else:
            # Pass the sampler metadata so that `BasicGNN` can trim node
            # features and edges that are no longer needed at each layer:
            out = model(
                batch.x,
                batch.edge_index,
                num_sampled_nodes_per_hop=batch.num_sampled_nodes,
                num_sampled_edges_per_hop=batch.num_sampled_edges,
            )

        out = out[:batch.batch_size]
        y = batch.y[:batch.batch_size]

        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()


print('One epoch training without Hierarchical Graph Sampling:')
train(trim=False)

print('One epoch training with Hierarchical Graph Sampling:')
train(trim=True)
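Note that the `trim=True` path assumes the mini-batches carry `num_sampled_nodes`/`num_sampled_edges`, which `NeighborLoader` only populates when a recent `pyg-lib` is available. A guard one could add near the top of the script (illustrative, not part of the committed example):

```python
# Illustrative guard, assuming `torch_geometric.typing.WITH_PYG_LIB` reflects
# whether `pyg-lib` is installed; without it, the per-hop sampling metadata
# used by the trimmed code path is not available.
import torch_geometric.typing

if not torch_geometric.typing.WITH_PYG_LIB:
    raise RuntimeError("'pyg-lib' is required for hierarchical sampling "
                       "metadata (num_sampled_nodes/num_sampled_edges).")
```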
