99import pyhf
1010
1111
12+ def pytest_addoption (parser ):
13+ parser .addoption (
14+ "--disable-backend" ,
15+ action = "append" ,
16+ type = str ,
17+ default = [],
18+ choices = ["tensorflow" , "pytorch" , "jax" , "minuit" ],
19+ help = "list of backends to disable in tests" ,
20+ )
21+
22+
1223# Factory as fixture pattern
1324@pytest .fixture
1425def get_json_from_tarfile ():
@@ -59,14 +70,14 @@ def reset_backend():
5970@pytest .fixture (
6071 scope = 'function' ,
6172 params = [
62- (pyhf . tensor . numpy_backend (), None ),
63- (pyhf . tensor . pytorch_backend (), None ),
64- (pyhf . tensor . pytorch_backend (precision = ' 64b' ), None ),
65- (pyhf . tensor . tensorflow_backend (), None ),
66- (pyhf . tensor . jax_backend (), None ),
73+ (( " numpy_backend" , dict ()), ( "scipy_optimizer" , dict ()) ),
74+ (( " pytorch_backend" , dict ()), ( "scipy_optimizer" , dict ()) ),
75+ (( " pytorch_backend" , dict (precision = " 64b" )), ( "scipy_optimizer" , dict ()) ),
76+ (( " tensorflow_backend" , dict ()), ( "scipy_optimizer" , dict ()) ),
77+ (( " jax_backend" , dict ()), ( "scipy_optimizer" , dict ()) ),
6778 (
68- pyhf . tensor . numpy_backend (poisson_from_normal = True ),
69- pyhf . optimize . minuit_optimizer ( ),
79+ ( " numpy_backend" , dict (poisson_from_normal = True ) ),
80+ ( " minuit_optimizer" , dict () ),
7081 ),
7182 ],
7283 ids = ['numpy' , 'pytorch' , 'pytorch64' , 'tensorflow' , 'jax' , 'numpy_minuit' ],
@@ -87,13 +98,26 @@ def backend(request):
8798 only_backends = [
8899 pid for pid in param_ids if request .node .get_closest_marker (f'only_{ pid } ' )
89100 ]
101+ disable_backend = any (
102+ backend in param_id for backend in request .config .option .disable_backend
103+ )
90104
91105 if skip_backend and (param_id in only_backends ):
92106 raise ValueError (
93107 f"Must specify skip_{ param_id } or only_{ param_id } but not both!"
94108 )
95109
96- if skip_backend :
110+ if disable_backend :
111+ pytest .skip (
112+ f"skipping { func_name } as the backend is disabled via "
113+ + " " .join (
114+ [
115+ f"--disable-backend { choice } "
116+ for choice in request .config .option .disable_backend
117+ ]
118+ )
119+ )
120+ elif skip_backend :
97121 pytest .skip (f"skipping { func_name } as specified" )
98122 elif only_backends and param_id not in only_backends :
99123 pytest .skip (
@@ -109,10 +133,14 @@ def backend(request):
109133 pytest .mark .xfail (reason = f"expect { func_name } to fail as specified" )
110134 )
111135
136+ tensor_config , optimizer_config = request .param
137+
138+ tensor = getattr (pyhf .tensor , tensor_config [0 ])(** tensor_config [1 ])
139+ optimizer = getattr (pyhf .optimize , optimizer_config [0 ])(** optimizer_config [1 ])
112140 # actual execution here, after all checks is done
113- pyhf .set_backend (* request . param )
141+ pyhf .set_backend (tensor , optimizer )
114142
115- yield request . param
143+ yield ( tensor , optimizer )
116144
117145
118146@pytest .fixture (
0 commit comments