diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b65b20ecee4..f0370c92ae90 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [2.3.0] - 2023-MM-DD
 
 ### Added
+- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
 - Add inputs_channels back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
diff --git a/test/nn/conv/test_gcn_conv.py b/test/nn/conv/test_gcn_conv.py
index 621ef603d6c6..226acd98474e 100644
--- a/test/nn/conv/test_gcn_conv.py
+++ b/test/nn/conv/test_gcn_conv.py
@@ -14,15 +14,19 @@ def test_gcn_conv():
     value = torch.rand(row.size(0))
     adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
     adj1 = adj2.set_value(None)
+    adj3 = adj1.to_torch_sparse_coo_tensor()
+    adj4 = adj2.to_torch_sparse_coo_tensor()
 
     conv = GCNConv(16, 32)
     assert conv.__repr__() == 'GCNConv(16, 32)'
     out1 = conv(x, edge_index)
     assert out1.size() == (4, 32)
     assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
+    assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
 
     out2 = conv(x, edge_index, value)
     assert out2.size() == (4, 32)
     assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
+    assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
 
     if is_full_test():
         t = '(Tensor, Tensor, OptTensor) -> Tensor'
diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py
index ae41e999f722..d87399d3e398 100644
--- a/torch_geometric/nn/conv/gcn_conv.py
+++ b/torch_geometric/nn/conv/gcn_conv.py
@@ -1,25 +1,29 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 from torch import Tensor
 from torch.nn import Parameter
-from torch_scatter import scatter_add
-from torch_sparse import SparseTensor, fill_diag, matmul, mul
+from torch_sparse import SparseTensor, fill_diag, mul
 from torch_sparse import sum as sparsesum
 
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import zeros
-from torch_geometric.typing import PairTensor  # noqa
-from torch_geometric.typing import Adj, OptTensor
-from torch_geometric.utils import add_remaining_self_loops
+from torch_geometric.typing import Adj, OptPairTensor, OptTensor
+from torch_geometric.utils import (
+    add_remaining_self_loops,
+    is_torch_sparse_tensor,
+    scatter,
+    spmm,
+    to_torch_coo_tensor,
+)
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 
 @torch.jit._overload
 def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
              add_self_loops=True, flow="source_to_target", dtype=None):
-    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> PairTensor  # noqa
+    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor  # noqa
     pass
 
 
@@ -36,7 +40,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
     fill_value = 2. if improved else 1.
 
     if isinstance(edge_index, SparseTensor):
-        assert flow in ["source_to_target"]
+        assert flow == 'source_to_target'
         adj_t = edge_index
         if not adj_t.has_value():
             adj_t = adj_t.fill_value(1., dtype=dtype)
@@ -47,8 +51,19 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
         deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
         adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
         adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
+        return adj_t
 
+    # `edge_index` can be a `torch.LongTensor` or `torch.sparse.Tensor`:
+    is_sparse_tensor = is_torch_sparse_tensor(edge_index)
+    if is_sparse_tensor:
+        assert flow == 'source_to_target'
+        # Reverse `flow` since sparse tensors model transposed adjacencies:
+        flow = 'target_to_source'
+        adj_t = edge_index
+        num_nodes = adj_t.size(0)
+        edge_index = adj_t._indices()
+        edge_weight = adj_t._values()
     else:
         assert flow in ["source_to_target", "target_to_source"]
         num_nodes = maybe_num_nodes(edge_index, num_nodes)
@@ -57,18 +72,24 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                      device=edge_index.device)
 
-        if add_self_loops:
-            edge_index, tmp_edge_weight = add_remaining_self_loops(
-                edge_index, edge_weight, fill_value, num_nodes)
-            assert tmp_edge_weight is not None
-            edge_weight = tmp_edge_weight
-
-        row, col = edge_index[0], edge_index[1]
-        idx = col if flow == "source_to_target" else row
-        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
-        deg_inv_sqrt = deg.pow_(-0.5)
-        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
-        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+    if add_self_loops:
+        edge_index, tmp_edge_weight = add_remaining_self_loops(
+            edge_index, edge_weight, fill_value, num_nodes)
+        assert tmp_edge_weight is not None
+        edge_weight = tmp_edge_weight
+
+    row, col = edge_index[0], edge_index[1]
+    idx = col if flow == "source_to_target" else row
+    deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
+    deg_inv_sqrt = deg.pow_(-0.5)
+    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
+    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+
+    if is_sparse_tensor:
+        adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes)
+        return adj_t, None
+    else:
+        return edge_index, edge_weight
 
 
 class GCNConv(MessagePassing):
@@ -128,7 +149,7 @@ class GCNConv(MessagePassing):
           - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
 
-    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
+    _cached_edge_index: Optional[OptPairTensor]
     _cached_adj_t: Optional[SparseTensor]
 
     def __init__(self, in_channels: int, out_channels: int,
@@ -207,4 +228,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
         return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
 
     def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
-        return matmul(adj_t, x, reduce=self.aggr)
+        return spmm(adj_t, x, reduce=self.aggr)
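# Usage sketch (illustrative, not part of the patch). With the changes above,
# `GCNConv` also accepts an adjacency given as a `torch.sparse` COO tensor,
# mirroring the updated `test_gcn_conv` (4 nodes, 16 -> 32 channels). This
# assumes a PyG build that includes this patch; the toy `edge_index` below is
# an assumption for illustration, since the test's own graph setup is not
# shown in the hunk:
import torch

from torch_geometric.nn import GCNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [0, 0, 1, 1]])

conv = GCNConv(16, 32)
out = conv(x, edge_index)  # existing dense `edge_index` path

# As in the test (`adj3.t()`), the sparse tensor models the *transposed*
# adjacency, so flip (row, col) to (col, row). Coalescing keeps `_indices()`
# and `_values()` well-defined for the new `gcn_norm` branch:
adj_t = torch.sparse_coo_tensor(
    edge_index.flip(0),              # transposed COO indices
    torch.ones(edge_index.size(1)),  # unweighted edges
    size=(4, 4),
).coalesce()

out_sparse = conv(x, adj_t)
assert torch.allclose(out, out_sparse, atol=1e-6)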