Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from sklearn import __version__
from sklearn._config import config_context, get_config
from sklearn.callback._callback_context import callback_management_context
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.utils._metadata_requests import _MetadataRequester, _routing_enabled
from sklearn.utils._missing import is_pandas_na, is_scalar_nan
Expand Down Expand Up @@ -1334,17 +1335,37 @@ def wrapper(estimator, *args, **kwargs):
if not global_skip_validation and not partial_fit_and_fitted:
estimator._validate_params()

with config_context(
skip_parameter_validation=(
prefer_skip_nested_validation or global_skip_validation
)
with (
config_context(
skip_parameter_validation=(
prefer_skip_nested_validation or global_skip_validation
)
),
callback_management_context(estimator, fit_method.__name__),
):
try:
return fit_method(estimator, *args, **kwargs)
finally:
if hasattr(estimator, "_callback_fit_ctx"):
estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator)
return fit_method(estimator, *args, **kwargs)

return wrapper

return decorator


def fit_callback_context(fit_method):
    """Decorator to run the fit methods within the callback context manager.

    Sets up ``estimator._callback_fit_ctx`` for the duration of the call via
    `callback_management_context`, which also guarantees the callbacks'
    clean-up runs when fit terminates, even if it raises.

    Parameters
    ----------
    fit_method : method
        The fit method to decorate.

    Returns
    -------
    decorated_fit : method
        The decorated fit method.
    """
    from functools import wraps

    @wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        with callback_management_context(estimator, fit_method.__name__):
            # Propagate the result: by convention fit returns the estimator
            # itself; the previous implementation silently dropped it, making
            # every decorated fit return None.
            return fit_method(estimator, *args, **kwargs)

    return wrapper
71 changes: 60 additions & 11 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from contextlib import contextmanager

from sklearn.callback import AutoPropagatedCallback

# TODO(callbacks): move these explanations into a dedicated user guide.
Expand Down Expand Up @@ -63,7 +65,8 @@
#
# @_fit_context()
# def fit(self, X, y):
# callback_ctx = self.__skl__init_callback_context__(max_subtasks=self.max_iter)
# callback_ctx = self._callback_fit_ctx
# callback_ctx.max_subtasks = self.max_iter
# callback_ctx.eval_on_fit_begin(estimator=self)
#
# for i in range(self.max_iter):
Expand Down Expand Up @@ -123,7 +126,7 @@ class CallbackContext:
"""

@classmethod
def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
def _from_estimator(cls, estimator, task_name):
"""Private constructor to create a root context.

Parameters
Expand All @@ -133,13 +136,6 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):

task_name : str
The name of the task this context is responsible for.

task_id : int
The id of the task this context is responsible for.

max_subtasks : int or None, default=None
The maximum number of subtasks of this task. 0 means it's a leaf.
None means the maximum number of subtasks is not known in advance.
"""
new_ctx = cls.__new__(cls)

Expand All @@ -148,10 +144,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
new_ctx.estimator_name = estimator.__class__.__name__
new_ctx.task_name = task_name
new_ctx.task_id = task_id
new_ctx.task_id = 0
new_ctx.parent = None
new_ctx._children_map = {}
new_ctx.max_subtasks = max_subtasks
new_ctx.max_subtasks = None
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None

Expand Down Expand Up @@ -269,6 +265,27 @@ def subcontext(self, task_name="", task_id=0, max_subtasks=None):
max_subtasks=max_subtasks,
)

def set_task_info(self, *, task_id=None, task_name=None, max_subtasks=None):
    """Update the task metadata of this context.

    Parameters
    ----------
    task_id : int or None, default=None
        Id of the context's task, ignored if None.
    task_name : str or None, default=None
        Name of the context's task, ignored if None.
    max_subtasks : int or None, default=None
        Number of maximum subtasks for this context's task, ignored if None.

    Returns
    -------
    self : object
        The context itself, to allow chaining calls.
    """
    updates = {
        "task_id": task_id,
        "task_name": task_name,
        "max_subtasks": max_subtasks,
    }
    # Only overwrite the attributes that were explicitly provided.
    for attribute, value in updates.items():
        if value is not None:
            setattr(self, attribute, value)

    return self

def eval_on_fit_begin(self, estimator):
"""Evaluate the `on_fit_begin` method of the callbacks.

Expand Down Expand Up @@ -415,3 +432,35 @@ def get_context_path(context):
if context.parent is None
else get_context_path(context.parent) + [context]
)


@contextmanager
def callback_management_context(estimator, fit_method_name):
    """Context manager for the CallbackContext initialization and clean-up during fit.

    Creates the root `CallbackContext` for this fit call, stores it on the
    estimator as `_callback_fit_ctx`, and guarantees on exit that the
    callbacks' `eval_on_fit_end` hook runs and the context attributes are
    removed, whether the fit succeeded or raised.

    Parameters
    ----------
    estimator : estimator instance
        Estimator being fitted.
    fit_method_name : str
        The name of the fit method being called.

    Yields
    ------
    None.
    """
    # Root context for this fit call; its task is named after the fit method.
    estimator._callback_fit_ctx = CallbackContext._from_estimator(
        estimator, task_name=fit_method_name
    )
    try:
        yield
    finally:
        try:
            # NOTE(review): called positionally here while other call sites use
            # the keyword form `eval_on_fit_end(estimator=estimator)` — confirm
            # both are intended.
            estimator._callback_fit_ctx.eval_on_fit_end(estimator)
            del estimator._callback_fit_ctx
        except AttributeError:
            # Best-effort clean-up: the context may already be gone. Beware
            # that an AttributeError raised *inside* `eval_on_fit_end` is also
            # swallowed here, in which case `_callback_fit_ctx` is left set.
            pass
        try:
            # Presumably attached when a meta-estimator propagates its
            # callbacks to this estimator — TODO confirm against the setter.
            del estimator._parent_callback_ctx
        except AttributeError:
            pass
