diff --git a/sklearn/base.py b/sklearn/base.py
index 9abd145ce37a6..bf902eb7756af 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -1335,11 +1335,7 @@ def wrapper(estimator, *args, **kwargs):
                     prefer_skip_nested_validation or global_skip_validation
                 )
             ):
-                try:
-                    return fit_method(estimator, *args, **kwargs)
-                finally:
-                    if hasattr(estimator, "_callback_fit_ctx"):
-                        estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator)
+                return fit_method(estimator, *args, **kwargs)
 
         return wrapper
 
diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py
index 1c44e01f094b4..cddac9fc2f99c 100644
--- a/sklearn/callback/_callback_context.py
+++ b/sklearn/callback/_callback_context.py
@@ -123,7 +123,7 @@ class CallbackContext:
     """
 
     @classmethod
-    def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
+    def _from_estimator(cls, estimator, task_name):
         """Private constructor to create a root context.
 
         Parameters
@@ -133,13 +133,6 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
 
         task_name : str
             The name of the task this context is responsible for.
-
-        task_id : int
-            The id of the task this context is responsible for.
-
-        max_subtasks : int or None, default=None
-            The maximum number of subtasks of this task. 0 means it's a leaf.
-            None means the maximum number of subtasks is not known in advance.
         """
         new_ctx = cls.__new__(cls)
 
@@ -148,10 +141,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
         new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
         new_ctx.estimator_name = estimator.__class__.__name__
         new_ctx.task_name = task_name
-        new_ctx.task_id = task_id
+        new_ctx.task_id = 0
         new_ctx.parent = None
         new_ctx._children_map = {}
-        new_ctx.max_subtasks = max_subtasks
+        new_ctx.max_subtasks = None
 
         new_ctx.prev_estimator_name = None
         new_ctx.prev_task_name = None
diff --git a/sklearn/callback/_mixin.py b/sklearn/callback/_mixin.py
index 2827d05e930d7..1bc51ee1e75c0 100644
--- a/sklearn/callback/_mixin.py
+++ b/sklearn/callback/_mixin.py
@@ -1,6 +1,8 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
+import functools
+
 from sklearn.callback._base import Callback
 from sklearn.callback._callback_context import CallbackContext
 
@@ -31,25 +33,28 @@ def set_callbacks(self, callbacks):
 
         return self
 
-    def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
-        """Initialize the callback context for the estimator.
 
-        Parameters
-        ----------
-        task_name : str, default='fit'
-            The name of the root task.
+def fit_callback(fit_method):
+    """Decorator to initialize the callback context for the fit methods."""
 
-        max_subtasks : int or None, default=None
-            The maximum number of subtasks that can be children of the root task. None
-            means the maximum number of subtasks is not known in advance.
+    @functools.wraps(fit_method)
+    def callback_wrapper(estimator, *args, **kwargs):
+        if not isinstance(estimator, CallbackSupportMixin):
+            raise ValueError(
+                f"Estimator {estimator.__class__.__name__} does not support callbacks,"
+                " as it does not inherit from CallbackSupportMixin."
+            )
 
-        Returns
-        -------
-        callback_fit_ctx : CallbackContext
-            The callback context for the estimator.
- """ - self._callback_fit_ctx = CallbackContext._from_estimator( - estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks + estimator.__sklearn_callback_fit_ctx__ = CallbackContext._from_estimator( + estimator, task_name=fit_method.__name__ ) - return self._callback_fit_ctx + try: + return fit_method(estimator, *args, **kwargs) + finally: + estimator.__sklearn_callback_fit_ctx__.eval_on_fit_end(estimator) + del estimator.__sklearn_callback_fit_ctx__ + if hasattr(estimator, "_parent_callback_ctx"): + del estimator._parent_callback_ctx + + return callback_wrapper diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 21a011e02afd8..7fe9a45541298 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -5,6 +5,7 @@ from sklearn.base import BaseEstimator, _fit_context, clone from sklearn.callback import CallbackSupportMixin +from sklearn.callback._mixin import fit_callback from sklearn.utils.parallel import Parallel, delayed @@ -14,7 +15,7 @@ class TestingCallback: def on_fit_begin(self, estimator): pass - def on_fit_end(self): + def on_fit_end(self, estimator, context): pass def on_fit_task_end(self, estimator, context, **kwargs): @@ -49,12 +50,12 @@ def __init__(self, max_iter=20, computation_intensity=0.001): self.max_iter = max_iter self.computation_intensity = computation_intensity - @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__( - max_subtasks=self.max_iter - ).eval_on_fit_begin(estimator=self) - + @fit_callback + @_fit_context(prefer_skip_nested_validation=True) + def fit(self, X=None, y=None, X_val=None, y_val=None): + callback_ctx = self.__sklearn_callback_fit_ctx__ + callback_ctx.max_subtasks = self.max_iter + callback_ctx.eval_on_fit_begin(estimator=self) for i in range(self.max_iter): subcontext = callback_ctx.subcontext(task_id=i) @@ -82,12 +83,12 @@ class WhileEstimator(CallbackSupportMixin, BaseEstimator): def __init__(self, computation_intensity=0.001): self.computation_intensity = computation_intensity - @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin( + @fit_callback + @_fit_context(prefer_skip_nested_validation=True) + def fit(self, X=None, y=None, X_val=None, y_val=None): + callback_ctx = self.__sklearn_callback_fit_ctx__.eval_on_fit_begin( estimator=self ) - i = 0 while True: subcontext = callback_ctx.subcontext(task_id=i) @@ -126,11 +127,12 @@ def __init__( self.n_jobs = n_jobs self.prefer = prefer + @fit_callback @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__( - max_subtasks=self.n_outer - ).eval_on_fit_begin(estimator=self) + callback_ctx = self.__sklearn_callback_fit_ctx__ + callback_ctx.max_subtasks = self.n_outer + callback_ctx.eval_on_fit_begin(estimator=self) Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( delayed(_func)( @@ -167,3 +169,48 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx): estimator=meta_estimator, data={"X_train": X, "y_train": y}, ) + + +class EstimatorWithoutCallbackMixin(BaseEstimator): + @fit_callback + def fit(self, X=None, y=None): + pass + + +class SimpleMetaEstimator(CallbackSupportMixin, BaseEstimator): + """A class that mimics the behavior of a meta-estimator that does not clone the + estimator and does not parallelize. 
+
+    There is no iteration; the meta-estimator simply calls the fit of the estimator
+    once in a subcontext.
+    """
+
+    _parameter_constraints: dict = {}
+
+    def __init__(self, estimator):
+        self.estimator = estimator
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._prefer_skip_nested_validation = False
+        return tags
+
+    @fit_callback
+    @_fit_context(prefer_skip_nested_validation=False)
+    def fit(self, X=None, y=None):
+        callback_ctx = self.__sklearn_callback_fit_ctx__
+        callback_ctx.max_subtasks = 1
+        callback_ctx.eval_on_fit_begin(estimator=self)
+        subcontext = callback_ctx.subcontext(task_name="subtask").propagate_callbacks(
+            sub_estimator=self.estimator
+        )
+        self.estimator.fit(X, y)
+        callback_ctx.eval_on_fit_task_end(
+            estimator=self,
+            data={
+                "X_train": X,
+                "y_train": y,
+            },
+        )
+
+        return self
diff --git a/sklearn/callback/tests/test_callback_context.py b/sklearn/callback/tests/test_callback_context.py
index e61a21d9461d5..7834fa2e775af 100644
--- a/sklearn/callback/tests/test_callback_context.py
+++ b/sklearn/callback/tests/test_callback_context.py
@@ -8,6 +8,7 @@
 from sklearn.callback.tests._utils import (
     Estimator,
     MetaEstimator,
+    SimpleMetaEstimator,
     TestingAutoPropagatedCallback,
     TestingCallback,
 )
@@ -22,7 +23,7 @@ def test_propagate_callbacks():
     metaestimator = MetaEstimator(estimator)
     metaestimator.set_callbacks([not_propagated_callback, propagated_callback])
 
-    callback_ctx = metaestimator.__skl_init_callback_context__()
+    callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
     callback_ctx.propagate_callbacks(estimator)
 
     assert hasattr(estimator, "_parent_callback_ctx")
@@ -35,7 +36,7 @@ def test_propagate_callback_no_callback():
     estimator = Estimator()
     metaestimator = MetaEstimator(estimator)
 
-    callback_ctx = metaestimator.__skl_init_callback_context__()
+    callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
     assert len(callback_ctx._callbacks) == 0
 
     callback_ctx.propagate_callbacks(estimator)
@@ -62,28 +63,20 @@ def test_auto_propagated_callbacks():
 def _make_task_tree(n_children, n_grandchildren):
     """Helper function to create a tree of tasks with a context for each task."""
     estimator = Estimator()
-    root = CallbackContext._from_estimator(
-        estimator,
-        task_name="root task",
-        task_id=0,
-        max_subtasks=n_children,
-    )
+    root = CallbackContext._from_estimator(estimator, task_name="root task")
+    root.max_subtasks = n_children
 
     for i in range(n_children):
-        child = CallbackContext._from_estimator(
-            estimator,
-            task_name="child task",
-            task_id=i,
-            max_subtasks=n_grandchildren,
-        )
+        child = CallbackContext._from_estimator(estimator, task_name="child task")
+        child.task_id = i
+        child.max_subtasks = n_grandchildren
         root._add_child(child)
 
         for j in range(n_grandchildren):
             grandchild = CallbackContext._from_estimator(
-                estimator,
-                task_name="grandchild task",
-                task_id=j,
+                estimator, task_name="grandchild task"
             )
+            grandchild.task_id = j
             child._add_child(grandchild)
 
     return root
@@ -122,57 +115,48 @@ def test_task_tree():
 def test_add_child():
     """Sanity check for the `_add_child` method."""
     estimator = Estimator()
-    root = CallbackContext._from_estimator(
-        estimator, task_name="root task", task_id=0, max_subtasks=2
-    )
+    root = CallbackContext._from_estimator(estimator, task_name="root task")
+    root.max_subtasks = 2
 
-    root._add_child(
-        CallbackContext._from_estimator(estimator, task_name="child task", task_id=0)
-    )
+    first_child = CallbackContext._from_estimator(estimator, task_name="child task")
+
+    root._add_child(first_child)
 
     assert root.max_subtasks == 2
     assert len(root._children_map) == 1
+    second_child = CallbackContext._from_estimator(estimator, task_name="child task")
     # root already has a child with id 0
     with pytest.raises(
         ValueError, match=r"Callback context .* already has a child with task_id=0"
     ):
-        root._add_child(
-            CallbackContext._from_estimator(
-                estimator, task_name="child task", task_id=0
-            )
-        )
+        root._add_child(second_child)
 
-    root._add_child(
-        CallbackContext._from_estimator(estimator, task_name="child task", task_id=1)
-    )
+    second_child.task_id = 1
+    root._add_child(second_child)
 
     assert len(root._children_map) == 2
+    third_child = CallbackContext._from_estimator(estimator, task_name="child task")
+    third_child.task_id = 2
     # root can have at most 2 children
     with pytest.raises(ValueError, match=r"Cannot add child to callback context"):
-        root._add_child(
-            CallbackContext._from_estimator(
-                estimator, task_name="child task", task_id=2
-            )
-        )
+        root._add_child(third_child)
 
 
 def test_merge_with():
     """Sanity check for the `_merge_with` method."""
     estimator = Estimator()
     meta_estimator = MetaEstimator(estimator)
-    outer_root = CallbackContext._from_estimator(
-        meta_estimator, task_name="root", task_id=0, max_subtasks=2
-    )
+    outer_root = CallbackContext._from_estimator(meta_estimator, task_name="root")
+    outer_root.max_subtasks = 2
 
     # Add a child task within the same estimator
-    outer_child = CallbackContext._from_estimator(
-        meta_estimator, task_name="child", task_id="id", max_subtasks=1
-    )
+    outer_child = CallbackContext._from_estimator(meta_estimator, task_name="child")
+    outer_child.max_subtasks = 1
     outer_root._add_child(outer_child)
 
     # The root task of the inner estimator is merged with (and effectively replaces)
     # a leaf of the outer estimator because they correspond to the same formal task.
-    inner_root = CallbackContext._from_estimator(estimator, task_name="root", task_id=0)
+    inner_root = CallbackContext._from_estimator(estimator, task_name="root")
     inner_root._merge_with(outer_child)
 
     assert inner_root.parent is outer_root
@@ -183,3 +167,12 @@
     # The name and estimator name of the tasks it was merged with are stored
     assert inner_root.prev_task_name == outer_child.task_name
     assert inner_root.prev_estimator_name == outer_child.estimator_name
+
+
+def test_no_parent_callback_after_fit():
+    """Check that the `_parent_callback_ctx` attribute does not survive after fit."""
+    estimator = Estimator()
+    meta_estimator = SimpleMetaEstimator(estimator)
+    meta_estimator.set_callbacks(TestingAutoPropagatedCallback())
+    meta_estimator.fit()
+    assert not hasattr(estimator, "_parent_callback_ctx")
diff --git a/sklearn/callback/tests/test_mixin.py b/sklearn/callback/tests/test_mixin.py
index 006439f2cbd18..134738d70db15 100644
--- a/sklearn/callback/tests/test_mixin.py
+++ b/sklearn/callback/tests/test_mixin.py
@@ -5,6 +5,7 @@
 
 from sklearn.callback.tests._utils import (
     Estimator,
+    EstimatorWithoutCallbackMixin,
     NotValidCallback,
     TestingAutoPropagatedCallback,
     TestingCallback,
@@ -41,10 +42,20 @@ def test_set_callbacks_error(callbacks):
         estimator.set_callbacks(callbacks)
 
 
-def test_init_callback_context():
-    """Sanity check for the `__skl_init_callback_context__` method."""
+def test_callback_removed_after_fit():
+    """Test that the __sklearn_callback_fit_ctx__ attribute gets removed after fit."""
     estimator = Estimator()
-    callback_ctx = estimator.__skl_init_callback_context__()
-
-    assert hasattr(estimator, "_callback_fit_ctx")
-    assert hasattr(callback_ctx, "_callbacks")
+    estimator.fit()
+    assert not hasattr(estimator, "__sklearn_callback_fit_ctx__")
+
+
+def test_decorator_error():
+    """Test the error raised by the fit_callback decorator if the estimator does not
+    inherit from CallbackSupportMixin."""
+    estimator = EstimatorWithoutCallbackMixin()
+    with pytest.raises(
+        ValueError,
+        match="does not support callbacks, as it does not inherit from"
+        " CallbackSupportMixin.",
+    ):
+        estimator.fit()
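# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch above). It mirrors the
# pattern used by the test estimators in sklearn/callback/tests/_utils.py and
# assumes the experimental, private callback API on this branch
# (CallbackSupportMixin, fit_callback, CallbackContext). `MyEstimator`, its
# `max_iter` parameter, and the per-iteration reporting call are hypothetical
# and only show how a fit method opts into callbacks with the new decorator.
# ---------------------------------------------------------------------------
from sklearn.base import BaseEstimator, _fit_context
from sklearn.callback import CallbackSupportMixin
from sklearn.callback._mixin import fit_callback


class MyEstimator(CallbackSupportMixin, BaseEstimator):
    _parameter_constraints: dict = {}

    def __init__(self, max_iter=10):
        self.max_iter = max_iter

    # `fit_callback` creates `self.__sklearn_callback_fit_ctx__` before calling
    # `fit`, evaluates `on_fit_end` afterwards, and removes the attribute again,
    # replacing the removed `__skl_init_callback_context__` method.
    @fit_callback
    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X=None, y=None):
        ctx = self.__sklearn_callback_fit_ctx__
        ctx.max_subtasks = self.max_iter
        ctx.eval_on_fit_begin(estimator=self)

        for i in range(self.max_iter):
            subcontext = ctx.subcontext(task_id=i)
            # ... one unit of work for this subtask ...
            # Assumed per-iteration reporting; the data payload mirrors the one
            # used by the test estimators in this patch.
            subcontext.eval_on_fit_task_end(
                estimator=self, data={"X_train": X, "y_train": y}
            )

        return self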