Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added an example for hierarichial sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))
- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
- Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))
Expand Down
59 changes: 59 additions & 0 deletions examples/hierarchical_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os.path as osp

import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models.basic_gnn import GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
loader = NeighborLoader(data, input_nodes=data.train_mask,
num_neighbors=[20, 10, 5], shuffle=True, **kwargs)

model = GraphSAGE(
dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes,
num_layers=3,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train(trim=False):
for batch in tqdm(loader):
optimizer.zero_grad()
batch = batch.to(device)

if not trim:
out = model(batch.x, batch.edge_index)
else:
out = model(
batch.x,
batch.edge_index,
num_sampled_nodes_per_hop=batch.num_sampled_nodes,
num_sampled_edges_per_hop=batch.num_sampled_edges,
)

out = out[:batch.batch_size]
y = batch.y[:batch.batch_size]

loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()


print('One epoch training without Hierarchical Graph Sampling:')
train(trim=False)

print('One epoch training with Hierarchical Graph Sampling:')
train(trim=True)