Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ def _clone_parametrized(estimator, *, safe=True):

# attach callbacks to the new estimator
if hasattr(estimator, "_skl_callbacks"):
# TODO(callbacks): Figure out the exact behavior we want when cloning an
# estimator with callbacks.
new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
new_object._skl_callbacks = estimator._skl_callbacks

# quick sanity check of the parameters of the clone
for name in new_object_params:
Expand Down
2 changes: 2 additions & 0 deletions sklearn/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sklearn.callback._base import AutoPropagatedCallback, Callback
from sklearn.callback._callback_context import CallbackContext
from sklearn.callback._metric_monitor import MetricMonitor
from sklearn.callback._mixin import CallbackSupportMixin
from sklearn.callback._progressbar import ProgressBar

Expand All @@ -16,5 +17,6 @@
"Callback",
"CallbackContext",
"CallbackSupportMixin",
"MetricMonitor",
"ProgressBar",
]
16 changes: 10 additions & 6 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import time

from sklearn.callback import AutoPropagatedCallback

# TODO(callbacks): move these explanations into a dedicated user guide.
Expand Down Expand Up @@ -145,8 +147,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):

# We don't store the estimator in the context to avoid circular references
# because the estimator already holds a reference to the context.
new_ctx.init_time = time.time()
new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
new_ctx.estimator_name = estimator.__class__.__name__
new_ctx.estimator_id = id(estimator)
new_ctx.task_name = task_name
new_ctx.task_id = task_id
new_ctx.parent = None
Expand Down Expand Up @@ -191,8 +195,10 @@ def _from_parent(cls, parent_context, *, task_name, task_id, max_subtasks=None):
"""
new_ctx = cls.__new__(cls)

new_ctx.init_time = time.time()
new_ctx._callbacks = parent_context._callbacks
new_ctx.estimator_name = parent_context.estimator_name
new_ctx.estimator_id = parent_context.estimator_id
new_ctx._estimator_depth = parent_context._estimator_depth
new_ctx.task_name = task_name
new_ctx.task_id = task_id
Expand Down Expand Up @@ -383,16 +389,14 @@ def propagate_callbacks(self, sub_estimator):
)
]

if not callbacks_to_propagate:
return self

# We store the parent context in the sub-estimator to be able to merge the
# task trees of the sub-estimator and the meta-estimator.
sub_estimator._parent_callback_ctx = self

sub_estimator.set_callbacks(
getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate
)
if callbacks_to_propagate:
sub_estimator.set_callbacks(
getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate
)

return self

Expand Down
108 changes: 108 additions & 0 deletions sklearn/callback/_metric_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import inspect
import time
from multiprocessing import Manager

import pandas as pd

from sklearn.callback._callback_context import get_context_path


class MetricMonitor:
"""Callback that monitors a metric for each iterative steps of an estimator.

The specified metric function is called on the target values `y` and the predicted
values on the samples `y_pred = estimator.predict(X)` at each iterative step of the
estimator.

Parameters
----------
metric : function
The metric to compute.
metric_params : dict or None, default=None
Additional keyword arguments for the metric function.
on_validation : bool, default=True
Whether to compute the metric on validation data (if True) or training data
(if False).
"""

def __init__(self, metric, metric_params=None, on_validation=True):
self.on_validation = on_validation
self.metric_params = metric_params or dict()
if metric_params is not None:
valid_params = inspect.signature(metric).parameters
invalid_params = [arg for arg in metric_params if arg not in valid_params]
if invalid_params:
raise ValueError(
f"The parameters '{invalid_params}' cannot be used with the"
f" function {metric.__module__}.{metric.__name__}."
)
self.metric_func = metric
self._shared_mem_log = Manager().list()

def on_fit_begin(self, estimator):
if not hasattr(estimator, "predict"):
raise ValueError(
f"Estimator {estimator.__class__} does not have a predict method, which"
" is necessary to use a MetricMonitor callback."
)

def on_fit_task_end(
self, estimator, context, data, from_reconstruction_attributes, **kwargs
):
# TODO: add check to verify we're on the innermost level of the fit loop
# e.g. for the KMeans
X, y = (
(data["X_val"], data["y_val"])
if self.on_validation
else (data["X_train"], data["y_train"])
)
y_pred = from_reconstruction_attributes().predict(X)
metric_value = self.metric_func(y, y_pred, **self.metric_params)
log_item = {self.metric_func.__name__: metric_value}
for depth, ctx in enumerate(get_context_path(context)):
if depth == 0:
timestamp = time.strftime(
"%Y-%m-%d_%H:%M:%S", time.localtime(ctx.init_time)
)
log_item["_run"] = (
f"{ctx.estimator_name}_{ctx.estimator_id}_{timestamp}"
)
prev_task_str = (
f"{ctx.prev_estimator_name}_{ctx.prev_task_name}|"
if ctx.prev_estimator_name is not None
else ""
)
log_item[f"{depth}_{prev_task_str}{ctx.estimator_name}_{ctx.task_name}"] = (
ctx.task_id
)
self._shared_mem_log.append(log_item)

def on_fit_end(self, estimator, context):
pass

def get_logs(self):
"""Generate a pandas Dataframe with the logged values.

