2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -95,7 +95,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
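The test changes below all exercise the same call pattern: build a `torch_sparse.SparseTensor` adjacency, derive a `torch.sparse` COO copy via `to_torch_sparse_coo_tensor()`, and check that both layouts (passed transposed) reproduce the `edge_index` output. A minimal sketch of that pattern, mirroring the updated `ResGatedGraphConv` test and assuming a PyG build that already includes the `torch.sparse` support referenced in the changelog entry above:

```python
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import ResGatedGraphConv

x1 = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index

adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))  # torch_sparse layout
adj2 = adj1.to_torch_sparse_coo_tensor()  # equivalent torch.sparse COO layout

conv = ResGatedGraphConv(8, 32)
out = conv(x1, edge_index)

# Both sparse layouts are passed transposed and should match the edge_index result.
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
```

The same pair of assertions is repeated for `PointNetConv`, `PointGNNConv`, `PointTransformerConv`, and `PPFConv` in the files below.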
20 changes: 13 additions & 7 deletions test/nn/conv/test_point_conv.py
@@ -14,7 +14,8 @@ def test_point_net_conv():
pos2 = torch.randn(2, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_coo_tensor()

local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32))
global_nn = Seq(Lin(32, 32))
@@ -29,7 +30,8 @@ def test_point_net_conv():
'))')
out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(OptTensor, Tensor, Tensor) -> Tensor'
@@ -38,14 +40,18 @@ def test_point_net_conv():

t = '(OptTensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()
out = conv(x1, (pos1, pos2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, None), (pos1, pos2), edge_index).tolist() == out.tolist()
assert torch.allclose(conv(x1, (pos1, pos2), adj.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj.t()), out,
assert torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out,
atol=1e-6)

if is_full_test():
@@ -56,5 +62,5 @@ def test_point_net_conv():

t = '(PairOptTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, None), (pos1, pos2), adj.t()), out,
assert torch.allclose(jit((x1, None), (pos1, pos2), adj1.t()), out,
atol=1e-6)
8 changes: 5 additions & 3 deletions test/nn/conv/test_point_gnn_conv.py
@@ -10,7 +10,8 @@ def test_pointgnn_conv():
pos = torch.randn(6, 3)
edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
adj2 = adj1.to_torch_sparse_coo_tensor()

conv = PointGNNConv(
mlp_h=MLP([8, 16, 3]),
@@ -25,7 +26,8 @@ def test_pointgnn_conv():

out = conv(x, pos, edge_index)
assert out.size() == (6, 8)
assert torch.allclose(conv(x, pos, adj.t()), out)
assert torch.allclose(conv(x, pos, adj1.t()), out)
assert torch.allclose(conv(x, pos, adj2.t()), out)

if is_full_test():
t = '(Tensor, Tensor, Tensor) -> Tensor'
@@ -34,4 +36,4 @@ def test_pointgnn_conv():

t = '(Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, pos, adj.t()), out)
assert torch.allclose(jit(x, pos, adj1.t()), out)
25 changes: 16 additions & 9 deletions test/nn/conv/test_point_transformer_conv.py
@@ -13,46 +13,53 @@ def test_point_transformer_conv():
pos2 = torch.randn(2, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_coo_tensor()

conv = PointTransformerConv(in_channels=16, out_channels=32)
assert str(conv) == 'PointTransformerConv(16, 32)'

out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, pos1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6)

t = '(Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)

pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
conv = PointTransformerConv(16, 32, pos_nn, attn_nn)

out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

conv = PointTransformerConv((16, 8), 32)
adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()

out = conv((x1, x2), (pos1, pos2), edge_index)
assert out.size() == (2, 32)
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj.t()), out,
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj2.t()), out,
atol=1e-6)

if is_full_test():
t = '(PairTensor, PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), (pos1, pos2), edge_index).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), (pos1, pos2), edge_index), out,
atol=1e-6)

t = '(PairTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj.t()), out,
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj1.t()), out,
atol=1e-6)
34 changes: 21 additions & 13 deletions test/nn/conv/test_ppf_conv.py
@@ -17,7 +17,8 @@ def test_ppf_conv():
n2 = F.normalize(torch.rand(2, 3), dim=-1)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_coo_tensor()

local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
global_nn = Seq(Lin(32, 32))
@@ -32,34 +33,41 @@ def test_ppf_conv():
'))')
out = conv(x1, pos1, n1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, pos1, n1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, pos1, n1, edge_index), out, atol=1e-6)

t = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, n1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, n1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()
out = conv(x1, (pos1, pos2), (n1, n2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, None), (pos1, pos2), (n1, n2),
edge_index).tolist() == out.tolist()
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj.t()), out,
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index),
out, atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj.t()),
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()),
out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()),
out, atol=1e-6)

if is_full_test():
t = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, None), (pos1, pos2), (n1, n2),
edge_index).tolist() == out.tolist()
assert torch.allclose(
jit((x1, None), (pos1, pos2), (n1, n2), edge_index), out,
atol=1e-6)

t = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, None), (pos1, pos2), (n1, n2), adj.t()),
out, atol=1e-6)
assert torch.allclose(
jit((x1, None), (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-6)
20 changes: 12 additions & 8 deletions test/nn/conv/test_res_gated_graph_conv.py
@@ -10,35 +10,39 @@ def test_res_gated_graph_conv():
x2 = torch.randn(2, 32)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_coo_tensor()

conv = ResGatedGraphConv(8, 32)
assert str(conv) == 'ResGatedGraphConv(8, 32)'
out = conv(x1, edge_index)
assert out.size() == (4, 32)
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)

t = '(Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()
conv = ResGatedGraphConv((8, 32), 32)
assert str(conv) == 'ResGatedGraphConv((8, 32), 32)'
out = conv((x1, x2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, x2), adj.t()).tolist() == out.tolist()
assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), edge_index).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)

t = '(PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), adj.t()).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), adj1.t()), out, atol=1e-6)
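The bipartite assertions in the tests above derive the `torch.sparse` COO adjacency only after resizing the `SparseTensor` to `(num_src, num_dst)`, so both layouts describe the same bipartite graph. A self-contained sketch of that step, again mirroring the `ResGatedGraphConv` shapes and assuming the same setup as the earlier sketch:

```python
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import ResGatedGraphConv

x1 = torch.randn(4, 8)
x2 = torch.randn(2, 32)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index

# Resize the adjacency to (num_src, num_dst) *before* deriving the COO copy,
# so that both layouts describe the same bipartite graph.
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)).sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_coo_tensor()

conv = ResGatedGraphConv((8, 32), 32)
out = conv((x1, x2), edge_index)

assert out.size() == (2, 32)
assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)
```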