Skip to content

Commit 547b11d

Browse files
committed
Make subtest arguments serializable
Signed-off-by: Keith Battocchi <[email protected]>
1 parent 1972ace commit 547b11d

File tree

10 files changed

+48
-30
lines changed

10 files changed

+48
-30
lines changed

econml/tests/test_cate_interpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_can_use_interpreters(self):
3333
est = LinearDML(model_y=LinearRegression(), model_t=LogisticRegression(), discrete_treatment=True)
3434
est.fit(Y, T, X=X)
3535
for intrp in [SingleTreeCateInterpreter(), SingleTreePolicyInterpreter()]:
36-
with self.subTest(t_shape=t_shape, y_shape=y_shape, intrp=intrp):
36+
with self.subTest(t_shape=t_shape, y_shape=y_shape, intrp=type(intrp).__name__):
3737
with self.assertRaises(Exception):
3838
# prior to calling interpret, can't plot, render, etc.
3939
intrp.plot()
@@ -137,12 +137,12 @@ def test_random_cate_settings(self):
137137
policy_intrp_kwargs.update(sample_treatment_costs=0.1)
138138
elif self.coinflip():
139139
if discrete_t:
140-
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10, 2)))
140+
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10, 2)).tolist())
141141
else:
142142
if self.coinflip():
143-
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10, 1)))
143+
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10, 1)).tolist())
144144
else:
145-
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10,)))
145+
policy_intrp_kwargs.update(sample_treatment_costs=np.random.normal(size=(10,)).tolist())
146146

147147
if self.coinflip():
148148
common_kwargs.update(feature_names=['A', 'B', 'C', 'D'])

econml/tests/test_discrete_outcome.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def test_constraints(self):
204204
]
205205

206206
for est in ests:
207-
with self.subTest(est=est, kind='discrete treatment'):
207+
with self.subTest(est=type(est).__name__, kind='discrete treatment'):
208208
est.discrete_treatment = False
209209
est.model_t = LogisticRegression()
210210
with pytest.raises(AttributeError):
@@ -217,7 +217,7 @@ def test_constraints(self):
217217
ests += [LinearDRLearner()]
218218
for est in ests:
219219
print(est)
220-
with self.subTest(est=est, kind='discrete outcome'):
220+
with self.subTest(est=type(est).__name__, kind='discrete outcome'):
221221
est.discrete_outcome = False
222222
if isinstance(est, LinearDRLearner):
223223
est.model_regression = LogisticRegression()

econml/tests/test_dml.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def make_random(n, is_discrete, d):
207207

208208
for inf in infs:
209209
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
210-
is_discrete=is_discrete, est=est, inf=inf):
210+
is_discrete=is_discrete, est=type(est).__name__,
211+
model_final=repr(getattr(est, 'model_final', None)),
212+
inf=repr(inf)):
211213

212214
if X is None and (not fit_cate_intercept):
213215
with pytest.raises(AttributeError):
@@ -485,7 +487,8 @@ def make_random(is_discrete, d):
485487

486488
for inf in infs:
487489
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
488-
is_discrete=is_discrete, est=est, inf=inf):
490+
is_discrete=is_discrete,
491+
featurizer=repr(est.featurizer), inf=repr(inf)):
489492
if X is None:
490493
with pytest.raises(AttributeError):
491494
est.fit(Y, T, X=X, W=W, inference=inf)

econml/tests/test_dmliv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def eff_shape(n, d_x, d_y):
106106

107107
for est in est_list:
108108
with self.subTest(d_w=d_w, d_x=d_x, binary_T=binary_T, binary_Z=binary_Z,
109-
featurizer=featurizer, est=est):
109+
featurizer=repr(featurizer), est=type(est).__name__,
110+
projection=getattr(est, 'projection', None)):
110111

