20 changes: 19 additions & 1 deletion test/nn/conv/test_message_passing.py
@@ -7,10 +7,10 @@
from torch.nn import Linear
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_sparse.matmul import spmm

from torch_geometric.nn import MessagePassing, aggr
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import spmm


class MyConv(MessagePassing):
@@ -55,30 +55,44 @@ def test_my_conv():
row, col = edge_index
value = torch.randn(row.size(0))
adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
torch_adj = adj.to_torch_sparse_coo_tensor()

conv = MyConv(8, 32)
out = conv(x1, edge_index, value)
assert out.size() == (4, 32)
assert conv(x1, edge_index, value, (4, 4)).tolist() == out.tolist()
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, torch_adj.t()), out)
conv.fuse = False
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, torch_adj.t()), out)
conv.fuse = True

adj = adj.sparse_resize((4, 2))
torch_adj = adj.to_torch_sparse_coo_tensor()
conv = MyConv((8, 16), 32)
out1 = conv((x1, x2), edge_index, value)
out2 = conv((x1, None), edge_index, value, (4, 2))
assert out1.size() == (2, 32)
assert out2.size() == (2, 32)
assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
conv.fuse = False
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
conv.fuse = True

# Test gradient computation (autograd backward) for PyTorch sparse tensors:
conv.fuse = True
torch_adj = torch_adj.requires_grad_()
conv((x1, x2), torch_adj.t()).sum().backward()
assert torch_adj.grad is not None


def test_my_conv_out_of_bounds():
x = torch.randn(3, 8)
@@ -202,6 +216,7 @@ def test_my_multiple_aggr_conv(multi_aggr_tuple):
out = conv(x, edge_index)
assert out.size() == (4, 16 * expand)
assert torch.allclose(conv(x, adj.t()), out)
assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out)


def test_my_multiple_aggr_conv_jittable():
@@ -272,6 +287,7 @@ def test_my_edge_conv():
assert out.size() == (4, 16)
assert torch.allclose(out, expected)
assert torch.allclose(conv(x, adj.t()), out)
assert torch.allclose(conv(x, adj.t().to_torch_sparse_coo_tensor()), out)


def test_my_edge_conv_jittable():
@@ -425,10 +441,12 @@ def test_my_default_arg_conv():
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
torch_adj = adj.to_torch_sparse_coo_tensor()

conv = MyDefaultArgConv()
assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]
assert conv(x, adj.t()).view(-1).tolist() == [0, 0, 0, 0]
assert conv(x, torch_adj.t()).view(-1).tolist() == [0, 0, 0, 0]


def test_my_default_arg_conv_jittable():
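Taken together, the new assertions exercise the following call pattern. The sketch below is for illustration only (it is not part of the diff) and assumes the toy MyConv layer defined at the top of this test file:

import torch
from torch_sparse import SparseTensor

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
value = torch.randn(edge_index.size(1))

adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=value,
                   sparse_sizes=(4, 4))             # existing torch_sparse path
torch_adj = adj.to_torch_sparse_coo_tensor()        # new torch.sparse path

conv = MyConv(8, 32)                                # toy layer from this test file
out = conv(x, edge_index, value)
assert torch.allclose(conv(x, adj.t()), out)        # torch_sparse.SparseTensor
assert torch.allclose(conv(x, torch_adj.t()), out)  # torch.sparse.Tensor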
80 changes: 62 additions & 18 deletions torch_geometric/nn/conv/message_passing.py
@@ -26,6 +26,7 @@
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import Adj, Size
from torch_geometric.utils import is_torch_sparse_tensor

from .utils.helpers import expand_left
from .utils.inspector import Inspector, func_body_repr, func_header_repr
@@ -182,7 +183,18 @@ def __init__(
def __check_input__(self, edge_index, size):
the_size: List[Optional[int]] = [None, None]

if isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
if self.flow == 'target_to_source':
raise ValueError(
('Flow direction "target_to_source" is invalid for '
'message propagation via `torch.sparse.Tensor`. If '
'you really want to make use of a reverse message '
'passing flow, pass in the transposed sparse tensor to '
'the message passing module, e.g., `adj_t.t()`.'))
the_size[0] = edge_index.size(1)
the_size[1] = edge_index.size(0)
return the_size
elif isinstance(edge_index, Tensor):
int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)

if edge_index.dtype not in int_dtypes:
@@ -214,8 +226,8 @@ def __check_input__(self, edge_index, size):

raise ValueError(
('`MessagePassing.propagate` only supports integer tensors of '
'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for '
'argument `edge_index`.'))
'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
'`torch.sparse.Tensor` for argument `edge_index`.'))

def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
the_size = size[dim]
@@ -227,7 +239,12 @@ def __set_size__(self, size: List[Optional[int]], dim: int, src: Tensor):
f'dimension {self.node_dim}, but expected size {the_size}.'))

def __lift__(self, src, edge_index, dim):
if isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
# TODO: Should we use `rowptr` when `dim=1`, as with `SparseTensor`?
assert dim == 0 or dim == 1
index = edge_index._indices()[1 - dim]
return src.index_select(self.node_dim, index)
elif isinstance(edge_index, Tensor):
try:
index = edge_index[dim]
return src.index_select(self.node_dim, index)
@@ -270,8 +287,8 @@ def __lift__(self, src, edge_index, dim):

raise ValueError(
('`MessagePassing.propagate` only supports integer tensors of '
'shape `[2, num_messages]` or `torch_sparse.SparseTensor` for '
'argument `edge_index`.'))
'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
'or `torch.sparse.Tensor` for argument `edge_index`.'))

def __collect__(self, args, edge_index, size, kwargs):
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
@@ -296,7 +313,26 @@ def __collect__(self, args, edge_index, size, kwargs):

out[arg] = data

if isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
if edge_index.requires_grad:
edge_index = edge_index.coalesce()
indices = edge_index.indices()
values = edge_index.values()
else:
indices = edge_index._indices()
values = edge_index._values()
out['adj_t'] = edge_index
out['edge_index'] = None
out['edge_index_i'] = indices[0]
out['edge_index_j'] = indices[1]
out['ptr'] = None # TODO: should we handle this?
if out.get('edge_weight', None) is None:
out['edge_weight'] = values
if out.get('edge_attr', None) is None:
out['edge_attr'] = values
if out.get('edge_type', None) is None:
out['edge_type'] = values
elif isinstance(edge_index, Tensor):
out['adj_t'] = None
out['edge_index'] = edge_index
out['edge_index_i'] = edge_index[i]
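With the mapping above, the values of a valued torch.sparse.Tensor are handed to the user-facing message signature as edge_weight, edge_attr, or edge_type whenever the corresponding argument resolves to None. A minimal sketch of a layer relying on this (a hypothetical example, not code from this diff):

from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj

class WeightedConv(MessagePassing):
    """Hypothetical layer; not part of this diff."""
    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x: Tensor, adj_t: Adj) -> Tensor:
        # Passing `edge_weight=None` lets the sparse values fill it in.
        return self.propagate(adj_t, x=x, edge_weight=None)

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        return edge_weight.view(-1, 1) * x_j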
@@ -327,8 +363,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
r"""The initial call to start propagating messages.

Args:
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
:obj:`torch_sparse.SparseTensor` that defines the underlying
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a
:obj:`torch_sparse.SparseTensor` or a
:obj:`torch.sparse.Tensor` that defines the underlying
graph connectivity/message passing flow.
:obj:`edge_index` holds the indices of a general (sparse)
assignment matrix of shape :obj:`[N, M]`.
@@ -338,9 +375,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
nodes in :obj:`edge_index[1]`
(in case :obj:`flow="source_to_target"`).
If :obj:`edge_index` is of type
:obj:`torch_sparse.SparseTensor`, its sparse indices
:obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
and :obj:`col = edge_index[0]`.
:obj:`torch_sparse.SparseTensor` or :obj:`torch.sparse.Tensor`,
its sparse indices :obj:`(row, col)` should relate to
:obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.
The major difference between the sparse formats and a plain
:obj:`edge_index` is that we need to input the *transposed*
sparse adjacency matrix into :func:`propagate`.
@@ -349,7 +386,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
If set to :obj:`None`, the size will be automatically inferred
and assumed to be quadratic.
This argument is ignored in case :obj:`edge_index` is a
:obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
:obj:`torch_sparse.SparseTensor` or
a :obj:`torch.sparse.Tensor`. (default: :obj:`None`)
**kwargs: Any additional data which is needed to construct and
aggregate messages, and to update node embeddings.
"""
@@ -363,7 +401,8 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
size = self.__check_input__(edge_index, size)

# Run "fused" message and aggregation (if applicable).
if (isinstance(edge_index, SparseTensor) and self.fuse
if ((isinstance(edge_index, SparseTensor)
or is_torch_sparse_tensor(edge_index)) and self.fuse
and not self.explain):
coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
size, kwargs)
@@ -451,8 +490,9 @@ def edge_updater(self, edge_index: Adj, **kwargs):
graph.

Args:
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
:obj:`torch_sparse.SparseTensor` that defines the underlying
edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor`, a
:obj:`torch_sparse.SparseTensor` or
a :obj:`torch.sparse.Tensor` that defines the underlying
graph connectivity/message passing flow.
See :meth:`propagate` for more information.
**kwargs: Any additional data which is needed to compute or update
@@ -549,13 +589,17 @@ def aggregate(self, inputs: Tensor, index: Tensor,
return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
dim=self.node_dim)

def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
def message_and_aggregate(
self,
adj_t: Union[SparseTensor, Tensor],
) -> Tensor:
r"""Fuses computations of :func:`message` and :func:`aggregate` into a
single function.
If applicable, this saves both time and memory since messages do not
explicitly need to be materialized.
This function will only get called in case it is implemented and
propagation takes place based on a :obj:`torch_sparse.SparseTensor`.
propagation takes place based on a :obj:`torch_sparse.SparseTensor`
or a :obj:`torch.sparse.Tensor`.
"""
raise NotImplementedError

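One possible fused implementation that covers both sparse formats is a single sparse-dense matmul via torch_geometric.utils.spmm, the helper the updated test file now imports. The layer below is an illustrative assumption, not code from this diff:

from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj
from torch_geometric.utils import spmm

class FusedConv(MessagePassing):
    """Hypothetical layer; `spmm` dispatches on both sparse formats."""
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x: Tensor, adj_t: Adj) -> Tensor:
        # `adj_t` may be a `torch_sparse.SparseTensor` or a `torch.sparse.Tensor`.
        return self.propagate(adj_t, x=x)

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        # A single sparse-dense matmul replaces explicit `message` + `aggregate`.
        return spmm(adj_t, x, reduce='sum')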