99 ModelReturnType ,
1010 ModelTaskLevel ,
1111)
12- from torch_geometric .nn import GCNConv , global_add_pool
12+ from torch_geometric .nn import GATConv , global_add_pool
1313
1414
15- class GCN (torch .nn .Module ):
15+ class GAT (torch .nn .Module ):
1616 def __init__ (self , model_config : ModelConfig ):
1717 super ().__init__ ()
1818 self .model_config = model_config
@@ -22,8 +22,8 @@ def __init__(self, model_config: ModelConfig):
2222 else :
2323 out_channels = 1
2424
25- self .conv1 = GCNConv (3 , 16 )
26- self .conv2 = GCNConv (16 , out_channels )
25+ self .conv1 = GATConv (3 , 16 , heads = 2 )
26+ self .conv2 = GATConv (16 * 2 , out_channels , heads = 1 )
2727
2828 def forward (self , x , edge_index , batch = None , edge_label_index = None ):
2929 x = self .conv1 (x , edge_index ).relu ()
@@ -110,7 +110,7 @@ def test_graph_mask_explainer_binary_classification(
110110 return_type = return_type ,
111111 )
112112
113- model = GCN (model_config )
113+ model = GAT (model_config )
114114
115115 target = None
116116 if explanation_type == 'phenomenon' :
@@ -162,7 +162,7 @@ def test_graph_mask_explainer_multiclass_classification(
162162 return_type = return_type ,
163163 )
164164
165- model = GCN (model_config )
165+ model = GAT (model_config )
166166
167167 target = None
168168 if explanation_type == 'phenomenon' :
@@ -207,7 +207,7 @@ def test_graph_mask_explainer_regression(
207207 task_level = task_level ,
208208 )
209209
210- model = GCN (model_config )
210+ model = GAT (model_config )
211211
212212 target = None
213213 if explanation_type == 'phenomenon' :
0 commit comments