1+
2+ import sympy .parsing .sympy_parser as parser
3+ import sympy
4+ from pyhf .parameters import ParamViewer
5+ import jax .numpy as jnp
6+ import jax
7+
8+ def create_modifiers (additional_parameters = None ):
9+
10+ class PureFunctionModifierBuilder :
11+ is_shared = True
12+ def __init__ (self , pdfconfig ):
13+ self .config = pdfconfig
14+ self .required_parsets = additional_parameters or {}
15+ self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16+
17+ def collect (self , thismod , nom ):
18+ maskval = True if thismod else False
19+ mask = [maskval ] * len (nom )
20+ return {'mask' : mask }
21+
22+ def append (self , key , channel , sample , thismod , defined_samp ):
23+ self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
24+
25+ nom = (
26+ defined_samp ['data' ]
27+ if defined_samp
28+ else [0.0 ] * self .config .channel_nbins [channel ]
29+ )
30+ moddata = self .collect (thismod , nom )
31+ self .builder_data ['local' ][key ][sample ]['data' ]['mask' ] += moddata ['mask' ]
32+
33+ if thismod is not None :
34+ formula = thismod ['data' ]['formula' ]
35+ parsed = parser .parse_expr (formula )
36+ free_symbols = parsed .free_symbols
37+ for x in free_symbols :
38+ self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
39+ else :
40+ parsed = None
41+ self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
42+
43+ def finalize (self ):
44+ list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
45+ self .builder_data ['global' ]['symbol_names' ] = list_of_symbols
46+ for modname , modspec in self .builder_data ['local' ].items ():
47+ for sample , samplespec in modspec .items ():
48+ for channel , channelspec in samplespec ['channels' ].items ():
49+ if channelspec ['parsed' ] is not None :
50+ channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
51+ else :
52+ channelspec ['jaxfunc' ] = lambda * args : 1.0
53+ return self .builder_data
54+
55+ class PureFunctionModifierApplicator :
56+ op_code = 'multiplication'
57+ name = 'purefunc'
58+
59+ def __init__ (
60+ self , modifiers = None , pdfconfig = None , builder_data = None , batch_size = None
61+ ):
62+ self .builder_data = builder_data
63+ self .batch_size = batch_size
64+ self .pdfconfig = pdfconfig
65+ self .inputs = [str (x ) for x in builder_data ['global' ]['symbols' ]]
66+
67+ self .keys = [f'{ mtype } /{ m } ' for m , mtype in modifiers ]
68+ self .modifiers = [m for m , _ in modifiers ]
69+
70+ parfield_shape = (
71+ (self .batch_size , pdfconfig .npars )
72+ if self .batch_size
73+ else (pdfconfig .npars ,)
74+ )
75+
76+ self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
77+ self .create_jax_eval ()
78+
79+ def create_jax_eval (self ):
80+ def eval_func (pars ):
81+ return jnp .array ([
82+ [
83+ jnp .concatenate ([
84+ self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
85+ for c in self .pdfconfig .channels
86+ ])
87+ for s in self .pdfconfig .samples
88+ ]
89+ for m in self .keys
90+
91+ ])
92+ self .jaxeval = eval_func
93+
94+ def apply_nonbatched (self ,pars ):
95+ return jnp .expand_dims (self .jaxeval (pars ),2 )
96+
97+ def apply_batched (self ,pars ):
98+ return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
99+
100+ def apply (self , pars ):
101+ if not self .param_viewer .index_selection :
102+ return
103+ if self .batch_size is None :
104+ par_selection = self .param_viewer .get (pars )
105+ results_purefunc = self .apply_nonbatched (par_selection )
106+ else :
107+ par_selection = self .param_viewer .get (pars )
108+ results_purefunc = self .apply_batched (par_selection )
109+ return results_purefunc
110+
111+ return PureFunctionModifierBuilder , PureFunctionModifierApplicator
112+
113+
114+ from pyhf .modifiers import histfactory_set
115+
116+ def enable (new_params = None ):
117+ modifier_set = {}
118+ modifier_set .update (** histfactory_set )
119+
120+ builder , applicator = create_modifiers (new_params )
121+
122+ modifier_set .update (** {
123+ applicator .name : (builder , applicator )}
124+ )
125+ return modifier_set
126+
127+ def new_unconstrained_scalars (new_params ):
128+ param_spec = {
129+ p ['name' ]:
130+ [{
131+ 'paramset_type' : 'unconstrained' ,
132+ 'n_parameters' : 1 ,
133+ 'is_shared' : True ,
134+ 'inits' : (p ['init' ],),
135+ 'bounds' : ((p ['min' ], p ['max' ]),),
136+ 'is_scalar' : True ,
137+ 'fixed' : False ,
138+ }]
139+ for p in new_params
140+ }
141+ return param_spec
0 commit comments