1212
1313if TYPE_CHECKING :
1414 from ConfigSpace import Configuration
15+ from shapiq import ValidApproximationIndices
1516
1617 from hypershap .utils import ConfigSpaceSearcher
1718
2021import matplotlib .pyplot as plt
2122import networkx as nx
2223import numpy as np
23- from shapiq import SHAPIQ , ExactComputer , InteractionValues , KernelSHAPIQ
24+ from shapiq import ExactComputer , InteractionValues
25+ from shapiq .explainer .configuration import setup_approximator_automatically
2426
2527from hypershap .games import (
2628 AblationGame ,
@@ -66,13 +68,13 @@ class HyperSHAP:
6668 __init__(explanation_task: ExplanationTask):
6769 Initializes the HyperSHAP instance with an explanation task.
6870
69- ablation(config_of_interest: Configuration, baseline_config: Configuration, index: str = "FSII", order: int = 2) -> InteractionValues:
71+ ablation(config_of_interest: Configuration, baseline_config: Configuration, index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7072 Computes and returns the interaction values for ablation analysis.
7173
72- tunability(baseline_config: Configuration | None, index: str = "FSII", order: int = 2) -> InteractionValues:
74+ tunability(baseline_config: Configuration | None, index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7375 Computes and returns the interaction values for tunability analysis.
7476
75- optimizer_bias(optimizer_of_interest: ConfigSpaceSearcher, optimizer_ensemble: list[ConfigSpaceSearcher], index: str = "FSII", order: int = 2) -> InteractionValues:
77+ optimizer_bias(optimizer_of_interest: ConfigSpaceSearcher, optimizer_ensemble: list[ConfigSpaceSearcher], index: ValidApproximationIndices = "FSII", order: int = 2) -> InteractionValues:
7678 Computes and returns the interaction values for optimizer bias analysis.
7779
7880 plot_si_graph(interaction_values: InteractionValues | None = None, save_path: str | None = None):
@@ -116,19 +118,22 @@ def __init__(
116118 )
117119 self .verbose = verbose
118120
119- def __get_interaction_values (self , game : AbstractHPIGame , index : str = "FSII" , order : int = 2 ) -> InteractionValues :
121+ def __get_interaction_values (
122+ self ,
123+ game : AbstractHPIGame ,
124+ index : ValidApproximationIndices = "FSII" ,
125+ order : int = 2 ,
126+ seed : int | None = 0 ,
127+ ) -> InteractionValues :
120128 if game .n_players <= EXACT_MAX_HYPERPARAMETERS :
121129 # instantiate exact computer if number of hyperparameters is small enough
122130 ec = ExactComputer (n_players = game .get_num_hyperparameters (), game = game ) # pyright: ignore
123131
124132 # compute interaction values with the given index and order
125133 interaction_values = ec (index = index , order = order )
126134 else :
127- # instantiate kernel
128- if index == "FSII" :
129- approx = SHAPIQ (n = game .n_players , max_order = 2 , index = index )
130- else :
131- approx = KernelSHAPIQ (n = game .n_players , max_order = 2 , index = index )
135+ # instantiate approximator
136+ approx = setup_approximator_automatically (index , order , game .n_players , seed )
132137
133138 # approximate interaction values with the given index and order
134139 interaction_values = approx (budget = self .approximation_budget , game = game )
@@ -142,15 +147,15 @@ def ablation(
142147 self ,
143148 config_of_interest : Configuration ,
144149 baseline_config : Configuration ,
145- index : str = "FSII" ,
150+ index : ValidApproximationIndices = "FSII" ,
146151 order : int = 2 ,
147152 ) -> InteractionValues :
148153 """Compute and return the interaction values for ablation analysis.
149154
150155 Args:
151156 config_of_interest (Configuration): The configuration of interest.
152157 baseline_config (Configuration): The baseline configuration.
153- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
158+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
154159 order (int, optional): The order of the interaction values. Defaults to 2.
155160
156161 Returns:
@@ -191,7 +196,7 @@ def ablation_multibaseline(
191196 config_of_interest : Configuration ,
192197 baseline_configs : list [Configuration ],
193198 aggregation : Aggregation = Aggregation .AVG ,
194- index : str = "FSII" ,
199+ index : ValidApproximationIndices = "FSII" ,
195200 order : int = 2 ,
196201 ) -> InteractionValues :
197202 """Compute and return the interaction values for multi-baseline ablation analysis.
@@ -200,7 +205,7 @@ def ablation_multibaseline(
200205 config_of_interest (Configuration): The configuration of interest.
201206 baseline_configs (list[Configuration]): The list of baseline configurations.
202207 aggregation (Aggregation): The aggregation method to use for computing interaction values.
203- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
208+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
204209 order (int, optional): The order of the interaction values. Defaults to 2.
205210
206211 Returns:
@@ -240,7 +245,7 @@ def ablation_multibaseline(
240245 def tunability (
241246 self ,
242247 baseline_config : Configuration | None = None ,
243- index : str = "FSII" ,
248+ index : ValidApproximationIndices = "FSII" ,
244249 order : int = 2 ,
245250 n_samples : int = 10_000 ,
246251 seed : int | None = 0 ,
@@ -298,7 +303,7 @@ def tunability(
298303 def sensitivity (
299304 self ,
300305 baseline_config : Configuration | None = None ,
301- index : str = "FSII" ,
306+ index : ValidApproximationIndices = "FSII" ,
302307 order : int = 2 ,
303308 n_samples : int = 10_000 ,
304309 seed : int | None = 0 ,
@@ -356,7 +361,7 @@ def sensitivity(
356361 def mistunability (
357362 self ,
358363 baseline_config : Configuration | None = None ,
359- index : str = "FSII" ,
364+ index : ValidApproximationIndices = "FSII" ,
360365 order : int = 2 ,
361366 n_samples : int = 10_000 ,
362367 seed : int | None = 0 ,
@@ -365,7 +370,7 @@ def mistunability(
365370
366371 Args:
367372 baseline_config (Configuration | None, optional): The baseline configuration. Defaults to None.
368- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
373+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
369374 order (int, optional): The order of the interaction values. Defaults to 2.
370375 n_samples (int, optional): The number of samples to use for simulating HPO. Defaults to 10_000.
371376 seed (int, optiona): The random seed for simulating HPO. Defaults to 0.
@@ -414,15 +419,15 @@ def optimizer_bias(
414419 self ,
415420 optimizer_of_interest : ConfigSpaceSearcher ,
416421 optimizer_ensemble : list [ConfigSpaceSearcher ],
417- index : str = "FSII" ,
422+ index : ValidApproximationIndices = "FSII" ,
418423 order : int = 2 ,
419424 ) -> InteractionValues :
420425 """Compute and return the interaction values for optimizer bias analysis.
421426
422427 Args:
423428 optimizer_of_interest (ConfigSpaceSearcher): The optimizer of interest.
424429 optimizer_ensemble (list[ConfigSpaceSearcher]): The ensemble of optimizers.
425- index (str , optional): The index to use for computing interaction values. Defaults to "FSII".
430+ index (ValidApproximationIndices , optional): The index to use for computing interaction values. Defaults to "FSII".
426431 order (int, optional): The order of the interaction values. Defaults to 2.
427432
428433 Returns:
0 commit comments