|
7 | 7 | from torch_geometric.nn.conv import MessagePassing |
8 | 8 | from torch_geometric.nn.dense.linear import Linear |
9 | 9 | from torch_geometric.nn.inits import glorot, zeros |
10 | | -from torch_geometric.typing import Adj, OptPairTensor, Size |
| 10 | +from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size |
11 | 11 |
|
12 | 12 |
|
13 | 13 | class GMMConv(MessagePassing): |
@@ -128,12 +128,16 @@ def reset_parameters(self): |
128 | 128 | zeros(self.bias) |
129 | 129 |
|
130 | 130 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, |
131 | | - edge_attr: Tensor, size: Size = None): |
132 | | - |
| 131 | + edge_attr: OptTensor = None, size: Size = None): |
133 | 132 | if isinstance(x, Tensor): |
134 | 133 | x = (x, x) |
135 | 134 |
|
136 | | - # propagate_type: (x: OptPairTensor, edge_attr: Tensor) |
| 135 | + if isinstance( |
| 136 | + edge_index, Tensor |
| 137 | + ) and edge_index.layout == torch.strided and edge_attr is None: |
| 138 | + raise ValueError('Strided edge indices require edge attributes') |
| 139 | + |
| 140 | + # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) |
137 | 141 | if not self.separate_gaussians: |
138 | 142 | out: OptPairTensor = (torch.matmul(x[0], self.g), x[1]) |
139 | 143 | out = self.propagate(edge_index, x=out, edge_attr=edge_attr, |
|
0 commit comments