From e56bcd8998c3258aa448d4e92e2024ec21eead8c Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Tue, 22 Nov 2022 13:04:31 +0800 Subject: [PATCH 01/10] Add PyTorch SparseTensor support --- torch_geometric/nn/conv/gcn_conv.py | 41 +++++++++++++++++++---------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 72d0eecc28c6..77a4ad77da6b 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -4,14 +4,20 @@ from torch import Tensor from torch.nn import Parameter from torch_scatter import scatter_add -from torch_sparse import SparseTensor, fill_diag, matmul, mul +from torch_sparse import SparseTensor, fill_diag, mul from torch_sparse import sum as sparsesum from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor, PairTensor -from torch_geometric.utils import add_remaining_self_loops +from torch_geometric.utils import ( + add_remaining_self_loops, + is_sparse, + is_torch_sparse_tensor, + spmm, + to_torch_coo_tensor, +) from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -34,9 +40,16 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, fill_value = 2. if improved else 1. - if isinstance(edge_index, SparseTensor): + if is_sparse(edge_index): assert flow in ["source_to_target"] adj_t = edge_index + if is_torch_sparse_tensor(adj_t): + edge_index, edge_weight = gcn_norm(adj_t._indices(), + adj_t._values(), num_nodes, + improved, add_self_loops, flow, + dtype) + return to_torch_coo_tensor(edge_index, edge_weight) + if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) if add_self_loops: @@ -169,27 +182,27 @@ def forward(self, x: Tensor, edge_index: Adj, """""" if self.normalize: - if isinstance(edge_index, Tensor): - cache = self._cached_edge_index + if is_sparse(edge_index): + cache = self._cached_adj_t if cache is None: - edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), self.improved, self.add_self_loops, self.flow, x.dtype) if self.cached: - self._cached_edge_index = (edge_index, edge_weight) + self._cached_adj_t = edge_index else: - edge_index, edge_weight = cache[0], cache[1] + edge_index = cache - elif isinstance(edge_index, SparseTensor): - cache = self._cached_adj_t + elif isinstance(edge_index, Tensor): + cache = self._cached_edge_index if cache is None: - edge_index = gcn_norm( # yapf: disable + edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), self.improved, self.add_self_loops, self.flow, x.dtype) if self.cached: - self._cached_adj_t = edge_index + self._cached_edge_index = (edge_index, edge_weight) else: - edge_index = cache + edge_index, edge_weight = cache[0], cache[1] x = self.lin(x) @@ -206,4 +219,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: - return matmul(adj_t, x, reduce=self.aggr) + return spmm(adj_t, x, reduce=self.aggr) From 31dbaec3387b0dcff94e53aab9aa7c58d0486659 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Tue, 22 Nov 2022 15:33:59 +0800 Subject: [PATCH 02/10] test --- test/nn/conv/test_gcn_conv.py | 4 ++++ torch_geometric/nn/conv/gcn_conv.py | 13 ++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/nn/conv/test_gcn_conv.py b/test/nn/conv/test_gcn_conv.py index 621ef603d6c6..226acd98474e 100644 --- a/test/nn/conv/test_gcn_conv.py +++ b/test/nn/conv/test_gcn_conv.py @@ -14,15 +14,19 @@ def test_gcn_conv(): value = torch.rand(row.size(0)) adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4)) adj1 = adj2.set_value(None) + adj3 = adj1.to_torch_sparse_coo_tensor() + adj4 = adj2.to_torch_sparse_coo_tensor() conv = GCNConv(16, 32) assert conv.__repr__() == 'GCNConv(16, 32)' out1 = conv(x, edge_index) assert out1.size() == (4, 32) assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) + assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6) out2 = conv(x, edge_index, value) assert out2.size() == (4, 32) assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) + assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6) if is_full_test(): t = '(Tensor, Tensor, OptTensor) -> Tensor' diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 77a4ad77da6b..0d4dcf097513 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -16,7 +16,6 @@ is_sparse, is_torch_sparse_tensor, spmm, - to_torch_coo_tensor, ) from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -43,12 +42,9 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, if is_sparse(edge_index): assert flow in ["source_to_target"] adj_t = edge_index - if is_torch_sparse_tensor(adj_t): - edge_index, edge_weight = gcn_norm(adj_t._indices(), - adj_t._values(), num_nodes, - improved, add_self_loops, flow, - dtype) - return to_torch_coo_tensor(edge_index, edge_weight) + is_torch_sparse = is_torch_sparse_tensor(adj_t) + if is_torch_sparse: + adj_t = SparseTensor.from_torch_sparse_coo_tensor(adj_t) if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) @@ -59,6 +55,9 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1)) adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) + + if is_torch_sparse: + adj_t = adj_t.to_torch_sparse_coo_tensor() return adj_t else: From 13ac255bebd5f56a73d02608708a94492127b711 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Tue, 22 Nov 2022 15:40:00 +0800 Subject: [PATCH 03/10] test --- test/nn/conv/test_gcn_conv.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/nn/conv/test_gcn_conv.py b/test/nn/conv/test_gcn_conv.py index 226acd98474e..231901a71682 100644 --- a/test/nn/conv/test_gcn_conv.py +++ b/test/nn/conv/test_gcn_conv.py @@ -33,6 +33,8 @@ def test_gcn_conv(): jit = torch.jit.script(conv.jittable(t)) assert jit(x, edge_index).tolist() == out1.tolist() assert jit(x, edge_index, value).tolist() == out2.tolist() + assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) + assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) t = '(Tensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) From bdd479737e94a8ce5161f2e2390e9744808deba3 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 1 Dec 2022 22:46:07 +0800 Subject: [PATCH 04/10] Update --- torch_geometric/nn/conv/gcn_conv.py | 38 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 0d4dcf097513..1aa537f49aa1 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -10,12 +10,12 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.nn.inits import zeros -from torch_geometric.typing import Adj, OptTensor, PairTensor +from torch_geometric.typing import Adj, OptPairTensor, OptTensor from torch_geometric.utils import ( add_remaining_self_loops, - is_sparse, is_torch_sparse_tensor, spmm, + to_torch_coo_tensor, ) from torch_geometric.utils.num_nodes import maybe_num_nodes @@ -23,7 +23,7 @@ @torch.jit._overload def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, add_self_loops=True, flow="source_to_target", dtype=None): - # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> PairTensor # noqa + # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor # noqa pass @@ -39,13 +39,9 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, fill_value = 2. if improved else 1. - if is_sparse(edge_index): + if isinstance(edge_index, SparseTensor): assert flow in ["source_to_target"] adj_t = edge_index - is_torch_sparse = is_torch_sparse_tensor(adj_t) - if is_torch_sparse: - adj_t = SparseTensor.from_torch_sparse_coo_tensor(adj_t) - if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) if add_self_loops: @@ -56,11 +52,16 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1)) adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) - if is_torch_sparse: - adj_t = adj_t.to_torch_sparse_coo_tensor() return adj_t - else: + is_torch_sparse = is_torch_sparse_tensor(edge_index) + if is_torch_sparse: + assert flow == "source_to_target" + flow = "target_to_source" + num_nodes = num_nodes if edge_index.size(0) is None else num_nodes + edge_index, edge_weight = edge_index._indices( + ), edge_index._values() + assert flow in ["source_to_target", "target_to_source"] num_nodes = maybe_num_nodes(edge_index, num_nodes) @@ -79,7 +80,14 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) - return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + if is_torch_sparse: + adj_t = to_torch_coo_tensor(edge_index, edge_weight, + size=num_nodes) + return adj_t, None + else: + return edge_index, edge_weight class GCNConv(MessagePassing): @@ -139,7 +147,7 @@ class GCNConv(MessagePassing): - **output:** node features :math:`(|\mathcal{V}|, F_{out})` """ - _cached_edge_index: Optional[Tuple[Tensor, Tensor]] + _cached_edge_index: Optional[OptPairTensor] _cached_adj_t: Optional[SparseTensor] def __init__(self, in_channels: int, out_channels: int, @@ -181,7 +189,7 @@ def forward(self, x: Tensor, edge_index: Adj, """""" if self.normalize: - if is_sparse(edge_index): + if isinstance(edge_index, SparseTensor): cache = self._cached_adj_t if cache is None: edge_index = gcn_norm( # yapf: disable From f5e9a65b4290f9e145c817001af42602ff2a3510 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 1 Dec 2022 22:47:08 +0800 Subject: [PATCH 05/10] Update --- torch_geometric/nn/conv/gcn_conv.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 1aa537f49aa1..f5d20955100a 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -189,27 +189,27 @@ def forward(self, x: Tensor, edge_index: Adj, """""" if self.normalize: - if isinstance(edge_index, SparseTensor): - cache = self._cached_adj_t + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index if cache is None: - edge_index = gcn_norm( # yapf: disable + edge_index, edge_weight = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), self.improved, self.add_self_loops, self.flow, x.dtype) if self.cached: - self._cached_adj_t = edge_index + self._cached_edge_index = (edge_index, edge_weight) else: - edge_index = cache + edge_index, edge_weight = cache[0], cache[1] - elif isinstance(edge_index, Tensor): - cache = self._cached_edge_index + elif isinstance(edge_index, SparseTensor): + cache = self._cached_adj_t if cache is None: - edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index = gcn_norm( # yapf: disable edge_index, edge_weight, x.size(self.node_dim), self.improved, self.add_self_loops, self.flow, x.dtype) if self.cached: - self._cached_edge_index = (edge_index, edge_weight) + self._cached_adj_t = edge_index else: - edge_index, edge_weight = cache[0], cache[1] + edge_index = cache x = self.lin(x) From 70441da13bcea12a165aa61432ffd9c18631efdd Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 1 Dec 2022 23:00:31 +0800 Subject: [PATCH 06/10] Update --- torch_geometric/nn/conv/gcn_conv.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index f5d20955100a..baaaa816040d 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -56,14 +56,16 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, else: is_torch_sparse = is_torch_sparse_tensor(edge_index) if is_torch_sparse: + adj_t = edge_index + num_nodes = adj_t.size(0) + edge_index = adj_t._indices() + edge_weight = adj_t._values() assert flow == "source_to_target" + # `adj_t` is transposed flow = "target_to_source" - num_nodes = num_nodes if edge_index.size(0) is None else num_nodes - edge_index, edge_weight = edge_index._indices( - ), edge_index._values() - - assert flow in ["source_to_target", "target_to_source"] - num_nodes = maybe_num_nodes(edge_index, num_nodes) + else: + assert flow in ["source_to_target", "target_to_source"] + num_nodes = maybe_num_nodes(edge_index, num_nodes) if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, From e35cda468239a37ae8299965fce30a2b1e8fb06b Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Thu, 1 Dec 2022 23:03:43 +0800 Subject: [PATCH 07/10] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a0236a0a10f..e59c0180de48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.3.0] - 2023-MM-DD ### Added +- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033)) ### Changed - Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113)) ### Removed From 89e1c723ad80d8195965655ce1ea1c8bcbd21489 Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 5 Dec 2022 20:56:31 +0800 Subject: [PATCH 08/10] Update --- torch_geometric/nn/conv/gcn_conv.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index baaaa816040d..b7b3bedfe08c 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -54,15 +54,15 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, return adj_t else: - is_torch_sparse = is_torch_sparse_tensor(edge_index) - if is_torch_sparse: + is_sparse_tensor = is_torch_sparse_tensor(edge_index) + if is_sparse_tensor: + assert flow == "source_to_target" + # `adj_t` is transposed + flow = "target_to_source" adj_t = edge_index num_nodes = adj_t.size(0) edge_index = adj_t._indices() edge_weight = adj_t._values() - assert flow == "source_to_target" - # `adj_t` is transposed - flow = "target_to_source" else: assert flow in ["source_to_target", "target_to_source"] num_nodes = maybe_num_nodes(edge_index, num_nodes) @@ -84,7 +84,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] - if is_torch_sparse: + if is_sparse_tensor: adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes) return adj_t, None From f08b95b3d27151280adfbf6998af86b7901b1dad Mon Sep 17 00:00:00 2001 From: EdisonLeeeee Date: Mon, 5 Dec 2022 20:58:31 +0800 Subject: [PATCH 09/10] Update --- torch_geometric/nn/conv/gcn_conv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index b7b3bedfe08c..3699dd54b732 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -3,7 +3,6 @@ import torch from torch import Tensor from torch.nn import Parameter -from torch_scatter import scatter_add from torch_sparse import SparseTensor, fill_diag, mul from torch_sparse import sum as sparsesum @@ -14,6 +13,7 @@ from torch_geometric.utils import ( add_remaining_self_loops, is_torch_sparse_tensor, + scatter, spmm, to_torch_coo_tensor, ) @@ -79,7 +79,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, row, col = edge_index[0], edge_index[1] idx = col if flow == "source_to_target" else row - deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes) + deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow_(-0.5) deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] From 2517f5cc8bce18bb4d8d7f9cfdcb84f6ffe2cbe0 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 6 Dec 2022 12:37:01 +0000 Subject: [PATCH 10/10] update --- test/nn/conv/test_gcn_conv.py | 2 - torch_geometric/nn/conv/gcn_conv.py | 64 ++++++++++++++--------------- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/test/nn/conv/test_gcn_conv.py b/test/nn/conv/test_gcn_conv.py index 231901a71682..226acd98474e 100644 --- a/test/nn/conv/test_gcn_conv.py +++ b/test/nn/conv/test_gcn_conv.py @@ -33,8 +33,6 @@ def test_gcn_conv(): jit = torch.jit.script(conv.jittable(t)) assert jit(x, edge_index).tolist() == out1.tolist() assert jit(x, edge_index, value).tolist() == out2.tolist() - assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6) - assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6) t = '(Tensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 3699dd54b732..d87399d3e398 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -40,7 +40,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, fill_value = 2. if improved else 1. if isinstance(edge_index, SparseTensor): - assert flow in ["source_to_target"] + assert flow == 'source_to_target' adj_t = edge_index if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) @@ -53,43 +53,43 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) return adj_t + + # `edge_index` can be a `torch.LongTensor` or `torch.sparse.Tensor`: + is_sparse_tensor = is_torch_sparse_tensor(edge_index) + if is_sparse_tensor: + assert flow == 'source_to_target' + # Reverse `flow` since sparse tensors model transposed adjacencies: + flow = 'target_to_source' + adj_t = edge_index + num_nodes = adj_t.size(0) + edge_index = adj_t._indices() + edge_weight = adj_t._values() else: - is_sparse_tensor = is_torch_sparse_tensor(edge_index) - if is_sparse_tensor: - assert flow == "source_to_target" - # `adj_t` is transposed - flow = "target_to_source" - adj_t = edge_index - num_nodes = adj_t.size(0) - edge_index = adj_t._indices() - edge_weight = adj_t._values() - else: - assert flow in ["source_to_target", "target_to_source"] - num_nodes = maybe_num_nodes(edge_index, num_nodes) + assert flow in ["source_to_target", "target_to_source"] + num_nodes = maybe_num_nodes(edge_index, num_nodes) if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) - if add_self_loops: - edge_index, tmp_edge_weight = add_remaining_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - assert tmp_edge_weight is not None - edge_weight = tmp_edge_weight - - row, col = edge_index[0], edge_index[1] - idx = col if flow == "source_to_target" else row - deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow_(-0.5) - deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) - edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] - - if is_sparse_tensor: - adj_t = to_torch_coo_tensor(edge_index, edge_weight, - size=num_nodes) - return adj_t, None - else: - return edge_index, edge_weight + if add_self_loops: + edge_index, tmp_edge_weight = add_remaining_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + assert tmp_edge_weight is not None + edge_weight = tmp_edge_weight + + row, col = edge_index[0], edge_index[1] + idx = col if flow == "source_to_target" else row + deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum') + deg_inv_sqrt = deg.pow_(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + if is_sparse_tensor: + adj_t = to_torch_coo_tensor(edge_index, edge_weight, size=num_nodes) + return adj_t, None + else: + return edge_index, edge_weight class GCNConv(MessagePassing):