@@ -33,36 +33,14 @@ def __init__(self, reduce: str = 'max'):
3333 def forward (
3434 self ,
3535 model : torch .nn .Module ,
36- x : Union [ Tensor , Dict [ NodeType , Tensor ]] ,
37- edge_index : Union [ Tensor , Dict [ EdgeType , Tensor ]] ,
36+ x : Tensor ,
37+ edge_index : Tensor ,
3838 * ,
3939 target : Tensor ,
4040 index : Optional [Union [int , Tensor ]] = None ,
4141 ** kwargs ,
42- ) -> Union [Explanation , HeteroExplanation ]:
43- """Generate explanations based on attention coefficients."""
44- self .is_hetero = isinstance (x , dict )
45-
46- # Collect attention coefficients
47- alphas_dict = self ._collect_attention_coefficients (
48- model , x , edge_index , ** kwargs )
49-
50- # Process attention coefficients
51- if self .is_hetero :
52- return self ._create_hetero_explanation (model , alphas_dict ,
53- edge_index , index , x )
54- else :
55- return self ._create_homo_explanation (model , alphas_dict ,
56- edge_index , index , x )
57-
58- def _collect_attention_coefficients (
59- self ,
60- model : torch .nn .Module ,
61- x : Union [Tensor , Dict [NodeType , Tensor ]],
62- edge_index : Union [Tensor , Dict [EdgeType , Tensor ]],
63- ** kwargs ,
6442 ) -> Explanation :
65- pass
43+ ...
6644
6745 @overload
6846 def forward (
@@ -75,7 +53,7 @@ def forward(
7553 index : Optional [Union [int , Tensor ]] = None ,
7654 ** kwargs ,
7755 ) -> HeteroExplanation :
78- pass
56+ ...
7957
8058 def forward (
8159 self ,
@@ -102,6 +80,26 @@ def forward(
10280 return self ._create_homo_explanation (model , alphas_dict ,
10381 edge_index , index , x )
10482
83+ @overload
84+ def _collect_attention_coefficients (
85+ self ,
86+ model : torch .nn .Module ,
87+ x : Tensor ,
88+ edge_index : Tensor ,
89+ ** kwargs ,
90+ ) -> List [Tensor ]:
91+ ...
92+
93+ @overload
94+ def _collect_attention_coefficients (
95+ self ,
96+ model : torch .nn .Module ,
97+ x : Dict [NodeType , Tensor ],
98+ edge_index : Dict [EdgeType , Tensor ],
99+ ** kwargs ,
100+ ) -> Dict [EdgeType , List [Tensor ]]:
101+ ...
102+
105103 def _collect_attention_coefficients (
106104 self ,
107105 model : torch .nn .Module ,
0 commit comments