Skip to content

Commit 82d8ad2

Browse files
Add PyTorch SparseTensor support for GCNConv and gcn_norm (#6033)
For `gcn_norm`, a PyTorch sparse tensor is converted to a `torch_sparse.SparseTensor` and then converted back. This round-trip conversion is not ideal, but there is currently no better solution. Co-authored-by: rusty1s <[email protected]>
1 parent 11e576b commit 82d8ad2

File tree

3 files changed

+48
-22
lines changed

3 files changed

+48
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.3.0] - 2023-MM-DD
77
### Added
8+
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
89
- Add inputs_channels back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
910
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
1011
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))

test/nn/conv/test_gcn_conv.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@ def test_gcn_conv():
1414
value = torch.rand(row.size(0))
1515
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
1616
adj1 = adj2.set_value(None)
17+
adj3 = adj1.to_torch_sparse_coo_tensor()
18+
adj4 = adj2.to_torch_sparse_coo_tensor()
1719

1820
conv = GCNConv(16, 32)
1921
assert conv.__repr__() == 'GCNConv(16, 32)'
2022
out1 = conv(x, edge_index)
2123
assert out1.size() == (4, 32)
2224
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
25+
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
2326
out2 = conv(x, edge_index, value)
2427
assert out2.size() == (4, 32)
2528
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
29+
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
2630

2731
if is_full_test():
2832
t = '(Tensor, Tensor, OptTensor) -> Tensor'

torch_geometric/nn/conv/gcn_conv.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
1-
from typing import Optional, Tuple
1+
from typing import Optional
22

33
import torch
44
from torch import Tensor
55
from torch.nn import Parameter
6-
from torch_scatter import scatter_add
7-
from torch_sparse import SparseTensor, fill_diag, matmul, mul
6+
from torch_sparse import SparseTensor, fill_diag, mul
87
from torch_sparse import sum as sparsesum
98

109
from torch_geometric.nn.conv import MessagePassing
1110
from torch_geometric.nn.dense.linear import Linear
1211
from torch_geometric.nn.inits import zeros
13-
from torch_geometric.typing import PairTensor # noqa
14-
from torch_geometric.typing import Adj, OptTensor
15-
from torch_geometric.utils import add_remaining_self_loops
12+
from torch_geometric.typing import Adj, OptPairTensor, OptTensor
13+
from torch_geometric.utils import (
14+
add_remaining_self_loops,
15+
is_torch_sparse_tensor,
16+
scatter,
17+
spmm,
18+
to_torch_coo_tensor,
19+
)
1620
from torch_geometric.utils.num_nodes import maybe_num_nodes
1721

1822

1923
@torch.jit._overload
2024
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
2125
add_self_loops=True, flow="source_to_target", dtype=None):
22-
# type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> PairTensor # noqa
26+
# type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor # noqa
2327
pass
2428

2529

@@ -36,7 +40,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
3640
fill_value = 2. if improved else 1.
3741

3842
if isinstance(edge_index, SparseTensor):
39-
assert flow in ["source_to_target"]
43+
assert flow == 'source_to_target'
4044
adj_t = edge_index
4145
if not adj_t.has_value():
4246
adj_t = adj_t.fill_value(1., dtype=dtype)
@@ -47,8 +51,19 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
4751
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
4852
adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
4953
adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
54+
5055
return adj_t
5156

57+
# `edge_index` can be a `torch.LongTensor` or `torch.sparse.Tensor`:
58+
is_sparse_tensor = is_torch_sparse_tensor(edge_index)
59+
if is_sparse_tensor:
60+
assert flow == 'source_to_target'
61+
# Reverse `flow` since sparse tensors model transposed adjacencies:
62+
flow = 'target_to_source'
63+
adj_t = edge_index
64+
num_nodes = adj_t.size(0)
65+
edge_index = adj_t._indices()
66+
edge_weight = adj_t._values()
5267
else:
5368
assert flow in ["source_to_target", "target_to_source"]
5469
num_nodes = maybe_num_nodes(edge_index, num_nodes)
@@ -57,18 +72,24 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
5772
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
5873
device=edge_index.device)
5974

60-
if add_self_loops:
61-
edge_index, tmp_edge_weight = add_remaining_self_loops(
62-
edge_index, edge_weight, fill_value, num_nodes)
63-
assert tmp_edge_weight is not None
64-
edge_weight = tmp_edge_weight
65-
66-
row, col = edge_index[0], edge_index[1]
67-
idx = col if flow == "source_to_target" else row
68-
deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
69-
deg_inv_sqrt = deg.pow_(-0.5)
70-
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
71-
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
75+
if add_self_loops:
76+
edge_index, tmp_edge_weight = add_remaining_self_loops(
77+
edge_index, edge_weight, fill_value, num_nodes)
78+
assert tmp_edge_weight is not None
79+
edge_weight = tmp_edge_weight
80+
81+
row, col = edge_index[0], edge_index[1]
82+
idx = col if flow == "source_to_target" else row
83+
deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
84+
deg_inv_sqrt = deg.pow_(-0.5)
85+
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
86+
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
87+
88+
if is_sparse_tensor:
89+
adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes)
90+
return adj_t, None
91+
else:
92+
return edge_index, edge_weight
7293

7394

7495
class GCNConv(MessagePassing):
@@ -128,7 +149,7 @@ class GCNConv(MessagePassing):
128149
- **output:** node features :math:`(|\mathcal{V}|, F_{out})`
129150
"""
130151

131-
_cached_edge_index: Optional[Tuple[Tensor, Tensor]]
152+
_cached_edge_index: Optional[OptPairTensor]
132153
_cached_adj_t: Optional[SparseTensor]
133154

134155
def __init__(self, in_channels: int, out_channels: int,
@@ -207,4 +228,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
207228
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
208229

209230
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
210-
return matmul(adj_t, x, reduce=self.aggr)
231+
return spmm(adj_t, x, reduce=self.aggr)

0 commit comments

Comments
 (0)