Skip to content

Commit 045070b

Browse files
author
Zecheng Zhang
committed
Update
1 parent 1940194 commit 045070b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

test/explain/algorithm/test_graphmask_explainer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
ModelReturnType,
1010
ModelTaskLevel,
1111
)
12-
from torch_geometric.nn import GCNConv, global_add_pool
12+
from torch_geometric.nn import GATConv, global_add_pool
1313

1414

15-
class GCN(torch.nn.Module):
15+
class GAT(torch.nn.Module):
1616
def __init__(self, model_config: ModelConfig):
1717
super().__init__()
1818
self.model_config = model_config
@@ -22,8 +22,8 @@ def __init__(self, model_config: ModelConfig):
2222
else:
2323
out_channels = 1
2424

25-
self.conv1 = GCNConv(3, 16)
26-
self.conv2 = GCNConv(16, out_channels)
25+
self.conv1 = GATConv(3, 16, heads=2)
26+
self.conv2 = GATConv(16 * 2, out_channels, heads=1)
2727

2828
def forward(self, x, edge_index, batch=None, edge_label_index=None):
2929
x = self.conv1(x, edge_index).relu()
@@ -110,7 +110,7 @@ def test_graph_mask_explainer_binary_classification(
110110
return_type=return_type,
111111
)
112112

113-
model = GCN(model_config)
113+
model = GAT(model_config)
114114

115115
target = None
116116
if explanation_type == 'phenomenon':
@@ -162,7 +162,7 @@ def test_graph_mask_explainer_multiclass_classification(
162162
return_type=return_type,
163163
)
164164

165-
model = GCN(model_config)
165+
model = GAT(model_config)
166166

167167
target = None
168168
if explanation_type == 'phenomenon':
@@ -207,7 +207,7 @@ def test_graph_mask_explainer_regression(
207207
task_level=task_level,
208208
)
209209

210-
model = GCN(model_config)
210+
model = GAT(model_config)
211211

212212
target = None
213213
if explanation_type == 'phenomenon':

0 commit comments

Comments
 (0)