1212
1313if TYPE_CHECKING :
1414 from ConfigSpace import Configuration
15+ from shapiq import ValidApproximationIndices
1516
1617 from hypershap .utils import ConfigSpaceSearcher
1718
2021import matplotlib .pyplot as plt
2122import networkx as nx
2223import numpy as np
23- from shapiq import SHAPIQ , ExactComputer , InteractionValues , KernelSHAPIQ
24+ from shapiq import ExactComputer , InteractionValues
25+ from shapiq .explainer .configuration import setup_approximator_automatically
2426
2527from hypershap .games import (
2628 AblationGame ,
@@ -66,13 +68,13 @@ class HyperSHAP:
6668 __init__(explanation_task: ExplanationTask):
6769 Initializes the HyperSHAP instance with an explanation task.
6870
69- ablation(config_of_interest: Configuration, baseline_config: Configuration, index: str = "FSII", order: int = 2) -> InteractionValues:
71+ ablation(config_of_interest: Configuration, baseline_config: Configuration, index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7072 Computes and returns the interaction values for ablation analysis.
7173
72- tunability(baseline_config: Configuration | None, index: str = "FSII", order: int = 2) -> InteractionValues:
74+ tunability(baseline_config: Configuration | None, index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7375 Computes and returns the interaction values for tunability analysis.
7476
75- optimizer_bias(optimizer_of_interest: ConfigSpaceSearcher, optimizer_ensemble: list[ConfigSpaceSearcher], index: str = "FSII", order: int = 2) -> InteractionValues:
77+ optimizer_bias(optimizer_of_interest: ConfigSpaceSearcher, optimizer_ensemble: list[ConfigSpaceSearcher], index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7678 Computes and returns the interaction values for optimizer bias analysis.
7779
7880 plot_si_graph(interaction_values: InteractionValues | None = None, save_path: str | None = None):
@@ -116,19 +118,22 @@ def __init__(
116118 )
117119 self .verbose = verbose
118120
119- def __get_interaction_values (self , game : AbstractHPIGame , index : str = "FSII" , order : int = 2 ) -> InteractionValues :
121+ def __get_interaction_values (
122+ self ,
123+ game : AbstractHPIGame ,
124+ index : ValidApproximationIndices = "FSII" ,
125+ order : int = 2 ,
126+ seed : int | None = 0 ,
127+ ) -> InteractionValues :
120128 if game .n_players <= EXACT_MAX_HYPERPARAMETERS :
121129 # instantiate exact computer if number of hyperparameters is small enough
122130 ec = ExactComputer (n_players = game .get_num_hyperparameters (), game = game ) # pyright: ignore
123131
124132 # compute interaction values with the given index and order
125133 interaction_values = ec (index = index , order = order )
126134 else :
127- # instantiate kernel
128- if index == "FSII" :
129- approx = SHAPIQ (n = game .n_players , max_order = 2 , index = index )
130- else :
131- approx = KernelSHAPIQ (n = game .n_players , max_order = 2 , index = index )
135+ # instantiate approximator
136+ approx = setup_approximator_automatically (index , order , game .n_players , seed )
132137
133138 # approximate interaction values with the given index and order
134139 interaction_values = approx (budget = self .approximation_budget , game = game )
@@ -142,15 +147,15 @@ def ablation(
142147 self ,
143148 config_of_interest : Configuration ,
144149 baseline_config : Configuration ,
145- index : str = "FSII" ,
150+ index : ValidApproximationIndices = "FSII" ,
146151 order : int = 2 ,
147152 ) -> InteractionValues :
148153 """Compute and return the interaction values for ablation analysis.
149154
150155 Args:
151156 config_of_interest (Configuration): The configuration of interest.
152157 baseline_config (Configuration): The baseline configuration.
153- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
158+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
154159 order (int, optional): The order of the interaction values. Defaults to 2.
155160
156161 Returns:
@@ -191,7 +196,7 @@ def ablation_multibaseline(
191196 config_of_interest : Configuration ,
192197 baseline_configs : list [Configuration ],
193198 aggregation : Aggregation = Aggregation .AVG ,
194- index : str = "FSII" ,
199+ index : ValidApproximationIndices = "FSII" ,
195200 order : int = 2 ,
196201 ) -> InteractionValues :
197202 """Compute and return the interaction values for multi-baseline ablation analysis.
@@ -200,7 +205,7 @@ def ablation_multibaseline(
200205 config_of_interest (Configuration): The configuration of interest.
201206 baseline_configs (list[Configuration]): The list of baseline configurations.
202207 aggregation (Aggregation): The aggregation method to use for computing interaction values.
203- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
208+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
204209 order (int, optional): The order of the interaction values. Defaults to 2.
205210
206211 Returns:
@@ -240,9 +245,10 @@ def ablation_multibaseline(
240245 def tunability (
241246 self ,
242247 baseline_config : Configuration | None = None ,
243- index : str = "FSII" ,
248+ index : ValidApproximationIndices = "FSII" ,
244249 order : int = 2 ,
245250 n_samples : int = 10_000 ,
251+ seed : int | None = 0 ,
246252 ) -> InteractionValues :
247253 """Compute and return the interaction values for tunability analysis.
248254
@@ -251,6 +257,7 @@ def tunability(
251257 index (str, optional): The index to use for computing interaction values. Defaults to "FSII".
252258 order (int, optional): The order of the interaction values. Defaults to 2.
253259 n_samples (int, optional): The number of samples to use for simulating HPO. Defaults to 10_000.
260+ seed (int, optiona): The random seed for simulating HPO. Defaults to 0.
254261
255262 Returns:
256263 InteractionValues: The computed interaction values.
@@ -278,6 +285,7 @@ def tunability(
278285 explanation_task = tunability_task ,
279286 n_samples = n_samples ,
280287 mode = Aggregation .MAX ,
288+ seed = seed ,
281289 ),
282290 n_workers = self .n_workers ,
283291 verbose = self .verbose ,
@@ -295,9 +303,10 @@ def tunability(
295303 def sensitivity (
296304 self ,
297305 baseline_config : Configuration | None = None ,
298- index : str = "FSII" ,
306+ index : ValidApproximationIndices = "FSII" ,
299307 order : int = 2 ,
300308 n_samples : int = 10_000 ,
309+ seed : int | None = 0 ,
301310 ) -> InteractionValues :
302311 """Compute and return the interaction values for sensitivity analysis.
303312
@@ -306,6 +315,7 @@ def sensitivity(
306315 index (str, optional): The index to use for computing interaction values. Defaults to "FSII".
307316 order (int, optional): The order of the interaction values. Defaults to 2.
308317 n_samples (int, optional): The number of samples to use for simulating HPO. Defaults to 10_000.
318+ seed (int, optiona): The random seed for simulating HPO. Defaults to 0.
309319
310320 Returns:
311321 InteractionValues: The computed interaction values.
@@ -333,6 +343,7 @@ def sensitivity(
333343 explanation_task = sensitivity_task ,
334344 n_samples = n_samples ,
335345 mode = Aggregation .VAR ,
346+ seed = seed ,
336347 ),
337348 n_workers = self .n_workers ,
338349 verbose = self .verbose ,
@@ -350,17 +361,19 @@ def sensitivity(
350361 def mistunability (
351362 self ,
352363 baseline_config : Configuration | None = None ,
353- index : str = "FSII" ,
364+ index : ValidApproximationIndices = "FSII" ,
354365 order : int = 2 ,
355366 n_samples : int = 10_000 ,
367+ seed : int | None = 0 ,
356368 ) -> InteractionValues :
357369 """Compute and return the interaction values for mistunability analysis.
358370
359371 Args:
360372 baseline_config (Configuration | None, optional): The baseline configuration. Defaults to None.
361- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
373+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
362374 order (int, optional): The order of the interaction values. Defaults to 2.
363375 n_samples (int, optional): The number of samples to use for simulating HPO. Defaults to 10_000.
376+ seed (int, optiona): The random seed for simulating HPO. Defaults to 0.
364377
365378 Returns:
366379 InteractionValues: The computed interaction values.
@@ -388,6 +401,7 @@ def mistunability(
388401 explanation_task = mistunability_task ,
389402 n_samples = n_samples ,
390403 mode = Aggregation .MIN ,
404+ seed = seed ,
391405 ),
392406 n_workers = self .n_workers ,
393407 verbose = self .verbose ,
@@ -405,15 +419,15 @@ def optimizer_bias(
405419 self ,
406420 optimizer_of_interest : ConfigSpaceSearcher ,
407421 optimizer_ensemble : list [ConfigSpaceSearcher ],
408- index : str = "FSII" ,
422+ index : ValidApproximationIndices = "FSII" ,
409423 order : int = 2 ,
410424 ) -> InteractionValues :
411425 """Compute and return the interaction values for optimizer bias analysis.
412426
413427 Args:
414428 optimizer_of_interest (ConfigSpaceSearcher): The optimizer of interest.
415429 optimizer_ensemble (list[ConfigSpaceSearcher]): The ensemble of optimizers.
416- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
430+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
417431 order (int, optional): The order of the interaction values. Defaults to 2.
418432
419433 Returns:
0 commit comments