Returns
-------
pandas.DataFrame
Multi-index DataFrame with indices corresponding to the task tree.
"""
logs = pd.DataFrame(list(self._shared_mem_log))
log_dict = {}
if not logs.empty:
for run_id in logs["_run"].unique():
run_log = logs.loc[logs["_run"] == run_id].copy()
# Drop columns that correspond to other runs task_id and are filled with
# NaNs, and the run column, but always keep the metric column.
columns_to_keep = ~(run_log.isnull().all())
columns_to_keep["_run"] = False
columns_to_keep[self.metric_func.__name__] = True
run_log = run_log.loc[:, columns_to_keep]
log_dict[run_id] = run_log.set_index(
[col for col in run_log.columns if col != self.metric_func.__name__]
).sort_index()
return log_dict
24 changes: 24 additions & 0 deletions sklearn/callback/_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import copy

from sklearn.callback._base import Callback
from sklearn.callback._callback_context import CallbackContext

Expand Down Expand Up @@ -53,3 +55,25 @@ def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
)

return self._callback_fit_ctx

def _from_reconstruction_attributes(self, *, reconstruction_attributes):
"""Return an as if fitted copy of this estimator

Parameters
----------
reconstruction_attributes : callable
A callable that has no arguments and returns the necessary fitted attributes
to create a working fitted estimator from this instance.

Using a callable allows lazy evaluation of the potentially costly
reconstruction attributes.

Returns
-------
fitted_estimator : estimator instance
The fitted copy of this estimator.
"""
new_estimator = copy.copy(self) # XXX deepcopy ?
for key, val in reconstruction_attributes().items():
setattr(new_estimator, key, val)
return new_estimator
82 changes: 65 additions & 17 deletions sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import time
from functools import partial

import numpy as np

from sklearn.base import BaseEstimator, _fit_context, clone
from sklearn.callback import CallbackSupportMixin
Expand Down Expand Up @@ -45,28 +48,51 @@ class Estimator(CallbackSupportMixin, BaseEstimator):

_parameter_constraints: dict = {}

def __init__(self, max_iter=20, computation_intensity=0.001):
def __init__(self, intercept=0, max_iter=20, computation_intensity=0.001):
self.max_iter = max_iter
self.computation_intensity = computation_intensity
self.intercept = intercept

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
data = {
"X_train": X_train,
"y_train": y_train,
"X_val": X_val,
"y_val": y_val,
}
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.max_iter
).eval_on_fit_begin(estimator=self)

for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_id=i)
subcontext = callback_ctx.subcontext(task_id=i, task_name="fit_iter")

time.sleep(self.computation_intensity) # Computation intensive task

if subcontext.eval_on_fit_task_end(
estimator=self,
data={"X_train": X, "y_train": y},
data=data,
from_reconstruction_attributes=partial(
self._from_reconstruction_attributes,
reconstruction_attributes=lambda: {"coef_": i + 1},
),
):
break

self.n_iter_ = i + 1
self.coef_ = i + 1

return self

def predict(self, X):
return np.mean(X, axis=1) * self.coef_ + self.intercept


class EstimatorWithoutPredict(CallbackSupportMixin, BaseEstimator):
_parameter_constraints: dict = {}

def fit(self):
self.__skl_init_callback_context__().eval_on_fit_begin(estimator=self)

return self

Expand All @@ -79,34 +105,51 @@ class WhileEstimator(CallbackSupportMixin, BaseEstimator):

_parameter_constraints: dict = {}

def __init__(self, computation_intensity=0.001):
def __init__(self, intercept=0, max_iter=20, computation_intensity=0.001):
self.intercept = intercept
self.computation_intensity = computation_intensity
self.max_iter = max_iter

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
data = {
"X_train": X_train,
"y_train": y_train,
"X_val": X_val,
"y_val": y_val,
}
callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
estimator=self
)

i = 0
while True:
subcontext = callback_ctx.subcontext(task_id=i)
subcontext = callback_ctx.subcontext(task_id=i, task_name="fit_iter")

time.sleep(self.computation_intensity) # Computation intensive task

if subcontext.eval_on_fit_task_end(
estimator=self,
data={"X_train": X, "y_train": y},
data=data,
from_reconstruction_attributes=partial(
self._from_reconstruction_attributes,
reconstruction_attributes=lambda: {"coef_": i + 1},
),
):
break

if i == 20:
if i == self.max_iter - 1:
break

i += 1

self.coef_ = i + 1

return self

def predict(self, X):
return np.mean(X, axis=1) * self.coef_ + self.intercept


class MetaEstimator(CallbackSupportMixin, BaseEstimator):
"""A class that mimics the behavior of a meta-estimator.
Expand All @@ -127,7 +170,13 @@ def __init__(
self.prefer = prefer

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
data = {
"X_train": X_train,
"y_train": y_train,
"X_val": X_val,
"y_val": y_val,
}
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.n_outer
).eval_on_fit_begin(estimator=self)
Expand All @@ -136,8 +185,7 @@ def fit(self, X=None, y=None):
delayed(_func)(
self,
self.estimator,
X,
y,
data,
callback_ctx=callback_ctx.subcontext(
task_name="outer", task_id=i, max_subtasks=self.n_inner
),
Expand All @@ -148,22 +196,22 @@ def fit(self, X=None, y=None):
return self


def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx):
def _func(meta_estimator, inner_estimator, data, *, callback_ctx):
for i in range(meta_estimator.n_inner):
est = clone(inner_estimator)
iter_id = callback_ctx.task_id * meta_estimator.n_inner + i
est.intercept = iter_id

inner_ctx = callback_ctx.subcontext(
task_name="inner", task_id=i
).propagate_callbacks(sub_estimator=est)

est.fit(X, y)
est.fit(**data)

inner_ctx.eval_on_fit_task_end(
estimator=meta_estimator,
data={"X_train": X, "y_train": y},
)

callback_ctx.eval_on_fit_task_end(
estimator=meta_estimator,
data={"X_train": X, "y_train": y},
)
Loading
Loading