Skip to content

Commit dd22da2

Browse files
author
XinweiHe
committed
fix
1 parent 0c53d1c commit dd22da2

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

torch_geometric/explain/algorithm/attention_explainer.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,14 @@ def __init__(self, reduce: str = 'max'):
3333
def forward(
3434
self,
3535
model: torch.nn.Module,
36-
x: Union[Tensor, Dict[NodeType, Tensor]],
37-
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
36+
x: Tensor,
37+
edge_index: Tensor,
3838
*,
3939
target: Tensor,
4040
index: Optional[Union[int, Tensor]] = None,
4141
**kwargs,
42-
) -> Union[Explanation, HeteroExplanation]:
43-
"""Generate explanations based on attention coefficients."""
44-
self.is_hetero = isinstance(x, dict)
45-
46-
# Collect attention coefficients
47-
alphas_dict = self._collect_attention_coefficients(
48-
model, x, edge_index, **kwargs)
49-
50-
# Process attention coefficients
51-
if self.is_hetero:
52-
return self._create_hetero_explanation(model, alphas_dict,
53-
edge_index, index, x)
54-
else:
55-
return self._create_homo_explanation(model, alphas_dict,
56-
edge_index, index, x)
57-
58-
def _collect_attention_coefficients(
59-
self,
60-
model: torch.nn.Module,
61-
x: Union[Tensor, Dict[NodeType, Tensor]],
62-
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
63-
**kwargs,
6442
) -> Explanation:
65-
pass
43+
...
6644

6745
@overload
6846
def forward(
@@ -75,7 +53,7 @@ def forward(
7553
index: Optional[Union[int, Tensor]] = None,
7654
**kwargs,
7755
) -> HeteroExplanation:
78-
pass
56+
...
7957

8058
def forward(
8159
self,
@@ -102,6 +80,26 @@ def forward(
10280
return self._create_homo_explanation(model, alphas_dict,
10381
edge_index, index, x)
10482

83+
@overload
84+
def _collect_attention_coefficients(
85+
self,
86+
model: torch.nn.Module,
87+
x: Tensor,
88+
edge_index: Tensor,
89+
**kwargs,
90+
) -> List[Tensor]:
91+
...
92+
93+
@overload
94+
def _collect_attention_coefficients(
95+
self,
96+
model: torch.nn.Module,
97+
x: Dict[NodeType, Tensor],
98+
edge_index: Dict[EdgeType, Tensor],
99+
**kwargs,
100+
) -> Dict[EdgeType, List[Tensor]]:
101+
...
102+
105103
def _collect_attention_coefficients(
106104
self,
107105
model: torch.nn.Module,

0 commit comments

Comments
 (0)