Skip to content

Commit 1b78b95

Browse files
authored
Merge pull request #23 from automl/feature/pseudo-randomness
Feature/pseudo randomness
2 parents 8892042 + 646dd0a commit 1b78b95

File tree

6 files changed

+33
-7
lines changed

6 files changed

+33
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# v0.0.4
2+
- Added pseudo-randomization
3+
- Added index-specific approximation
4+
15
# v0.0.3
26
- Added multi-baseline ablation game. This game computes ablation paths with respect to multiple baseline configurations and aggregates values for different paths via mean, min, max or variance.
37
- Added waterfall plots to the HyperSHAP interface.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "hypershap"
3-
version = "0.0.2"
3+
version = "0.0.4"
44
description = "HyperSHAP is a post-hoc explanation method for hyperparameter optimization."
55
authors = [{ name = "Marcel Wever", email = "[email protected]" }]
66
readme = "README.md"

src/hypershap/surrogate_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def __init__(
182182
config_space: ConfigurationSpace,
183183
data: list[tuple[Configuration, float]],
184184
base_model: BaseEstimator | None = None,
185+
seed: int | None = 0,
185186
) -> None:
186187
"""Initialize the DataBasedSurrogateModel with data and an optional base model.
187188
@@ -191,13 +192,14 @@ def __init__(
191192
is a tuple of (Configuration, float).
192193
base_model: The base model to be used for fitting the surrogate model.
193194
If None, a RandomForestRegressor is used.
195+
seed: The random seed for pseudo-randomization of the surrogate model. Defaults to 0.
194196
195197
"""
196198
train_x = np.array([obs[0].get_array() for obs in data])
197199
train_y = np.array([obs[1] for obs in data])
198200

199201
if base_model is None:
200-
base_model = RandomForestRegressor()
202+
base_model = RandomForestRegressor(random_state=seed)
201203

202204
pipeline = cast("SklearnRegressorProtocol", base_model)
203205
pipeline.fit(train_x, train_y)

src/hypershap/task.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from hypershap import ConfigSpaceSearcher
1919

20+
from copy import deepcopy
21+
2022
from sklearn.ensemble import RandomForestRegressor
2123

2224
from hypershap.surrogate_model import DataBasedSurrogateModel, ModelBasedSurrogateModel, SurrogateModel
@@ -132,6 +134,7 @@ def from_function(
132134
function: Callable[[Configuration], float],
133135
n_samples: int = 1_000,
134136
base_model: BaseEstimator | None = None,
137+
seed: int | None = 0,
135138
) -> ExplanationTask:
136139
"""Create an ExplanationTask from a function that evaluates configurations.
137140
@@ -140,17 +143,21 @@ def from_function(
140143
function: A callable that takes a configuration and returns its performance.
141144
n_samples: The number of configurations to sample for training the surrogate model. Defaults to 1000.
142145
base_model: The base model to use for training the surrogate model. Defaults to RandomForestRegressor.
146+
seed: The seed for the random number generator; it is used to seed a deep copy of the config space.
143147
144148
Returns:
145149
An ExplanationTask instance.
146150
147151
"""
148-
samples: list[Configuration] = config_space.sample_configuration(n_samples)
152+
cs = deepcopy(config_space)
153+
if seed is not None:
154+
cs.seed(seed)
155+
samples: list[Configuration] = cs.sample_configuration(n_samples)
149156
values: list[float] = [function(config) for config in samples]
150157
data: list[tuple[Configuration, float]] = list(zip(samples, values, strict=False))
151-
base_model = base_model if base_model is not None else RandomForestRegressor()
158+
base_model = base_model if base_model is not None else RandomForestRegressor(random_state=seed)
152159

153-
return ExplanationTask.from_data(config_space=config_space, data=data, base_model=base_model)
160+
return ExplanationTask.from_data(config_space=cs, data=data, base_model=base_model)
154161

155162
@staticmethod
156163
def from_function_multidata(

src/hypershap/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
from abc import ABC, abstractmethod
9+
from copy import deepcopy
910
from enum import Enum
1011
from typing import TYPE_CHECKING
1112

@@ -83,6 +84,7 @@ def __init__(
8384
explanation_task: BaselineExplanationTask,
8485
mode: Aggregation = Aggregation.MAX,
8586
n_samples: int = 10_000,
87+
seed: int | None = 0,
8688
) -> None:
8789
"""Initialize the random configuration space searcher.
8890
@@ -91,11 +93,14 @@ def __init__(
9193
space and surrogate model.
9294
mode: The aggregation mode for performance values.
9395
n_samples: The number of configurations to sample.
96+
seed: The random seed for sampling configurations from the config space.
9497
9598
"""
9699
super().__init__(explanation_task, mode=mode)
97-
98-
sampled_configurations = self.explanation_task.config_space.sample_configuration(size=n_samples)
100+
cs = deepcopy(explanation_task.config_space)
101+
if seed is not None:
102+
cs.seed(seed)
103+
sampled_configurations = cs.sample_configuration(size=n_samples)
99104
self.random_sample = np.array([config.get_array() for config in sampled_configurations])
100105

101106
# cache coalition values to ensure monotonicity for min/max

tests/test_extended_settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ def test_large_ablation(large_base_et: ExplanationTask) -> None:
2626
hypershap.ablation(comparison, baseline)
2727

2828

29+
def test_large_ablation_kernelshap(large_base_et: ExplanationTask) -> None:
30+
"""Test HyperSHAP with large config space."""
31+
baseline = large_base_et.config_space.sample_configuration()
32+
comparison = large_base_et.config_space.sample_configuration()
33+
hypershap = HyperSHAP(explanation_task=large_base_et, approximation_budget=2**7)
34+
hypershap.ablation(comparison, baseline, index="k-SII")
35+
36+
2937
def test_multi_data_ablation(
3038
multi_data_baseline_config: Configuration,
3139
multi_data_config_space: ConfigurationSpace,

0 commit comments

Comments
 (0)