2 changes: 1 addition & 1 deletion sklearn/callback/_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
The callback context for the estimator.
"""
self._callback_fit_ctx = CallbackContext._from_estimator(
estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks
estimator=self, task_name=task_name
)

return self._callback_fit_ctx
101 changes: 91 additions & 10 deletions sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import time

from sklearn.base import BaseEstimator, _fit_context, clone
from sklearn.base import BaseEstimator, _fit_context, clone, fit_callback_context
from sklearn.callback import CallbackSupportMixin
from sklearn.utils.parallel import Parallel, delayed

Expand All @@ -14,7 +14,7 @@ class TestingCallback:
def on_fit_begin(self, estimator):
pass

def on_fit_end(self):
def on_fit_end(self, estimator, context):
pass

def on_fit_task_end(self, estimator, context, **kwargs):
Expand Down Expand Up @@ -51,12 +51,12 @@ def __init__(self, max_iter=20, computation_intensity=0.001):

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
self._callback_fit_ctx.set_task_info(
max_subtasks=self.max_iter
).eval_on_fit_begin(estimator=self)

for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_id=i)
subcontext = self._callback_fit_ctx.subcontext(task_id=i)

time.sleep(self.computation_intensity) # Computation intensive task

Expand Down Expand Up @@ -84,13 +84,11 @@ def __init__(self, computation_intensity=0.001):

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
estimator=self
)
self._callback_fit_ctx.eval_on_fit_begin(estimator=self)

i = 0
while True:
subcontext = callback_ctx.subcontext(task_id=i)
subcontext = self._callback_fit_ctx.subcontext(task_id=i)

time.sleep(self.computation_intensity) # Computation intensive task

Expand Down Expand Up @@ -128,7 +126,7 @@ def __init__(

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
self._callback_fit_ctx.set_task_info(
max_subtasks=self.n_outer
).eval_on_fit_begin(estimator=self)

Expand All @@ -138,7 +136,7 @@ def fit(self, X=None, y=None):
self.estimator,
X,
y,
callback_ctx=callback_ctx.subcontext(
callback_ctx=self._callback_fit_ctx.subcontext(
task_name="outer", task_id=i, max_subtasks=self.n_inner
),
)
Expand Down Expand Up @@ -167,3 +165,86 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx):
estimator=meta_estimator,
data={"X_train": X, "y_train": y},
)


class SimpleMetaEstimator(CallbackSupportMixin, BaseEstimator):
    """A class that mimics the behavior of a meta-estimator that does not clone the
    estimator and does not parallelize.
    There is no iteration, the meta estimator simply calls the fit of the estimator once
    in a subcontext.
    """

    # No parameters beyond the wrapped estimator, hence no constraints.
    _parameter_constraints: dict = {}

    def __init__(self, estimator):
        self.estimator = estimator

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # Keep nested parameter validation enabled for the inner estimator.
        tags._prefer_skip_nested_validation = False
        return tags

    @_fit_context(prefer_skip_nested_validation=False)
    def fit(self, X=None, y=None):
        # The root context has exactly one subtask: the inner estimator's fit.
        self._callback_fit_ctx.set_task_info(max_subtasks=1).eval_on_fit_begin(
            estimator=self
        )
        # Create the subtask context and forward the callbacks to the inner
        # estimator *before* fitting it, so its fit runs with them attached.
        subcontext = self._callback_fit_ctx.subcontext(
            task_name="subtask"
        ).propagate_callbacks(sub_estimator=self.estimator)
        self.estimator.fit(X, y)
        # NOTE(review): the task-end hook is evaluated on the root context, not
        # on `subcontext` — sibling test estimators call it on the subcontext;
        # confirm this asymmetry is intended.
        self._callback_fit_ctx.eval_on_fit_task_end(
            estimator=self,
            data={
                "X_train": X,
                "y_train": y,
            },
        )

        return self


class PublicFitDecoratorEstimator(CallbackSupportMixin, BaseEstimator):
    """Estimator mimicking a third-party class that relies on the public
    `fit_callback_context` decorator to manage callback contexts during fit.
    """

    _parameter_constraints: dict = {}

    def __init__(self, max_iter=20, computation_intensity=0.001):
        self.max_iter = max_iter
        self.computation_intensity = computation_intensity

    @fit_callback_context
    def fit(self, X=None, y=None):
        root_ctx = self._callback_fit_ctx
        root_ctx.set_task_info(max_subtasks=self.max_iter)
        root_ctx.eval_on_fit_begin(estimator=self)

        for iteration in range(self.max_iter):
            task_ctx = root_ctx.subcontext(task_id=iteration)

            # Simulate a computation intensive task.
            time.sleep(self.computation_intensity)

            stop_requested = task_ctx.eval_on_fit_task_end(
                estimator=self,
                data={"X_train": X, "y_train": y},
            )
            if stop_requested:
                break

        self.n_iter_ = iteration + 1

        return self


class ParentFitEstimator(Estimator):
    """A class mimicking an estimator that delegates fitting to its parent class."""

    _parameter_constraints: dict = {}

    def __init__(self, max_iter=20, computation_intensity=0.001):
        # Forward the constructor arguments unchanged to the parent.
        super().__init__(
            max_iter=max_iter, computation_intensity=computation_intensity
        )

    @_fit_context(prefer_skip_nested_validation=False)
    def fit(self, X=None, y=None):
        # Delegate entirely to the parent implementation.
        fitted = super().fit(X, y)
        return fitted
Loading
Loading