diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2efc9b85c0ac..cf431cbf031f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939))
 - Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938))
 - Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919))
-- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906))
+- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944))
 - Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))
 - Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))
 - Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py
index 3c844d0643e1..343c735d08e7 100644
--- a/test/nn/conv/test_message_passing.py
+++ b/test/nn/conv/test_message_passing.py
@@ -7,10 +7,10 @@
 from torch.nn import Linear
 from torch_scatter import scatter
 from torch_sparse import SparseTensor
-from torch_sparse.matmul import spmm
 
 from torch_geometric.nn import MessagePassing, aggr
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import spmm
 
 
 class MyConv(MessagePassing):
@@ -55,29 +55,44 @@ def test_my_conv():
     row, col = edge_index
     value = torch.randn(row.size(0))
     adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
 
     conv = MyConv(8, 32)
     out = conv(x1, edge_index, value)
     assert out.size() == (4, 32)
-    assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist()
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out)
+    assert torch.allclose(conv(x1, adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out)
     conv.fuse = False
-    assert conv(x1, adj.t()).tolist() == out.tolist()
+    assert torch.allclose(conv(x1, adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out)
     conv.fuse = True
 
     adj = adj.sparse_resize((4, 2))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
+
     conv = MyConv((8, 16), 32)
     out1 = conv((x1, x2), edge_index, value)
     out2 = conv((x1, None), edge_index, value, (4, 2))
     assert out1.size() == (2, 32)
     assert out2.size() == (2, 32)
-    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
+    assert torch.allclose(conv((x1, x2), adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, None), adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
     conv.fuse = False
-    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
-    assert conv((x1, None), adj.t()).tolist() == out2.tolist()
+    assert torch.allclose(conv((x1, x2), adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, None), adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
+    conv.fuse = True
+
+    # Test gradient computation for `torch.sparse` tensors:
+    conv.fuse = True
+    torch_adj = torch_adj.requires_grad_()
+    conv((x1, x2), torch_adj.t()).sum().backward()
+    assert torch_adj.grad is not None
 
 
 def test_my_conv_out_of_bounds():
@@ -197,11 +212,13 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple):
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
 
     conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs)
     out = conv(x, edge_index)
     assert out.size() == (4, 16 * expand)
     assert torch.allclose(conv(x, adj.t()), out)
+    assert torch.allclose(conv(x, torch_adj.t()), out)
 
 
 def test_my_multiple_aggr_conv_jittable():
@@ -264,6 +281,7 @@ def test_my_edge_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
 
     expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='add')
 
@@ -272,6 +290,7 @@
     assert out.size() == (4, 16)
     assert torch.allclose(out, expected)
     assert torch.allclose(conv(x, adj.t()), out)
+    assert torch.allclose(conv(x, torch_adj.t()), out)
 
 
 def test_my_edge_conv_jittable():
@@ -425,10 +444,12 @@ def test_my_default_arg_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
     adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    torch_adj = adj.to_torch_sparse_coo_tensor()
 
     conv = MyDefaultArgConv()
     assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
     assert conv(x, adj.t()).view(-1).tolist() == [0, 0, 0, 0]
+    assert conv(x, torch_adj.t()).view(-1).tolist() == [0, 0, 0, 0]
 
 
 def test_my_default_arg_conv_jittable():
diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py
index 95be8e1a9f92..944cc22c6c15 100644
--- a/torch_geometric/nn/conv/message_passing.py
+++ b/torch_geometric/nn/conv/message_passing.py
@@ -26,6 +26,7 @@
 from torch_geometric.nn.aggr import Aggregation, MultiAggregation
 from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
 from torch_geometric.typing import Adj, Size
+from torch_geometric.utils import is_torch_sparse_tensor
 
 from .utils.helpers import expand_left
 from .utils.inspector import Inspector, func_body_repr, func_header_repr
@@ -182,7 +183,18 @@ def __init__(
     def __check_input__(self, edge_index, size):
         the_size: List[Optional[int]] = [None, None]
 
-        if isinstance(edge_index, Tensor):
+        if is_torch_sparse_tensor(edge_index):
+            if self.flow == 'target_to_source':
+                raise ValueError(
+                    ('Flow direction "target_to_source" is invalid for '
+                     'message propagation via `torch.sparse.Tensor`. If '
If ' + 'you really want to make use of a reverse message ' + 'passing flow, pass in the transposed sparse tensor to ' + 'the message passing module, e.g., `adj_t.t()`.')) + the_size[0] = edge_index.size(1) + the_size[1] = edge_index.size(0) + return the_size + elif isinstance(edge_index, Tensor): int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64) if edge_index.dtype not in int_dtypes: @@ -214,8 +226,8 @@ def __check_input__(self, edge_index, size): raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' - 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' - 'argument `edge_index`.')) + 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or ' + '`torch.sparse.Tensor` for argument `edge_index`.')) def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): the_size = size[dim] @@ -227,7 +239,12 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor): f'dimension {self.node_dim}, but expected size {the_size}.')) def __lift__(self, src, edge_index, dim): - if isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + assert dim == 0 or dim == 1 + index = edge_index._indices()[1 - dim] + return src.index_select(self.node_dim, index) + + elif isinstance(edge_index, Tensor): try: index = edge_index[dim] return src.index_select(self.node_dim, index) @@ -270,8 +287,8 @@ def __lift__(self, src, edge_index, dim): raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' - 'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for ' - 'argument `edge_index`.')) + 'shape `[2, num_messages]`, `torch_sparse.SparseTensor` ' + 'or `torch.sparse.Tensor` for argument `edge_index`.')) def __collect__(self, args, edge_index, size, kwargs): i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1) @@ -296,12 +313,33 @@ def __collect__(self, args, edge_index, size, kwargs): out[arg] = data - if isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + if edge_index.requires_grad: + edge_index = edge_index.coalesce() + indices = edge_index.indices() + values = edge_index.values() + else: + indices = edge_index._indices() + values = edge_index._values() + out['adj_t'] = edge_index + out['edge_index'] = None + out['edge_index_i'] = indices[0] + out['edge_index_j'] = indices[1] + out['ptr'] = None # TODO Get `rowptr` from CSR representation. + if out.get('edge_weight', None) is None: + out['edge_weight'] = values + if out.get('edge_attr', None) is None: + out['edge_attr'] = values + if out.get('edge_type', None) is None: + out['edge_type'] = values + + elif isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index out['edge_index_i'] = edge_index[i] out['edge_index_j'] = edge_index[j] out['ptr'] = None + elif isinstance(edge_index, SparseTensor): out['adj_t'] = edge_index out['edge_index'] = None @@ -327,8 +365,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs): r"""The initial call to start propagating messages. Args: - edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a - :obj:`torch_sparse.SparseTensor` that defines the underlying + edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a + :obj:`torch_sparse.SparseTensor` or a + :obj:`torch.sparse.Tensor` that defines the underlying graph connectivity/message passing flow. :obj:`edge_index` holds the indices of a general (sparse) assignment matrix of shape :obj:`[N, M]`. 
@@ -338,9 +377,9 @@
                 nodes in :obj:`edge_index[1]`
                 (in case :obj:`flow="source_to_target"`).
                 If :obj:`edge_index` is of type
-                :obj:`torch_sparse.SparseTensor`, its sparse indices
-                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
-                and :obj:`col = edge_index[0]`.
+                :obj:`torch_sparse.SparseTensor` or :obj:`torch.sparse.Tensor`,
+                its sparse indices :obj:`(row, col)` should relate to
+                :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
                 The major difference between both formats is that we need to
                 input the *transposed* sparse adjacency matrix into
                 :func:`propagate`.
@@ -349,7 +388,8 @@
                 If set to :obj:`None`, the size will be automatically inferred
                 and assumed to be quadratic.
                 This argument is ignored in case :obj:`edge_index` is a
-                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
+                :obj:`torch_sparse.SparseTensor` or
+                a :obj:`torch.sparse.Tensor`. (default: :obj:`None`)
             **kwargs: Any additional data which is needed to construct and
                 aggregate messages, and to update node embeddings.
         """
@@ -363,7 +403,8 @@
         size = self.__check_input__(edge_index, size)
 
         # Run "fused" message and aggregation (if applicable).
-        if (isinstance(edge_index, SparseTensor) and self.fuse
+        if ((isinstance(edge_index, SparseTensor)
+                or is_torch_sparse_tensor(edge_index)) and self.fuse
                 and not self.explain):
             coll_dict = self.__collect__(self.__fused_user_args__,
                                          edge_index, size, kwargs)
@@ -451,8 +492,9 @@
         graph.
 
         Args:
-            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
-                :obj:`torch_sparse.SparseTensor` that defines the underlying
+            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a
+                :obj:`torch_sparse.SparseTensor` or
+                a :obj:`torch.sparse.Tensor` that defines the underlying
                 graph connectivity/message passing flow.
                 See :meth:`propagate` for more information.
             **kwargs: Any additional data which is needed to compute or update
@@ -549,13 +591,17 @@
         return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                 dim=self.node_dim)
 
-    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
+    def message_and_aggregate(
+        self,
+        adj_t: Union[SparseTensor, Tensor],
+    ) -> Tensor:
         r"""Fuses computations of :func:`message` and :func:`aggregate` into a
         single function.
         If applicable, this saves both time and memory since messages do not
         explicitly need to be materialized.
-        This function will only gets called in case it is implemented and
-        propagation takes place based on a :obj:`torch_sparse.SparseTensor`.
+        This function will only get called in case it is implemented and
+        propagation takes place based on a :obj:`torch_sparse.SparseTensor`
+        or a :obj:`torch.sparse.Tensor`.
         """
         raise NotImplementedError
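
A minimal usage sketch of the code path this patch enables. `SumConv` is a hypothetical module written purely for illustration; it mirrors the pattern of `MyConv` from the tests above, fusing message and aggregation via `torch_geometric.utils.spmm`, which the tests exercise with both `torch_sparse.SparseTensor` and `torch.sparse.Tensor` inputs.

import torch

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import spmm


class SumConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        # `edge_index` may be a `[2, num_edges]` index tensor, a transposed
        # `torch_sparse.SparseTensor`, or (with this patch) a transposed
        # `torch.sparse` COO tensor:
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Non-fused path, used for `[2, num_edges]` index tensors:
        return x_j

    def message_and_aggregate(self, adj_t, x):
        # Fused path, taken whenever the adjacency is passed in sparse form:
        return spmm(adj_t, x, reduce=self.aggr)


x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Build the *transposed* adjacency matrix as a `torch.sparse` COO tensor
# (row = target, col = source). Coalescing gives the canonical,
# duplicate-free form that the collection logic above expects:
adj_t = torch.sparse_coo_tensor(
    edge_index.flip(0), torch.ones(edge_index.size(1)), (4, 4)).coalesce()

conv = SumConv()
assert torch.allclose(conv(x, edge_index), conv(x, adj_t))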
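
A side note on the `requires_grad` branch added to `__collect__`: the raw accessors `_indices()`/`_values()` also work on uncoalesced tensors but do not participate in autograd, whereas the public `indices()`/`values()` require a coalesced tensor and keep the gradient graph intact, which is why the patch coalesces first whenever gradients are needed. A standalone sketch of that behavior, assuming only stock PyTorch sparse semantics:

import torch

i = torch.tensor([[0, 0], [1, 1]])              # two entries, both at (0, 1)
v = torch.tensor([1., 2.], requires_grad=True)
a = torch.sparse_coo_tensor(i, v, (2, 2))

assert not a.is_coalesced()    # COO tensors are not coalesced by default
b = a.coalesce()               # merges duplicates into a single entry of 3.0
b.values().sum().backward()    # `values()` is differentiable after coalescing
assert torch.equal(v.grad, torch.ones(2))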