Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added `NodeFormer` model and example ([#10409](https://github.com/pyg-team/pytorch_geometric/pull/10409))
- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))
- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))
Expand Down
55 changes: 47 additions & 8 deletions examples/ogbn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer
from torch_geometric.nn.models import (
GAT,
GraphSAGE,
NodeFormer,
Polynormer,
SGFormer,
)
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
Expand All @@ -37,7 +43,7 @@
"--model",
type=str.lower,
default='SGFormer',
choices=['sage', 'gat', 'sgformer', 'polynormer'],
choices=['sage', 'gat', 'sgformer', 'polynormer', 'nodeformer'],
help="Model used for training",
)

Expand All @@ -55,6 +61,8 @@
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--wd', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lamda', type=float, default=0.1,
help='weight for edge reg loss of nodeformer')
parser.add_argument(
'--use_directed_graph',
action='store_true',
Expand Down Expand Up @@ -132,18 +140,23 @@ def train(epoch: int) -> tuple[Tensor, float]:

total_loss = total_correct = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
if args.model in ['sgformer', 'polynormer']:
if args.model == 'polynormer' and epoch == args.local_epochs:
print('start global attention')
model._global = True
out = model(batch.x, batch.edge_index.to(device),
batch.batch.to(device))[:batch.batch_size]
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
elif args.model in ['nodeformer']:
out, link_loss = model(batch.x, batch.edge_index)
out = out[:batch.batch_size]
else:
out = model(batch.x,
batch.edge_index.to(device))[:batch.batch_size]
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].squeeze().to(torch.long)
loss = F.cross_entropy(out, y)
if args.model in ['nodeformer']:
loss -= args.lamda * sum(link_loss) / len(link_loss)
loss.backward()
optimizer.step()

Expand All @@ -168,6 +181,9 @@ def test(loader: NeighborLoader) -> float:
if args.model in ['sgformer', 'polynormer']:
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
elif args.model in ['nodeformer']:
out, _ = model(batch.x, batch.edge_index)
out = out[:batch.batch_size]
else:
out = model(batch.x, batch.edge_index)[:batch_size]
pred = out.argmax(dim=-1)
Expand Down Expand Up @@ -214,6 +230,24 @@ def get_model(model_name: str) -> torch.nn.Module:
out_channels=dataset.num_classes,
local_layers=num_layers,
)
elif model_name == 'nodeformer':
model = NodeFormer(
in_channels=dataset.num_features,
hidden_channels=num_hidden_channels,
out_channels=dataset.num_classes,
num_layers=num_layers,
num_heads=args.num_heads,
use_bn=True,
nb_random_features=100,
use_gumbel=True,
use_residual=True,
use_act=True,
use_jk=True,
nb_gumbel_sample=5,
rb_order=1,
rb_trans='identity',
tau=1.0,
)
else:
raise ValueError(f'Unsupported model type: {model_name}')

Expand All @@ -227,8 +261,13 @@ def get_model(model_name: str) -> torch.nn.Module:
lr=args.lr,
weight_decay=args.wd,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
patience=5)
if args.model == 'nodeformer':
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[100, 200],
gamma=0.5)
else:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='max', patience=5)

print(f'Total time before training begins took '
f'{time.perf_counter() - wall_clock_start:.4f}s')
Expand Down
50 changes: 50 additions & 0 deletions test/nn/models/test_nodeformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import torch

from torch_geometric.nn.models import NodeFormer


@pytest.mark.parametrize('use_bn', [True, False])
@pytest.mark.parametrize('use_gumbel', [True, False])
@pytest.mark.parametrize('use_residual', [True, False])
@pytest.mark.parametrize('use_act', [True, False])
@pytest.mark.parametrize('use_jk', [True, False])
@pytest.mark.parametrize('use_edge_loss', [True, False])
@pytest.mark.parametrize('rb_trans', ['sigmoid', 'identity'])
@pytest.mark.parametrize('rb_order', [0, 1])
def test_nodeformer(
use_bn,
use_gumbel,
use_residual,
use_act,
use_jk,
use_edge_loss,
rb_trans,
rb_order,
):
x = torch.randn(10, 16)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
])

model = NodeFormer(
in_channels=16,
hidden_channels=128,
out_channels=40,
num_layers=3,
use_bn=use_bn,
use_gumbel=use_gumbel,
use_residual=use_residual,
use_act=use_act,
use_jk=use_jk,
use_edge_loss=use_edge_loss,
rb_trans=rb_trans,
rb_order=rb_order,
)
if use_edge_loss:
out, link_loss = model(x, edge_index)
assert len(link_loss) == 3
else:
out = model(x, edge_index)
assert out.size() == (10, 40)
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .sgformer import SGFormer

from .polynormer import Polynormer
from .nodeformer import NodeFormer
# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
captum_output_to_dicts)
Expand Down Expand Up @@ -85,5 +86,6 @@
'LPFormer',
'SGFormer',
'Polynormer',
'NodeFormer',
'ARLinkPredictor',
]
Loading
Loading