1-
21import sympy .parsing .sympy_parser as parser
32import sympy
43from pyhf .parameters import ParamViewer
54import jax .numpy as jnp
65import jax
76
7+
88def create_modifiers ():
99
1010 class PureFunctionModifierBuilder :
1111 is_shared = True
12+
1213 def __init__ (self , pdfconfig ):
1314 self .config = pdfconfig
1415 self .required_parsets = {}
15- self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16+ self .builder_data = {'local' : {}, 'global' : {'symbols' : set ()}}
1617
1718 def collect (self , thismod , nom ):
1819 maskval = True if thismod else False
@@ -21,23 +22,25 @@ def collect(self, thismod, nom):
2122
2223 def require_synbols_as_scalars (self , symbols ):
2324 param_spec = {
24- p :
25- [{
26- 'paramset_type' : 'unconstrained' ,
27- 'n_parameters' : 1 ,
28- 'is_shared' : True ,
29- 'inits' : (1.0 ,),
30- 'bounds' : ((0 ,10 ),),
31- 'is_scalar' : True ,
32- 'fixed' : False ,
33- }]
25+ p : [
26+ {
27+ 'paramset_type' : 'unconstrained' ,
28+ 'n_parameters' : 1 ,
29+ 'is_shared' : True ,
30+ 'inits' : (1.0 ,),
31+ 'bounds' : ((0 , 10 ),),
32+ 'is_scalar' : True ,
33+ 'fixed' : False ,
34+ }
35+ ]
3436 for p in symbols
3537 }
3638 return param_spec
3739
38-
3940 def append (self , key , channel , sample , thismod , defined_samp ):
40- self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
41+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
42+ sample , {}
43+ ).setdefault ('data' , {'mask' : []})
4144
4245 nom = (
4346 defined_samp ['data' ]
@@ -52,10 +55,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
5255 parsed = parser .parse_expr (formula )
5356 free_symbols = parsed .free_symbols
5457 for x in free_symbols :
55- self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
58+ self .builder_data ['global' ].setdefault ('symbols' , set ()).add (x )
5659 else :
5760 parsed = None
58- self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
61+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
62+ sample , {}
63+ ).setdefault ('channels' , {}).setdefault (channel , {})['parsed' ] = parsed
5964
6065 def finalize (self ):
6166 list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
@@ -67,7 +72,9 @@ def finalize(self):
6772 for sample , samplespec in modspec .items ():
6873 for channel , channelspec in samplespec ['channels' ].items ():
6974 if channelspec ['parsed' ] is not None :
70- channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
75+ channelspec ['jaxfunc' ] = sympy .lambdify (
76+ list_of_symbols , channelspec ['parsed' ], 'jax'
77+ )
7178 else :
7279 channelspec ['jaxfunc' ] = lambda * args : 1.0
7380 return self .builder_data
@@ -93,28 +100,37 @@ def __init__(
93100 else (pdfconfig .npars ,)
94101 )
95102
96- self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
103+ self .param_viewer = ParamViewer (
104+ parfield_shape , pdfconfig .par_map , self .inputs
105+ )
97106 self .create_jax_eval ()
98107
99108 def create_jax_eval (self ):
100109 def eval_func (pars ):
101- return jnp .array ([
110+ return jnp .array (
102111 [
103- jnp .concatenate ([
104- self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
105- for c in self .pdfconfig .channels
106- ])
107- for s in self .pdfconfig .samples
112+ [
113+ jnp .concatenate (
114+ [
115+ self .builder_data ['local' ][m ][s ]['channels' ][c ][
116+ 'jaxfunc'
117+ ](* pars )
118+ * jnp .ones (self .pdfconfig .channel_nbins [c ])
119+ for c in self .pdfconfig .channels
120+ ]
121+ )
122+ for s in self .pdfconfig .samples
123+ ]
124+ for m in self .keys
108125 ]
109- for m in self . keys
126+ )
110127
111- ])
112128 self .jaxeval = eval_func
113-
114- def apply_nonbatched (self ,pars ):
115- return jnp .expand_dims (self .jaxeval (pars ),2 )
116129
117- def apply_batched (self ,pars ):
130+ def apply_nonbatched (self , pars ):
131+ return jnp .expand_dims (self .jaxeval (pars ), 2 )
132+
133+ def apply_batched (self , pars ):
118134 return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
119135
120136 def apply (self , pars ):
@@ -127,19 +143,18 @@ def apply(self, pars):
127143 par_selection = self .param_viewer .get (pars )
128144 results_purefunc = self .apply_batched (par_selection )
129145 return results_purefunc
130-
146+
131147 return PureFunctionModifierBuilder , PureFunctionModifierApplicator
132148
133149
134150from pyhf .modifiers import histfactory_set
135151
152+
136153def enable ():
137154 modifier_set = {}
138155 modifier_set .update (** histfactory_set )
139156
140157 builder , applicator = create_modifiers ()
141158
142- modifier_set .update (** {
143- applicator .name : (builder , applicator )}
144- )
145- return modifier_set
159+ modifier_set .update (** {applicator .name : (builder , applicator )})
160+ return modifier_set
0 commit comments