6 changes: 6 additions & 0 deletions test/nn/conv/test_gcn_conv.py
@@ -14,21 +14,27 @@ def test_gcn_conv():
     value = torch.rand(row.size(0))
     adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
     adj1 = adj2.set_value(None)
+    adj3 = adj1.to_torch_sparse_coo_tensor()
+    adj4 = adj2.to_torch_sparse_coo_tensor()
 
     conv = GCNConv(16, 32)
     assert conv.__repr__() == 'GCNConv(16, 32)'
     out1 = conv(x, edge_index)
     assert out1.size() == (4, 32)
     assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
+    assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
     out2 = conv(x, edge_index, value)
     assert out2.size() == (4, 32)
     assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
+    assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
 
     if is_full_test():
         t = '(Tensor, Tensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
         assert jit(x, edge_index).tolist() == out1.tolist()
         assert jit(x, edge_index, value).tolist() == out2.tolist()
+        assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)
+        assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)
 
         t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
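Note (not part of the PR): a minimal sketch of how the two adjacency formats exercised by the new assertions relate. SparseTensor.to_torch_sparse_coo_tensor() returns a torch.sparse COO tensor carrying the same indices and values, which is what the added adj3/adj4 checks feed into GCNConv. The row/col/value tensors below are placeholders, not the test's own fixture.

import torch
from torch_sparse import SparseTensor

# Placeholder graph; the real test derives row/col from its own edge_index.
row = torch.tensor([0, 1, 2, 3])
col = torch.tensor([1, 2, 3, 0])
value = torch.rand(row.size(0))

adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj_coo = adj.to_torch_sparse_coo_tensor()

# The same matrix built directly as a torch.sparse COO tensor.
adj_coo2 = torch.sparse_coo_tensor(torch.stack([row, col]), value, (4, 4))
assert torch.allclose(adj_coo.to_dense(), adj_coo2.to_dense())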
40 changes: 26 additions & 14 deletions torch_geometric/nn/conv/gcn_conv.py
@@ -4,14 +4,19 @@
 from torch import Tensor
 from torch.nn import Parameter
 from torch_scatter import scatter_add
-from torch_sparse import SparseTensor, fill_diag, matmul, mul
+from torch_sparse import SparseTensor, fill_diag, mul
 from torch_sparse import sum as sparsesum
 
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import zeros
 from torch_geometric.typing import Adj, OptTensor, PairTensor
-from torch_geometric.utils import add_remaining_self_loops
+from torch_geometric.utils import (
+    add_remaining_self_loops,
+    is_sparse,
+    is_torch_sparse_tensor,
+    spmm,
+)
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 
@@ -34,9 +39,13 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
 
     fill_value = 2. if improved else 1.
 
-    if isinstance(edge_index, SparseTensor):
+    if is_sparse(edge_index):
         assert flow in ["source_to_target"]
         adj_t = edge_index
+        is_torch_sparse = is_torch_sparse_tensor(adj_t)
+        if is_torch_sparse:
+            adj_t = SparseTensor.from_torch_sparse_coo_tensor(adj_t)
+
         if not adj_t.has_value():
             adj_t = adj_t.fill_value(1., dtype=dtype)
         if add_self_loops:
@@ -46,6 +55,9 @@
         deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
         adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
         adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
+
+        if is_torch_sparse:
+            adj_t = adj_t.to_torch_sparse_coo_tensor()
         return adj_t
 
     else:
@@ -169,27 +181,27 @@ def forward(self, x: Tensor, edge_index: Adj,
         """"""
 
         if self.normalize:
-            if isinstance(edge_index, Tensor):
-                cache = self._cached_edge_index
+            if is_sparse(edge_index):
+                cache = self._cached_adj_t
                 if cache is None:
-                    edge_index, edge_weight = gcn_norm(  # yapf: disable
+                    edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
                         self.improved, self.add_self_loops, self.flow, x.dtype)
                     if self.cached:
-                        self._cached_edge_index = (edge_index, edge_weight)
+                        self._cached_adj_t = edge_index
                 else:
-                    edge_index, edge_weight = cache[0], cache[1]
+                    edge_index = cache
 
-            elif isinstance(edge_index, SparseTensor):
-                cache = self._cached_adj_t
+            elif isinstance(edge_index, Tensor):
+                cache = self._cached_edge_index
                 if cache is None:
-                    edge_index = gcn_norm(  # yapf: disable
+                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
                         self.improved, self.add_self_loops, self.flow, x.dtype)
                     if self.cached:
-                        self._cached_adj_t = edge_index
+                        self._cached_edge_index = (edge_index, edge_weight)
                 else:
-                    edge_index = cache
+                    edge_index, edge_weight = cache[0], cache[1]
 
         x = self.lin(x)
 
@@ -206,4 +218,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
         return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
 
     def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
-        return matmul(adj_t, x, reduce=self.aggr)
+        return spmm(adj_t, x, reduce=self.aggr)
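
Note (not part of the PR): a hedged usage sketch of what the change enables, mirroring the updated test rather than any documented API. After this PR, GCNConv's forward should accept a dense edge_index, a torch_sparse.SparseTensor, or a torch.sparse COO adjacency, with matching outputs. The graph below is a placeholder.

import torch
from torch_geometric.nn import GCNConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # source -> target pairs

conv = GCNConv(16, 32)
out = conv(x, edge_index)

# Transposed adjacency (row = target, col = source) as a torch.sparse COO tensor,
# following the adj3.t() / adj4.t() pattern used in the test above.
adj_t = torch.sparse_coo_tensor(edge_index.flip(0),
                                torch.ones(edge_index.size(1)), (4, 4))
assert torch.allclose(conv(x, adj_t), out, atol=1e-6)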