def fit_callback_context(fit_method):
    """Decorator to run the fit methods within the callback context manager.

    The decorated method is executed inside `callback_management_context`, called
    with the estimator and the name of the decorated method.

    Parameters
    ----------
    fit_method : method
        The fit method to decorate.

    Returns
    -------
    decorated_fit : method
        The decorated fit method. It forwards the return value of `fit_method`
        and keeps its name and docstring.
    """
    # Function-scope import so that the decorator does not depend on how the
    # enclosing module imports `functools`.
    import functools

    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        with callback_management_context(estimator, fit_method.__name__):
            # Return the result of the wrapped method: without it a decorated
            # `fit` would return None instead of the fitted estimator.
            return fit_method(estimator, *args, **kwargs)

    return wrapper
""" new_ctx = cls.__new__(cls) @@ -148,10 +144,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None): new_ctx._callbacks = getattr(estimator, "_skl_callbacks", []) new_ctx.estimator_name = estimator.__class__.__name__ new_ctx.task_name = task_name - new_ctx.task_id = task_id + new_ctx.task_id = 0 new_ctx.parent = None new_ctx._children_map = {} - new_ctx.max_subtasks = max_subtasks + new_ctx.max_subtasks = None new_ctx.prev_estimator_name = None new_ctx.prev_task_name = None @@ -269,6 +265,27 @@ def subcontext(self, task_name="", task_id=0, max_subtasks=None): max_subtasks=max_subtasks, ) + def set_task_info(self, *, task_id=None, task_name=None, max_subtasks=None): + """Setter for the task_id, task_name and max_subtasks attributes. + + Parameters + ---------- + task_id : int or None, default=None + Id of the context's task, ignored if None. + task_name : str or None, default=None + Name of the context's task, ignored if None. + max_subtasks : int or None, default=None + Number of maximum subtasks for this context's task, ignored if None. + """ + if task_id is not None: + self.task_id = task_id + if task_name is not None: + self.task_name = task_name + if max_subtasks is not None: + self.max_subtasks = max_subtasks + + return self + def eval_on_fit_begin(self, estimator): """Evaluate the `on_fit_begin` method of the callbacks. @@ -415,3 +432,35 @@ def get_context_path(context): if context.parent is None else get_context_path(context.parent) + [context] ) + + +@contextmanager +def callback_management_context(estimator, fit_method_name): + """Context manager for the CallbackContext initialization and clean-up during fit. + + Parameters + ---------- + estimator : estimator instance + Estimator being fitted. + fit_method_name : str + The name of the fit method being called. + + Yields + ------ + None. 
+ """ + estimator._callback_fit_ctx = CallbackContext._from_estimator( + estimator, task_name=fit_method_name + ) + try: + yield + finally: + try: + estimator._callback_fit_ctx.eval_on_fit_end(estimator) + del estimator._callback_fit_ctx + except AttributeError: + pass + try: + del estimator._parent_callback_ctx + except AttributeError: + pass diff --git a/sklearn/callback/_mixin.py b/sklearn/callback/_mixin.py index 2827d05e930d7..6289faec8ed50 100644 --- a/sklearn/callback/_mixin.py +++ b/sklearn/callback/_mixin.py @@ -49,7 +49,7 @@ def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None): The callback context for the estimator. """ self._callback_fit_ctx = CallbackContext._from_estimator( - estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks + estimator=self, task_name=task_name ) return self._callback_fit_ctx diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 21a011e02afd8..6cbece4bd0efd 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -3,7 +3,7 @@ import time -from sklearn.base import BaseEstimator, _fit_context, clone +from sklearn.base import BaseEstimator, _fit_context, clone, fit_callback_context from sklearn.callback import CallbackSupportMixin from sklearn.utils.parallel import Parallel, delayed @@ -14,7 +14,7 @@ class TestingCallback: def on_fit_begin(self, estimator): pass - def on_fit_end(self): + def on_fit_end(self, estimator, context): pass def on_fit_task_end(self, estimator, context, **kwargs): @@ -51,12 +51,12 @@ def __init__(self, max_iter=20, computation_intensity=0.001): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__( + self._callback_fit_ctx.set_task_info( max_subtasks=self.max_iter ).eval_on_fit_begin(estimator=self) for i in range(self.max_iter): - subcontext = callback_ctx.subcontext(task_id=i) + subcontext = 
self._callback_fit_ctx.subcontext(task_id=i) time.sleep(self.computation_intensity) # Computation intensive task @@ -84,13 +84,11 @@ def __init__(self, computation_intensity=0.001): @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin( - estimator=self - ) + self._callback_fit_ctx.eval_on_fit_begin(estimator=self) i = 0 while True: - subcontext = callback_ctx.subcontext(task_id=i) + subcontext = self._callback_fit_ctx.subcontext(task_id=i) time.sleep(self.computation_intensity) # Computation intensive task @@ -128,7 +126,7 @@ def __init__( @_fit_context(prefer_skip_nested_validation=False) def fit(self, X=None, y=None): - callback_ctx = self.__skl_init_callback_context__( + self._callback_fit_ctx.set_task_info( max_subtasks=self.n_outer ).eval_on_fit_begin(estimator=self) @@ -138,7 +136,7 @@ def fit(self, X=None, y=None): self.estimator, X, y, - callback_ctx=callback_ctx.subcontext( + callback_ctx=self._callback_fit_ctx.subcontext( task_name="outer", task_id=i, max_subtasks=self.n_inner ), ) @@ -167,3 +165,86 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx): estimator=meta_estimator, data={"X_train": X, "y_train": y}, ) + + +class SimpleMetaEstimator(CallbackSupportMixin, BaseEstimator): + """A class that mimics the behavior of a meta-estimator that does not clone the + estimator and does not parallelize. + There is no iteration, the meta estimator simply calls the fit of the estimator once + in a subcontext. 
+ """ + + _parameter_constraints: dict = {} + + def __init__(self, estimator): + self.estimator = estimator + + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags._prefer_skip_nested_validation = False + return tags + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X=None, y=None): + self._callback_fit_ctx.set_task_info(max_subtasks=1).eval_on_fit_begin( + estimator=self + ) + subcontext = self._callback_fit_ctx.subcontext( + task_name="subtask" + ).propagate_callbacks(sub_estimator=self.estimator) + self.estimator.fit(X, y) + self._callback_fit_ctx.eval_on_fit_task_end( + estimator=self, + data={ + "X_train": X, + "y_train": y, + }, + ) + + return self + + +class PublicFitDecoratorEstimator(CallbackSupportMixin, BaseEstimator): + """A class that mimics a third-party estimator using the public decorator + `fit_callback_context` to manage callback contexts during fit. + """ + + _parameter_constraints: dict = {} + + def __init__(self, max_iter=20, computation_intensity=0.001): + self.max_iter = max_iter + self.computation_intensity = computation_intensity + + @fit_callback_context + def fit(self, X=None, y=None): + self._callback_fit_ctx.set_task_info( + max_subtasks=self.max_iter + ).eval_on_fit_begin(estimator=self) + + for i in range(self.max_iter): + subcontext = self._callback_fit_ctx.subcontext(task_id=i) + + time.sleep(self.computation_intensity) # Computation intensive task + + if subcontext.eval_on_fit_task_end( + estimator=self, + data={"X_train": X, "y_train": y}, + ): + break + + self.n_iter_ = i + 1 + + return self + + +class ParentFitEstimator(Estimator): + """A class mimicking an estimator that uses its parent fit method.""" + + _parameter_constraints: dict = {} + + def __init__(self, max_iter=20, computation_intensity=0.001): + super().__init__(max_iter, computation_intensity) + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X=None, y=None): + return super().fit(X, y) diff --git 
a/sklearn/callback/tests/test_callback_context.py b/sklearn/callback/tests/test_callback_context.py index e61a21d9461d5..021763ef150c9 100644 --- a/sklearn/callback/tests/test_callback_context.py +++ b/sklearn/callback/tests/test_callback_context.py @@ -8,6 +8,7 @@ from sklearn.callback.tests._utils import ( Estimator, MetaEstimator, + SimpleMetaEstimator, TestingAutoPropagatedCallback, TestingCallback, ) @@ -22,7 +23,7 @@ def test_propagate_callbacks(): metaestimator = MetaEstimator(estimator) metaestimator.set_callbacks([not_propagated_callback, propagated_callback]) - callback_ctx = metaestimator.__skl_init_callback_context__() + callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit") callback_ctx.propagate_callbacks(estimator) assert hasattr(estimator, "_parent_callback_ctx") @@ -35,7 +36,7 @@ def test_propagate_callback_no_callback(): estimator = Estimator() metaestimator = MetaEstimator(estimator) - callback_ctx = metaestimator.__skl_init_callback_context__() + callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit") assert len(callback_ctx._callbacks) == 0 callback_ctx.propagate_callbacks(estimator) @@ -62,28 +63,19 @@ def test_auto_propagated_callbacks(): def _make_task_tree(n_children, n_grandchildren): """Helper function to create a tree of tasks with a context for each task.""" estimator = Estimator() - root = CallbackContext._from_estimator( - estimator, - task_name="root task", - task_id=0, - max_subtasks=n_children, - ) + root = CallbackContext._from_estimator(estimator, task_name="root task") + root.set_task_info(max_subtasks=n_children) for i in range(n_children): - child = CallbackContext._from_estimator( - estimator, - task_name="child task", - task_id=i, - max_subtasks=n_grandchildren, - ) + child = CallbackContext._from_estimator(estimator, task_name="child task") + child.set_task_info(task_id=i, max_subtasks=n_grandchildren) root._add_child(child) for j in range(n_grandchildren): grandchild = 
CallbackContext._from_estimator( - estimator, - task_name="grandchild task", - task_id=j, + estimator, task_name="grandchild task" ) + grandchild.set_task_info(task_id=j) child._add_child(grandchild) return root @@ -122,57 +114,48 @@ def test_task_tree(): def test_add_child(): """Sanity check for the `_add_child` method.""" estimator = Estimator() - root = CallbackContext._from_estimator( - estimator, task_name="root task", task_id=0, max_subtasks=2 - ) + root = CallbackContext._from_estimator(estimator, task_name="root task") + root.set_task_info(max_subtasks=2) - root._add_child( - CallbackContext._from_estimator(estimator, task_name="child task", task_id=0) - ) + first_child = CallbackContext._from_estimator(estimator, task_name="child task") + + root._add_child(first_child) assert root.max_subtasks == 2 assert len(root._children_map) == 1 + second_child = CallbackContext._from_estimator(estimator, task_name="child task") # root already has a child with id 0 with pytest.raises( ValueError, match=r"Callback context .* already has a child with task_id=0" ): - root._add_child( - CallbackContext._from_estimator( - estimator, task_name="child task", task_id=0 - ) - ) + root._add_child(second_child) - root._add_child( - CallbackContext._from_estimator(estimator, task_name="child task", task_id=1) - ) + second_child.set_task_info(task_id=1) + root._add_child(second_child) assert len(root._children_map) == 2 + third_child = CallbackContext._from_estimator(estimator, task_name="child task") + third_child.set_task_info(task_id=2) # root can have at most 2 children with pytest.raises(ValueError, match=r"Cannot add child to callback context"): - root._add_child( - CallbackContext._from_estimator( - estimator, task_name="child task", task_id=2 - ) - ) + root._add_child(third_child) def test_merge_with(): """Sanity check for the `_merge_with` method.""" estimator = Estimator() meta_estimator = MetaEstimator(estimator) - outer_root = CallbackContext._from_estimator( - 
meta_estimator, task_name="root", task_id=0, max_subtasks=2 - ) + outer_root = CallbackContext._from_estimator(meta_estimator, task_name="root") + outer_root.set_task_info(max_subtasks=2) # Add a child task within the same estimator - outer_child = CallbackContext._from_estimator( - meta_estimator, task_name="child", task_id="id", max_subtasks=1 - ) + outer_child = CallbackContext._from_estimator(meta_estimator, task_name="child") + outer_child.set_task_info(max_subtasks=1) outer_root._add_child(outer_child) # The root task of the inner estimator is merged with (and effectively replaces) # a leaf of the outer estimator because they correspond to the same formal task. - inner_root = CallbackContext._from_estimator(estimator, task_name="root", task_id=0) + inner_root = CallbackContext._from_estimator(estimator, task_name="root") inner_root._merge_with(outer_child) assert inner_root.parent is outer_root @@ -183,3 +166,12 @@ def test_merge_with(): # The name and estimator name of the tasks it was merged with are stored assert inner_root.prev_task_name == outer_child.task_name assert inner_root.prev_estimator_name == outer_child.estimator_name + + +def test_no_parent_callback_after_fit(): + """Check that the `_parent_callback_ctx` attribute does not survive after fit.""" + estimator = Estimator() + meta_estimator = SimpleMetaEstimator(estimator) + meta_estimator.set_callbacks(TestingAutoPropagatedCallback()) + meta_estimator.fit() + assert not hasattr(estimator, "_parent_callback_ctx") diff --git a/sklearn/callback/tests/test_mixin.py b/sklearn/callback/tests/test_mixin.py index 006439f2cbd18..64587c3a8aabd 100644 --- a/sklearn/callback/tests/test_mixin.py +++ b/sklearn/callback/tests/test_mixin.py @@ -6,6 +6,8 @@ from sklearn.callback.tests._utils import ( Estimator, NotValidCallback, + ParentFitEstimator, + PublicFitDecoratorEstimator, TestingAutoPropagatedCallback, TestingCallback, ) @@ -48,3 +50,23 @@ def test_init_callback_context(): assert hasattr(estimator, 
"_callback_fit_ctx") assert hasattr(callback_ctx, "_callbacks") + + +def test_callback_removed_after_fit(): + """Test that the _callback_fit_ctx attribute gets removed after fit.""" + estimator = Estimator() + estimator.fit() + assert not hasattr(estimator, "_callback_fit_ctx") + + +def test_public_fit_decorator(): + """Sanity check of the public fit decorator to manage callback contexts during + fit.""" + estimator = PublicFitDecoratorEstimator() + estimator.fit() + + +def test_inheritated_fit(): + """Test with an estimator that uses its parent fit function.""" + estimator = ParentFitEstimator() + estimator.fit()