Skip to content

Commit ee1ee52

Browse files
authored
Merge pull request #34 from automl/improvement/monotonicity_conditions_fallback
Improvement/monotonicity conditions fallback
2 parents c767532 + 05b67c3 commit ee1ee52

File tree

5 files changed

+135
-3
lines changed

5 files changed

+135
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# v0.0.6
2+
3+
## Improvements
4+
- Added fallback for configuration spaces with conditions resulting in all configurations being filtered out.
5+
- Added caching and a function in RandomConfigSpaceSearcher to ensure monotonicity of the value function.
6+
17
# v0.0.5
28

39
## Features

src/hypershap/utils.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,67 @@ def search(self, coalition: np.ndarray) -> float:
161161
)
162162

163163
# predict performance values with the help of the surrogate model for the filtered configurations
164-
vals: np.ndarray = np.array(self.explanation_task.get_single_surrogate_model().evaluate(filtered_samples))
164+
if len(filtered_samples) > 0:
165+
vals: np.ndarray = np.array(
166+
self.explanation_task.get_single_surrogate_model().evaluate(filtered_samples),
167+
)
168+
else:
169+
logger.warning(
170+
"WARNING: After filtering for conditions, no configurations were left, thus, using baseline value.",
171+
)
172+
vals = np.array([self.search(np.array([False] * len(coalition)))])
165173
else:
166174
vals: np.ndarray = np.array(self.explanation_task.get_single_surrogate_model().evaluate(temp_random_sample))
167175

168-
return evaluate_aggregation(self.mode, vals)
176+
# determine the final, aggregated value of the coalition
177+
res = evaluate_aggregation(self.mode, vals)
178+
179+
# in case we are maximizing or minimizing, ensure that the value function is monotone
180+
if self.mode in (Aggregation.MAX, Aggregation.MIN):
181+
res = self._ensure_monotonicity(coalition, res)
182+
183+
# cache the coalition's value
184+
self.coalition_cache[str(coalition.tolist())] = res
185+
186+
return res
187+
188+
def _ensure_monotonicity(self, coalition: np.ndarray, value: float) -> float:
189+
"""Ensure that the value function is monotonically increasing/decreasing depending on whether we want to maximize or minimize respectively.
190+
191+
Args:
192+
coalition: The current coalition.
193+
value: The value of the coalition as determined by searching.
194+
195+
Returns: The monotonicity-ensured value of the coalition.
196+
197+
"""
198+
monotone_value = value
199+
checked_one = False
200+
201+
for i in range(len(coalition)):
202+
if coalition[i]: # check whether the entry is True and set it to False to check for a cached result
203+
temp_coalition = coalition.copy()
204+
temp_coalition[i] = False
205+
if str(temp_coalition.tolist()) in self.coalition_cache:
206+
checked_one = True
207+
monotone_value = evaluate_aggregation(
208+
self.mode,
209+
np.array([
210+
monotone_value,
211+
self.coalition_cache[str(temp_coalition.tolist())],
212+
]),
213+
)
214+
215+
if not checked_one and coalition.any():
216+
logger.warning(
217+
"Could not ensure monotonicity as none of the coalitions with one player less has been cached so far.",
218+
)
219+
220+
if value < monotone_value: # pragma: no cover
221+
logger.debug(
222+
"Ensured monotonicity with a sub-coalition's value. Increased the value of the current coalition from %s to %s.",
223+
value,
224+
monotone_value,
225+
)
226+
227+
return monotone_value

tests/fixtures/simple_setup.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
from __future__ import annotations
44

55
import pytest
6-
from ConfigSpace import Configuration, ConfigurationSpace, LessThanCondition, UniformFloatHyperparameter
6+
from ConfigSpace import (
7+
Configuration,
8+
ConfigurationSpace,
9+
GreaterThanCondition,
10+
LessThanCondition,
11+
UniformFloatHyperparameter,
12+
)
713

814
from hypershap import ExplanationTask
915

@@ -88,10 +94,34 @@ def simple_cond_config_space() -> ConfigurationSpace:
8894
return config_space
8995

