|
7 | 7 |
|
8 | 8 |
|
@pytest.mark.parametrize('project', [False, True])
@pytest.mark.parametrize('aggr', ['mean', 'sum'])
def test_sage_conv(project, aggr):
    """Test SAGEConv on dense edge indices, SparseTensor adjacencies and
    (for ``aggr='sum'``) PyTorch sparse COO tensors, in both the
    homogeneous and bipartite (source/destination feature) settings.

    All equivalent message-passing formats must produce the same output
    (up to float tolerance), and TorchScript-compiled variants must match
    the eager results when full tests are enabled.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    # Same adjacency as a native PyTorch sparse COO tensor.
    adj2 = adj.to_torch_sparse_coo_tensor()

    conv = SAGEConv(8, 32, project=project, aggr=aggr)
    assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'
    out = conv(x1, edge_index)
    assert out.size() == (4, 32)
    # All input formats must agree up to float tolerance:
    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
    if aggr == 'sum':
        # torch sparse COO matmul only supports sum aggregation:
        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

    if is_full_test():
        t = '(Tensor, Tensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6)

        t = '(Tensor, SparseTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

    # Bipartite message passing: 4 source nodes -> 2 destination nodes.
    adj = adj.sparse_resize((4, 2))
    adj2 = adj.to_torch_sparse_coo_tensor()
    conv = SAGEConv((8, 16), 32, project=project, aggr=aggr)
    assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})'
    out1 = conv((x1, x2), edge_index)
    out2 = conv((x1, None), edge_index, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
    assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
    assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
    if aggr == 'sum':
        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

    if is_full_test():
        t = '(OptPairTensor, Tensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)
        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1,
                              atol=1e-6)
        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2,
                              atol=1e-6)

        t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
        assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)
57 | 66 |
|
58 | 67 |
|
59 | 68 | def test_lstm_aggr_sage_conv(): |
|
0 commit comments