1414
1515
1616def explain_message (self , out : Tensor , x_i : Tensor , x_j : Tensor ) -> Tensor :
17+ orig_size = out .size ()
18+ if out .ndim == 3 :
19+ out = out .view (out .size (0 ), - 1 )
20+ if x_i .ndim == 3 :
21+ x_i = x_i .view (x_i .size (0 ), - 1 )
22+ if x_j .ndim == 3 :
23+ x_j = x_j .view (x_j .size (0 ), - 1 )
1724 basis_messages = F .layer_norm (out , (out .size (- 1 ), )).relu ()
1825
1926 if getattr (self , 'message_scale' , None ) is not None :
@@ -33,6 +40,8 @@ def explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
3340 self .latest_source_embeddings = x_j
3441 self .latest_target_embeddings = x_i
3542
43+ if len (orig_size ) == 3 :
44+ basis_messages = basis_messages .view (orig_size [0 ], orig_size [1 ], - 1 )
3645 return basis_messages
3746
3847
@@ -194,8 +203,13 @@ def _set_masks(self, x: Tensor):
194203 self .node_feat_mask = torch .nn .Parameter (
195204 torch .randn (1 , num_feat , device = device ) * std )
196205
197- def _set_trainable (self , i_dims : List [int ], j_dims : List [int ],
198- h_dims : List [int ], device : torch .device ):
206+ def _set_trainable (
207+ self ,
208+ i_dims : List [int ],
209+ j_dims : List [int ],
210+ h_dims : List [int ],
211+ device : torch .device ,
212+ ):
199213 baselines , self .gates , full_biases = [], torch .nn .ModuleList (), []
200214 zipped = zip (i_dims , j_dims , h_dims )
201215
@@ -361,7 +375,13 @@ def _train_explainer(
361375 for module in model .modules ():
362376 if isinstance (module , MessagePassing ):
363377 input_dims .append (module .in_channels )
364- output_dims .append (module .out_channels )
378+ if hasattr (module , 'heads' ):
379+ heads = module .heads
380+ else :
381+ heads = 1
382+ # If multihead attention is used, the output channels are
383+ # multiplied by the number of heads
384+ output_dims .append (module .out_channels * heads )
365385
366386 self ._set_masks (x )
367387 self ._set_trainable (input_dims , output_dims , output_dims , x .device )
@@ -405,6 +425,8 @@ def _train_explainer(
405425 output = self .full_biases [i ]
406426 for j in range (len (gate_input )):
407427 input = gate_input [j ][i ]
428+ if input .ndim == 3 :
429+ input = input .view (input .size (0 ), - 1 )
408430 try :
409431 partial = self .gates [i * 4 ][j ](input )
410432 except Exception :
0 commit comments