Skip to content

Commit 119a5ab

Browse files
committed
update
1 parent 480d64d commit 119a5ab

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

torch_geometric/nn/conv/gmm_conv.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch_geometric.nn.conv import MessagePassing
88
from torch_geometric.nn.dense.linear import Linear
99
from torch_geometric.nn.inits import glorot, zeros
10-
from torch_geometric.typing import Adj, OptPairTensor, Size
10+
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
1111

1212

1313
class GMMConv(MessagePassing):
@@ -128,12 +128,16 @@ def reset_parameters(self):
128128
zeros(self.bias)
129129

130130
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
131-
edge_attr: Tensor, size: Size = None):
132-
131+
edge_attr: OptTensor = None, size: Size = None):
133132
if isinstance(x, Tensor):
134133
x = (x, x)
135134

136-
# propagate_type: (x: OptPairTensor, edge_attr: Tensor)
135+
if isinstance(
136+
edge_index, Tensor
137+
) and edge_index.layout == torch.strided and edge_attr is None:
138+
raise ValueError('Strided edge indices require edge attributes')
139+
140+
# propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
137141
if not self.separate_gaussians:
138142
out: OptPairTensor = (torch.matmul(x[0], self.g), x[1])
139143
out = self.propagate(edge_index, x=out, edge_attr=edge_attr,

0 commit comments

Comments
 (0)