Skip to content

Commit ecbe778

Browse files
authored
feat: Allow for disabling backend when running tests via pytest (#2340)
* Add pytest configuration `--disable-backend` that will skip tests that use that backend choice.
1 parent d4f6b30 commit ecbe778

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

tests/conftest.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
import pyhf
1010

1111

12+
def pytest_addoption(parser):
13+
parser.addoption(
14+
"--disable-backend",
15+
action="append",
16+
type=str,
17+
default=[],
18+
choices=["tensorflow", "pytorch", "jax", "minuit"],
19+
help="list of backends to disable in tests",
20+
)
21+
22+
1223
# Factory as fixture pattern
1324
@pytest.fixture
1425
def get_json_from_tarfile():
@@ -59,14 +70,14 @@ def reset_backend():
5970
@pytest.fixture(
6071
scope='function',
6172
params=[
62-
(pyhf.tensor.numpy_backend(), None),
63-
(pyhf.tensor.pytorch_backend(), None),
64-
(pyhf.tensor.pytorch_backend(precision='64b'), None),
65-
(pyhf.tensor.tensorflow_backend(), None),
66-
(pyhf.tensor.jax_backend(), None),
73+
(("numpy_backend", dict()), ("scipy_optimizer", dict())),
74+
(("pytorch_backend", dict()), ("scipy_optimizer", dict())),
75+
(("pytorch_backend", dict(precision="64b")), ("scipy_optimizer", dict())),
76+
(("tensorflow_backend", dict()), ("scipy_optimizer", dict())),
77+
(("jax_backend", dict()), ("scipy_optimizer", dict())),
6778
(
68-
pyhf.tensor.numpy_backend(poisson_from_normal=True),
69-
pyhf.optimize.minuit_optimizer(),
79+
("numpy_backend", dict(poisson_from_normal=True)),
80+
("minuit_optimizer", dict()),
7081
),
7182
],
7283
ids=['numpy', 'pytorch', 'pytorch64', 'tensorflow', 'jax', 'numpy_minuit'],
@@ -87,13 +98,26 @@ def backend(request):
8798
only_backends = [
8899
pid for pid in param_ids if request.node.get_closest_marker(f'only_{pid}')
89100
]
101+
disable_backend = any(
102+
backend in param_id for backend in request.config.option.disable_backend
103+
)
90104

91105
if skip_backend and (param_id in only_backends):
92106
raise ValueError(
93107
f"Must specify skip_{param_id} or only_{param_id} but not both!"
94108
)
95109

96-
if skip_backend:
110+
if disable_backend:
111+
pytest.skip(
112+
f"skipping {func_name} as the backend is disabled via "
113+
+ " ".join(
114+
[
115+
f"--disable-backend {choice}"
116+
for choice in request.config.option.disable_backend
117+
]
118+
)
119+
)
120+
elif skip_backend:
97121
pytest.skip(f"skipping {func_name} as specified")
98122
elif only_backends and param_id not in only_backends:
99123
pytest.skip(
@@ -109,10 +133,14 @@ def backend(request):
109133
pytest.mark.xfail(reason=f"expect {func_name} to fail as specified")
110134
)
111135

136+
tensor_config, optimizer_config = request.param
137+
138+
tensor = getattr(pyhf.tensor, tensor_config[0])(**tensor_config[1])
139+
optimizer = getattr(pyhf.optimize, optimizer_config[0])(**optimizer_config[1])
112140
# actual execution here, after all checks is done
113-
pyhf.set_backend(*request.param)
141+
pyhf.set_backend(tensor, optimizer)
114142

115-
yield request.param
143+
yield (tensor, optimizer)
116144

117145

118146
@pytest.fixture(

0 commit comments

Comments
 (0)