From 8acdcf8f2b9a0bb065b2e6de5fe80dd352455c04 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 10 Nov 2022 08:38:16 +0800 Subject: [PATCH 1/7] add SparseaTensor support for MessagePassing --- torch_geometric/nn/conv/message_passing.py | 73 ++++++++++++++++------ 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 95be8e1a9f92..4ad09d6e957c 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -26,6 +26,7 @@ from torch_geometric.nn.aggr import Aggregation, MultiAggregation from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver from torch_geometric.typing import Adj, Size +from torch_geometric.utils import is_torch_sparse_tensor from .utils.helpers import expand_left from .utils.inspector import Inspector, func_body_repr, func_header_repr @@ -182,7 +183,18 @@ def __init__( def __check_input__(self, edge_index, size): the_size: List[Optional[int]] = [None, None] - if isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + if self.flow == 'target_to_source': + raise ValueError( + ('Flow direction "target_to_source" is invalid for ' + 'message propagation via `torch.sparse.Tensor`. If ' + 'you really want to make use of a reverse message ' + 'passing flow, pass in the transposed sparse tensor to ' + 'the message passing module, e.g., `adj_t.t()`.')) + the_size[0] = edge_index.size(1) + the_size[1] = edge_index.size(0) + return the_size + elif isinstance(edge_index, Tensor): int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64) if edge_index.dtype not in int_dtypes: @@ -214,8 +226,8 @@ def __check_input__(self, edge_index, size): raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' - 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' - 'argument `edge_index`.')) + 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or ' + '`torch.sparse.Tensor` for argument `edge_index`.')) def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): the_size = size[dim] @@ -227,7 +239,11 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): f'dimension {self.node_dim}, but expected size {the_size}.')) def __lift__(self, src, edge_index, dim): - if isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + # TODO: should we use `rowptr` when `dim=1`` as like SparseTensor? + index = edge_index._indices()[dim] + return src.index_select(self.node_dim, index) + elif isinstance(edge_index, Tensor): try: index = edge_index[dim] return src.index_select(self.node_dim, index) @@ -270,8 +286,8 @@ def __lift__(self, src, edge_index, dim): raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' - 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' - 'argument `edge_index`.')) + 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` ' + 'or `torch.sparse.Tensor` for argument `edge_index`.')) def __collect__(self, args, edge_index, size, kwargs): i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) @@ -296,7 +312,22 @@ def __collect__(self, args, edge_index, size, kwargs): out[arg] = data - if isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + # TODO: Since `._values()`` returns a detached tensor, + # should we use a coalesced matrix instead? + # This may lead to more overheads though. + out['adj_t'] = edge_index + out['edge_index'] = None + out['edge_index_i'] = edge_index._indices()[0] + out['edge_index_j'] = edge_index._indices()[1] + out['ptr'] = None # TODO: should we handle this? + if out.get('edge_weight', None) is None: + out['edge_weight'] = edge_index._values() + if out.get('edge_attr', None) is None: + out['edge_attr'] = edge_index._values() + if out.get('edge_type', None) is None: + out['edge_type'] = edge_index._values() + elif isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index out['edge_index_i'] = edge_index[i] @@ -327,8 +358,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): r"""The initial call to start propagating messages. Args: - edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a - :obj:`torch_sparse.SparseTensor` that defines the underlying + edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a + :obj:`torch_sparse.SparseTensor` or a + :obj:`torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. :obj:`edge_index` holds the indices of a general (sparse) assignment matrix of shape :obj:`[N, M]`. @@ -338,9 +370,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): nodes in :obj:`edge_index[1]` (in case :obj:`flow="source_to_target"`). If :obj:`edge_index` is of type - :obj:`torch_sparse.SparseTensor`, its sparse indices - :obj:`(row, col)` should relate to :obj:`row = edge_index[1]` - and :obj:`col = edge_index[0]`. + :obj:`torch_sparse.SparseTensor` or :obj:`torch.sparse.Tensor`, + its sparse indices :obj:`(row, col)` should relate to + :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`. The major difference between both formats is that we need to input the *transposed* sparse adjacency matrix into :func:`propagate`. @@ -349,7 +381,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): If set to :obj:`None`, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case :obj:`edge_index` is a - :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`) + :obj:`torch_sparse.SparseTensor` or + a :obj:`torch.sparse.Tensor`. (default: :obj:`None`) **kwargs: Any additional data which is needed to construct and aggregate messages, and to update node embeddings. """ @@ -363,7 +396,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): size = self.__check_input__(edge_index, size) # Run "fused" message and aggregation (if applicable). - if (isinstance(edge_index, SparseTensor) and self.fuse + if ((isinstance(edge_index, SparseTensor) + or is_torch_sparse_tensor(edge_index)) and self.fuse and not self.explain): coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs) @@ -451,8 +485,9 @@ def edge_updater(self, edge_index: Adj, **kwargs): graph. Args: - edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a - :obj:`torch_sparse.SparseTensor` that defines the underlying + edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a + :obj:`torch_sparse.SparseTensor` or + a :obj:`torch.sparse.Tensor` that defines the underlying graph connectivity/message passing flow. See :meth:`propagate` for more information. **kwargs: Any additional data which is needed to compute or update @@ -549,13 +584,15 @@ def aggregate(self, inputs: Tensor, index: Tensor, return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim) - def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: + def message_and_aggregate(self, adj_t: Union[SparseTensor, + Tensor]) -> Tensor: r"""Fuses computations of :func:`message` and :func:`aggregate` into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only gets called in case it is implemented and - propagation takes place based on a :obj:`torch_sparse.SparseTensor`. + propagation takes place based on a :obj:`torch_sparse.SparseTensor` + or a `torch.sparse.Tensor`. """ raise NotImplementedError From 185c8b0ba2af76a9205d6223994e40cce96cd4a9 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 10 Nov 2022 15:57:24 +0800 Subject: [PATCH 2/7] add test --- test/nn/conv/test_message_passing.py | 22 +++++++++++++++++++++- torch_geometric/nn/conv/message_passing.py | 3 ++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index 3c844d0643e1..43436194ee5f 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -7,10 +7,10 @@ from torch.nn import Linear from torch_scatter import scatter from torch_sparse import SparseTensor -from torch_sparse.matmul import spmm from torch_geometric.nn import MessagePassing, aggr from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size +from torch_geometric.utils import spmm class MyConv(MessagePassing): @@ -61,8 +61,12 @@ def test_my_conv(): assert out.size() == (4, 32) assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist() assert conv(x1, adj.t()).tolist() == out.tolist() + assert conv(x1, + adj.t().to_torch_sparse_coo_tensor()).tolist() == out.tolist() conv.fuse = False assert conv(x1, adj.t()).tolist() == out.tolist() + assert conv(x1, + adj.t().to_torch_sparse_coo_tensor()).tolist() == out.tolist() conv.fuse = True adj = adj.sparse_resize((4, 2)) @@ -73,10 +77,22 @@ def test_my_conv(): assert out2.size() == (2, 32) assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist() assert conv((x1, x2), adj.t()).tolist() == out1.tolist() + assert conv( + (x1, x2), + adj.t().to_torch_sparse_coo_tensor()).tolist() == out1.tolist() assert conv((x1, None), adj.t()).tolist() == out2.tolist() + assert conv( + (x1, None), + adj.t().to_torch_sparse_coo_tensor()).tolist() == out2.tolist() conv.fuse = False assert conv((x1, x2), adj.t()).tolist() == out1.tolist() + assert conv( + (x1, x2), + adj.t().to_torch_sparse_coo_tensor()).tolist() == out1.tolist() assert conv((x1, None), adj.t()).tolist() == out2.tolist() + assert conv( + (x1, None), + adj.t().to_torch_sparse_coo_tensor()).tolist() == out2.tolist() conv.fuse = True @@ -202,6 +218,7 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple): out = conv(x, edge_index) assert out.size() == (4, 16 * expand) assert torch.allclose(conv(x, adj.t()), out) + assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out) def test_my_multiple_aggr_conv_jittable(): @@ -272,6 +289,7 @@ def test_my_edge_conv(): assert out.size() == (4, 16) assert torch.allclose(out, expected) assert torch.allclose(conv(x, adj.t()), out) + assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out) def test_my_edge_conv_jittable(): @@ -429,6 +447,8 @@ def test_my_default_arg_conv(): conv = MyDefaultArgConv() assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] assert conv(x, adj.t()).view(-1).tolist() == [0, 0, 0, 0] + assert conv(x, adj.t().to_torch_sparse_coo_tensor() + ).view(-1).tolist() == [0, 0, 0, 0] # yapf: disable def test_my_default_arg_conv_jittable(): diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 4ad09d6e957c..fb496d99e179 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -241,7 +241,8 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): def __lift__(self, src, edge_index, dim): if is_torch_sparse_tensor(edge_index): # TODO: should we use `rowptr` when `dim=1`` as like SparseTensor? - index = edge_index._indices()[dim] + assert dim == 0 or dim == 1 + index = edge_index._indices()[1 - dim] return src.index_select(self.node_dim, index) elif isinstance(edge_index, Tensor): try: From 2226e3772d7b1f8502a0ce7ed3c68dfa251f002d Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 10 Nov 2022 16:03:12 +0800 Subject: [PATCH 3/7] format --- test/nn/conv/test_message_passing.py | 28 ++++++++-------------- torch_geometric/nn/conv/message_passing.py | 6 +++-- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index 43436194ee5f..b70435540c74 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -55,21 +55,21 @@ def test_my_conv(): row, col = edge_index value = torch.randn(row.size(0)) adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) + torch_adj = adj.to_torch_sparse_coo_tensor() conv = MyConv(8, 32) out = conv(x1, edge_index, value) assert out.size() == (4, 32) assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist() assert conv(x1, adj.t()).tolist() == out.tolist() - assert conv(x1, - adj.t().to_torch_sparse_coo_tensor()).tolist() == out.tolist() + assert conv(x1, torch_adj.t()).tolist() == out.tolist() conv.fuse = False assert conv(x1, adj.t()).tolist() == out.tolist() - assert conv(x1, - adj.t().to_torch_sparse_coo_tensor()).tolist() == out.tolist() + assert conv(x1, torch_adj.t()).tolist() == out.tolist() conv.fuse = True adj = adj.sparse_resize((4, 2)) + torch_adj = adj.to_torch_sparse_coo_tensor() conv = MyConv((8, 16), 32) out1 = conv((x1, x2), edge_index, value) out2 = conv((x1, None), edge_index, value, (4, 2)) @@ -77,22 +77,14 @@ def test_my_conv(): assert out2.size() == (2, 32) assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist() assert conv((x1, x2), adj.t()).tolist() == out1.tolist() - assert conv( - (x1, x2), - adj.t().to_torch_sparse_coo_tensor()).tolist() == out1.tolist() + assert conv((x1, x2), torch_adj.t()).tolist() == out1.tolist() assert conv((x1, None), adj.t()).tolist() == out2.tolist() - assert conv( - (x1, None), - adj.t().to_torch_sparse_coo_tensor()).tolist() == out2.tolist() + assert conv((x1, None), torch_adj.t()).tolist() == out2.tolist() conv.fuse = False assert conv((x1, x2), adj.t()).tolist() == out1.tolist() - assert conv( - (x1, x2), - adj.t().to_torch_sparse_coo_tensor()).tolist() == out1.tolist() + assert conv((x1, x2), torch_adj.t()).tolist() == out1.tolist() assert conv((x1, None), adj.t()).tolist() == out2.tolist() - assert conv( - (x1, None), - adj.t().to_torch_sparse_coo_tensor()).tolist() == out2.tolist() + assert conv((x1, None), torch_adj.t()).tolist() == out2.tolist() conv.fuse = True @@ -443,12 +435,12 @@ def test_my_default_arg_conv(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + torch_adj = adj.to_torch_sparse_coo_tensor() conv = MyDefaultArgConv() assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] assert conv(x, adj.t()).view(-1).tolist() == [0, 0, 0, 0] - assert conv(x, adj.t().to_torch_sparse_coo_tensor() - ).view(-1).tolist() == [0, 0, 0, 0] # yapf: disable + assert conv(x, torch_adj.t()).view(-1).tolist() == [0, 0, 0, 0] def test_my_default_arg_conv_jittable(): diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index fb496d99e179..37016fa52c56 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -585,8 +585,10 @@ def aggregate(self, inputs: Tensor, index: Tensor, return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim) - def message_and_aggregate(self, adj_t: Union[SparseTensor, - Tensor]) -> Tensor: + def message_and_aggregate( + self, + adj_t: Union[SparseTensor, Tensor], + ) -> Tensor: r"""Fuses computations of :func:`message` and :func:`aggregate` into a single function. If applicable, this saves both time and memory since messages do not From 40e652e03ec5ff0c62236d654ed437154a122a00 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 10 Nov 2022 21:42:05 +0800 Subject: [PATCH 4/7] Update test --- test/nn/conv/test_message_passing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index b70435540c74..f4a8d9cb4e0c 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -62,10 +62,10 @@ def test_my_conv(): assert out.size() == (4, 32) assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist() assert conv(x1, adj.t()).tolist() == out.tolist() - assert conv(x1, torch_adj.t()).tolist() == out.tolist() + assert torch.allclose(conv(x1, torch_adj.t()), out) conv.fuse = False assert conv(x1, adj.t()).tolist() == out.tolist() - assert conv(x1, torch_adj.t()).tolist() == out.tolist() + assert torch.allclose(conv(x1, torch_adj.t()), out) conv.fuse = True adj = adj.sparse_resize((4, 2)) @@ -77,14 +77,14 @@ def test_my_conv(): assert out2.size() == (2, 32) assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist() assert conv((x1, x2), adj.t()).tolist() == out1.tolist() - assert conv((x1, x2), torch_adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) assert conv((x1, None), adj.t()).tolist() == out2.tolist() - assert conv((x1, None), torch_adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((x1, None), torch_adj.t()), out2) conv.fuse = False assert conv((x1, x2), adj.t()).tolist() == out1.tolist() - assert conv((x1, x2), torch_adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) assert conv((x1, None), adj.t()).tolist() == out2.tolist() - assert conv((x1, None), torch_adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((x1, None), torch_adj.t()), out2) conv.fuse = True From 5a7b51e8a2eff30109cdb29c51c682396d31e20b Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 10 Nov 2022 22:09:57 +0800 Subject: [PATCH 5/7] Update test --- test/nn/conv/test_message_passing.py | 6 ++++++ torch_geometric/nn/conv/message_passing.py | 20 ++++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index f4a8d9cb4e0c..92531aac2b66 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -87,6 +87,12 @@ def test_my_conv(): assert torch.allclose(conv((x1, None), torch_adj.t()), out2) conv.fuse = True + # Test backward compatibility for PyTorch SparseTensor + conv.fuse = True + torch_adj = torch_adj.requires_grad_() + conv((x1, x2), torch_adj.t()).sum().backward() + assert torch_adj.grad is not None + def test_my_conv_out_of_bounds(): x = torch.randn(3, 8) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 37016fa52c56..35347e19866f 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -314,20 +314,24 @@ def __collect__(self, args, edge_index, size, kwargs): out[arg] = data if is_torch_sparse_tensor(edge_index): - # TODO: Since `._values()`` returns a detached tensor, - # should we use a coalesced matrix instead? - # This may lead to more overheads though. + if edge_index.requires_grad: + edge_index = edge_index.coalesce() + indices = edge_index.indices() + values = edge_index.values() + else: + indices = edge_index._indices() + values = edge_index._values() out['adj_t'] = edge_index out['edge_index'] = None - out['edge_index_i'] = edge_index._indices()[0] - out['edge_index_j'] = edge_index._indices()[1] + out['edge_index_i'] = indices[0] + out['edge_index_j'] = indices[1] out['ptr'] = None # TODO: should we handle this? if out.get('edge_weight', None) is None: - out['edge_weight'] = edge_index._values() + out['edge_weight'] = values if out.get('edge_attr', None) is None: - out['edge_attr'] = edge_index._values() + out['edge_attr'] = values if out.get('edge_type', None) is None: - out['edge_type'] = edge_index._values() + out['edge_type'] = values elif isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index From cedce468784c4306f5e49fc6e5b3a33a9d010db4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 15 Nov 2022 12:18:01 +0000 Subject: [PATCH 6/7] update --- test/nn/conv/test_message_passing.py | 25 ++++++++++++---------- torch_geometric/nn/conv/message_passing.py | 10 +++++---- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index 92531aac2b66..343c735d08e7 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -60,34 +60,35 @@ def test_my_conv(): conv = MyConv(8, 32) out = conv(x1, edge_index, value) assert out.size() == (4, 32) - assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist() - assert conv(x1, adj.t()).tolist() == out.tolist() + assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out) + assert torch.allclose(conv(x1, adj.t()), out) assert torch.allclose(conv(x1, torch_adj.t()), out) conv.fuse = False - assert conv(x1, adj.t()).tolist() == out.tolist() + assert torch.allclose(conv(x1, adj.t()), out) assert torch.allclose(conv(x1, torch_adj.t()), out) conv.fuse = True adj = adj.sparse_resize((4, 2)) torch_adj = adj.to_torch_sparse_coo_tensor() + conv = MyConv((8, 16), 32) out1 = conv((x1, x2), edge_index, value) out2 = conv((x1, None), edge_index, value, (4, 2)) assert out1.size() == (2, 32) assert out2.size() == (2, 32) - assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist() - assert conv((x1, x2), adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) + assert torch.allclose(conv((x1, x2), adj.t()), out1) assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) - assert conv((x1, None), adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((x1, None), adj.t()), out2) assert torch.allclose(conv((x1, None), torch_adj.t()), out2) conv.fuse = False - assert conv((x1, x2), adj.t()).tolist() == out1.tolist() + assert torch.allclose(conv((x1, x2), adj.t()), out1) assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) - assert conv((x1, None), adj.t()).tolist() == out2.tolist() + assert torch.allclose(conv((x1, None), adj.t()), out2) assert torch.allclose(conv((x1, None), torch_adj.t()), out2) conv.fuse = True - # Test backward compatibility for PyTorch SparseTensor + # Test backward compatibility for `torch.sparse` tensors: conv.fuse = True torch_adj = torch_adj.requires_grad_() conv((x1, x2), torch_adj.t()).sum().backward() @@ -211,12 +212,13 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + torch_adj = adj.to_torch_sparse_coo_tensor() conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs) out = conv(x, edge_index) assert out.size() == (4, 16 * expand) assert torch.allclose(conv(x, adj.t()), out) - assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out) + assert torch.allclose(conv(x, torch_adj.t()), out) def test_my_multiple_aggr_conv_jittable(): @@ -279,6 +281,7 @@ def test_my_edge_conv(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + torch_adj = adj.to_torch_sparse_coo_tensor() expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='add') @@ -287,7 +290,7 @@ def test_my_edge_conv(): assert out.size() == (4, 16) assert torch.allclose(out, expected) assert torch.allclose(conv(x, adj.t()), out) - assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out) + assert torch.allclose(conv(x, torch_adj.t()), out) def test_my_edge_conv_jittable(): diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index 35347e19866f..944cc22c6c15 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -240,10 +240,10 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): def __lift__(self, src, edge_index, dim): if is_torch_sparse_tensor(edge_index): - # TODO: should we use `rowptr` when `dim=1`` as like SparseTensor? assert dim == 0 or dim == 1 index = edge_index._indices()[1 - dim] return src.index_select(self.node_dim, index) + elif isinstance(edge_index, Tensor): try: index = edge_index[dim] @@ -325,19 +325,21 @@ def __collect__(self, args, edge_index, size, kwargs): out['edge_index'] = None out['edge_index_i'] = indices[0] out['edge_index_j'] = indices[1] - out['ptr'] = None # TODO: should we handle this? + out['ptr'] = None # TODO Get `rowptr` from CSR representation. if out.get('edge_weight', None) is None: out['edge_weight'] = values if out.get('edge_attr', None) is None: out['edge_attr'] = values if out.get('edge_type', None) is None: out['edge_type'] = values + elif isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index out['edge_index_i'] = edge_index[i] out['edge_index_j'] = edge_index[j] out['ptr'] = None + elif isinstance(edge_index, SparseTensor): out['adj_t'] = edge_index out['edge_index'] = None @@ -365,7 +367,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): Args: edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a :obj:`torch_sparse.SparseTensor` or a - :obj:`torch.sparse.Tensor that defines the underlying + :obj:`torch.sparse.Tensor` that defines the underlying graph connectivity/message passing flow. :obj:`edge_index` holds the indices of a general (sparse) assignment matrix of shape :obj:`[N, M]`. @@ -599,7 +601,7 @@ def message_and_aggregate( explicitly need to be materialized. This function will only gets called in case it is implemented and propagation takes place based on a :obj:`torch_sparse.SparseTensor` - or a `torch.sparse.Tensor`. + or a :obj:`torch.sparse.Tensor`. """ raise NotImplementedError From 9c575b11b0dec657c9672cc0a04fb9818138c5c5 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 15 Nov 2022 12:19:08 +0000 Subject: [PATCH 7/7] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60010ea5f730..8f87510aba40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939)) - Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938)) - Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919)) -- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906)) +- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944)) - Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903)) - Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886)) - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))