Skip to content

Commit 1940194

Browse files
author
Zecheng Zhang
committed
Update
1 parent 002c7fd commit 1940194

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

torch_geometric/explain/algorithm/graphmask_explainer.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414

1515

1616
def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
17+
orig_size = out.size()
18+
if out.ndim == 3:
19+
out = out.view(out.size(0), -1)
20+
if x_i.ndim == 3:
21+
x_i = x_i.view(x_i.size(0), -1)
22+
if x_j.ndim == 3:
23+
x_j = x_j.view(x_j.size(0), -1)
1724
basis_messages = F.layer_norm(out, (out.size(-1), )).relu()
1825

1926
if getattr(self, 'message_scale', None) is not None:
@@ -33,6 +40,8 @@ def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
3340
self.latest_source_embeddings = x_j
3441
self.latest_target_embeddings = x_i
3542

43+
if len(orig_size) == 3:
44+
basis_messages = basis_messages.view(orig_size[0], orig_size[1], -1)
3645
return basis_messages
3746

3847

@@ -194,8 +203,13 @@ def _set_masks(self, x: Tensor):
194203
self.node_feat_mask = torch.nn.Parameter(
195204
torch.randn(1, num_feat, device=device) * std)
196205

197-
def _set_trainable(self, i_dims: List[int], j_dims: List[int],
198-
h_dims: List[int], device: torch.device):
206+
def _set_trainable(
207+
self,
208+
i_dims: List[int],
209+
j_dims: List[int],
210+
h_dims: List[int],
211+
device: torch.device,
212+
):
199213
baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []
200214
zipped = zip(i_dims, j_dims, h_dims)
201215

@@ -361,7 +375,13 @@ def _train_explainer(
361375
for module in model.modules():
362376
if isinstance(module, MessagePassing):
363377
input_dims.append(module.in_channels)
364-
output_dims.append(module.out_channels)
378+
if hasattr(module, 'heads'):
379+
heads = module.heads
380+
else:
381+
heads = 1
382+
# If multihead attention is used, the output channels are
383+
# multiplied by the number of heads
384+
output_dims.append(module.out_channels * heads)
365385

366386
self._set_masks(x)
367387
self._set_trainable(input_dims, output_dims, output_dims, x.device)
@@ -405,6 +425,8 @@ def _train_explainer(
405425
output = self.full_biases[i]
406426
for j in range(len(gate_input)):
407427
input = gate_input[j][i]
428+
if input.ndim == 3:
429+
input = input.view(input.size(0), -1)
408430
try:
409431
partial = self.gates[i * 4][j](input)
410432
except Exception:

0 commit comments

Comments
 (0)