9096

97+
@pytest.fixture(scope="session")
98+
def simple_act_config_space() -> ConfigurationSpace:
99+
"""Return a simple config space with activation structure for testing."""
100+
config_space = ConfigurationSpace()
101+
config_space.seed(42)
102+
103+
a = UniformFloatHyperparameter("a", 0, 1, 0)
104+
b = UniformFloatHyperparameter("b", 0, 1, 0)
105+
config_space.add(a)
106+
config_space.add(b)
107+
108+
config_space.add(GreaterThanCondition(b, a, 0.3))
109+
return config_space
110+
111+
91112
@pytest.fixture(scope="session")
92113
def simple_cond_base_et(
93114
simple_cond_config_space: ConfigurationSpace,
94115
simple_blackbox_function: SimpleBlackboxFunction,
95116
) -> ExplanationTask:
96117
"""Return a base explanation task for the simple setup with conditions."""
97118
return ExplanationTask.from_function(simple_cond_config_space, simple_blackbox_function.evaluate)
119+
120+
121+
@pytest.fixture(scope="session")
122+
def simple_act_base_et(
123+
simple_act_config_space: ConfigurationSpace,
124+
simple_blackbox_function: SimpleBlackboxFunction,
125+
) -> ExplanationTask:
126+
"""Return a base explanation task for the simple setup with conditions."""
127+
return ExplanationTask.from_function(simple_act_config_space, simple_blackbox_function.evaluate)

tests/test_extended_settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,10 @@ def test_tunability_with_conditions(simple_cond_base_et: ExplanationTask) -> Non
7979
hypershap = HyperSHAP(simple_cond_base_et)
8080
iv = hypershap.tunability(simple_cond_base_et.config_space.get_default_configuration())
8181
assert iv is not None, "Interaction values should not be none."
82+
83+
84+
def test_tunability_with_activation_structures(simple_act_base_et: ExplanationTask) -> None:
85+
"""Test the tunability task with a configuration space that has conditions."""
86+
hypershap = HyperSHAP(simple_act_base_et)
87+
iv = hypershap.tunability(simple_act_base_et.config_space.get_default_configuration())
88+
assert iv is not None, "Interaction values should not be none."

tests/test_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from hypershap import ExplanationTask
1313
from tests.fixtures.simple_setup import SimpleBlackboxFunction
1414

15+
from ConfigSpace import Configuration
16+
1517
from hypershap.task import BaselineExplanationTask
1618
from hypershap.utils import Aggregation, RandomConfigSpaceSearcher, evaluate_aggregation
1719

@@ -144,3 +146,31 @@ def test_evaluate_aggregation() -> None:
144146
assert evaluate_aggregation(Aggregation.MAX, vals) == AGG_LIST[2]
145147
assert evaluate_aggregation(Aggregation.AVG, vals) == np.array(AGG_LIST).mean()
146148
assert abs(evaluate_aggregation(Aggregation.VAR, vals) - np.array(AGG_LIST).var()) < EPSILON
149+
150+
151+
def test_fallback_unfulfilled_conditions(simple_act_base_et: ExplanationTask) -> None:
152+
"""Test the fallback strategy when no configurations are left in random sample after filtering for conditions."""
153+
bet = BaselineExplanationTask(
154+
simple_act_base_et.config_space,
155+
simple_act_base_et.surrogate_model,
156+
simple_act_base_et.config_space.get_default_configuration(),
157+
)
158+
rcss = RandomConfigSpaceSearcher(bet)
159+
rcss.random_sample = np.array([
160+
Configuration(
161+
configuration_space=simple_act_base_et.config_space,
162+
values={
163+
"a": 0.4,
164+
"b": 0.1,
165+
},
166+
).get_array(),
167+
Configuration(
168+
configuration_space=simple_act_base_et.config_space,
169+
values={
170+
"a": 0.5,
171+
"b": 0.1,
172+
},
173+
).get_array(),
174+
])
175+
value = rcss.search(np.array([False, True]))
176+
assert value is not None

0 commit comments

Comments
 (0)