From 39e611f1d7a990942895d6be412c9733d1481d1e Mon Sep 17 00:00:00 2001 From: xnuohz Date: Mon, 11 Aug 2025 21:44:29 +0800 Subject: [PATCH 1/7] update --- examples/ogbn_train.py | 96 ++++- test/nn/models/test_nodeformer.py | 23 ++ torch_geometric/nn/models/__init__.py | 2 + torch_geometric/nn/models/nodeformer.py | 508 ++++++++++++++++++++++++ 4 files changed, 616 insertions(+), 13 deletions(-) create mode 100644 test/nn/models/test_nodeformer.py create mode 100644 torch_geometric/nn/models/nodeformer.py diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py index 141780eef78c..c2d4cdf5e64e 100644 --- a/examples/ogbn_train.py +++ b/examples/ogbn_train.py @@ -11,7 +11,13 @@ from torch_geometric import seed_everything from torch_geometric.loader import NeighborLoader -from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer +from torch_geometric.nn.models import ( + GAT, + GraphSAGE, + NodeFormer, + Polynormer, + SGFormer, +) from torch_geometric.utils import ( add_self_loops, remove_self_loops, @@ -37,24 +43,26 @@ "--model", type=str.lower, default='SGFormer', - choices=['sage', 'gat', 'sgformer', 'polynormer'], + choices=['sage', 'gat', 'sgformer', 'polynormer', 'nodeformer'], help="Model used for training", ) -parser.add_argument('-e', '--epochs', type=int, default=50) +parser.add_argument('-e', '--epochs', type=int, default=100) parser.add_argument('-le', '--local_epochs', type=int, default=50, help='warmup epochs for polynormer') parser.add_argument('--num_layers', type=int, default=3) -parser.add_argument('--num_heads', type=int, default=1, +parser.add_argument('--num_heads', type=int, default=4, help='number of heads for GAT or Graph Transformer model.') -parser.add_argument('-b', '--batch_size', type=int, default=1024) +parser.add_argument('-b', '--batch_size', type=int, default=2048) parser.add_argument('--num_workers', type=int, default=12) parser.add_argument('--fan_out', type=int, default=10, help='number of neighbors in each layer') -parser.add_argument('--hidden_channels', type=int, default=256) -parser.add_argument('--lr', type=float, default=0.003) +parser.add_argument('--hidden_channels', type=int, default=128) +parser.add_argument('--lr', type=float, default=1e-2) parser.add_argument('--wd', type=float, default=0.0) parser.add_argument('--dropout', type=float, default=0.5) +parser.add_argument('--lamda', type=float, default=0.1, + help='[nodeforder]weight for edge reg loss') parser.add_argument( '--use_directed_graph', action='store_true', @@ -132,18 +140,24 @@ def train(epoch: int) -> tuple[Tensor, float]: total_loss = total_correct = 0 for batch in train_loader: + batch = batch.to(device) optimizer.zero_grad() if args.model in ['sgformer', 'polynormer']: if args.model == 'polynormer' and epoch == args.local_epochs: print('start global attention') model._global = True - out = model(batch.x, batch.edge_index.to(device), - batch.batch.to(device))[:batch.batch_size] + out = model(batch.x, batch.edge_index, + batch.batch)[:batch.batch_size] + elif args.model in ['nodeformer']: + # import pdb; pdb.set_trace() + out, link_loss = model(batch.x, batch.edge_index) + out = out[:batch.batch_size] else: - out = model(batch.x, - batch.edge_index.to(device))[:batch.batch_size] + out = model(batch.x, batch.edge_index)[:batch.batch_size] y = batch.y[:batch.batch_size].squeeze().to(torch.long) loss = F.cross_entropy(out, y) + if args.model in ['nodeformer']: + loss -= args.lamda * sum(link_loss) / len(link_loss) loss.backward() optimizer.step() @@ -168,6 +182,9 @@ 
def test(loader: NeighborLoader) -> float: if args.model in ['sgformer', 'polynormer']: out = model(batch.x, batch.edge_index, batch.batch)[:batch.batch_size] + elif args.model in ['nodeformer']: + out, _ = model(batch.x, batch.edge_index) + out = out[:batch.batch_size] else: out = model(batch.x, batch.edge_index)[:batch_size] pred = out.argmax(dim=-1) @@ -181,6 +198,16 @@ def test(loader: NeighborLoader) -> float: def get_model(model_name: str) -> torch.nn.Module: if model_name == 'gat': + """ + Average Epoch Time on training: 0.5798s + Average Epoch Time on inference: 0.1722s + Average Epoch Time: 0.7520s + Median Epoch Time: 0.7528s + Best Validation Accuracy: 68.54% + Testing... + Test Accuracy: 67.45% + Total Program Runtime: 76.2115s + """ model = GAT( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, @@ -190,6 +217,16 @@ def get_model(model_name: str) -> torch.nn.Module: heads=args.num_heads, ) elif model_name == 'sage': + """ + Average Epoch Time on training: 0.3990s + Average Epoch Time on inference: 0.1387s + Average Epoch Time: 0.5378s + Median Epoch Time: 0.5314s + Best Validation Accuracy: 69.69% + Testing... + Test Accuracy: 68.26% + Total Program Runtime: 54.5042s + """ model = GraphSAGE( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, @@ -214,6 +251,34 @@ def get_model(model_name: str) -> torch.nn.Module: out_channels=dataset.num_classes, local_layers=num_layers, ) + elif model_name == 'nodeformer': + """ + Average Epoch Time on training: 2.4006s + Average Epoch Time on inference: 0.2627s + Average Epoch Time: 2.6633s + Median Epoch Time: 2.6391s + Best Validation Accuracy: 69.96% + Testing... + Test Accuracy: 68.18% + Total Program Runtime: 267.4139s + """ + model = NodeFormer( + in_channels=dataset.num_features, + hidden_channels=num_hidden_channels, + out_channels=dataset.num_classes, + num_layers=num_layers, + num_heads=args.num_heads, + use_bn=True, + nb_random_features=100, + use_gumbel=True, + use_residual=True, + use_act=True, + use_jk=True, + nb_gumbel_sample=5, + rb_order=1, + rb_trans='identity', + tau=1.0, + ) else: raise ValueError(f'Unsupported model type: {model_name}') @@ -227,8 +292,13 @@ def get_model(model_name: str) -> torch.nn.Module: lr=args.lr, weight_decay=args.wd, ) -scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', - patience=5) +if args.model == 'nodeformer': + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, + milestones=[100, 200], + gamma=0.5) +else: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode='max', patience=5) print(f'Total time before training begins took ' f'{time.perf_counter() - wall_clock_start:.4f}s') diff --git a/test/nn/models/test_nodeformer.py b/test/nn/models/test_nodeformer.py new file mode 100644 index 000000000000..3f4ce1e5bb05 --- /dev/null +++ b/test/nn/models/test_nodeformer.py @@ -0,0 +1,23 @@ +import torch + +from torch_geometric.nn.models import NodeFormer +from torch_geometric.testing import withPackage + + +@withPackage('torch_sparse') +def test_sgformer(): + x = torch.randn(10, 16) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], + ]) + + model = NodeFormer( + in_channels=16, + hidden_channels=128, + out_channels=40, + num_layers=3, + ) + out, link_loss = model(x, edge_index) + assert out.size() == (10, 40) + assert len(link_loss) == 3 diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 269fed1da780..b3b6a5065c71 
100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -36,6 +36,7 @@ from .glem import GLEM from .sgformer import SGFormer from .polynormer import Polynormer +from .nodeformer import NodeFormer # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -92,5 +93,6 @@ 'GLEM', 'SGFormer', 'Polynormer', + 'NodeFormer', 'ARLinkPredictor', ] diff --git a/torch_geometric/nn/models/nodeformer.py b/torch_geometric/nn/models/nodeformer.py new file mode 100644 index 000000000000..7574acf2014c --- /dev/null +++ b/torch_geometric/nn/models/nodeformer.py @@ -0,0 +1,508 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_sparse import SparseTensor, matmul + +from torch_geometric.utils import degree + +BIG_CONSTANT = 1e8 + + +def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False): + nb_full_blocks = int(m / d) + block_list = [] + current_seed = seed + for _ in range(nb_full_blocks): + torch.manual_seed(current_seed) + if struct_mode: + q = create_products_of_givens_rotations(d, current_seed) + else: + unstructured_block = torch.randn((d, d)) + q, _ = torch.linalg.qr(unstructured_block) + q = torch.t(q) + block_list.append(q) + current_seed += 1 + remaining_rows = m - nb_full_blocks * d + if remaining_rows > 0: + torch.manual_seed(current_seed) + if struct_mode: + q = create_products_of_givens_rotations(d, current_seed) + else: + unstructured_block = torch.randn((d, d)) + q, _ = torch.linalg.qr(unstructured_block) + q = torch.t(q) + block_list.append(q[0:remaining_rows]) + final_matrix = torch.vstack(block_list) + + current_seed += 1 + torch.manual_seed(current_seed) + if scaling == 0: + multiplier = torch.norm(torch.randn((m, d)), dim=1) + elif scaling == 1: + multiplier = torch.sqrt(torch.tensor(float(d))) * torch.ones(m) + else: + raise ValueError("Scaling must be one of {0, 1}. 
Was %s" % scaling) + + return torch.matmul(torch.diag(multiplier), final_matrix) + + +def create_products_of_givens_rotations(dim, seed): + nb_givens_rotations = dim * int(math.ceil(math.log(float(dim)))) + q = np.eye(dim, dim) + np.random.seed(seed) + for _ in range(nb_givens_rotations): + random_angle = math.pi * np.random.uniform() + random_indices = np.random.choice(dim, 2) + index_i = min(random_indices[0], random_indices[1]) + index_j = max(random_indices[0], random_indices[1]) + slice_i = q[index_i] + slice_j = q[index_j] + new_slice_i = math.cos(random_angle) * slice_i + math.cos( + random_angle) * slice_j + new_slice_j = -math.sin(random_angle) * slice_i + math.cos( + random_angle) * slice_j + q[index_i] = new_slice_i + q[index_j] = new_slice_j + return torch.tensor(q, dtype=torch.float32) + + +def relu_kernel_transformation(data, is_query, projection_matrix=None, + numerical_stabilizer=0.001): + del is_query + if projection_matrix is None: + return F.relu(data) + numerical_stabilizer + else: + ratio = 1.0 / torch.sqrt( + torch.tensor(projection_matrix.shape[0], torch.float32)) + data_dash = ratio * torch.einsum("bnhd,md->bnhm", data, + projection_matrix) + return F.relu(data_dash) + numerical_stabilizer + + +def softmax_kernel_transformation(data, is_query, projection_matrix=None, + numerical_stabilizer=0.000001): + data_normalizer = 1.0 / torch.sqrt( + torch.sqrt(torch.tensor(data.shape[-1], dtype=torch.float32))) + data = data_normalizer * data + ratio = 1.0 / torch.sqrt( + torch.tensor(projection_matrix.shape[0], dtype=torch.float32)) + data_dash = torch.einsum("bnhd,md->bnhm", data, projection_matrix) + diag_data = torch.square(data) + diag_data = torch.sum(diag_data, dim=len(data.shape) - 1) + diag_data = diag_data / 2.0 + diag_data = torch.unsqueeze(diag_data, dim=len(data.shape) - 1) + last_dims_t = len(data_dash.shape) - 1 + attention_dims_t = len(data_dash.shape) - 3 + if is_query: + data_dash = ratio * ( + torch.exp(data_dash - diag_data - + torch.max(data_dash, dim=last_dims_t, keepdim=True)[0]) + + numerical_stabilizer) + else: + data_dash = ratio * (torch.exp(data_dash - diag_data - torch.max( + torch.max(data_dash, dim=last_dims_t, keepdim=True)[0], + dim=attention_dims_t, keepdim=True)[0]) + numerical_stabilizer) + return data_dash + + +def numerator(qs, ks, vs): + kvs = torch.einsum("nbhm,nbhd->bhmd", ks, + vs) # kvs refers to U_k in the paper + return torch.einsum("nbhm,bhmd->nbhd", qs, kvs) + + +def denominator(qs, ks): + all_ones = torch.ones([ks.shape[0]]).to(qs.device) + ks_sum = torch.einsum("nbhm,n->bhm", ks, + all_ones) # ks_sum refers to O_k in the paper + return torch.einsum("nbhm,bhm->nbh", qs, ks_sum) + + +def numerator_gumbel(qs, ks, vs): + kvs = torch.einsum("nbhkm,nbhd->bhkmd", ks, + vs) # kvs refers to U_k in the paper + return torch.einsum("nbhm,bhkmd->nbhkd", qs, kvs) + + +def denominator_gumbel(qs, ks): + all_ones = torch.ones([ks.shape[0]]).to(qs.device) + ks_sum = torch.einsum("nbhkm,n->bhkm", ks, + all_ones) # ks_sum refers to O_k in the paper + return torch.einsum("nbhm,bhkm->nbhk", qs, ks_sum) + + +def kernelized_softmax( + query, + key, + value, + kernel_transformation, + projection_matrix=None, + edge_index=None, + tau=0.25, + return_weight=True, +): + r"""Fast computation of all-pair attentive aggregation with linear + complexity. 
+ input: query/key/value [B, N, H, D] + return: updated node emb, attention weight (for computing edge loss) + B = graph number (always equal to 1 in Node Classification), + N = node number, H = head number, + M = random feature dimension, D = hidden size. + """ + query = query / math.sqrt(tau) + key = key / math.sqrt(tau) + query_prime = kernel_transformation(query, True, + projection_matrix) # [B, N, H, M] + key_prime = kernel_transformation(key, False, + projection_matrix) # [B, N, H, M] + query_prime = query_prime.permute(1, 0, 2, 3) # [N, B, H, M] + key_prime = key_prime.permute(1, 0, 2, 3) # [N, B, H, M] + value = value.permute(1, 0, 2, 3) # [N, B, H, D] + + # compute updated node emb, this step requires O(N) + z_num = numerator(query_prime, key_prime, value) + z_den = denominator(query_prime, key_prime) + + z_num = z_num.permute(1, 0, 2, 3) # [B, N, H, D] + z_den = z_den.permute(1, 0, 2) + z_den = torch.unsqueeze(z_den, len(z_den.shape)) + z_output = z_num / z_den # [B, N, H, D] + + # query edge prob for computing edge-level reg loss, + # this step requires O(E) + if return_weight: + start, end = edge_index + query_end, key_start = query_prime[end], key_prime[ + start] # [E, B, H, M] + edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, + key_start) # [E, B, H] + edge_attn_num = edge_attn_num.permute(1, 0, 2) # [B, E, H] + attn_normalizer = denominator(query_prime, key_prime) # [N, B, H] + edge_attn_dem = attn_normalizer[end] # [E, B, H] + edge_attn_dem = edge_attn_dem.permute(1, 0, 2) # [B, E, H] + A_weight = edge_attn_num / edge_attn_dem # [B, E, H] + + return z_output, A_weight + + else: + return z_output + + +def kernelized_gumbel_softmax( + query, + key, + value, + kernel_transformation, + projection_matrix=None, + edge_index=None, + K=10, + tau=0.25, + return_weight=True, +): + r"""Fast computation of all-pair attentive aggregation with + linear complexity. + input: query/key/value [B, N, H, D] + return: updated node emb, attention weight (for computing edge loss) + B = graph number (always equal to 1 in Node Classification), + N = node number, H = head number, + M = random feature dimension, D = hidden size, + K = number of Gumbel sampling. + """ + query = query / math.sqrt(tau) + key = key / math.sqrt(tau) + query_prime = kernel_transformation(query, True, + projection_matrix) # [B, N, H, M] + key_prime = kernel_transformation(key, False, + projection_matrix) # [B, N, H, M] + query_prime = query_prime.permute(1, 0, 2, 3) # [N, B, H, M] + key_prime = key_prime.permute(1, 0, 2, 3) # [N, B, H, M] + value = value.permute(1, 0, 2, 3) # [N, B, H, D] + + # compute updated node emb, this step requires O(N) + gumbels = (-torch.empty(key_prime.shape[:-1] + (K, ), memory_format=torch. 
+ legacy_contiguous_format).exponential_().log()).to( + query.device) / tau # [N, B, H, K] + key_t_gumbel = key_prime.unsqueeze(3) * gumbels.exp().unsqueeze( + 4) # [N, B, H, K, M] + z_num = numerator_gumbel(query_prime, key_t_gumbel, + value) # [N, B, H, K, D] + z_den = denominator_gumbel(query_prime, key_t_gumbel) # [N, B, H, K] + + z_num = z_num.permute(1, 0, 2, 3, 4) # [B, N, H, K, D] + z_den = z_den.permute(1, 0, 2, 3) # [B, N, H, K] + z_den = torch.unsqueeze(z_den, len(z_den.shape)) + z_output = torch.mean(z_num / z_den, dim=3) # [B, N, H, D] + + # query edge prob for computing edge-level reg loss, + # this step requires O(E) + if return_weight: + start, end = edge_index + query_end, key_start = query_prime[end], key_prime[ + start] # [E, B, H, M] + edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, + key_start) # [E, B, H] + edge_attn_num = edge_attn_num.permute(1, 0, 2) # [B, E, H] + attn_normalizer = denominator(query_prime, key_prime) # [N, B, H] + edge_attn_dem = attn_normalizer[end] # [E, B, H] + edge_attn_dem = edge_attn_dem.permute(1, 0, 2) # [B, E, H] + A_weight = edge_attn_num / edge_attn_dem # [B, E, H] + + return z_output, A_weight + + else: + return z_output + + +def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'): + r"""Compute updated result by the relational bias of input adjacency + the implementation is similar to the Graph Convolution Network with a + (shared) scalar weight for each edge. + """ + row, col = edge_index + d_in = degree(col, x.shape[1]).float() + d_norm_in = (1. / d_in[col]).sqrt() + d_out = degree(row, x.shape[1]).float() + d_norm_out = (1. / d_out[row]).sqrt() + conv_output = [] + for i in range(x.shape[2]): + if trans == 'sigmoid': + b_i = F.sigmoid(b[i]) + elif trans == 'identity': + b_i = b[i] + else: + raise NotImplementedError + value = torch.ones_like(row) * b_i * d_norm_in * d_norm_out + adj_i = SparseTensor(row=col, col=row, value=value, + sparse_sizes=(x.shape[1], x.shape[1])) + conv_output.append(matmul(adj_i, x[:, :, i])) # [B, N, D] + conv_output = torch.stack(conv_output, dim=2) # [B, N, H, D] + return conv_output + + +def adj_mul(adj_i, adj, N): + adj_i_sp = torch.sparse_coo_tensor( + adj_i, + torch.ones(adj_i.shape[1], dtype=torch.float).to(adj.device), (N, N)) + adj_sp = torch.sparse_coo_tensor( + adj, + torch.ones(adj.shape[1], dtype=torch.float).to(adj.device), (N, N)) + adj_j = torch.sparse.mm(adj_i_sp, adj_sp) + adj_j = adj_j.coalesce().indices() + return adj_j + + +class NodeFormerConv(nn.Module): + r"""One layer of NodeFormer that attentive aggregates all nodes + over a latent graph. + Return: node embeddings for next layer, edge loss at this layer. 
+ """ + def __init__(self, in_channels, out_channels, num_heads, + kernel_transformation=softmax_kernel_transformation, + projection_matrix_type='a', nb_random_features=10, + use_gumbel=True, nb_gumbel_sample=10, rb_order=0, + rb_trans='sigmoid', use_edge_loss=True): + super().__init__() + self.Wk = nn.Linear(in_channels, out_channels * num_heads) + self.Wq = nn.Linear(in_channels, out_channels * num_heads) + self.Wv = nn.Linear(in_channels, out_channels * num_heads) + self.Wo = nn.Linear(out_channels * num_heads, out_channels) + if rb_order >= 1: + self.b = torch.nn.Parameter(torch.FloatTensor(rb_order, num_heads), + requires_grad=True) + + self.out_channels = out_channels + self.num_heads = num_heads + self.kernel_transformation = kernel_transformation + self.projection_matrix_type = projection_matrix_type + self.nb_random_features = nb_random_features + self.use_gumbel = use_gumbel + self.nb_gumbel_sample = nb_gumbel_sample + self.rb_order = rb_order + self.rb_trans = rb_trans + self.use_edge_loss = use_edge_loss + + def reset_parameters(self): + self.Wk.reset_parameters() + self.Wq.reset_parameters() + self.Wv.reset_parameters() + self.Wo.reset_parameters() + if self.rb_order >= 1: + if self.rb_trans == 'sigmoid': + torch.nn.init.constant_(self.b, 0.1) + elif self.rb_trans == 'identity': + torch.nn.init.constant_(self.b, 1.0) + + def forward(self, z, adjs, tau): + N = z.size(1) + query = self.Wq(z).reshape(-1, N, self.num_heads, self.out_channels) + key = self.Wk(z).reshape(-1, N, self.num_heads, self.out_channels) + value = self.Wv(z).reshape(-1, N, self.num_heads, self.out_channels) + + if self.projection_matrix_type is None: + projection_matrix = None + else: + dim = query.shape[-1] + seed = torch.ceil(torch.abs(torch.sum(query) * BIG_CONSTANT)).to( + torch.int32) + projection_matrix = create_projection_matrix( + self.nb_random_features, dim, seed=seed).to(query.device) + + # compute all-pair message passing update and attn weight + # on input edges, requires O(N) or O(N + E) + # only using Gumbel noise for training + if self.use_gumbel and self.training: + z_next, weight = kernelized_gumbel_softmax( + query, key, value, self.kernel_transformation, + projection_matrix, adjs[0], self.nb_gumbel_sample, tau, + self.use_edge_loss) + else: + z_next, weight = kernelized_softmax(query, key, value, + self.kernel_transformation, + projection_matrix, adjs[0], + tau, self.use_edge_loss) + + # compute update by relational bias of input adjacency, requires O(E) + for i in range(self.rb_order): + z_next += add_conv_relational_bias(value, adjs[i], self.b[i], + self.rb_trans) + + # aggregate results of multiple heads + z_next = self.Wo(z_next.flatten(-2, -1)) + + # compute edge regularization loss on input adjacency + if self.use_edge_loss: + row, col = adjs[0] + d_in = degree(col, query.shape[1]).float() + d_norm = 1. / d_in[col] + d_norm_ = d_norm.reshape(1, -1, 1).repeat(1, 1, weight.shape[-1]) + link_loss = torch.mean(weight.log() * d_norm_) + + return z_next, link_loss + + else: + return z_next + + +class NodeFormer(nn.Module): + r"""The NodeFormer model from the + `"NodeFormer: A Scalable Graph Structure Learning + Transformer for Node Classification" + `_ paper. + Predicted node labels, a list of edge losses at every layer. + + Args: + in_channels (int): Input channels. + hidden_channels (int): Hidden channels. + out_channels (int): Output channels. 
+ """ + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + num_layers=3, + num_heads=4, + dropout=0.0, + kernel_transformation=softmax_kernel_transformation, + nb_random_features=30, + use_bn=True, + use_gumbel=True, + use_residual=True, + use_act=False, + use_jk=False, + nb_gumbel_sample=10, + rb_order=0, + rb_trans='sigmoid', + use_edge_loss=True, + tau=0.25, + ): + super().__init__() + + self.convs = nn.ModuleList() + self.fcs = nn.ModuleList() + self.fcs.append(nn.Linear(in_channels, hidden_channels)) + self.bns = nn.ModuleList() + self.bns.append(nn.LayerNorm(hidden_channels)) + for _ in range(num_layers): + self.convs.append( + NodeFormerConv(hidden_channels, hidden_channels, + num_heads=num_heads, + kernel_transformation=kernel_transformation, + nb_random_features=nb_random_features, + use_gumbel=use_gumbel, + nb_gumbel_sample=nb_gumbel_sample, + rb_order=rb_order, rb_trans=rb_trans, + use_edge_loss=use_edge_loss)) + self.bns.append(nn.LayerNorm(hidden_channels)) + + if use_jk: + self.fcs.append( + nn.Linear(hidden_channels * num_layers + hidden_channels, + out_channels)) + else: + self.fcs.append(nn.Linear(hidden_channels, out_channels)) + + self.dropout = dropout + self.activation = F.elu + self.use_bn = use_bn + self.use_residual = use_residual + self.use_act = use_act + self.use_jk = use_jk + self.use_edge_loss = use_edge_loss + self.tau = tau + self.rb_order = rb_order + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + for fc in self.fcs: + fc.reset_parameters() + + # def forward(self, x, adjs, tau=1.0): + def forward(self, x, adjs): + # todo: edge_index of high order adjacency + # for i in range(self.rb_order - 1): + # adj = adj_mul(adj, adj, n) + # adjs.append(adj) + adjs = [adjs] + x = x.unsqueeze(0) # [B, N, H, D], B=1 denotes number of graph + layer_ = [] + link_loss_ = [] + z = self.fcs[0](x) + if self.use_bn: + z = self.bns[0](z) + z = self.activation(z) + z = F.dropout(z, p=self.dropout, training=self.training) + layer_.append(z) + + for i, conv in enumerate(self.convs): + if self.use_edge_loss: + z, link_loss = conv(z, adjs, self.tau) + link_loss_.append(link_loss) + else: + z = conv(z, adjs, self.tau) + if self.use_residual: + z += layer_[i] + if self.use_bn: + z = self.bns[i + 1](z) + if self.use_act: + z = self.activation(z) + z = F.dropout(z, p=self.dropout, training=self.training) + layer_.append(z) + + if self.use_jk: # use jk connection for each layer + z = torch.cat(layer_, dim=-1) + + x_out = self.fcs[-1](z).squeeze(0) + x_out = F.log_softmax(x_out, dim=-1) + + # import pdb; pdb.set_trace() + if self.use_edge_loss: + return x_out, link_loss_ + else: + return x_out From 7009a8b3f3288111456ef51e2722010fbf9a436f Mon Sep 17 00:00:00 2001 From: xnuohz Date: Mon, 11 Aug 2025 21:47:18 +0800 Subject: [PATCH 2/7] update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f25f2b7e80fe..07b3b27a8d8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
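As a rough illustration (not taken from this patch), the positive random-feature approximation that the patch's `softmax_kernel_transformation`, `numerator`, and `denominator` helpers implement can be reproduced in a few lines of plain PyTorch; the helper `phi`, the projection `proj`, and the sizes below are illustrative names only:

import torch

torch.manual_seed(0)


def phi(x, proj):
    # Positive random features for the softmax kernel: with the rows of
    # `proj` drawn from N(0, I), phi(q) . phi(k) is an unbiased estimate of
    # exp(q . k / sqrt(d)) after the 1 / d**0.25 normalization below, so the
    # N x N attention matrix never has to be materialized.
    x = x / x.size(-1)**0.25                      # same normalizer as above
    wx = x @ proj.t()                             # [N, m]
    sq = (x * x).sum(dim=-1, keepdim=True) / 2.0  # [N, 1]
    return torch.exp(wx - sq) / proj.size(0)**0.5


N, d, m = 8, 4, 1024
q, k, v = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
proj = torch.randn(m, d)                          # random projection matrix

# Exact softmax attention, O(N^2):
exact = torch.softmax(q @ k.t() / d**0.5, dim=-1) @ v

# Linearized attention, O(N * m), analogous to `numerator` / `denominator`:
q_p, k_p = phi(q, proj), phi(k, proj)
num = q_p @ (k_p.t() @ v)                         # [N, d]
den = q_p @ k_p.sum(dim=0).unsqueeze(-1)          # [N, 1]
approx = num / den

print((exact - approx).abs().max())               # small, shrinks as m grows

The implementation above additionally subtracts a running maximum inside the exponential for numerical stability and applies the same estimator per attention head (and per Gumbel sample during training).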
### Added +- Added `NodeFormer` model and example ([#10409](https://github.com/pyg-team/pytorch_geometric/pull/10409)) - Added `Polynormer` model and example ([#9908](https://github.com/pyg-team/pytorch_geometric/pull/9908)) - Added `ProteinMPNN` model and example ([#10289](https://github.com/pyg-team/pytorch_geometric/pull/10289)) - Added the `Teeth3DS` dataset, an extended benchmark for intraoral 3D scan analysis ([#9833](https://github.com/pyg-team/pytorch_geometric/pull/9833)) From 082d6e6b26fd3e8b4ef211634596e21be84f8216 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Mon, 8 Sep 2025 22:28:13 +0800 Subject: [PATCH 3/7] update --- examples/ogbn_train.py | 43 ++++--------------------- test/nn/models/test_nodeformer.py | 4 +-- torch_geometric/nn/models/nodeformer.py | 37 +++++++++++---------- 3 files changed, 25 insertions(+), 59 deletions(-) diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py index c2d4cdf5e64e..5ef1184210af 100644 --- a/examples/ogbn_train.py +++ b/examples/ogbn_train.py @@ -47,22 +47,22 @@ help="Model used for training", ) -parser.add_argument('-e', '--epochs', type=int, default=100) +parser.add_argument('-e', '--epochs', type=int, default=50) parser.add_argument('-le', '--local_epochs', type=int, default=50, help='warmup epochs for polynormer') parser.add_argument('--num_layers', type=int, default=3) -parser.add_argument('--num_heads', type=int, default=4, +parser.add_argument('--num_heads', type=int, default=1, help='number of heads for GAT or Graph Transformer model.') -parser.add_argument('-b', '--batch_size', type=int, default=2048) +parser.add_argument('-b', '--batch_size', type=int, default=1024) parser.add_argument('--num_workers', type=int, default=12) parser.add_argument('--fan_out', type=int, default=10, help='number of neighbors in each layer') -parser.add_argument('--hidden_channels', type=int, default=128) -parser.add_argument('--lr', type=float, default=1e-2) +parser.add_argument('--hidden_channels', type=int, default=256) +parser.add_argument('--lr', type=float, default=0.003) parser.add_argument('--wd', type=float, default=0.0) parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--lamda', type=float, default=0.1, - help='[nodeforder]weight for edge reg loss') + help='weight for edge reg loss of nodeformer') parser.add_argument( '--use_directed_graph', action='store_true', @@ -149,7 +149,6 @@ def train(epoch: int) -> tuple[Tensor, float]: out = model(batch.x, batch.edge_index, batch.batch)[:batch.batch_size] elif args.model in ['nodeformer']: - # import pdb; pdb.set_trace() out, link_loss = model(batch.x, batch.edge_index) out = out[:batch.batch_size] else: @@ -198,16 +197,6 @@ def test(loader: NeighborLoader) -> float: def get_model(model_name: str) -> torch.nn.Module: if model_name == 'gat': - """ - Average Epoch Time on training: 0.5798s - Average Epoch Time on inference: 0.1722s - Average Epoch Time: 0.7520s - Median Epoch Time: 0.7528s - Best Validation Accuracy: 68.54% - Testing... 
- Test Accuracy: 67.45% - Total Program Runtime: 76.2115s - """ model = GAT( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, @@ -217,16 +206,6 @@ def get_model(model_name: str) -> torch.nn.Module: heads=args.num_heads, ) elif model_name == 'sage': - """ - Average Epoch Time on training: 0.3990s - Average Epoch Time on inference: 0.1387s - Average Epoch Time: 0.5378s - Median Epoch Time: 0.5314s - Best Validation Accuracy: 69.69% - Testing... - Test Accuracy: 68.26% - Total Program Runtime: 54.5042s - """ model = GraphSAGE( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, @@ -252,16 +231,6 @@ def get_model(model_name: str) -> torch.nn.Module: local_layers=num_layers, ) elif model_name == 'nodeformer': - """ - Average Epoch Time on training: 2.4006s - Average Epoch Time on inference: 0.2627s - Average Epoch Time: 2.6633s - Median Epoch Time: 2.6391s - Best Validation Accuracy: 69.96% - Testing... - Test Accuracy: 68.18% - Total Program Runtime: 267.4139s - """ model = NodeFormer( in_channels=dataset.num_features, hidden_channels=num_hidden_channels, diff --git a/test/nn/models/test_nodeformer.py b/test/nn/models/test_nodeformer.py index 3f4ce1e5bb05..85d37dbe54ab 100644 --- a/test/nn/models/test_nodeformer.py +++ b/test/nn/models/test_nodeformer.py @@ -1,11 +1,9 @@ import torch from torch_geometric.nn.models import NodeFormer -from torch_geometric.testing import withPackage -@withPackage('torch_sparse') -def test_sgformer(): +def test_nodeformer(): x = torch.randn(10, 16) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], diff --git a/torch_geometric/nn/models/nodeformer.py b/torch_geometric/nn/models/nodeformer.py index 7574acf2014c..a0f3e5f9117f 100644 --- a/torch_geometric/nn/models/nodeformer.py +++ b/torch_geometric/nn/models/nodeformer.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_sparse import SparseTensor, matmul from torch_geometric.utils import degree @@ -257,27 +256,33 @@ def kernelized_gumbel_softmax( def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'): - r"""Compute updated result by the relational bias of input adjacency - the implementation is similar to the Graph Convolution Network with a - (shared) scalar weight for each edge. - """ row, col = edge_index - d_in = degree(col, x.shape[1]).float() + B, N, H, D = x.shape + + d_in = degree(col, N).float() d_norm_in = (1. / d_in[col]).sqrt() - d_out = degree(row, x.shape[1]).float() + d_out = degree(row, N).float() d_norm_out = (1. 
/ d_out[row]).sqrt() + conv_output = [] - for i in range(x.shape[2]): + for i in range(H): if trans == 'sigmoid': - b_i = F.sigmoid(b[i]) + b_i = torch.sigmoid(b[i]) elif trans == 'identity': b_i = b[i] else: raise NotImplementedError - value = torch.ones_like(row) * b_i * d_norm_in * d_norm_out - adj_i = SparseTensor(row=col, col=row, value=value, - sparse_sizes=(x.shape[1], x.shape[1])) - conv_output.append(matmul(adj_i, x[:, :, i])) # [B, N, D] + + value = b_i * d_norm_in * d_norm_out # [E] + + out = torch.zeros(B, N, D, device=x.device, dtype=x.dtype) + out.index_add_( + 1, col, + x[:, row, i, :] * value.view(1, -1, 1) # :fire: 关键修正:扩展到 [B, E, D] + ) + + conv_output.append(out) + conv_output = torch.stack(conv_output, dim=2) # [B, N, H, D] return conv_output @@ -463,12 +468,7 @@ def reset_parameters(self): for fc in self.fcs: fc.reset_parameters() - # def forward(self, x, adjs, tau=1.0): def forward(self, x, adjs): - # todo: edge_index of high order adjacency - # for i in range(self.rb_order - 1): - # adj = adj_mul(adj, adj, n) - # adjs.append(adj) adjs = [adjs] x = x.unsqueeze(0) # [B, N, H, D], B=1 denotes number of graph layer_ = [] @@ -501,7 +501,6 @@ def forward(self, x, adjs): x_out = self.fcs[-1](z).squeeze(0) x_out = F.log_softmax(x_out, dim=-1) - # import pdb; pdb.set_trace() if self.use_edge_loss: return x_out, link_loss_ else: From d2844f95cc6a41789878a58a8c9a16160d93a62f Mon Sep 17 00:00:00 2001 From: xnuohz Date: Mon, 8 Sep 2025 22:28:41 +0800 Subject: [PATCH 4/7] update --- torch_geometric/nn/models/nodeformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/models/nodeformer.py b/torch_geometric/nn/models/nodeformer.py index a0f3e5f9117f..e836cf0017f4 100644 --- a/torch_geometric/nn/models/nodeformer.py +++ b/torch_geometric/nn/models/nodeformer.py @@ -277,8 +277,9 @@ def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'): out = torch.zeros(B, N, D, device=x.device, dtype=x.dtype) out.index_add_( - 1, col, - x[:, row, i, :] * value.view(1, -1, 1) # :fire: 关键修正:扩展到 [B, E, D] + 1, + col, + x[:, row, i, :] * value.view(1, -1, 1) # :fire: 关键修正:扩展到 [B, E, D] ) conv_output.append(out) From 0224761c6214a55a75a175173ed40a047a0a572b Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 14 Sep 2025 17:25:23 +0800 Subject: [PATCH 5/7] improve test cov --- test/nn/models/test_nodeformer.py | 35 +++++- torch_geometric/nn/models/nodeformer.py | 156 +++++++++++++++++------- 2 files changed, 142 insertions(+), 49 deletions(-) diff --git a/test/nn/models/test_nodeformer.py b/test/nn/models/test_nodeformer.py index 85d37dbe54ab..0547107d8cde 100644 --- a/test/nn/models/test_nodeformer.py +++ b/test/nn/models/test_nodeformer.py @@ -1,9 +1,27 @@ +import pytest import torch from torch_geometric.nn.models import NodeFormer -def test_nodeformer(): +@pytest.mark.parametrize('use_bn', [True, False]) +@pytest.mark.parametrize('use_gumbel', [True, False]) +@pytest.mark.parametrize('use_residual', [True, False]) +@pytest.mark.parametrize('use_act', [True, False]) +@pytest.mark.parametrize('use_jk', [True, False]) +@pytest.mark.parametrize('use_edge_loss', [True, False]) +@pytest.mark.parametrize('rb_trans', ['sigmoid', 'identity']) +@pytest.mark.parametrize('rb_order', [0, 1]) +def test_nodeformer( + use_bn, + use_gumbel, + use_residual, + use_act, + use_jk, + use_edge_loss, + rb_trans, + rb_order, +): x = torch.randn(10, 16) edge_index = torch.tensor([ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], @@ -15,7 +33,18 @@ def test_nodeformer(): hidden_channels=128, 
out_channels=40, num_layers=3, + use_bn=use_bn, + use_gumbel=use_gumbel, + use_residual=use_residual, + use_act=use_act, + use_jk=use_jk, + use_edge_loss=use_edge_loss, + rb_trans=rb_trans, + rb_order=rb_order, ) - out, link_loss = model(x, edge_index) + if use_edge_loss: + out, link_loss = model(x, edge_index) + assert len(link_loss) == 3 + else: + out = model(x, edge_index) assert out.size() == (10, 40) - assert len(link_loss) == 3 diff --git a/torch_geometric/nn/models/nodeformer.py b/torch_geometric/nn/models/nodeformer.py index e836cf0017f4..4bb6ff5c917e 100644 --- a/torch_geometric/nn/models/nodeformer.py +++ b/torch_geometric/nn/models/nodeformer.py @@ -108,28 +108,28 @@ def softmax_kernel_transformation(data, is_query, projection_matrix=None, def numerator(qs, ks, vs): - kvs = torch.einsum("nbhm,nbhd->bhmd", ks, - vs) # kvs refers to U_k in the paper + # kvs refers to U_k in the paper + kvs = torch.einsum("nbhm,nbhd->bhmd", ks, vs) return torch.einsum("nbhm,bhmd->nbhd", qs, kvs) def denominator(qs, ks): all_ones = torch.ones([ks.shape[0]]).to(qs.device) - ks_sum = torch.einsum("nbhm,n->bhm", ks, - all_ones) # ks_sum refers to O_k in the paper + # ks_sum refers to O_k in the paper + ks_sum = torch.einsum("nbhm,n->bhm", ks, all_ones) return torch.einsum("nbhm,bhm->nbh", qs, ks_sum) def numerator_gumbel(qs, ks, vs): - kvs = torch.einsum("nbhkm,nbhd->bhkmd", ks, - vs) # kvs refers to U_k in the paper + # kvs refers to U_k in the paper + kvs = torch.einsum("nbhkm,nbhd->bhkmd", ks, vs) return torch.einsum("nbhm,bhkmd->nbhkd", qs, kvs) def denominator_gumbel(qs, ks): all_ones = torch.ones([ks.shape[0]]).to(qs.device) - ks_sum = torch.einsum("nbhkm,n->bhkm", ks, - all_ones) # ks_sum refers to O_k in the paper + # ks_sum refers to O_k in the paper + ks_sum = torch.einsum("nbhkm,n->bhkm", ks, all_ones) return torch.einsum("nbhm,bhkm->nbhk", qs, ks_sum) @@ -276,11 +276,7 @@ def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'): value = b_i * d_norm_in * d_norm_out # [E] out = torch.zeros(B, N, D, device=x.device, dtype=x.dtype) - out.index_add_( - 1, - col, - x[:, row, i, :] * value.view(1, -1, 1) # :fire: 关键修正:扩展到 [B, E, D] - ) + out.index_add_(1, col, x[:, row, i, :] * value.view(1, -1, 1)) conv_output.append(out) @@ -301,15 +297,49 @@ def adj_mul(adj_i, adj, N): class NodeFormerConv(nn.Module): - r"""One layer of NodeFormer that attentive aggregates all nodes - over a latent graph. - Return: node embeddings for next layer, edge loss at this layer. + r"""One layer of NodeFormer model from the + `"NodeFormer: A Scalable Graph Structure Learning + Transformer for Node Classification" + `_ paper. + that attentive aggregates all nodes over a latent graph. + Predicted node labels, a list of edge losses at every layer. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + num_heads (int): Number of parallel heads. + kernel_transformation (func, optional): The kernel + transformation function. + (default: :func:`softmax_kernel_transformation`) + projection_matrix_type (str, optional): The type of projection matrix + to use ('a' or 'b') (default: 'a'). + nb_random_features (int, optional): The number of random features. + (default: 10). + use_gumbel (bool, optional): Whether to use Gumbel sampling + (default: True). + nb_gumbel_sample (int, optional): The number of Gumbel samples. + (default: 10). + rb_order (int, optional): The order of relational bias. + (default: 0). + rb_trans (str, optional): The type of transformation. 
+ relational bias ('sigmoid' or 'identity') (default: 'sigmoid'). + use_edge_loss (bool, optional): Whether to use edge loss + (default: True). """ - def __init__(self, in_channels, out_channels, num_heads, - kernel_transformation=softmax_kernel_transformation, - projection_matrix_type='a', nb_random_features=10, - use_gumbel=True, nb_gumbel_sample=10, rb_order=0, - rb_trans='sigmoid', use_edge_loss=True): + def __init__( + self, + in_channels: int, + out_channels: int, + num_heads: int, + kernel_transformation=softmax_kernel_transformation, + projection_matrix_type: str = 'a', + nb_random_features: int = 10, + use_gumbel: bool = True, + nb_gumbel_sample: int = 10, + rb_order: int = 0, + rb_trans: str = 'sigmoid', + use_edge_loss: bool = True, + ): super().__init__() self.Wk = nn.Linear(in_channels, out_channels * num_heads) self.Wq = nn.Linear(in_channels, out_channels * num_heads) @@ -360,16 +390,21 @@ def forward(self, z, adjs, tau): # on input edges, requires O(N) or O(N + E) # only using Gumbel noise for training if self.use_gumbel and self.training: - z_next, weight = kernelized_gumbel_softmax( - query, key, value, self.kernel_transformation, - projection_matrix, adjs[0], self.nb_gumbel_sample, tau, - self.use_edge_loss) + result = kernelized_gumbel_softmax(query, key, value, + self.kernel_transformation, + projection_matrix, adjs[0], + self.nb_gumbel_sample, tau, + self.use_edge_loss) else: - z_next, weight = kernelized_softmax(query, key, value, - self.kernel_transformation, - projection_matrix, adjs[0], - tau, self.use_edge_loss) + result = kernelized_softmax(query, key, value, + self.kernel_transformation, + projection_matrix, adjs[0], tau, + self.use_edge_loss) + if self.use_edge_loss: + z_next, weight = result + else: + z_next = result # compute update by relational bias of input adjacency, requires O(E) for i in range(self.rb_order): z_next += add_conv_relational_bias(value, adjs[i], self.b[i], @@ -403,27 +438,56 @@ class NodeFormer(nn.Module): in_channels (int): Input channels. hidden_channels (int): Hidden channels. out_channels (int): Output channels. + num_layers (int): Number of layers. + (default: `3`) + num_heads (int): Number of heads. + (default: `4`) + dropout (float): Dropout rate. + (default: `0.0`) + kernel_transformation=softmax_kernel_transformation, + nb_random_features (int): Number of random features. + (default: `30`) + use_bn (bool): Whether to use batch normalization. + (default: `True`) + use_gumbel (bool): Whether to use Gumbel softmax. + (default: `True`) + use_residual (bool): Whether to use residual connection. + (default: `True`) + use_act (bool): Whether to use activation function. + (default: `False`) + use_jk (bool): Whether to use JK aggregation. + (default: `False`) + nb_gumbel_sample (int): Number of Gumbel samples. + (default: `10`) + rb_order (int): Order of relational bias. + (default: `0`) + rb_trans (str): Type of relational bias transformation. + (default: `'sigmoid'`) + use_edge_loss (bool): Whether to use edge loss. + (default: `True`) + tau (float): Temperature parameter for Gumbel softmax. 
+ (default: `0.25`) """ def __init__( self, - in_channels, - hidden_channels, - out_channels, - num_layers=3, - num_heads=4, - dropout=0.0, + in_channels: int, + hidden_channels: int, + out_channels: int, + num_layers: int = 3, + num_heads: int = 4, + dropout: float = 0.0, kernel_transformation=softmax_kernel_transformation, - nb_random_features=30, - use_bn=True, - use_gumbel=True, - use_residual=True, - use_act=False, - use_jk=False, - nb_gumbel_sample=10, - rb_order=0, - rb_trans='sigmoid', - use_edge_loss=True, - tau=0.25, + nb_random_features: int = 30, + use_bn: bool = True, + use_gumbel: bool = True, + use_residual: bool = True, + use_act: bool = False, + use_jk: bool = False, + nb_gumbel_sample: int = 10, + rb_order: int = 0, + rb_trans: str = 'sigmoid', + use_edge_loss: bool = True, + tau: float = 0.25, ): super().__init__() From 3b51dc776936dcc44194f5af562599df85ff32f7 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 14 Sep 2025 17:40:37 +0800 Subject: [PATCH 6/7] fix docs --- torch_geometric/nn/models/nodeformer.py | 39 +++++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/torch_geometric/nn/models/nodeformer.py b/torch_geometric/nn/models/nodeformer.py index 4bb6ff5c917e..89191bf05cb3 100644 --- a/torch_geometric/nn/models/nodeformer.py +++ b/torch_geometric/nn/models/nodeformer.py @@ -1,4 +1,5 @@ import math +from typing import Callable import numpy as np import torch @@ -308,7 +309,7 @@ class NodeFormerConv(nn.Module): in_channels (int): Size of each input sample. out_channels (int): Size of each output sample. num_heads (int): Number of parallel heads. - kernel_transformation (func, optional): The kernel + kernel_transformation (Callable, optional): The kernel transformation function. (default: :func:`softmax_kernel_transformation`) projection_matrix_type (str, optional): The type of projection matrix @@ -320,18 +321,18 @@ class NodeFormerConv(nn.Module): nb_gumbel_sample (int, optional): The number of Gumbel samples. (default: 10). rb_order (int, optional): The order of relational bias. - (default: 0). + (default: 0) rb_trans (str, optional): The type of transformation. relational bias ('sigmoid' or 'identity') (default: 'sigmoid'). - use_edge_loss (bool, optional): Whether to use edge loss - (default: True). + use_edge_loss (bool, optional): Whether to use edge loss. + (default: True) """ def __init__( self, in_channels: int, out_channels: int, num_heads: int, - kernel_transformation=softmax_kernel_transformation, + kernel_transformation: Callable = softmax_kernel_transformation, projection_matrix_type: str = 'a', nb_random_features: int = 10, use_gumbel: bool = True, @@ -444,7 +445,9 @@ class NodeFormer(nn.Module): (default: `4`) dropout (float): Dropout rate. (default: `0.0`) - kernel_transformation=softmax_kernel_transformation, + kernel_transformation (Callable, optional): The kernel + transformation function. + (default: :func:`softmax_kernel_transformation`) nb_random_features (int): Number of random features. (default: `30`) use_bn (bool): Whether to use batch normalization. @@ -462,7 +465,7 @@ class NodeFormer(nn.Module): rb_order (int): Order of relational bias. (default: `0`) rb_trans (str): Type of relational bias transformation. - (default: `'sigmoid'`) + (default: `sigmoid`) use_edge_loss (bool): Whether to use edge loss. (default: `True`) tau (float): Temperature parameter for Gumbel softmax. 
@@ -476,7 +479,7 @@ def __init__( num_layers: int = 3, num_heads: int = 4, dropout: float = 0.0, - kernel_transformation=softmax_kernel_transformation, + kernel_transformation: Callable = softmax_kernel_transformation, nb_random_features: int = 30, use_bn: bool = True, use_gumbel: bool = True, @@ -498,14 +501,18 @@ def __init__( self.bns.append(nn.LayerNorm(hidden_channels)) for _ in range(num_layers): self.convs.append( - NodeFormerConv(hidden_channels, hidden_channels, - num_heads=num_heads, - kernel_transformation=kernel_transformation, - nb_random_features=nb_random_features, - use_gumbel=use_gumbel, - nb_gumbel_sample=nb_gumbel_sample, - rb_order=rb_order, rb_trans=rb_trans, - use_edge_loss=use_edge_loss)) + NodeFormerConv( + hidden_channels, + hidden_channels, + num_heads=num_heads, + kernel_transformation=kernel_transformation, + nb_random_features=nb_random_features, + use_gumbel=use_gumbel, + nb_gumbel_sample=nb_gumbel_sample, + rb_order=rb_order, + rb_trans=rb_trans, + use_edge_loss=use_edge_loss, + )) self.bns.append(nn.LayerNorm(hidden_channels)) if use_jk: From ba13929c0128c8857058977a95d257e7d75265a3 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 14 Sep 2025 17:46:30 +0800 Subject: [PATCH 7/7] fix lint --- torch_geometric/metrics/link_pred.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index 5e2a95edf8a7..614d81898e5f 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -53,7 +53,7 @@ def pred_rel_mat(self) -> Tensor: # Flatten both prediction and ground-truth indices, and determine # overlaps afterwards via `torch.searchsorted`. - max_index = max( # type: ignore + max_index = max( self.pred_index_mat.max() if self.pred_index_mat.numel() > 0 else 0, self.edge_label_index[1].max() @@ -820,8 +820,9 @@ def compute(self) -> Tensor: right = pred[col.cpu()].to(device) # Use offset to work around applying `isin` along a specific dim: - i = max(left.max(), right.max()) + 1 # type: ignore - i = torch.arange(0, i * row.size(0), i, device=device).view(-1, 1) + i = max(left.max(), right.max()) + 1 + i = torch.arange(0, i * row.size(0), i, + device=device).view(-1, 1) # type: ignore isin = torch.isin(left + i, right + i) # Compute personalization via average inverse cosine similarity:
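As a rough end-to-end sketch of how the `NodeFormer` added in this PR is driven from a single training step (mirroring the `nodeformer` branch of `examples/ogbn_train.py`; the toy data, hyperparameters, and the 0.1 edge-loss weight below are illustrative only):

import torch
import torch.nn.functional as F

from torch_geometric.nn.models import NodeFormer

x = torch.randn(100, 16)                      # toy node features
edge_index = torch.randint(0, 100, (2, 400))  # toy observed edges
y = torch.randint(0, 7, (100, ))              # toy node labels

model = NodeFormer(in_channels=16, hidden_channels=64, out_channels=7,
                   num_layers=2, num_heads=4, rb_order=1,
                   rb_trans='identity', use_edge_loss=True)
model.reset_parameters()  # also initializes the relational bias weights
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
optimizer.zero_grad()
out, link_loss = model(x, edge_index)  # log-probs and per-layer edge losses
loss = F.nll_loss(out, y)
loss = loss - 0.1 * sum(link_loss) / len(link_loss)  # edge regularization
loss.backward()
optimizer.step()

Since `forward` already ends with `log_softmax`, the sketch pairs the output with `F.nll_loss`; the example script feeds the same output to `F.cross_entropy` instead.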