Skip to content

Commit 98abaae

Browse files
EdisonLeeeeewsad1rusty1s
authored
Pytorch Sparse Tensor support: PointConv, PointGNNConv, PointTransformerConv, PPFConv, and ResGatedGraphConv (#6937)
Co-authored-by: Jinu Sunil <[email protected]> Co-authored-by: Matthias Fey <[email protected]>
1 parent 5549b85 commit 98abaae

File tree

6 files changed

+71
-42
lines changed

6 files changed

+71
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9696
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
9797
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
9898
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
99-
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
99+
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
100100
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
101101
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
102102
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))

test/nn/conv/test_point_conv.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def test_point_net_conv():
1414
pos2 = torch.randn(2, 3)
1515
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
1616
row, col = edge_index
17-
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
17+
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
18+
adj2 = adj1.to_torch_sparse_csc_tensor()
1819

1920
local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32))
2021
global_nn = Seq(Lin(32, 32))
@@ -29,7 +30,8 @@ def test_point_net_conv():
2930
'))')
3031
out = conv(x1, pos1, edge_index)
3132
assert out.size() == (4, 32)
32-
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
33+
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
34+
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)
3335

3436
if is_full_test():
3537
t = '(OptTensor, Tensor, Tensor) -> Tensor'
@@ -38,14 +40,18 @@ def test_point_net_conv():
3840

3941
t = '(OptTensor, Tensor, SparseTensor) -> Tensor'
4042
jit = torch.jit.script(conv.jittable(t))
41-
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
43+
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)
4244

43-
adj = adj.sparse_resize((4, 2))
45+
adj1 = adj1.sparse_resize((4, 2))
46+
adj2 = adj1.to_torch_sparse_csc_tensor()
4447
out = conv(x1, (pos1, pos2), edge_index)
4548
assert out.size() == (2, 32)
4649
assert conv((x1, None), (pos1, pos2), edge_index).tolist() == out.tolist()
47-
assert torch.allclose(conv(x1, (pos1, pos2), adj.t()), out, atol=1e-6)
48-
assert torch.allclose(conv((x1, None), (pos1, pos2), adj.t()), out,
50+
assert torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6)
51+
assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6)
52+
assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out,
53+
atol=1e-6)
54+
assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out,
4955
atol=1e-6)
5056

5157
if is_full_test():
@@ -56,5 +62,5 @@ def test_point_net_conv():
5662

5763
t = '(PairOptTensor, PairTensor, SparseTensor) -> Tensor'
5864
jit = torch.jit.script(conv.jittable(t))
59-
assert torch.allclose(jit((x1, None), (pos1, pos2), adj.t()), out,
65+
assert torch.allclose(jit((x1, None), (pos1, pos2), adj1.t()), out,
6066
atol=1e-6)

test/nn/conv/test_point_gnn_conv.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from torch_geometric.testing import is_full_test
66

77

8-
def test_pointgnn_conv():
8+
def test_point_gnn_conv():
99
x = torch.randn(6, 8)
1010
pos = torch.randn(6, 3)
1111
edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]])
1212
row, col = edge_index
13-
adj = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
13+
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
14+
adj2 = adj1.to_torch_sparse_csc_tensor()
1415

