1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
- Added `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
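
Taken together, the `torch.sparse` entries above mean that a `GCNConv` layer can now consume an adjacency given as a native `torch.sparse` COO tensor, alongside the long-supported `edge_index` and `torch_sparse.SparseTensor` formats. A minimal sketch (the 4-node graph and feature sizes are made up for illustration, assuming a PyG build that includes these PRs):

```python
import torch
from torch_geometric.nn import GCNConv

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],   # source nodes
                           [1, 0, 2, 1, 3, 2]])  # target nodes
x = torch.randn(4, 16)

conv = GCNConv(16, 32)
out1 = conv(x, edge_index)  # classic COO edge_index path

# New path: a native torch.sparse adjacency. Note the transpose,
# since PyG's sparse inputs model the transposed adjacency `adj_t`.
adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)),
                              size=(4, 4))
out2 = conv(x, adj.t())

assert torch.allclose(out1, out2, atol=1e-6)
```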
4 changes: 4 additions & 0 deletions test/nn/conv/test_gcn_conv.py
@@ -14,15 +14,19 @@ def test_gcn_conv():
     value = torch.rand(row.size(0))
     adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
     adj1 = adj2.set_value(None)
+    adj3 = adj1.to_torch_sparse_coo_tensor()
+    adj4 = adj2.to_torch_sparse_coo_tensor()
 
     conv = GCNConv(16, 32)
     assert conv.__repr__() == 'GCNConv(16, 32)'
     out1 = conv(x, edge_index)
     assert out1.size() == (4, 32)
     assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
+    assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
     out2 = conv(x, edge_index, value)
     assert out2.size() == (4, 32)
     assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
+    assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
 
     if is_full_test():
         t = '(Tensor, Tensor, OptTensor) -> Tensor'
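
The four added lines mirror the existing `SparseTensor` assertions: `adj3`/`adj4` are the same adjacencies converted to native `torch.sparse` COO tensors, and the convolution must produce identical output for every input format. A hedged sketch of the conversion the test relies on (toy sizes, for illustration):

```python
import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
value = torch.rand(row.size(0))

adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))
coo = adj.to_torch_sparse_coo_tensor()  # same graph, torch.sparse layout

assert coo.layout == torch.sparse_coo
# Holds here because the entries are already sorted in row-major order:
assert torch.equal(coo.coalesce().indices(), torch.stack([row, col]))
```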
65 changes: 43 additions & 22 deletions torch_geometric/nn/conv/gcn_conv.py
@@ -1,25 +1,29 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 from torch import Tensor
 from torch.nn import Parameter
-from torch_scatter import scatter_add
-from torch_sparse import SparseTensor, fill_diag, matmul, mul
+from torch_sparse import SparseTensor, fill_diag, mul
 from torch_sparse import sum as sparsesum
 
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import zeros
-from torch_geometric.typing import PairTensor  # noqa
-from torch_geometric.typing import Adj, OptTensor
-from torch_geometric.utils import add_remaining_self_loops
+from torch_geometric.typing import Adj, OptPairTensor, OptTensor
+from torch_geometric.utils import (
+    add_remaining_self_loops,
+    is_torch_sparse_tensor,
+    scatter,
+    spmm,
+    to_torch_coo_tensor,
+)
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 
 @torch.jit._overload
 def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
              add_self_loops=True, flow="source_to_target", dtype=None):
-    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> PairTensor  # noqa
+    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor  # noqa
     pass


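Two things happen in this hunk: `torch_scatter.scatter_add` and `torch_sparse.matmul` give way to PyG's own `scatter` and `spmm` wrappers, and helpers for native sparse tensors come in. The overload's return annotation also widens from `PairTensor` to `OptPairTensor` (`Tuple[Tensor, Optional[Tensor]]` as defined in `torch_geometric.typing`), since the `torch.sparse` path returns `(adj, None)`. A small sketch of the two new helpers (toy graph, for illustration):

```python
import torch
from torch_geometric.utils import is_torch_sparse_tensor, to_torch_coo_tensor

edge_index = torch.tensor([[0, 1], [1, 0]])

adj = to_torch_coo_tensor(edge_index, size=2)  # values default to ones
assert is_torch_sparse_tensor(adj)             # True: sparse COO layout
assert not is_torch_sparse_tensor(edge_index)  # False: plain LongTensor
```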
@@ -36,7 +40,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
     fill_value = 2. if improved else 1.
 
     if isinstance(edge_index, SparseTensor):
-        assert flow in ["source_to_target"]
+        assert flow == 'source_to_target'
         adj_t = edge_index
         if not adj_t.has_value():
             adj_t = adj_t.fill_value(1., dtype=dtype)
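
The `torch_sparse.SparseTensor` branch is unchanged apart from the tightened assert (the `fill_diag`/`sparsesum` lines between the two hunks are collapsed here): it takes a possibly unweighted sparse adjacency and hands back the symmetrically normalized one. A hedged usage sketch:

```python
import torch
from torch_sparse import SparseTensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm

# Toy 2-node graph with one undirected edge (made up for illustration).
adj_t = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]),
                     sparse_sizes=(2, 2))
adj_t = gcn_norm(adj_t)  # adds self-loops, applies D^-1/2 (A+I) D^-1/2
assert isinstance(adj_t, SparseTensor)
```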
@@ -47,8 +51,19 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
         deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
         adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
         adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
 
         return adj_t
+
+    # `edge_index` can be a `torch.LongTensor` or `torch.sparse.Tensor`:
+    is_sparse_tensor = is_torch_sparse_tensor(edge_index)
+    if is_sparse_tensor:
+        assert flow == 'source_to_target'
+        # Reverse `flow` since sparse tensors model transposed adjacencies:
+        flow = 'target_to_source'
+        adj_t = edge_index
+        num_nodes = adj_t.size(0)
+        edge_index = adj_t._indices()
+        edge_weight = adj_t._values()
     else:
         assert flow in ["source_to_target", "target_to_source"]
         num_nodes = maybe_num_nodes(edge_index, num_nodes)
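
The new branch is the heart of the change: a native sparse tensor is detected, its indices and values are unpacked, and `flow` is reversed because PyG treats a sparse matrix input as the transposed adjacency, where messages flow along columns. A sketch of that convention (a hypothetical 3-cycle, for illustration):

```python
import torch

edge_index = torch.tensor([[0, 1, 2],   # row 0: source nodes
                           [1, 2, 0]])  # row 1: target nodes

# The transposed adjacency swaps the roles of the two index rows, which is
# exactly why gcn_norm flips `flow` to 'target_to_source' for sparse input.
adj_t = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(3),
                                size=(3, 3)).coalesce()
row, col = adj_t._indices()  # row holds targets, col holds sources
```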
@@ -57,18 +72,24 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                      device=edge_index.device)
 
-        if add_self_loops:
-            edge_index, tmp_edge_weight = add_remaining_self_loops(
-                edge_index, edge_weight, fill_value, num_nodes)
-            assert tmp_edge_weight is not None
-            edge_weight = tmp_edge_weight
-
-        row, col = edge_index[0], edge_index[1]
-        idx = col if flow == "source_to_target" else row
-        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
-        deg_inv_sqrt = deg.pow_(-0.5)
-        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
-        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+    if add_self_loops:
+        edge_index, tmp_edge_weight = add_remaining_self_loops(
+            edge_index, edge_weight, fill_value, num_nodes)
+        assert tmp_edge_weight is not None
+        edge_weight = tmp_edge_weight
+
+    row, col = edge_index[0], edge_index[1]
+    idx = col if flow == "source_to_target" else row
+    deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
+    deg_inv_sqrt = deg.pow_(-0.5)
+    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
+    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
+
+    if is_sparse_tensor:
+        adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes)
+        return adj_t, None
+    else:
+        return edge_index, edge_weight
 
 
 class GCNConv(MessagePassing):
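
Before the class-level changes, a quick sanity check on what the rewritten dense path computes: with defaults, `gcn_norm` adds self-loops and rescales every edge weight to `deg(i)^-1/2 * w * deg(j)^-1/2`. A hand-worked sketch on a toy graph (not part of the PR, assuming the post-PR signature):

```python
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

# Two nodes joined by one undirected edge (made up for illustration).
edge_index = torch.tensor([[0, 1],
                           [1, 0]])
edge_index, edge_weight = gcn_norm(edge_index, num_nodes=2)

# After self-loops, both nodes have degree 2, so every edge weight becomes
# (1/sqrt(2)) * 1 * (1/sqrt(2)) = 0.5, for four edges in total.
assert edge_index.size(1) == 4
assert torch.allclose(edge_weight, torch.full((4, ), 0.5))
```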
@@ -128,7 +149,7 @@ class GCNConv(MessagePassing):
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
     """
 
-    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
+    _cached_edge_index: Optional[OptPairTensor]
     _cached_adj_t: Optional[SparseTensor]
 
     def __init__(self, in_channels: int, out_channels: int,
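
The annotation widens for the same reason as the overload earlier: with a `torch.sparse` input, the cached pair is `(adj, None)`. Nothing changes for callers; caching remains opt-in and is only valid when the graph is static. A hedged usage sketch:

```python
import torch
from torch_geometric.nn import GCNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

conv = GCNConv(16, 32, cached=True)  # cache only valid for static graphs
out = conv(x, edge_index)  # first call runs gcn_norm and stores the result
out = conv(x, edge_index)  # later calls reuse conv._cached_edge_index
```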
@@ -207,4 +228,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
         return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
 
     def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
-        return matmul(adj_t, x, reduce=self.aggr)
+        return spmm(adj_t, x, reduce=self.aggr)
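
`torch_sparse.matmul` only understands `SparseTensor`, so the fused message-and-aggregate step now routes through `torch_geometric.utils.spmm`, which dispatches on both sparse formats. A hedged sketch of the wrapper, assuming a build where `spmm` accepts `torch.sparse` input for the GCN default `reduce='sum'`:

```python
import torch
from torch_geometric.utils import spmm

adj = torch.sparse_coo_tensor(torch.tensor([[0, 1], [1, 0]]),
                              torch.ones(2), size=(2, 2))
x = torch.randn(2, 8)

out = spmm(adj, x, reduce='sum')  # sparse @ dense, here equal to adj @ x
assert out.size() == (2, 8)
```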