111112
# ensure we can serialize unfit estimator
112113
pickle.dumps(est)
@@ -195,7 +196,7 @@ def true_fn(X):
195196
return true_ate
196197
y, T, Z, X = dgp(n, p, true_fn)
197198
for est in ests_list:
198-
with self.subTest(est=est):
199+
with self.subTest(projection=est.projection):
199200
est.fit(y, T, Z=Z, X=None, W=X, inference="auto")
200201
ate_lb, ate_ub = est.effect_interval()
201202
np.testing.assert_array_less(ate_lb, true_ate)
@@ -250,7 +251,7 @@ def test_groups(self):
250251
]
251252

252253
for est in est_list:
253-
with self.subTest(est=est):
254+
with self.subTest(est=type(est).__name__, projection=getattr(est, 'projection', None)):
254255
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
255256
est.score(y, T, Z=Z, X=X, W=W)
256257
est.const_marginal_effect(X)

econml/tests/test_driv.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ def eff_shape(n, d_x):
172172
for est in est_list:
173173
with self.subTest(d_w=d_w, d_x=d_x, binary_T=binary_T, binary_Z=binary_Z,
174174
projection=projection, fit_cov_directly=fit_cov_directly,
175-
featurizer=featurizer,
176-
est=est):
175+
featurizer=repr(featurizer), est=type(est).__name__):
177176

178177
# TODO: serializing/deserializing for every combination -- is this necessary?
179178
# ensure we can serialize unfit estimator
@@ -266,7 +265,7 @@ def dgp(n, p, true_fn):
266265
use_ray=use_ray
267266
)]
268267
for est in ests_list:
269-
with self.subTest(est=est):
268+
with self.subTest(est=type(est).__name__):
270269
# no heterogeneity
271270
n = 1000
272271
p = 10
@@ -415,7 +414,7 @@ def ceil(a, b): # ceiling analog of //
415414
]
416415

417416
for est in est_list:
418-
with self.subTest(est=est):
417+
with self.subTest(est=type(est).__name__):
419418
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
420419
est.score(y, T, Z=Z, X=X, W=W)
421420
est.const_marginal_effect(X)

econml/tests/test_drlearner.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def make_random(is_discrete, d):
136136

137137
for inf in infs:
138138
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
139-
is_discrete=is_discrete, est=est, inf=inf):
139+
is_discrete=is_discrete, est=type(est).__name__, inf=repr(inf)):
140140
est.fit(Y, T, X=X, W=W, inference=inf)
141141

142142
# ensure that we can serialize fit estimator
@@ -453,9 +453,12 @@ def test_drlearner_all_attributes(self):
453453
if (not isinstance(models[2], StatsModelsLinearRegression)) and (sample_var
454454
is not None):
455455
continue
456-
with self.subTest(X=X, W=W, sample_weight=sample_weight, freq_weight=freq_weight,
457-
sample_var=sample_var,
458-
featurizer=featurizer, models=models,
456+
with self.subTest(X_present=(X is not None), W_present=(W is not None),
457+
sample_weight_present=(sample_weight is not None),
458+
freq_weight_present=(freq_weight is not None),
459+
sample_var_present=(sample_var is not None),
460+
featurizer=repr(featurizer),
461+
models=[type(m).__name__ for m in models],
459462
multitask_model_final=multitask_model_final):
460463
est = DRLearner(model_propensity=models[0],
461464
model_regression=models[1],
@@ -577,10 +580,13 @@ def _test_drlearner_with_inference_all_attributes(self, use_ray):
577580
LinearRegression(), SparseLinearDRLearner,
578581
'auto')
579582
]:
580-
with self.subTest(X=X, W=W, sample_weight=sample_weight, freq_weight=freq_weight,
581-
sample_var=sample_var,
582-
featurizer=featurizer, model_y=model_y, model_t=model_t,
583-
est_class=est_class, inference=inference):
583+
with self.subTest(X_present=(X is not None), W_present=(W is not None),
584+
sample_weight_present=(sample_weight is not None),
585+
freq_weight_present=(freq_weight is not None),
586+
sample_var_present=(sample_var is not None),
587+
featurizer=repr(featurizer),
588+
model_y=type(model_y).__name__, model_t=type(model_t).__name__,
589+
est_class=est_class.__name__, inference=repr(inference)):
584590
if (X is None) and (est_class == SparseLinearDRLearner):
585591
continue
586592
if (X is None) and (est_class == ForestDRLearner):

