Skip to content

Commit 216fa11

Browse files
committed
test
1 parent cb7002f commit 216fa11

File tree

3 files changed

+41
-37
lines changed

3 files changed

+41
-37
lines changed

test/nn/conv/test_gin_conv.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def test_gin_conv():
2626
'))')
2727
out = conv(x1, edge_index)
2828
assert out.size() == (4, 32)
29-
assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
30-
assert conv(x1, adj.t()).tolist() == out.tolist()
31-
assert conv(x1, adj2.t()).tolist() == out.tolist()
29+
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
30+
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
31+
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
3232

3333
if is_full_test():
3434
t = '(Tensor, Tensor, Size) -> Tensor'
@@ -46,11 +46,11 @@ def test_gin_conv():
4646
out2 = conv((x1, None), edge_index, (4, 2))
4747
assert out1.size() == (2, 32)
4848
assert out2.size() == (2, 32)
49-
assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
50-
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
51-
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
52-
assert conv((x1, x2), adj2.t()).tolist() == out1.tolist()
53-
assert conv((x1, None), adj2.t()).tolist() == out2.tolist()
49+
assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
50+
assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
51+
assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
52+
assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
53+
assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)
5454

5555
if is_full_test():
5656
t = '(OptPairTensor, Tensor, Size) -> Tensor'

test/nn/conv/test_graph_conv.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ def test_graph_conv():
2020
assert conv.__repr__() == 'GraphConv(8, 32)'
2121
out11 = conv(x1, edge_index)
2222
assert out11.size() == (4, 32)
23-
assert conv(x1, edge_index, size=(4, 4)).tolist() == out11.tolist()
24-
assert conv(x1, adj1.t()).tolist() == out11.tolist()
25-
assert conv(x1, adj3.t()).tolist() == out11.tolist()
23+
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11, atol=1e-6)
24+
assert torch.allclose(conv(x1, adj1.t()), out11, atol=1e-6)
25+
assert torch.allclose(conv(x1, adj3.t()), out11, atol=1e-6)
2626

2727
out12 = conv(x1, edge_index, value)
2828
assert out12.size() == (4, 32)
29-
assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out12.tolist()
30-
assert conv(x1, adj2.t()).tolist() == out12.tolist()
31-
assert conv(x1, adj4.t()).tolist() == out12.tolist()
29+
assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out12,
30+
atol=1e-6)
31+
assert torch.allclose(conv(x1, adj2.t()), out12, atol=1e-6)
32+
assert torch.allclose(conv(x1, adj4.t()), out12, atol=1e-6)
3233

3334
if is_full_test():
3435
t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
@@ -58,12 +59,14 @@ def test_graph_conv():
5859
assert out22.size() == (2, 32)
5960
assert out23.size() == (2, 32)
6061
assert out24.size() == (2, 32)
61-
assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out21.tolist()
62-
assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out22.tolist()
63-
assert conv((x1, x2), adj1.t()).tolist() == out21.tolist()
64-
assert conv((x1, x2), adj2.t()).tolist() == out22.tolist()
65-
assert conv((x1, x2), adj3.t()).tolist() == out21.tolist()
66-
assert conv((x1, x2), adj4.t()).tolist() == out22.tolist()
62+
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out21,
63+
atol=1e-6)
64+
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out22,
65+
atol=1e-6)
66+
assert torch.allclose(conv((x1, x2), adj1.t()), out21, atol=1e-6)
67+
assert torch.allclose(conv((x1, x2), adj2.t()), out22, atol=1e-6)
68+
assert torch.allclose(conv((x1, x2), adj3.t()), out21, atol=1e-6)
69+
assert torch.allclose(conv((x1, x2), adj4.t()), out22, atol=1e-6)
6770

6871
if is_full_test():
6972
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'

test/nn/conv/test_sage_conv.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ def test_sage_conv(project, aggr):
2020
assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'
2121
out = conv(x1, edge_index)
2222
assert out.size() == (4, 32)
23-
assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
24-
assert conv(x1, adj.t()).tolist() == out.tolist()
23+
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
24+
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
2525
if aggr == 'sum':
26-
assert conv(x1, adj2.t()).tolist() == out.tolist()
26+
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
2727

2828
if is_full_test():
2929
t = '(Tensor, Tensor, Size) -> Tensor'
3030
jit = torch.jit.script(conv.jittable(t))
31-
assert jit(x1, edge_index).tolist() == out.tolist()
32-
assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
31+
assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
32+
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6)
3333

3434
t = '(Tensor, SparseTensor, Size) -> Tensor'
3535
jit = torch.jit.script(conv.jittable(t))
36-
assert jit(x1, adj.t()).tolist() == out.tolist()
36+
assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)
3737

3838
adj = adj.sparse_resize((4, 2))
3939
adj2 = adj.to_torch_sparse_coo_tensor()
@@ -43,25 +43,26 @@ def test_sage_conv(project, aggr):
4343
out2 = conv((x1, None), edge_index, (4, 2))
4444
assert out1.size() == (2, 32)
4545
assert out2.size() == (2, 32)
46-
assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
47-
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
48-
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
46+
assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
47+
assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
48+
assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
4949
if aggr == 'sum':
50-
assert conv((x1, x2), adj2.t()).tolist() == out1.tolist()
51-
assert conv((x1, None), adj2.t()).tolist() == out2.tolist()
50+
assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
51+
assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)
5252

5353
if is_full_test():
5454
t = '(OptPairTensor, Tensor, Size) -> Tensor'
5555
jit = torch.jit.script(conv.jittable(t))
56-
assert jit((x1, x2), edge_index).tolist() == out1.tolist()
57-
assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out1.tolist()
58-
assert jit((x1, None), edge_index,
59-
size=(4, 2)).tolist() == out2.tolist()
56+
assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)
57+
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1,
58+
atol=1e-6)
59+
assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2,
60+
atol=1e-6)
6061

6162
t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
6263
jit = torch.jit.script(conv.jittable(t))
63-
assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
64-
assert jit((x1, None), adj.t()).tolist() == out2.tolist()
64+
assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
65+
assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)
6566

6667

6768
def test_lstm_aggr_sage_conv():

0 commit comments

Comments
 (0)