1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added `NodeFormer` model and example ([#10409](https://github.com/pyg-team/pytorch_geometric/pull/10409))
- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))
- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))
- Added `connected_components()` method to `Data` and `HeterData` ([#10388](https://github.com/pyg-team/pytorch_geometric/pull/10388))
55 changes: 47 additions & 8 deletions examples/ogbn_train.py
@@ -11,7 +11,13 @@

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer
+from torch_geometric.nn.models import (
+    GAT,
+    GraphSAGE,
+    NodeFormer,
+    Polynormer,
+    SGFormer,
+)
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
@@ -37,7 +43,7 @@
"--model",
type=str.lower,
default='SGFormer',
-    choices=['sage', 'gat', 'sgformer', 'polynormer'],
+    choices=['sage', 'gat', 'sgformer', 'polynormer', 'nodeformer'],
help="Model used for training",
)

@@ -55,6 +61,8 @@
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--wd', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.5)
+parser.add_argument('--lamda', type=float, default=0.1,
+                    help='weight for edge reg loss of nodeformer')
parser.add_argument(
'--use_directed_graph',
action='store_true',
@@ -132,18 +140,23 @@ def train(epoch: int) -> tuple[Tensor, float]:

total_loss = total_correct = 0
for batch in train_loader:
+        batch = batch.to(device)
optimizer.zero_grad()
if args.model in ['sgformer', 'polynormer']:
if args.model == 'polynormer' and epoch == args.local_epochs:
print('start global attention')
model._global = True
-            out = model(batch.x, batch.edge_index.to(device),
-                        batch.batch.to(device))[:batch.batch_size]
+            out = model(batch.x, batch.edge_index,
+                        batch.batch)[:batch.batch_size]
+        elif args.model in ['nodeformer']:
+            out, link_loss = model(batch.x, batch.edge_index)
+            out = out[:batch.batch_size]
else:
-            out = model(batch.x,
-                        batch.edge_index.to(device))[:batch.batch_size]
+            out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch.batch_size].squeeze().to(torch.long)
loss = F.cross_entropy(out, y)
+        if args.model in ['nodeformer']:
+            loss -= args.lamda * sum(link_loss) / len(link_loss)
loss.backward()
optimizer.step()

@@ -168,6 +181,9 @@ def test(loader: NeighborLoader) -> float:
if args.model in ['sgformer', 'polynormer']:
out = model(batch.x, batch.edge_index,
batch.batch)[:batch.batch_size]
+        elif args.model in ['nodeformer']:
+            out, _ = model(batch.x, batch.edge_index)
+            out = out[:batch.batch_size]
else:
out = model(batch.x, batch.edge_index)[:batch_size]
pred = out.argmax(dim=-1)
@@ -214,6 +230,24 @@ def get_model(model_name: str) -> torch.nn.Module:
out_channels=dataset.num_classes,
local_layers=num_layers,
)
+    elif model_name == 'nodeformer':
+        model = NodeFormer(
+            in_channels=dataset.num_features,
+            hidden_channels=num_hidden_channels,
+            out_channels=dataset.num_classes,
+            num_layers=num_layers,
+            num_heads=args.num_heads,
+            use_bn=True,
+            nb_random_features=100,
+            use_gumbel=True,
+            use_residual=True,
+            use_act=True,
+            use_jk=True,
+            nb_gumbel_sample=5,
+            rb_order=1,
+            rb_trans='identity',
+            tau=1.0,
+        )
else:
raise ValueError(f'Unsupported model type: {model_name}')

@@ -227,8 +261,13 @@ def get_model(model_name: str) -> torch.nn.Module:
lr=args.lr,
weight_decay=args.wd,
)
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
-                                                       patience=5)
+if args.model == 'nodeformer':
+    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+                                                     milestones=[100, 200],
+                                                     gamma=0.5)
+else:
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode='max', patience=5)

print(f'Total time before training begins took '
f'{time.perf_counter() - wall_clock_start:.4f}s')
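For reference, the new `nodeformer` branches in `examples/ogbn_train.py` reduce to the training-step pattern below. This is a minimal self-contained sketch with toy tensors (sizes and hyperparameters are illustrative, not the example's defaults); it assumes only what the diff shows: `NodeFormer(in_channels, hidden_channels, out_channels, num_layers)` whose forward returns logits plus a list of per-layer edge-regularization losses.

import torch
import torch.nn.functional as F

from torch_geometric.nn.models import NodeFormer

# Toy stand-in for a NeighborLoader mini-batch: the first `batch_size`
# nodes are the seed nodes that carry labels.
x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
batch_size = 4
y = torch.randint(0, 4, (batch_size, ))
lamda = 0.1  # mirrors the --lamda default added above

model = NodeFormer(in_channels=16, hidden_channels=32, out_channels=4,
                   num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

optimizer.zero_grad()
# NodeFormer's forward returns logits and a list of per-layer
# edge-regularization losses (one entry per layer).
out, link_loss = model(x, edge_index)
loss = F.cross_entropy(out[:batch_size], y)
# Subtract the mean edge-regularization term, weighted by lamda,
# exactly as the training loop in the diff does.
loss = loss - lamda * sum(link_loss) / len(link_loss)
loss.backward()
optimizer.step()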
21 changes: 21 additions & 0 deletions test/nn/models/test_nodeformer.py
@@ -0,0 +1,21 @@
+import torch
+
+from torch_geometric.nn.models import NodeFormer
+
+
+def test_nodeformer():
+    x = torch.randn(10, 16)
+    edge_index = torch.tensor([
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],
+    ])
+
+    model = NodeFormer(
+        in_channels=16,
+        hidden_channels=128,
+        out_channels=40,
+        num_layers=3,
+    )
+    out, link_loss = model(x, edge_index)
+    assert out.size() == (10, 40)
+    assert len(link_loss) == 3
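At inference time the edge-regularization term is simply discarded, which is what the `test()` branch in the example does. Continuing from the test fixture above (reusing `model`, `x`, and `edge_index`), a short sketch:

model.eval()
with torch.no_grad():
    out, _ = model(x, edge_index)  # ignore the edge losses at eval time
pred = out.argmax(dim=-1)  # per-node class predictions, shape [10]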
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
@@ -33,6 +33,7 @@
from .sgformer import SGFormer

from .polynormer import Polynormer
+from .nodeformer import NodeFormer
# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
captum_output_to_dicts)
@@ -85,5 +86,6 @@
'LPFormer',
'SGFormer',
'Polynormer',
+    'NodeFormer',
'ARLinkPredictor',
]
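With the two registrations above, the model is exposed through the public models namespace alongside the other transformer-style baselines:

from torch_geometric.nn.models import NodeFormer, Polynormer, SGFormer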