
Commit 2fc66c5

Add PyTorch SparseTensor support for GINConv, SAGEConv, and GraphConv (#6532)
This PR adds PyTorch SparseTensor support for `GINConv`, `SAGEConv`, and `GraphConv` ~~and also updates the type hints in `message_and_aggregate` correspondingly~~.
1 parent 6d6162d commit 2fc66c5
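
With this change, all three layers accept a `torch.sparse` COO adjacency in addition to a `torch_sparse.SparseTensor`. A minimal usage sketch, reusing the test graph from this PR (layer sizes are illustrative); note that the tests below only exercise the `torch.sparse` path of `SAGEConv` with `aggr='sum'`:

```python
import torch
from torch_geometric.nn import SAGEConv

x = torch.randn(4, 8)                                    # node features
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])  # COO edge list

# Transposed adjacency (row = target, col = source), as PyG's
# message_and_aggregate expects.
adj_t = torch.sparse_coo_tensor(
    indices=edge_index.flip(0),
    values=torch.ones(edge_index.size(1)),
    size=(4, 4),
).coalesce()

conv = SAGEConv(8, 32, aggr='sum')
out = conv(x, adj_t)          # should match conv(x, edge_index)
assert out.size() == (4, 32)
```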

File tree

7 files changed: +71 −41 lines changed


CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
 - Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
 - Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
-- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514))
+- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532))
 - Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))

test/nn/conv/test_gin_conv.py

Lines changed: 10 additions & 5 deletions

@@ -14,6 +14,7 @@ def test_gin_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj.to_torch_sparse_coo_tensor()

     nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
     conv = GINConv(nn, train_eps=True)
@@ -25,8 +26,9 @@ def test_gin_conv():
         '))')
     out = conv(x1, edge_index)
     assert out.size() == (4, 32)
-    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor, Size) -> Tensor'
@@ -39,13 +41,16 @@ def test_gin_conv():
         assert jit(x1, adj.t()).tolist() == out.tolist()

     adj = adj.sparse_resize((4, 2))
+    adj2 = adj.to_torch_sparse_coo_tensor()
     out1 = conv((x1, x2), edge_index)
     out2 = conv((x1, None), edge_index, (4, 2))
     assert out1.size() == (2, 32)
     assert out2.size() == (2, 32)
-    assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

     if is_full_test():
         t = '(OptPairTensor, Tensor, Size) -> Tensor'
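
In the tests, the `torch.sparse` input is derived from the existing `SparseTensor` via `to_torch_sparse_coo_tensor()`. For intuition, a hand-rolled rough equivalent for the unweighted case (a sketch, not code from this PR):

```python
import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 1, 2, 3])
col = torch.tensor([0, 0, 1, 1])
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

# Roughly what adj.to_torch_sparse_coo_tensor() produces: COO indices
# stacked as a 2 x nnz tensor, with implicit all-ones values.
adj2 = torch.sparse_coo_tensor(
    indices=torch.stack([row, col]),
    values=torch.ones(row.numel()),
    size=(4, 4),
).coalesce()
```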

test/nn/conv/test_graph_conv.py

Lines changed: 19 additions & 8 deletions

@@ -13,18 +13,23 @@ def test_graph_conv():
     value = torch.randn(edge_index.size(1))
     adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
     adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
+    adj3 = adj1.to_torch_sparse_coo_tensor()
+    adj4 = adj2.to_torch_sparse_coo_tensor()

     conv = GraphConv(8, 32)
     assert conv.__repr__() == 'GraphConv(8, 32)'
     out11 = conv(x1, edge_index)
     assert out11.size() == (4, 32)
-    assert conv(x1, edge_index, size=(4, 4)).tolist() == out11.tolist()
-    assert conv(x1, adj1.t()).tolist() == out11.tolist()
+    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11, atol=1e-6)
+    assert torch.allclose(conv(x1, adj1.t()), out11, atol=1e-6)
+    assert torch.allclose(conv(x1, adj3.t()), out11, atol=1e-6)

     out12 = conv(x1, edge_index, value)
     assert out12.size() == (4, 32)
-    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out12.tolist()
-    assert conv(x1, adj2.t()).tolist() == out12.tolist()
+    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out12,
+                          atol=1e-6)
+    assert torch.allclose(conv(x1, adj2.t()), out12, atol=1e-6)
+    assert torch.allclose(conv(x1, adj4.t()), out12, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
@@ -42,6 +47,8 @@ def test_graph_conv():

     adj1 = adj1.sparse_resize((4, 2))
     adj2 = adj2.sparse_resize((4, 2))
+    adj3 = adj1.to_torch_sparse_coo_tensor()
+    adj4 = adj2.to_torch_sparse_coo_tensor()
     conv = GraphConv((8, 16), 32)
     assert conv.__repr__() == 'GraphConv((8, 16), 32)'
     out21 = conv((x1, x2), edge_index)
@@ -52,10 +59,14 @@ def test_graph_conv():
     assert out22.size() == (2, 32)
     assert out23.size() == (2, 32)
     assert out24.size() == (2, 32)
-    assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out21.tolist()
-    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out22.tolist()
-    assert conv((x1, x2), adj1.t()).tolist() == out21.tolist()
-    assert conv((x1, x2), adj2.t()).tolist() == out22.tolist()
+    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out21,
+                          atol=1e-6)
+    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out22,
+                          atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj1.t()), out21, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj2.t()), out22, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj3.t()), out21, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj4.t()), out22, atol=1e-6)

     if is_full_test():
         t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
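
`GraphConv` additionally supports edge weights, which is why the test keeps both an unweighted (`adj1`/`adj3`) and a weighted (`adj2`/`adj4`) variant; since `GraphConv.message_and_aggregate` does not strip the values, `spmm` scales each message by its edge weight. A weighted `torch.sparse` adjacency can be built directly (illustrative sketch):

```python
import torch

row = torch.tensor([0, 1, 2, 3])
col = torch.tensor([0, 0, 1, 1])
value = torch.randn(4)  # one weight per edge

# Weighted COO adjacency: the stored values act as edge weights in spmm.
adj4 = torch.sparse_coo_tensor(
    indices=torch.stack([row, col]),
    values=value,
    size=(4, 4),
).coalesce()
```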

test/nn/conv/test_sage_conv.py

Lines changed: 28 additions & 19 deletions

@@ -7,53 +7,62 @@


 @pytest.mark.parametrize('project', [False, True])
-def test_sage_conv(project):
+@pytest.mark.parametrize('aggr', ['mean', 'sum'])
+def test_sage_conv(project, aggr):
     x1 = torch.randn(4, 8)
     x2 = torch.randn(2, 16)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj.to_torch_sparse_coo_tensor()

-    conv = SAGEConv(8, 32, project=project)
-    assert str(conv) == 'SAGEConv(8, 32, aggr=mean)'
+    conv = SAGEConv(8, 32, project=project, aggr=aggr)
+    assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'
     out = conv(x1, edge_index)
     assert out.size() == (4, 32)
-    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
+    if aggr == 'sum':
+        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert jit(x1, edge_index).tolist() == out.tolist()
-        assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
+        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
+        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6)

         t = '(Tensor, SparseTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert jit(x1, adj.t()).tolist() == out.tolist()
+        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

     adj = adj.sparse_resize((4, 2))
-    conv = SAGEConv((8, 16), 32, project=project)
-    assert str(conv) == 'SAGEConv((8, 16), 32, aggr=mean)'
+    adj2 = adj.to_torch_sparse_coo_tensor()
+    conv = SAGEConv((8, 16), 32, project=project, aggr=aggr)
+    assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})'
     out1 = conv((x1, x2), edge_index)
     out2 = conv((x1, None), edge_index, (4, 2))
     assert out1.size() == (2, 32)
     assert out2.size() == (2, 32)
-    assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
+    if aggr == 'sum':
+        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
+        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

     if is_full_test():
         t = '(OptPairTensor, Tensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert jit((x1, x2), edge_index).tolist() == out1.tolist()
-        assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out1.tolist()
-        assert jit((x1, None), edge_index,
-                   size=(4, 2)).tolist() == out2.tolist()
+        assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1,
+                              atol=1e-6)
+        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2,
+                              atol=1e-6)

         t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
-        assert jit((x1, None), adj.t()).tolist() == out2.tolist()
+        assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
+        assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)


 def test_lstm_aggr_sage_conv():
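
Note the `if aggr == 'sum':` guard around the `torch.sparse` assertions above: plain sparse-dense matrix multiplication only realizes a sum aggregation, so presumably other reductions were not yet covered for `torch.sparse` inputs at this point. For `reduce='sum'`, `spmm` boils down to an ordinary sparse matmul (sketch):

```python
import torch

x = torch.randn(4, 8)
adj_t = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 1, 1], [0, 1, 2, 3]]),
    values=torch.ones(4),
    size=(4, 4),
)

# Row i of the result is the sum of x[j] over i's neighbors j,
# i.e. exactly the 'sum' aggregation.
out = torch.sparse.mm(adj_t, x)
assert out.size() == (4, 8)
```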

torch_geometric/nn/conv/gin_conv.py

Lines changed: 5 additions & 3 deletions

@@ -2,11 +2,12 @@

 import torch
 from torch import Tensor
-from torch_sparse import SparseTensor, matmul
+from torch_sparse import SparseTensor

 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import spmm

 from ..inits import reset

@@ -84,8 +85,9 @@ def message(self, x_j: Tensor) -> Tensor:

     def message_and_aggregate(self, adj_t: SparseTensor,
                               x: OptPairTensor) -> Tensor:
-        adj_t = adj_t.set_value(None, layout=None)
-        return matmul(adj_t, x[0], reduce=self.aggr)
+        if isinstance(adj_t, SparseTensor):
+            adj_t = adj_t.set_value(None, layout=None)
+        return spmm(adj_t, x[0], reduce=self.aggr)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(nn={self.nn})'
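
The new `isinstance` guard is needed because `set_value` only exists on `torch_sparse.SparseTensor`; a `torch.sparse` tensor carries its values implicitly and is passed through unchanged. `torch_geometric.utils.spmm` then dispatches on the input type, conceptually along these lines (a simplified sketch, not the actual implementation):

```python
import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul


def spmm_sketch(src, other: Tensor, reduce: str = 'sum') -> Tensor:
    # torch_sparse path: matmul supports 'sum', 'mean', 'min', 'max'.
    if isinstance(src, SparseTensor):
        return matmul(src, other, reduce=reduce)
    # torch.sparse path: torch.sparse.mm realizes a sum aggregation.
    if reduce == 'sum':
        return torch.sparse.mm(src, other)
    raise NotImplementedError(f"reduce='{reduce}' not handled in this sketch")
```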

torch_geometric/nn/conv/graph_conv.py

Lines changed: 3 additions & 2 deletions

@@ -1,11 +1,12 @@
 from typing import Tuple, Union

 from torch import Tensor
-from torch_sparse import SparseTensor, matmul
+from torch_sparse import SparseTensor

 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import spmm


 class GraphConv(MessagePassing):
@@ -91,4 +92,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:

     def message_and_aggregate(self, adj_t: SparseTensor,
                               x: OptPairTensor) -> Tensor:
-        return matmul(adj_t, x[0], reduce=self.aggr)
+        return spmm(adj_t, x[0], reduce=self.aggr)

torch_geometric/nn/conv/sage_conv.py

Lines changed: 5 additions & 3 deletions

@@ -3,12 +3,13 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import LSTM
-from torch_sparse import SparseTensor, matmul
+from torch_sparse import SparseTensor

 from torch_geometric.nn.aggr import Aggregation, MultiAggregation
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import Adj, OptPairTensor, Size
+from torch_geometric.utils import spmm


 class SAGEConv(MessagePassing):
@@ -145,8 +146,9 @@ def message(self, x_j: Tensor) -> Tensor:

     def message_and_aggregate(self, adj_t: SparseTensor,
                               x: OptPairTensor) -> Tensor:
-        adj_t = adj_t.set_value(None, layout=None)
-        return matmul(adj_t, x[0], reduce=self.aggr)
+        if isinstance(adj_t, SparseTensor):
+            adj_t = adj_t.set_value(None, layout=None)
+        return spmm(adj_t, x[0], reduce=self.aggr)

     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}({self.in_channels}, '
