diff --git a/sklearn/base.py b/sklearn/base.py
index 9abd145ce37a6..f85678028f2f9 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -136,9 +136,7 @@ def _clone_parametrized(estimator, *, safe=True):
 
     # attach callbacks to the new estimator
     if hasattr(estimator, "_skl_callbacks"):
-        # TODO(callbacks): Figure out the exact behavior we want when cloning an
-        # estimator with callbacks.
-        new_object._skl_callbacks = clone(estimator._skl_callbacks, safe=False)
+        new_object._skl_callbacks = estimator._skl_callbacks
 
     # quick sanity check of the parameters of the clone
     for name in new_object_params:
diff --git a/sklearn/callback/__init__.py b/sklearn/callback/__init__.py
index 22adc8cf4e4cc..4cf1c54d340a2 100644
--- a/sklearn/callback/__init__.py
+++ b/sklearn/callback/__init__.py
@@ -8,6 +8,7 @@
 
 from sklearn.callback._base import AutoPropagatedCallback, Callback
 from sklearn.callback._callback_context import CallbackContext
+from sklearn.callback._metric_monitor import MetricMonitor
 from sklearn.callback._mixin import CallbackSupportMixin
 from sklearn.callback._progressbar import ProgressBar
 
@@ -16,5 +17,6 @@
     "Callback",
     "CallbackContext",
     "CallbackSupportMixin",
+    "MetricMonitor",
     "ProgressBar",
 ]
diff --git a/sklearn/callback/_callback_context.py b/sklearn/callback/_callback_context.py
index c9ffc590f965d..cbcbacdd83140 100644
--- a/sklearn/callback/_callback_context.py
+++ b/sklearn/callback/_callback_context.py
@@ -1,6 +1,8 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
+import time
+
 from sklearn.callback import AutoPropagatedCallback
 
 # TODO(callbacks): move these explanations into a dedicated user guide.
@@ -145,8 +147,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
 
         # We don't store the estimator in the context to avoid circular references
         # because the estimator already holds a reference to the context.
+        new_ctx.init_time = time.time()
         new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
         new_ctx.estimator_name = estimator.__class__.__name__
+        new_ctx.estimator_id = id(estimator)
         new_ctx.task_name = task_name
         new_ctx.task_id = task_id
         new_ctx.parent = None
@@ -191,8 +195,10 @@ def _from_parent(cls, parent_context, *, task_name, task_id, max_subtasks=None):
         """
         new_ctx = cls.__new__(cls)
 
+        new_ctx.init_time = time.time()
         new_ctx._callbacks = parent_context._callbacks
         new_ctx.estimator_name = parent_context.estimator_name
+        new_ctx.estimator_id = parent_context.estimator_id
         new_ctx._estimator_depth = parent_context._estimator_depth
         new_ctx.task_name = task_name
         new_ctx.task_id = task_id
@@ -383,16 +389,14 @@ def propagate_callbacks(self, sub_estimator):
             )
         ]
 
-        if not callbacks_to_propagate:
-            return self
-
         # We store the parent context in the sub-estimator to be able to merge the
         # task trees of the sub-estimator and the meta-estimator.
         sub_estimator._parent_callback_ctx = self
 
-        sub_estimator.set_callbacks(
-            getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate
-        )
+        if callbacks_to_propagate:
+            sub_estimator.set_callbacks(
+                getattr(sub_estimator, "_skl_callbacks", []) + callbacks_to_propagate
+            )
 
         return self
diff --git a/sklearn/callback/_metric_monitor.py b/sklearn/callback/_metric_monitor.py
new file mode 100644
index 0000000000000..7936cfb462759
--- /dev/null
+++ b/sklearn/callback/_metric_monitor.py
@@ -0,0 +1,108 @@
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+import inspect
+import time
+from multiprocessing import Manager
+
+import pandas as pd
+
+from sklearn.callback._callback_context import get_context_path
+
+
+class MetricMonitor:
+    """Callback that monitors a metric at each iterative step of an estimator.
+
+    The specified metric function is called on the target values `y` and the
+    predicted values `y_pred = estimator.predict(X)` at each iterative step of the
+    estimator.
+
+    Parameters
+    ----------
+    metric : callable
+        The metric to compute, called as `metric(y, y_pred, **metric_params)`.
+    metric_params : dict or None, default=None
+        Additional keyword arguments for the metric function.
+    on_validation : bool, default=True
+        Whether to compute the metric on validation data (if True) or training data
+        (if False).
+    """
+
+    def __init__(self, metric, metric_params=None, on_validation=True):
+        self.on_validation = on_validation
+        self.metric_params = metric_params or dict()
+        if metric_params is not None:
+            valid_params = inspect.signature(metric).parameters
+            invalid_params = [arg for arg in metric_params if arg not in valid_params]
+            if invalid_params:
+                raise ValueError(
+                    f"The parameters '{invalid_params}' cannot be used with the"
+                    f" function {metric.__module__}.{metric.__name__}."
+                )
+        self.metric_func = metric
+        self._shared_mem_log = Manager().list()
+
+    def on_fit_begin(self, estimator):
+        if not hasattr(estimator, "predict"):
+            raise ValueError(
+                f"Estimator {estimator.__class__} does not have a predict method, which"
+                " is necessary to use a MetricMonitor callback."
+            )
+
+    def on_fit_task_end(
+        self, estimator, context, data, from_reconstruction_attributes, **kwargs
+    ):
+        # TODO: add check to verify we're on the innermost level of the fit loop,
+        # e.g. for KMeans.
+        X, y = (
+            (data["X_val"], data["y_val"])
+            if self.on_validation
+            else (data["X_train"], data["y_train"])
+        )
+        y_pred = from_reconstruction_attributes().predict(X)
+        metric_value = self.metric_func(y, y_pred, **self.metric_params)
+        log_item = {self.metric_func.__name__: metric_value}
+        for depth, ctx in enumerate(get_context_path(context)):
+            if depth == 0:
+                timestamp = time.strftime(
+                    "%Y-%m-%d_%H:%M:%S", time.localtime(ctx.init_time)
+                )
+                log_item["_run"] = (
+                    f"{ctx.estimator_name}_{ctx.estimator_id}_{timestamp}"
+                )
+            prev_task_str = (
+                f"{ctx.prev_estimator_name}_{ctx.prev_task_name}|"
+                if ctx.prev_estimator_name is not None
+                else ""
+            )
+            log_item[f"{depth}_{prev_task_str}{ctx.estimator_name}_{ctx.task_name}"] = (
+                ctx.task_id
+            )
+        self._shared_mem_log.append(log_item)
+
+    def on_fit_end(self, estimator, context):
+        pass
+
+    def get_logs(self):
+        """Return a dict of pandas DataFrames with the logged metric values.
+
+        Returns
+        -------
+        logs : dict of str to pandas.DataFrame
+            One multi-index DataFrame per run, keyed by run id, with indices
+            corresponding to the task tree.
+        """
+        logs = pd.DataFrame(list(self._shared_mem_log))
+        log_dict = {}
+        if not logs.empty:
+            for run_id in logs["_run"].unique():
+                run_log = logs.loc[logs["_run"] == run_id].copy()
+                # Drop columns that correspond to other runs' task_id and are filled
+                # with NaNs, and the run column, but always keep the metric column.
+                columns_to_keep = ~(run_log.isnull().all())
+                columns_to_keep["_run"] = False
+                columns_to_keep[self.metric_func.__name__] = True
+                run_log = run_log.loc[:, columns_to_keep]
+                log_dict[run_id] = run_log.set_index(
+                    [col for col in run_log.columns if col != self.metric_func.__name__]
+                ).sort_index()
+        return log_dict
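A minimal usage sketch of the callback added above (not part of the patch). It relies on the toy `Estimator` from `sklearn.callback.tests._utils` introduced further down in this diff, so it only runs with the patch applied; the data shapes are illustrative:

```python
import numpy as np

from sklearn.callback import MetricMonitor
from sklearn.callback.tests._utils import Estimator  # toy estimator from this patch
from sklearn.metrics import mean_squared_error

rng = np.random.RandomState(0)
X_train, y_train = rng.uniform(size=(20, 5)), rng.uniform(size=20)
X_val, y_val = rng.uniform(size=(20, 5)), rng.uniform(size=20)

monitor = MetricMonitor(mean_squared_error, on_validation=True)
estimator = Estimator(max_iter=10)
estimator.set_callbacks(monitor)
estimator.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val)

# get_logs() returns one multi-index DataFrame per fit run, keyed by a run id
# built from the estimator name, its id() and the fit start time.
for run_id, run_log in monitor.get_logs().items():
    print(run_id)
    print(run_log)
```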
+ """ + logs = pd.DataFrame(list(self._shared_mem_log)) + log_dict = {} + if not logs.empty: + for run_id in logs["_run"].unique(): + run_log = logs.loc[logs["_run"] == run_id].copy() + # Drop columns that correspond to other runs task_id and are filled with + # NaNs, and the run column, but always keep the metric column. + columns_to_keep = ~(run_log.isnull().all()) + columns_to_keep["_run"] = False + columns_to_keep[self.metric_func.__name__] = True + run_log = run_log.loc[:, columns_to_keep] + log_dict[run_id] = run_log.set_index( + [col for col in run_log.columns if col != self.metric_func.__name__] + ).sort_index() + return log_dict diff --git a/sklearn/callback/_mixin.py b/sklearn/callback/_mixin.py index 2827d05e930d7..d940937b4561b 100644 --- a/sklearn/callback/_mixin.py +++ b/sklearn/callback/_mixin.py @@ -1,6 +1,8 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import copy + from sklearn.callback._base import Callback from sklearn.callback._callback_context import CallbackContext @@ -53,3 +55,25 @@ def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None): ) return self._callback_fit_ctx + + def _from_reconstruction_attributes(self, *, reconstruction_attributes): + """Return an as if fitted copy of this estimator + + Parameters + ---------- + reconstruction_attributes : callable + A callable that has no arguments and returns the necessary fitted attributes + to create a working fitted estimator from this instance. + + Using a callable allows lazy evaluation of the potentially costly + reconstruction attributes. + + Returns + ------- + fitted_estimator : estimator instance + The fitted copy of this estimator. + """ + new_estimator = copy.copy(self) # XXX deepcopy ? + for key, val in reconstruction_attributes().items(): + setattr(new_estimator, key, val) + return new_estimator diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py index 21a011e02afd8..40571ecc7cdc5 100644 --- a/sklearn/callback/tests/_utils.py +++ b/sklearn/callback/tests/_utils.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause import time +from functools import partial + +import numpy as np from sklearn.base import BaseEstimator, _fit_context, clone from sklearn.callback import CallbackSupportMixin @@ -45,28 +48,51 @@ class Estimator(CallbackSupportMixin, BaseEstimator): _parameter_constraints: dict = {} - def __init__(self, max_iter=20, computation_intensity=0.001): + def __init__(self, intercept=0, max_iter=20, computation_intensity=0.001): self.max_iter = max_iter self.computation_intensity = computation_intensity + self.intercept = intercept @_fit_context(prefer_skip_nested_validation=False) - def fit(self, X=None, y=None): + def fit(self, X_train=None, y_train=None, X_val=None, y_val=None): + data = { + "X_train": X_train, + "y_train": y_train, + "X_val": X_val, + "y_val": y_val, + } callback_ctx = self.__skl_init_callback_context__( max_subtasks=self.max_iter ).eval_on_fit_begin(estimator=self) for i in range(self.max_iter): - subcontext = callback_ctx.subcontext(task_id=i) + subcontext = callback_ctx.subcontext(task_id=i, task_name="fit_iter") time.sleep(self.computation_intensity) # Computation intensive task if subcontext.eval_on_fit_task_end( estimator=self, - data={"X_train": X, "y_train": y}, + data=data, + from_reconstruction_attributes=partial( + self._from_reconstruction_attributes, + reconstruction_attributes=lambda: {"coef_": i + 1}, + ), ): break - self.n_iter_ = i + 1 + self.coef_ = i + 1 + + return self + + 
diff --git a/sklearn/callback/tests/_utils.py b/sklearn/callback/tests/_utils.py
index 21a011e02afd8..40571ecc7cdc5 100644
--- a/sklearn/callback/tests/_utils.py
+++ b/sklearn/callback/tests/_utils.py
@@ -2,6 +2,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import time
+from functools import partial
+
+import numpy as np
 
 from sklearn.base import BaseEstimator, _fit_context, clone
 from sklearn.callback import CallbackSupportMixin
@@ -45,28 +48,51 @@ class Estimator(CallbackSupportMixin, BaseEstimator):
 
     _parameter_constraints: dict = {}
 
-    def __init__(self, max_iter=20, computation_intensity=0.001):
+    def __init__(self, intercept=0, max_iter=20, computation_intensity=0.001):
         self.max_iter = max_iter
         self.computation_intensity = computation_intensity
+        self.intercept = intercept
 
     @_fit_context(prefer_skip_nested_validation=False)
-    def fit(self, X=None, y=None):
+    def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
+        data = {
+            "X_train": X_train,
+            "y_train": y_train,
+            "X_val": X_val,
+            "y_val": y_val,
+        }
         callback_ctx = self.__skl_init_callback_context__(
             max_subtasks=self.max_iter
         ).eval_on_fit_begin(estimator=self)
 
         for i in range(self.max_iter):
-            subcontext = callback_ctx.subcontext(task_id=i)
+            subcontext = callback_ctx.subcontext(task_id=i, task_name="fit_iter")
 
             time.sleep(self.computation_intensity)  # Computation intensive task
 
             if subcontext.eval_on_fit_task_end(
                 estimator=self,
-                data={"X_train": X, "y_train": y},
+                data=data,
+                from_reconstruction_attributes=partial(
+                    self._from_reconstruction_attributes,
+                    reconstruction_attributes=lambda: {"coef_": i + 1},
+                ),
             ):
                 break
 
-        self.n_iter_ = i + 1
+        self.coef_ = i + 1
+
+        return self
+
+    def predict(self, X):
+        return np.mean(X, axis=1) * self.coef_ + self.intercept
+
+
+class EstimatorWithoutPredict(CallbackSupportMixin, BaseEstimator):
+    _parameter_constraints: dict = {}
+
+    def fit(self):
+        self.__skl_init_callback_context__().eval_on_fit_begin(estimator=self)
 
         return self
 
@@ -79,34 +105,51 @@ class WhileEstimator(CallbackSupportMixin, BaseEstimator):
 
     _parameter_constraints: dict = {}
 
-    def __init__(self, computation_intensity=0.001):
+    def __init__(self, intercept=0, max_iter=20, computation_intensity=0.001):
+        self.intercept = intercept
         self.computation_intensity = computation_intensity
+        self.max_iter = max_iter
 
     @_fit_context(prefer_skip_nested_validation=False)
-    def fit(self, X=None, y=None):
+    def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
+        data = {
+            "X_train": X_train,
+            "y_train": y_train,
+            "X_val": X_val,
+            "y_val": y_val,
+        }
         callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
             estimator=self
         )
 
         i = 0
 
         while True:
-            subcontext = callback_ctx.subcontext(task_id=i)
+            subcontext = callback_ctx.subcontext(task_id=i, task_name="fit_iter")
 
             time.sleep(self.computation_intensity)  # Computation intensive task
 
             if subcontext.eval_on_fit_task_end(
                 estimator=self,
-                data={"X_train": X, "y_train": y},
+                data=data,
+                from_reconstruction_attributes=partial(
+                    self._from_reconstruction_attributes,
+                    reconstruction_attributes=lambda: {"coef_": i + 1},
+                ),
             ):
                 break
 
-            if i == 20:
+            if i == self.max_iter - 1:
                 break
 
             i += 1
 
+        self.coef_ = i + 1
+
         return self
 
+    def predict(self, X):
+        return np.mean(X, axis=1) * self.coef_ + self.intercept
+
 
 class MetaEstimator(CallbackSupportMixin, BaseEstimator):
     """A class that mimics the behavior of a meta-estimator.
@@ -127,7 +170,13 @@ def __init__(
         self.prefer = prefer
 
     @_fit_context(prefer_skip_nested_validation=False)
-    def fit(self, X=None, y=None):
+    def fit(self, X_train=None, y_train=None, X_val=None, y_val=None):
+        data = {
+            "X_train": X_train,
+            "y_train": y_train,
+            "X_val": X_val,
+            "y_val": y_val,
+        }
         callback_ctx = self.__skl_init_callback_context__(
             max_subtasks=self.n_outer
         ).eval_on_fit_begin(estimator=self)
@@ -136,8 +185,7 @@ def fit(self, X=None, y=None):
             delayed(_func)(
                 self,
                 self.estimator,
-                X,
-                y,
+                data,
                 callback_ctx=callback_ctx.subcontext(
                     task_name="outer", task_id=i, max_subtasks=self.n_inner
                 ),
@@ -148,22 +196,22 @@ def fit(self, X=None, y=None):
 
         return self
 
 
-def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx):
+def _func(meta_estimator, inner_estimator, data, *, callback_ctx):
     for i in range(meta_estimator.n_inner):
         est = clone(inner_estimator)
 
+        iter_id = callback_ctx.task_id * meta_estimator.n_inner + i
+        est.intercept = iter_id
         inner_ctx = callback_ctx.subcontext(
             task_name="inner", task_id=i
         ).propagate_callbacks(sub_estimator=est)
 
-        est.fit(X, y)
+        est.fit(**data)
 
         inner_ctx.eval_on_fit_task_end(
             estimator=meta_estimator,
-            data={"X_train": X, "y_train": y},
         )
 
     callback_ctx.eval_on_fit_task_end(
         estimator=meta_estimator,
-        data={"X_train": X, "y_train": y},
     )
diff --git a/sklearn/callback/tests/test_callback_context.py b/sklearn/callback/tests/test_callback_context.py
index e61a21d9461d5..635c7c6720a96 100644
--- a/sklearn/callback/tests/test_callback_context.py
+++ b/sklearn/callback/tests/test_callback_context.py
@@ -56,7 +56,7 @@ def test_auto_propagated_callbacks():
         r"sub-estimator .*of a meta-estimator .*can't have auto-propagated callbacks"
     )
     with pytest.raises(TypeError, match=match):
-        meta_estimator.fit(X=None, y=None)
+        meta_estimator.fit()
 
 
 def _make_task_tree(n_children, n_grandchildren):
diff --git a/sklearn/callback/tests/test_metric_monitor.py b/sklearn/callback/tests/test_metric_monitor.py
new file mode 100644
index 0000000000000..003f3675541fb
--- /dev/null
+++ b/sklearn/callback/tests/test_metric_monitor.py
@@ -0,0 +1,186 @@
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+from itertools import product
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from sklearn.callback import MetricMonitor
+from sklearn.callback.tests._utils import (
+    Estimator,
+    EstimatorWithoutPredict,
+    MetaEstimator,
+    WhileEstimator,
+)
+from sklearn.metrics import mean_pinball_loss, mean_squared_error
+
+
+@pytest.mark.parametrize("EstimatorClass", [Estimator, WhileEstimator])
+@pytest.mark.parametrize(
+    "metric, metric_params",
+    [(mean_squared_error, None), (mean_pinball_loss, {"alpha": 0.6})],
+)
+def test_metric_monitor(EstimatorClass, metric, metric_params):
+    max_iter = 3
+    n_dim = 5
+    n_samples = 3
+    intercept = 1
+    estimator = EstimatorClass(intercept=intercept, max_iter=max_iter)
+    callback_train = MetricMonitor(
+        metric, metric_params=metric_params, on_validation=False
+    )
+    callback_val = MetricMonitor(
+        metric, metric_params=metric_params, on_validation=True
+    )
+    estimator.set_callbacks([callback_train, callback_val])
+    rng = np.random.RandomState(0)
+    X_train, y_train = rng.uniform(size=(n_dim, n_samples)), rng.uniform(size=n_dim)
+    X_val, y_val = rng.uniform(size=(n_dim, n_samples)), rng.uniform(size=n_dim)
+
+    estimator.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val)
+
+    metric_params = metric_params or dict()
+    log_train = callback_train.get_logs()
+    assert len(log_train) == 1
+    run_id_train, log_train = next(iter(log_train.items()))
+    assert f"{estimator.__class__.__name__}_{id(estimator)}_" in run_id_train
+
+    expected_log_train = pd.DataFrame(
+        [
+            {
+                f"0_{estimator.__class__.__name__}_fit": 0,
+                f"1_{estimator.__class__.__name__}_fit_iter": i,
+                metric.__name__: metric(
+                    y_train, X_train.mean(axis=1) * (i + 1) + intercept, **metric_params
+                ),
+            }
+            for i in range(max_iter)
+        ]
+    )
+    expected_log_train = expected_log_train.set_index(
+        [col for col in expected_log_train.columns if col != metric.__name__]
+    )
+    assert log_train.equals(expected_log_train)
+    assert np.array_equal(log_train.index.names, expected_log_train.index.names)
+
+    log_val = callback_val.get_logs()
+    assert len(log_val) == 1
+    run_id_val, log_val = next(iter(log_val.items()))
+    assert f"{estimator.__class__.__name__}_{id(estimator)}_" in run_id_val
+
+    expected_log_val = pd.DataFrame(
+        [
+            {
+                f"0_{estimator.__class__.__name__}_fit": 0,
+                f"1_{estimator.__class__.__name__}_fit_iter": i,
+                metric.__name__: metric(
+                    y_val, X_val.mean(axis=1) * (i + 1) + intercept, **metric_params
+                ),
+            }
+            for i in range(max_iter)
+        ]
+    )
+    expected_log_val = expected_log_val.set_index(
+        [col for col in expected_log_val.columns if col != metric.__name__]
+    )
+    assert log_val.equals(expected_log_val)
+    assert np.array_equal(log_val.index.names, expected_log_val.index.names)
+
+
+def test_no_predict_error():
+    estimator = EstimatorWithoutPredict()
+    callback = MetricMonitor(mean_pinball_loss, metric_params={"alpha": 0.6})
+    estimator.set_callbacks(callback)
+
+    with pytest.raises(ValueError, match="does not have a predict method"):
+        estimator.fit()
+
+
+def test_wrong_kwarg_error():
+    with pytest.raises(ValueError, match="cannot be used with the function"):
+        MetricMonitor(mean_pinball_loss, metric_params={"wrong_name": 0.6})
metric_params={"wrong_name": 0.6}) + + +@pytest.mark.parametrize("prefer", ["processes", "threads"]) +@pytest.mark.parametrize( + "metric, metric_params", + [(mean_squared_error, None), (mean_pinball_loss, {"alpha": 0.6})], +) +def test_within_meta_estimator(prefer, metric, metric_params): + n_outer = 3 + n_inner = 2 + max_iter = 4 + n_dim = 5 + n_samples = 3 + rng = np.random.RandomState(0) + X_train, y_train = rng.uniform(size=(n_dim, n_samples)), rng.uniform(size=n_dim) + X_val, y_val = rng.uniform(size=(n_dim, n_samples)), rng.uniform(size=n_dim) + callback_train = MetricMonitor( + metric, metric_params=metric_params, on_validation=False + ) + callback_val = MetricMonitor( + metric, metric_params=metric_params, on_validation=True + ) + est = Estimator(max_iter=max_iter) + est.set_callbacks([callback_train, callback_val]) + meta_est = MetaEstimator( + est, n_outer=n_outer, n_inner=n_inner, n_jobs=2, prefer=prefer + ) + + meta_est.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val) + + metric_params = metric_params or dict() + expected_log_train = [] + expected_log_val = [] + for i_outer, i_inner in product(range(n_outer), range(n_inner)): + est = Estimator(intercept=i_outer * n_inner + i_inner) + for i_estimator_fit_iter in range(max_iter): + setattr(est, "coef_", i_estimator_fit_iter + 1) + expected_log_train.append( + { + metric.__name__: metric( + y_train, est.predict(X_train), **metric_params + ), + f"0_{meta_est.__class__.__name__}_fit": 0, + f"1_{meta_est.__class__.__name__}_outer": i_outer, + f"2_{meta_est.__class__.__name__}_inner|" + f"{est.__class__.__name__}_fit": i_inner, + f"3_{est.__class__.__name__}_fit_iter": i_estimator_fit_iter, + } + ) + expected_log_val.append( + { + metric.__name__: metric(y_val, est.predict(X_val), **metric_params), + f"0_{meta_est.__class__.__name__}_fit": 0, + f"1_{meta_est.__class__.__name__}_outer": i_outer, + f"2_{meta_est.__class__.__name__}_inner|" + f"{est.__class__.__name__}_fit": i_inner, + f"3_{est.__class__.__name__}_fit_iter": i_estimator_fit_iter, + } + ) + expected_log_train = pd.DataFrame(expected_log_train) + expected_log_train = expected_log_train.set_index( + [col for col in expected_log_train.columns if col != metric.__name__] + ) + expected_log_val = pd.DataFrame(expected_log_val) + expected_log_val = expected_log_val.set_index( + [col for col in expected_log_val.columns if col != metric.__name__] + ) + + log_train = callback_train.get_logs() + assert len(log_train) == 1 + run_id_train, log_train = next(iter(log_train.items())) + log_val = callback_val.get_logs() + assert len(log_val) == 1 + run_id_val, log_val = next(iter(log_val.items())) + + assert f"{meta_est.__class__.__name__}_{id(meta_est)}_" in run_id_train + assert f"{meta_est.__class__.__name__}_{id(meta_est)}_" in run_id_val + assert len(log_train) == len(expected_log_train) + assert len(log_val) == len(expected_log_val) + assert np.array_equal(log_train.index.names, expected_log_train.index.names) + assert np.array_equal(log_val.index.names, expected_log_val.index.names) + assert log_train.equals(expected_log_train) + assert log_val.equals(expected_log_val)