Skip to content

Commit f08b95b

Browse files
committed
Update
1 parent 89e1c72 commit f08b95b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_geometric/nn/conv/gcn_conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44
from torch import Tensor
55
from torch.nn import Parameter
6-
from torch_scatter import scatter_add
76
from torch_sparse import SparseTensor, fill_diag, mul
87
from torch_sparse import sum as sparsesum
98

@@ -14,6 +13,7 @@
1413
from torch_geometric.utils import (
1514
add_remaining_self_loops,
1615
is_torch_sparse_tensor,
16+
scatter,
1717
spmm,
1818
to_torch_coo_tensor,
1919
)
@@ -79,7 +79,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
7979

8080
row, col = edge_index[0], edge_index[1]
8181
idx = col if flow == "source_to_target" else row
82-
deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
82+
deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes)
8383
deg_inv_sqrt = deg.pow_(-0.5)
8484
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
8585
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

0 commit comments

Comments
 (0)