@@ -20,20 +20,20 @@ def test_sage_conv(project, aggr):
2020 assert str (conv ) == f'SAGEConv(8, 32, aggr={ aggr } )'
2121 out = conv (x1 , edge_index )
2222 assert out .size () == (4 , 32 )
23- assert conv (x1 , edge_index , size = (4 , 4 )). tolist () == out . tolist ( )
24- assert conv (x1 , adj .t ()). tolist () == out . tolist ( )
23+ assert torch . allclose ( conv (x1 , edge_index , size = (4 , 4 )), out , atol = 1e-6 )
24+ assert torch . allclose ( conv (x1 , adj .t ()), out , atol = 1e-6 )
2525 if aggr == 'sum' :
26- assert conv (x1 , adj2 .t ()). tolist () == out . tolist ( )
26+ assert torch . allclose ( conv (x1 , adj2 .t ()), out , atol = 1e-6 )
2727
2828 if is_full_test ():
2929 t = '(Tensor, Tensor, Size) -> Tensor'
3030 jit = torch .jit .script (conv .jittable (t ))
31- assert jit (x1 , edge_index ). tolist () == out . tolist ( )
32- assert jit (x1 , edge_index , size = (4 , 4 )). tolist () == out . tolist ( )
31+ assert torch . allclose ( jit (x1 , edge_index ), out , atol = 1e-6 )
32+ assert torch . allclose ( jit (x1 , edge_index , size = (4 , 4 )), out , atol = 1e-6 )
3333
3434 t = '(Tensor, SparseTensor, Size) -> Tensor'
3535 jit = torch .jit .script (conv .jittable (t ))
36- assert jit (x1 , adj .t ()). tolist () == out . tolist ( )
36+ assert torch . allclose ( jit (x1 , adj .t ()), out , atol = 1e-6 )
3737
3838 adj = adj .sparse_resize ((4 , 2 ))
3939 adj2 = adj .to_torch_sparse_coo_tensor ()
@@ -43,25 +43,26 @@ def test_sage_conv(project, aggr):
4343 out2 = conv ((x1 , None ), edge_index , (4 , 2 ))
4444 assert out1 .size () == (2 , 32 )
4545 assert out2 .size () == (2 , 32 )
46- assert conv ((x1 , x2 ), edge_index , (4 , 2 )). tolist () == out1 . tolist ( )
47- assert conv ((x1 , x2 ), adj .t ()). tolist () == out1 . tolist ( )
48- assert conv ((x1 , None ), adj .t ()). tolist () == out2 . tolist ( )
46+ assert torch . allclose ( conv ((x1 , x2 ), edge_index , (4 , 2 )), out1 , atol = 1e-6 )
47+ assert torch . allclose ( conv ((x1 , x2 ), adj .t ()), out1 , atol = 1e-6 )
48+ assert torch . allclose ( conv ((x1 , None ), adj .t ()), out2 , atol = 1e-6 )
4949 if aggr == 'sum' :
50- assert conv ((x1 , x2 ), adj2 .t ()). tolist () == out1 . tolist ( )
51- assert conv ((x1 , None ), adj2 .t ()). tolist () == out2 . tolist ( )
50+ assert torch . allclose ( conv ((x1 , x2 ), adj2 .t ()), out1 , atol = 1e-6 )
51+ assert torch . allclose ( conv ((x1 , None ), adj2 .t ()), out2 , atol = 1e-6 )
5252
5353 if is_full_test ():
5454 t = '(OptPairTensor, Tensor, Size) -> Tensor'
5555 jit = torch .jit .script (conv .jittable (t ))
56- assert jit ((x1 , x2 ), edge_index ).tolist () == out1 .tolist ()
57- assert jit ((x1 , x2 ), edge_index , size = (4 , 2 )).tolist () == out1 .tolist ()
58- assert jit ((x1 , None ), edge_index ,
59- size = (4 , 2 )).tolist () == out2 .tolist ()
56+ assert torch .allclose (jit ((x1 , x2 ), edge_index ), out1 , atol = 1e-6 )
57+ assert torch .allclose (jit ((x1 , x2 ), edge_index , size = (4 , 2 )), out1 ,
58+ atol = 1e-6 )
59+ assert torch .allclose (jit ((x1 , None ), edge_index , size = (4 , 2 )), out2 ,
60+ atol = 1e-6 )
6061
6162 t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
6263 jit = torch .jit .script (conv .jittable (t ))
63- assert jit ((x1 , x2 ), adj .t ()). tolist () == out1 . tolist ( )
64- assert jit ((x1 , None ), adj .t ()). tolist () == out2 . tolist ( )
64+ assert torch . allclose ( jit ((x1 , x2 ), adj .t ()), out1 , atol = 1e-6 )
65+ assert torch . allclose ( jit ((x1 , None ), adj .t ()), out2 , atol = 1e-6 )
6566
6667
6768def test_lstm_aggr_sage_conv ():
0 commit comments