Skip to content

Commit 97f29e4

Browse files
authored
MNT: Use bare joblib.Parallel (#631)
* MNT: sort imports

  Signed-off-by: DerWeh <[email protected]>

* MNT: drop abstraction around joblib.Parallel

  Joblib already offers a system to add custom backends:
  https://joblib.readthedocs.io/en/stable/custom_parallel_backend.html

  Signed-off-by: DerWeh <[email protected]>

---------

Signed-off-by: DerWeh <[email protected]>
1 parent aac2ff5 commit 97f29e4

File tree

4 files changed

+14
-50
lines changed

4 files changed

+14
-50
lines changed

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
import heapq
55
import json
66
import logging
7-
from operator import itemgetter
87
import os
98
from copy import deepcopy
10-
from itertools import combinations, count
11-
from math import ceil, isnan
12-
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union, Callable
13-
from warnings import warn
149
from dataclasses import dataclass, field
10+
from itertools import combinations, count, starmap
11+
from math import ceil, isnan
1512
from multiprocessing import shared_memory
16-
import numpy as np
13+
from operator import itemgetter
14+
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
15+
from warnings import warn
1716

17+
import numpy as np
18+
from joblib import Parallel, delayed
1819
from sklearn.base import (
1920
BaseEstimator,
2021
ClassifierMixin,
@@ -28,7 +29,6 @@
2829
from ... import develop
2930
from ...api.base import ExplainerMixin
3031
from ...api.templates import FeatureValueExplanation
31-
from ...provider import JobLibProvider
3232
from ...utils._clean_simple import (
3333
clean_dimensions,
3434
clean_X_and_init_score,
@@ -44,6 +44,7 @@
4444
)
4545
from ...utils._histogram import make_all_histogram_edges
4646
from ...utils._link import inv_link, link_func
47+
from ...utils._measure_mem import total_bytes
4748
from ...utils._misc import clean_index, clean_indexes
4849
from ...utils._native import Native
4950
from ...utils._preprocessor import construct_bins
@@ -54,6 +55,7 @@
5455
)
5556
from ...utils._rank_interactions import rank_interactions
5657
from ...utils._seed import normalize_seed
58+
from ...utils._shared_dataset import SharedDataset
5759
from ...utils._unify_data import unify_data
5860
from ._bin import (
5961
ebm_eval_terms,
@@ -71,8 +73,6 @@
7173
process_terms,
7274
remove_extra_bins,
7375
)
74-
from ...utils._shared_dataset import SharedDataset
75-
from ...utils._measure_mem import total_bytes
7676

7777
_log = logging.getLogger(__name__)
7878

@@ -1053,7 +1053,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
10531053

10541054
exclude_features = {i for i, v in enumerate(monotone_constraints) if v != 0}
10551055

1056-
provider = JobLibProvider(n_jobs=self.n_jobs)
1056+
parallel = Parallel(n_jobs=self.n_jobs)
10571057

10581058
bagged_intercept = np.zeros((self.outer_bags, n_scores), np.float64)
10591059
if not is_differential_privacy:
@@ -1208,7 +1208,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
12081208
)
12091209
)
12101210

1211-
results = provider.parallel(boost, parallel_args)
1211+
results = parallel(starmap(delayed(boost), parallel_args))
12121212

12131213
# let python reclaim the dataset memory via reference counting
12141214
# parallel_args holds references to dataset, so must be deleted
@@ -1320,8 +1320,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
13201320
)
13211321
)
13221322

1323-
bagged_ranked_interaction = provider.parallel(
1324-
rank_interactions, parallel_args
1323+
bagged_ranked_interaction = parallel(
1324+
starmap(delayed(rank_interactions), parallel_args)
13251325
)
13261326

13271327
# this holds references to dataset, internal_bags, and scores_bags which we want python to reclaim later
@@ -1480,7 +1480,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
14801480
)
14811481
)
14821482

1483-
results = provider.parallel(boost, parallel_args)
1483+
results = parallel(starmap(delayed(boost), parallel_args))
14841484

14851485
# allow python to reclaim these big memory items via reference counting
14861486
# this holds references to dataset, scores_bags, and bags

python/interpret-core/interpret/provider/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) 2023 The InterpretML Contributors
22
# Distributed under the MIT software license
33

4-
from ._compute import ComputeProvider, JobLibProvider # noqa: F401
54
from ._visualize import ( # noqa: F401
65
AutoVisualizeProvider,
76
DashProvider,

python/interpret-core/interpret/provider/_compute.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

python/interpret-core/tests/provider/test_providers.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
AutoVisualizeProvider,
88
DashProvider,
99
InlineProvider,
10-
JobLibProvider,
1110
PreserveProvider,
1211
)
1312

@@ -29,12 +28,6 @@ def example_explanation():
2928
return explainer.explain_local(data["test"]["X"].head(), data["test"]["y"].head())
3029

3130

32-
def test_joblib_provider():
33-
provider = JobLibProvider()
34-
results = provider.parallel(task_fn, task_args_iter)
35-
assert results == [2, 4, 6]
36-
37-
3831
@pytest.mark.slow
3932
def test_auto_visualize_provider(example_explanation):
4033
# NOTE: We know this environment is going to use Dash.

0 commit comments

Comments
 (0)