1516
conv = PointGNNConv(
1617
mlp_h=MLP([8, 16, 3]),
@@ -25,7 +26,8 @@ def test_pointgnn_conv():
2526

2627
out = conv(x, pos, edge_index)
2728
assert out.size() == (6, 8)
28-
assert torch.allclose(conv(x, pos, adj.t()), out)
29+
assert torch.allclose(conv(x, pos, adj1.t()), out)
30+
assert torch.allclose(conv(x, pos, adj2.t()), out)
2931

3032
if is_full_test():
3133
t = '(Tensor, Tensor, Tensor) -> Tensor'
@@ -34,4 +36,4 @@ def test_pointgnn_conv():
3436

3537
t = '(Tensor, Tensor, SparseTensor) -> Tensor'
3638
jit = torch.jit.script(conv.jittable(t))
37-
assert torch.allclose(jit(x, pos, adj.t()), out)
39+
assert torch.allclose(jit(x, pos, adj1.t()), out)

test/nn/conv/test_point_transformer_conv.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,46 +13,53 @@ def test_point_transformer_conv():
1313
pos2 = torch.randn(2, 3)
1414
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
1515
row, col = edge_index
16-
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
16+
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
17+
adj2 = adj1.to_torch_sparse_csc_tensor()
1718

1819
conv = PointTransformerConv(in_channels=16, out_channels=32)
1920
assert str(conv) == 'PointTransformerConv(16, 32)'
2021

2122
out = conv(x1, pos1, edge_index)
2223
assert out.size() == (4, 32)
23-
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
24+
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
25+
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)
2426

2527
if is_full_test():
2628
t = '(Tensor, Tensor, Tensor) -> Tensor'
2729
jit = torch.jit.script(conv.jittable(t))
28-
assert jit(x1, pos1, edge_index).tolist() == out.tolist()
30+
assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6)
2931

3032
t = '(Tensor, Tensor, SparseTensor) -> Tensor'
3133
jit = torch.jit.script(conv.jittable(t))
32-
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
34+
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)
3335

3436
pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
3537
attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
3638
conv = PointTransformerConv(16, 32, pos_nn, attn_nn)
3739

3840
out = conv(x1, pos1, edge_index)
3941
assert out.size() == (4, 32)
40-
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
42+
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
43+
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)
4144

4245
conv = PointTransformerConv((16, 8), 32)
43-
adj = adj.sparse_resize((4, 2))
46+
adj1 = adj1.sparse_resize((4, 2))
47+
adj2 = adj1.to_torch_sparse_csc_tensor()
4448

4549
out = conv((x1, x2), (pos1, pos2), edge_index)
4650
assert out.size() == (2, 32)
47-
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj.t()), out,
51+
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj1.t()), out,
52+
atol=1e-6)
53+
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj2.t()), out,
4854
atol=1e-6)
4955

5056
if is_full_test():
5157
t = '(PairTensor, PairTensor, Tensor) -> Tensor'
5258
jit = torch.jit.script(conv.jittable(t))
53-
assert jit((x1, x2), (pos1, pos2), edge_index).tolist() == out.tolist()
59+
assert torch.allclose(jit((x1, x2), (pos1, pos2), edge_index), out,
60+
atol=1e-6)
5461

5562
t = '(PairTensor, PairTensor, SparseTensor) -> Tensor'
5663
jit = torch.jit.script(conv.jittable(t))
57-
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj.t()), out,
64+
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj1.t()), out,
5865
atol=1e-6)

test/nn/conv/test_ppf_conv.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def test_ppf_conv():
1717
n2 = F.normalize(torch.rand(2, 3), dim=-1)
1818
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
1919
row, col = edge_index
20-
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
20+
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
21+
adj2 = adj1.to_torch_sparse_csc_tensor()
2122

2223
local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
2324
global_nn = Seq(Lin(32, 32))
@@ -32,34 +33,43 @@ def test_ppf_conv():
3233
'))')
3334
out = conv(x1, pos1, n1, edge_index)
3435
assert out.size() == (4, 32)
35-
assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)
36+
assert torch.allclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-6)
37+
assert torch.allclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-6)
3638

3739
if is_full_test():
3840
t = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
3941
jit = torch.jit.script(conv.jittable(t))
40-
assert jit(x1, pos1, n1, edge_index).tolist() == out.tolist()
42+
assert torch.allclose(jit(x1, pos1, n1, edge_index), out, atol=1e-6)
4143

