Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
Expand Down
20 changes: 13 additions & 7 deletions test/nn/conv/test_point_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def test_point_net_conv():
pos2 = torch.randn(2, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32))
global_nn = Seq(Lin(32, 32))
Expand All @@ -29,7 +30,8 @@ def test_point_net_conv():
'))')
out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(OptTensor, Tensor, Tensor) -> Tensor'
Expand All @@ -38,14 +40,18 @@ def test_point_net_conv():

t = '(OptTensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_csc_tensor()
out = conv(x1, (pos1, pos2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, None), (pos1, pos2), edge_index).tolist() == out.tolist()
assert torch.allclose(conv(x1, (pos1, pos2), adj.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj.t()), out,
assert torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out,
atol=1e-6)

if is_full_test():
Expand All @@ -56,5 +62,5 @@ def test_point_net_conv():

t = '(PairOptTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, None), (pos1, pos2), adj.t()), out,
assert torch.allclose(jit((x1, None), (pos1, pos2), adj1.t()), out,
atol=1e-6)
10 changes: 6 additions & 4 deletions test/nn/conv/test_point_gnn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from torch_geometric.testing import is_full_test


def test_pointgnn_conv():
def test_point_gnn_conv():
x = torch.randn(6, 8)
pos = torch.randn(6, 3)
edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(6, 6))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = PointGNNConv(
mlp_h=MLP([8, 16, 3]),
Expand All @@ -25,7 +26,8 @@ def test_pointgnn_conv():

out = conv(x, pos, edge_index)
assert out.size() == (6, 8)
assert torch.allclose(conv(x, pos, adj.t()), out)
assert torch.allclose(conv(x, pos, adj1.t()), out)
assert torch.allclose(conv(x, pos, adj2.t()), out)

if is_full_test():
t = '(Tensor, Tensor, Tensor) -> Tensor'
Expand All @@ -34,4 +36,4 @@ def test_pointgnn_conv():

t = '(Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, pos, adj.t()), out)
assert torch.allclose(jit(x, pos, adj1.t()), out)
25 changes: 16 additions & 9 deletions test/nn/conv/test_point_transformer_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,53 @@ def test_point_transformer_conv():
pos2 = torch.randn(2, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = PointTransformerConv(in_channels=16, out_channels=32)
assert str(conv) == 'PointTransformerConv(16, 32)'

out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, pos1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6)

t = '(Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, adj1.t()), out, atol=1e-6)

pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
conv = PointTransformerConv(16, 32, pos_nn, attn_nn)

out = conv(x1, pos1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)

conv = PointTransformerConv((16, 8), 32)
adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_csc_tensor()

out = conv((x1, x2), (pos1, pos2), edge_index)
assert out.size() == (2, 32)
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj.t()), out,
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, x2), (pos1, pos2), adj2.t()), out,
atol=1e-6)

if is_full_test():
t = '(PairTensor, PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), (pos1, pos2), edge_index).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), (pos1, pos2), edge_index), out,
atol=1e-6)

t = '(PairTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj.t()), out,
assert torch.allclose(jit((x1, x2), (pos1, pos2), adj1.t()), out,
atol=1e-6)
36 changes: 23 additions & 13 deletions test/nn/conv/test_ppf_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def test_ppf_conv():
n2 = F.normalize(torch.rand(2, 3), dim=-1)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
global_nn = Seq(Lin(32, 32))
Expand All @@ -32,34 +33,43 @@ def test_ppf_conv():
'))')
out = conv(x1, pos1, n1, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, pos1, n1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, pos1, n1, edge_index), out, atol=1e-6)

t = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x1, pos1, n1, adj.t()), out, atol=1e-6)
assert torch.allclose(jit(x1, pos1, n1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_csc_tensor()
out = conv(x1, (pos1, pos2), (n1, n2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, None), (pos1, pos2), (n1, n2),
edge_index).tolist() == out.tolist()
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj.t()), out,
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index),
out, atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out,
atol=1e-6)
assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out,
atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj.t()),
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()),
out, atol=1e-6)
assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()),
out, atol=1e-6)

if is_full_test():
t = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, None), (pos1, pos2), (n1, n2),
edge_index).tolist() == out.tolist()
assert torch.allclose(
jit((x1, None), (pos1, pos2), (n1, n2), edge_index),
out,
atol=1e-6,
)

t = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit((x1, None), (pos1, pos2), (n1, n2), adj.t()),
out, atol=1e-6)
assert torch.allclose(
jit((x1, None), (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-6)
20 changes: 12 additions & 8 deletions test/nn/conv/test_res_gated_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,39 @@ def test_res_gated_graph_conv():
x2 = torch.randn(2, 32)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
adj2 = adj1.to_torch_sparse_csc_tensor()

conv = ResGatedGraphConv(8, 32)
assert str(conv) == 'ResGatedGraphConv(8, 32)'
out = conv(x1, edge_index)
assert out.size() == (4, 32)
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, edge_index).tolist() == out.tolist()
assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)

t = '(Tensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj1 = adj1.sparse_resize((4, 2))
adj2 = adj1.to_torch_sparse_csc_tensor()
conv = ResGatedGraphConv((8, 32), 32)
assert str(conv) == 'ResGatedGraphConv((8, 32), 32)'
out = conv((x1, x2), edge_index)
assert out.size() == (2, 32)
assert conv((x1, x2), adj.t()).tolist() == out.tolist()
assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(PairTensor, Tensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), edge_index).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)

t = '(PairTensor, SparseTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), adj.t()).tolist() == out.tolist()
assert torch.allclose(jit((x1, x2), adj1.t()), out, atol=1e-6)