econml/tests/test_dynamic_dml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def make_random(n, is_discrete, d):
9393

9494
for inf in all_infs:
9595
with self.subTest(d_w=d_w, d_x=d_x, d_y=d_y, d_t=d_t,
96-
is_discrete=is_discrete, est=est, inf=inf):
96+
is_discrete=is_discrete, est=type(est).__name__, inf=repr(inf)):
9797

9898
if X is None and (not fit_cate_intercept):
9999
with pytest.raises(AttributeError):

econml/tests/test_missing_values.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ def test_missing2(self):
235235
]
236236

237237
for est in x_w_missing_models:
238-
with self.subTest(est=est, kind='missing X and W'):
238+
with self.subTest(est=type(est).__name__,
239+
projection=getattr(est, 'projection', None),
240+
prel_cate_approach=getattr(est, 'prel_cate_approach', None),
241+
kind='missing X and W'):
239242

240243
if 'Z' in inspect.getfullargspec(est.fit).kwonlyargs:
241244
include_Z = True
@@ -265,7 +268,10 @@ def test_missing2(self):
265268
self.assertRaises(ValueError, est.dowhy.fit, **data_dict)
266269

267270
for est in w_missing_models:
268-
with self.subTest(est=est, kind='missing W'):
271+
with self.subTest(est=type(est).__name__,
272+
projection=getattr(est, 'projection', None),
273+
prel_cate_approach=getattr(est, 'prel_cate_approach', None),
274+
kind='missing W'):
269275

270276
if 'Z' in inspect.getfullargspec(est.fit).kwonlyargs:
271277
include_Z = True
@@ -291,7 +297,10 @@ def test_missing2(self):
291297
self.assertRaises(ValueError, est.dowhy.fit, **data_dict)
292298

293299
for est in metalearners:
294-
with self.subTest(est=est, kind='metalearner'):
300+
with self.subTest(est=type(est).__name__,
301+
projection=getattr(est, 'projection', None),
302+
prel_cate_approach=getattr(est, 'prel_cate_approach', None),
303+
kind='metalearner'):
295304

296305
data_dict = create_data_dict(y, T, X, X_missing, W, W_missing, Z,
297306
X_has_missing=True, W_has_missing=False, include_Z=False)

econml/tests/test_model_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_sklearn_model_selection(self):
130130
for mdl in mdls:
131131
# these models only work on multi-output data
132132
use_array = isinstance(mdl, (MultiTaskElasticNetCV, MultiTaskLassoCV))
133-
with self.subTest(model=mdl):
133+
with self.subTest(model=type(mdl).__name__):
134134
est = LinearDML(model_t=mdl,
135135
discrete_treatment=is_discrete,
136136
model_y=LinearRegression())

econml/tests/test_shap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_continuous_t(self):
3737
), model_t=LinearRegression(), model_final=RandomForestRegressor(), featurizer=featurizer),
3838
]
3939
for est in est_list:
40-
with self.subTest(est=est, featurizer=featurizer, d_y=d_y, d_t=d_t):
40+
with self.subTest(est=type(est).__name__, featurizer=repr(featurizer), d_y=d_y, d_t=d_t):
4141
est.fit(Y, T, X=X, W=W)
4242
shap_values = est.shap_values(X[:10], feature_names=["a", "b", "c"],
4343
background_samples=None)
@@ -95,7 +95,7 @@ def test_discrete_t(self):
9595
DRLearner(multitask_model_final=False, featurizer=featurizer),
9696
ForestDRLearner()]
9797
for est in est_list:
98-
with self.subTest(est=est, featurizer=featurizer, d_y=d_y, d_t=d_t):
98+
with self.subTest(est=type(est).__name__, featurizer=repr(featurizer), d_y=d_y, d_t=d_t):
9999
if isinstance(est, (TLearner, SLearner, XLearner, DomainAdaptationLearner)):
100100
est.fit(Y, T, X=X)
101101
else:

0 commit comments

Comments (0)