4244
t = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
4345
jit = torch.jit.script(conv.jittable(t))
44-
assert torch.allclose(jit(x1, pos1, n1, adj.t()), out, atol=1e-6)
46+
assert torch.allclose(jit(x1, pos1, n1, adj1.t()), out, atol=1e-6)
4547

46-
adj = adj.sparse_resize((4, 2))
48+
adj1 = adj1.sparse_resize((4, 2))
49+
adj2 = adj1.to_torch_sparse_csc_tensor()
4750
out = conv(x1, (pos1, pos2), (n1, n2), edge_index)
4851
assert out.size() == (2, 32)
49-
assert conv((x1, None), (pos1, pos2), (n1, n2),
50-
edge_index).tolist() == out.tolist()
51-
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj.t()), out,
52+
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index),
53+
out, atol=1e-6)
54+
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out,
55+
atol=1e-6)
56+
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out,
5257
atol=1e-6)
53-
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj.t()),
58+
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()),
59+
out, atol=1e-6)
60+
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()),
5461
out, atol=1e-6)
5562

5663
if is_full_test():
5764
t = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
5865
jit = torch.jit.script(conv.jittable(t))
59-
assert jit((x1, None), (pos1, pos2), (n1, n2),
60-
edge_index).tolist() == out.tolist()
66+
assert torch.allclose(
67+
jit((x1, None), (pos1, pos2), (n1, n2), edge_index),
68+
out,
69+
atol=1e-6,
70+
)
6171

6272
t = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
6373
jit = torch.jit.script(conv.jittable(t))
64-
assert torch.allclose(jit((x1, None), (pos1, pos2), (n1, n2), adj.t()),
65-
out, atol=1e-6)
74+
assert torch.allclose(
75+
jit((x1, None), (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-6)

test/nn/conv/test_res_gated_graph_conv.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,39 @@ def test_res_gated_graph_conv():
1010
x2 = torch.randn(2, 32)
1111
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
1212
row, col = edge_index
13-
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
13+
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
14+
adj2 = adj1.to_torch_sparse_csc_tensor()
1415

1516
conv = ResGatedGraphConv(8, 32)
1617
assert str(conv) == 'ResGatedGraphConv(8, 32)'
1718
out = conv(x1, edge_index)
1819
assert out.size() == (4, 32)
19-
assert conv(x1, adj.t()).tolist() == out.tolist()
20+
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
21+
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
2022

2123
if is_full_test():
2224
t = '(Tensor, Tensor) -> Tensor'
2325
jit = torch.jit.script(conv.jittable(t))
24-
assert jit(x1, edge_index).tolist() == out.tolist()
26+
assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
2527

2628
t = '(Tensor, SparseTensor) -> Tensor'
2729
jit = torch.jit.script(conv.jittable(t))
28-
assert jit(x1, adj.t()).tolist() == out.tolist()
30+
assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6)
2931

30-
adj = adj.sparse_resize((4, 2))
32+
adj1 = adj1.sparse_resize((4, 2))
33+
adj2 = adj1.to_torch_sparse_csc_tensor()
3134
conv = ResGatedGraphConv((8, 32), 32)
3235
assert str(conv) == 'ResGatedGraphConv((8, 32), 32)'
3336
out = conv((x1, x2), edge_index)
3437
assert out.size() == (2, 32)
35-
assert conv((x1, x2), adj.t()).tolist() == out.tolist()
38+
assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
39+
assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)
3640

3741
if is_full_test():
3842
t = '(PairTensor, Tensor) -> Tensor'
3943
jit = torch.jit.script(conv.jittable(t))
40-
assert jit((x1, x2), edge_index).tolist() == out.tolist()
44+
assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)
4145

4246
t = '(PairTensor, SparseTensor) -> Tensor'
4347
jit = torch.jit.script(conv.jittable(t))
44-
assert jit((x1, x2), adj.t()).tolist() == out.tolist()
48+
assert torch.allclose(jit((x1, x2), adj1.t()), out, atol=1e-6)

0 commit comments

Comments
 (0)