diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py index 1c44e01f094b4..4de695a2b1dff 100644 --- a/sklearn/callback/_callback_context.py +++ b/sklearn/callback/_callback_context.py @@ -1,6 +1,9 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import inspect +import warnings +from sklearn.base import BaseEstimator from sklearn.callback import AutoPropagatedCallback # TODO(callbacks): move these explanations into a dedicated user guide. @@ -141,6 +144,18 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None): The maximum number of subtasks of this task. 0 means it's a leaf. None means the maximum number of subtasks is not known in advance. """ + if hasattr(estimator, "_skl_callbacks") and estimator._skl_callbacks: + meta_est_no_callback = called_from_no_callback_meta_estimator() + if meta_est_no_callback is not None: + warnings.warn( + f"The estimator {estimator.__class__.__name__} which supports" + f" callbacks is used within the fitting of a {meta_est_no_callback}" + " meta-estimator which does not support callbacks. The behaviour of" + f" callbacks that are attached to {estimator.__class__.__name__}" + " will be undefined.", + UserWarning, + ) + new_ctx = cls.__new__(cls) # We don't store the estimator in the context to avoid circular references @@ -358,6 +373,8 @@ def propagate_callbacks(self, sub_estimator): sub_estimator : estimator instance The estimator to which the callbacks should be propagated. """ + from sklearn.callback._mixin import CallbackSupportMixin + bad_callbacks = [ callback.__class__.__name__ for callback in getattr(sub_estimator, "_skl_callbacks", []) @@ -385,6 +402,16 @@ def propagate_callbacks(self, sub_estimator): if not callbacks_to_propagate: return self + if not isinstance(sub_estimator, CallbackSupportMixin): + warnings.warn( + f"The estimator {sub_estimator.__class__.__name__} which does not" + " supports callbacks is being used in a meta-estimator which supports" + " callbacks. The callbacks will not be propagated through this" + " estimator.", + UserWarning, + ) + return self + # We store the parent context in the sub-estimator to be able to merge the # task trees of the sub-estimator and the meta-estimator. sub_estimator._parent_callback_ctx = self @@ -415,3 +442,30 @@ def get_context_path(context): if context.parent is None else get_context_path(context.parent) + [context] ) + + +def called_from_no_callback_meta_estimator(): + """Helper function to check if in the traceback there is a call of a fit function + from a meta-estimator that does not support callbacks. + + Returns + ------- + str or None + The name of the class of the meta-estimator if there is one in the traceback + which does not support callback, None otherwise. + """ + from sklearn.callback._mixin import CallbackSupportMixin + + for frame_info in inspect.stack()[1:]: + if ( + frame_info.function not in ("fit", "fit_transform", "partial_fit") + or "self" not in frame_info.frame.f_locals + ): + continue + + if isinstance( + frame_info.frame.f_locals["self"], BaseEstimator + ) and not isinstance(frame_info.frame.f_locals["self"], CallbackSupportMixin): + return frame_info.frame.f_locals["self"].__class__.__name__ + + return None diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 21a011e02afd8..8e0e27f070c2e 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -14,7 +14,7 @@ class TestingCallback: def on_fit_begin(self, estimator): pass - def on_fit_end(self): + def on_fit_end(self, estimator, data): pass def on_fit_task_end(self, estimator, context, **kwargs): @@ -167,3 +167,57 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx): estimator=meta_estimator, data={"X_train": X, "y_train": y}, ) + + +class MetaEstimatorNoCallback(BaseEstimator): + """A class that mimics the behavior of a meta-estimator which does not support + callbacks. + """ + + _parameter_constraints: dict = {} + + def __init__( + self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes" + ): + self.estimator = estimator + self.n_outer = n_outer + self.n_inner = n_inner + self.n_jobs = n_jobs + self.prefer = prefer + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X=None, y=None): + Parallel(n_jobs=self.n_jobs, prefer=self.prefer)( + delayed(_func_no_callback)(self, self.estimator, X, y) + for i in range(self.n_outer) + ) + + return self + + +def _func_no_callback(meta_estimator, inner_estimator, X, y): + for i in range(meta_estimator.n_inner): + est = clone(inner_estimator) + + est.fit(X, y) + + +class EstimatorNoCallback(BaseEstimator): + """A class that mimics the behavior of an estimator which does not support + callbacks. + """ + + _parameter_constraints: dict = {} + + def __init__(self, max_iter=20, computation_intensity=0.001): + self.max_iter = max_iter + self.computation_intensity = computation_intensity + + @_fit_context(prefer_skip_nested_validation=False) + def fit(self, X=None, y=None): + for i in range(self.max_iter): + time.sleep(self.computation_intensity) # Computation intensive task + + self.n_iter_ = i + 1 + + return self diff --git a/sklearn/callback/tests/test_callback_context.py b/sklearn/callback/tests/test_callback_context.py index e61a21d9461d5..772ed1c39fcd0 100644 --- a/sklearn/callback/tests/test_callback_context.py +++ b/sklearn/callback/tests/test_callback_context.py @@ -7,7 +7,9 @@ from sklearn.callback._callback_context import CallbackContext, get_context_path from sklearn.callback.tests._utils import ( Estimator, + EstimatorNoCallback, MetaEstimator, + MetaEstimatorNoCallback, TestingAutoPropagatedCallback, TestingCallback, ) @@ -183,3 +185,27 @@ def test_merge_with(): # The name and estimator name of the tasks it was merged with are stored assert inner_root.prev_task_name == outer_child.task_name assert inner_root.prev_estimator_name == outer_child.estimator_name + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +@pytest.mark.parametrize("prefer", ["threads", "processes"]) +def test_no_callback_meta_est_warning(n_jobs, prefer): + estimator = Estimator() + estimator.set_callbacks(TestingCallback()) + meta_estimator = MetaEstimatorNoCallback(estimator, n_jobs=n_jobs, prefer=prefer) + with pytest.warns( + UserWarning, + match="meta-estimator which does not support callbacks.", + ): + meta_estimator.fit() + + +def test_no_callback_est_in_meta_est(): + estimator = EstimatorNoCallback() + meta_estimator = MetaEstimator(estimator) + meta_estimator.set_callbacks(TestingAutoPropagatedCallback()) + with pytest.warns( + UserWarning, + match="which does not supports callbacks is being used in a meta-estimator", + ): + meta_estimator.fit()