|
4 | 4 | import heapq |
5 | 5 | import json |
6 | 6 | import logging |
7 | | -from operator import itemgetter |
8 | 7 | import os |
9 | 8 | from copy import deepcopy |
10 | | -from itertools import combinations, count |
11 | | -from math import ceil, isnan |
12 | | -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union, Callable |
13 | | -from warnings import warn |
14 | 9 | from dataclasses import dataclass, field |
| 10 | +from itertools import combinations, count, starmap |
| 11 | +from math import ceil, isnan |
15 | 12 | from multiprocessing import shared_memory |
16 | | -import numpy as np |
| 13 | +from operator import itemgetter |
| 14 | +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union |
| 15 | +from warnings import warn |
17 | 16 |
|
| 17 | +import numpy as np |
| 18 | +from joblib import Parallel, delayed |
18 | 19 | from sklearn.base import ( |
19 | 20 | BaseEstimator, |
20 | 21 | ClassifierMixin, |
|
28 | 29 | from ... import develop |
29 | 30 | from ...api.base import ExplainerMixin |
30 | 31 | from ...api.templates import FeatureValueExplanation |
31 | | -from ...provider import JobLibProvider |
32 | 32 | from ...utils._clean_simple import ( |
33 | 33 | clean_dimensions, |
34 | 34 | clean_X_and_init_score, |
|
44 | 44 | ) |
45 | 45 | from ...utils._histogram import make_all_histogram_edges |
46 | 46 | from ...utils._link import inv_link, link_func |
| 47 | +from ...utils._measure_mem import total_bytes |
47 | 48 | from ...utils._misc import clean_index, clean_indexes |
48 | 49 | from ...utils._native import Native |
49 | 50 | from ...utils._preprocessor import construct_bins |
|
54 | 55 | ) |
55 | 56 | from ...utils._rank_interactions import rank_interactions |
56 | 57 | from ...utils._seed import normalize_seed |
| 58 | +from ...utils._shared_dataset import SharedDataset |
57 | 59 | from ...utils._unify_data import unify_data |
58 | 60 | from ._bin import ( |
59 | 61 | ebm_eval_terms, |
|
71 | 73 | process_terms, |
72 | 74 | remove_extra_bins, |
73 | 75 | ) |
74 | | -from ...utils._shared_dataset import SharedDataset |
75 | | -from ...utils._measure_mem import total_bytes |
76 | 76 |
|
77 | 77 | _log = logging.getLogger(__name__) |
78 | 78 |
|
@@ -1053,7 +1053,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): |
1053 | 1053 |
|
1054 | 1054 | exclude_features = {i for i, v in enumerate(monotone_constraints) if v != 0} |
1055 | 1055 |
|
1056 | | - provider = JobLibProvider(n_jobs=self.n_jobs) |
| 1056 | + parallel = Parallel(n_jobs=self.n_jobs) |
1057 | 1057 |
|
1058 | 1058 | bagged_intercept = np.zeros((self.outer_bags, n_scores), np.float64) |
1059 | 1059 | if not is_differential_privacy: |
@@ -1208,7 +1208,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): |
1208 | 1208 | ) |
1209 | 1209 | ) |
1210 | 1210 |
|
1211 | | - results = provider.parallel(boost, parallel_args) |
| 1211 | + results = parallel(starmap(delayed(boost), parallel_args)) |
1212 | 1212 |
|
1213 | 1213 | # let python reclaim the dataset memory via reference counting |
1214 | 1214 | # parallel_args holds references to dataset, so must be deleted |
@@ -1320,8 +1320,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): |
1320 | 1320 | ) |
1321 | 1321 | ) |
1322 | 1322 |
|
1323 | | - bagged_ranked_interaction = provider.parallel( |
1324 | | - rank_interactions, parallel_args |
| 1323 | + bagged_ranked_interaction = parallel( |
| 1324 | + starmap(delayed(rank_interactions), parallel_args) |
1325 | 1325 | ) |
1326 | 1326 |
|
1327 | 1327 | # this holds references to dataset, internal_bags, and scores_bags which we want python to reclaim later |
@@ -1480,7 +1480,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None): |
1480 | 1480 | ) |
1481 | 1481 | ) |
1482 | 1482 |
|
1483 | | - results = provider.parallel(boost, parallel_args) |
| 1483 | + results = parallel(starmap(delayed(boost), parallel_args)) |
1484 | 1484 |
|
1485 | 1485 | # allow python to reclaim these big memory items via reference counting |
1486 | 1486 | # this holds references to dataset, scores_bags, and bags |
|
0 commit comments