@@ -114,312 +114,6 @@ function mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::
     return log(rand()) + logp_current ≤ logp_proposal + log_proposal_ratio
 end
 
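# Note (illustrative annotation, not part of the diff): the context line above is
# the standard Metropolis-Hastings acceptance test. It accepts the proposal with
# probability min(1, exp(logp_proposal + log_proposal_ratio - logp_current));
# comparing against log(rand()) performs that test without leaving log space.
# For example (assuming access to the internal function):
#
#     Turing.Inference.mh_accept(-2.0, -1.0, 0.0)  # always true, since exp(1) ≥ 1
#     Turing.Inference.mh_accept(-1.0, -2.0, 0.0)  # true with probability exp(-1) ≈ 0.37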
-######################
-# Default Transition #
-######################
-getstats(::Any) = NamedTuple()
-getstats(nt::NamedTuple) = nt
-
-struct Transition{T,F<:AbstractFloat,N<:NamedTuple}
-    θ::T
-    logprior::F
-    loglikelihood::F
-    stat::N
-
-    """
-        Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true)
-
-    Construct a new `Turing.Inference.Transition` object from the outputs of a
-    sampler step.
-
-    Here, `vi` represents a VarInfo _for which the appropriate parameters have
-    already been set_. However, the accumulators (e.g. logp) may in general
-    have junk contents. The role of this method is to re-evaluate `model` and
-    thus set the accumulators to the correct values.
-
-    `stats` is any object on which `Turing.Inference.getstats` can be called to
-    return a NamedTuple of statistics. This could be, for example, the transition
-    returned by an (unwrapped) external sampler. Alternatively, it could simply
-    be a NamedTuple itself (for which `getstats` acts as the identity).
-
-    By default, the model is re-evaluated in order to obtain:
-    - the values of the parameters as per the user's parameterisation (`vals_as_in_model`)
-    - the components of the log joint probability (`logprior`, `loglikelihood`),
-      which are then guaranteed to be correct.
-
-    If you **know** for a fact that the VarInfo `vi` already contains this information,
-    then you can set `reevaluate=false` to skip the re-evaluation step.
-
-    !!! warning
-        Note that in general this is unsafe and may lead to wrong results.
-
-    If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
-    the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
-    and `LogLikelihoodAccumulator` set up with the correct values. Note that the
-    `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
-    must be set up to track `x := y` statements.
-    """
-    function Transition(
-        model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true
-    )
-        # Avoid mutating vi as it may be used later e.g. when constructing
-        # sampler states.
-        vi = deepcopy(vi)
-        if reevaluate
-            vi = DynamicPPL.setaccs!!(
-                vi,
-                (
-                    DynamicPPL.ValuesAsInModelAccumulator(true),
-                    DynamicPPL.LogPriorAccumulator(),
-                    DynamicPPL.LogLikelihoodAccumulator(),
-                ),
-            )
-            _, vi = DynamicPPL.evaluate!!(model, vi)
-        end
-
-        # Extract all the information we need.
-        vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
-        logprior = DynamicPPL.getlogprior(vi)
-        loglikelihood = DynamicPPL.getloglikelihood(vi)
-
-        # Get additional statistics.
-        stats = getstats(stats)
-        return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
-            vals_as_in_model, logprior, loglikelihood, stats
-        )
-    end
-
-    function Transition(
-        model::DynamicPPL.Model,
-        untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
-        stats;
-        reevaluate=true,
-    )
-        # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
-        # much faster to convert it to a typed varinfo first, hence this method.
-        # https://github.com/TuringLang/Turing.jl/issues/2604
-        return Transition(
-            model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate
-        )
-    end
-end
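# Illustrative sketch (not part of the diff): how the removed constructor was
# typically driven after a sampler step. The toy model `demo` and the stats
# tuple are assumed placeholders, and this only applies to a Turing version
# that still defines `Turing.Inference.Transition`.

using Turing, DynamicPPL

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()
vi = DynamicPPL.VarInfo(model)  # a VarInfo whose parameters are already set

# Re-evaluates `model`, filling in `vals_as_in_model`, logprior, and loglikelihood:
t = Turing.Inference.Transition(model, vi, (accept=true,))

# Only safe if `vi` already carries the required accumulators (see the warning above):
# t = Turing.Inference.Transition(model, vi, (accept=true,); reevaluate=false)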
-
-function getstats_with_lp(t::Transition)
-    return merge(
-        t.stat,
-        (
-            lp=t.logprior + t.loglikelihood,
-            logprior=t.logprior,
-            loglikelihood=t.loglikelihood,
-        ),
-    )
-end
-function getstats_with_lp(vi::AbstractVarInfo)
-    return (
-        lp=DynamicPPL.getlogjoint(vi),
-        logprior=DynamicPPL.getlogprior(vi),
-        loglikelihood=DynamicPPL.getloglikelihood(vi),
-    )
-end
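# Illustrative sketch (not part of the diff), with made-up numbers: `getstats`
# normalises arbitrary sampler metadata to a NamedTuple, and `getstats_with_lp`
# then appends the log-probability components, e.g.
#
#     getstats((accept=true, step_size=0.1))  # (accept = true, step_size = 0.1)
#     getstats_with_lp(t)
#     # (accept = true, lp = -3.2, logprior = -1.1, loglikelihood = -2.1)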
-
-##########################
-# Chain making utilities #
-##########################
-
-# TODO(penelopeysm): Separate Turing.Inference.getparams (should only be
-# defined for AbstractVarInfo and Turing.Inference.Transition; returns varname
-# => value maps) from AbstractMCMC.getparams (defined for any sampler transition,
-# returns vector).
-"""
-    Turing.Inference.getparams(model::DynamicPPL.Model, t::Any)
-
-Return a vector of parameter values from the given sampler transition `t` (i.e.,
-the first return value of AbstractMCMC.step). By default, returns the `t.θ` field.
-
-!!! note
-    This method only needs to be implemented for external samplers. It will be
-    removed in future releases and replaced with `AbstractMCMC.getparams`.
-"""
-getparams(::DynamicPPL.Model, t) = t.θ
-"""
-    Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo)
-
-Return a key-value map of parameters from the varinfo.
-"""
-function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
-    t = Transition(model, vi, nothing)
-    return getparams(model, t)
-end
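# Illustrative sketch (not part of the diff): both removed `getparams` methods
# return the parameters as a VarName => value map in the user's own
# parameterisation; the values below are made up for the toy `demo` model:
#
#     getparams(model, t)
#     # OrderedDict{VarName, Any}(x => 0.12, y => -0.48)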
-function _params_to_array(model::DynamicPPL.Model, ts::Vector)
-    names_set = OrderedSet{VarName}()
-    # Extract the parameter names and values from each transition.
-    dicts = map(ts) do t
-        # In general getparams returns a dict of VarName => values. We need to also
-        # split it up into constituent elements using
-        # `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
-        # won't understand it.
-        vals = getparams(model, t)
-        nms_and_vs = if isempty(vals)
-            Tuple{VarName,Any}[]
-        else
-            iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
-            mapreduce(collect, vcat, iters)
-        end
-        nms = map(first, nms_and_vs)
-        vs = map(last, nms_and_vs)
-        for nm in nms
-            push!(names_set, nm)
-        end
-        # Convert the names and values to a single dictionary.
-        return OrderedDict(zip(nms, vs))
-    end
-    names = collect(names_set)
-    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
-
-    return names, vals
-end
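# Illustrative sketch (not part of the diff): `varname_and_value_leaves` is what
# splits non-scalar parameters into the per-element columns MCMCChains expects.
# For an assumed vector-valued variable `z` (exact printed form may differ):
#
#     collect(AbstractPPL.varname_and_value_leaves(@varname(z), [1.0, 2.0]))
#     # [(z[1], 1.0), (z[2], 2.0)]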
-
-function get_transition_extras(ts::AbstractVector)
-    # Extract stats + log probabilities from each transition or VarInfo.
-    extra_data = map(getstats_with_lp, ts)
-    return names_values(extra_data)
-end
-
-function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
-    values = [getfield(data, name) for data in extra_data, name in names]
-    return collect(names), values
-end
-
-function names_values(xs::AbstractVector{<:NamedTuple})
-    # Obtain all parameter names.
-    names_set = Set{Symbol}()
-    for x in xs
-        for k in keys(x)
-            push!(names_set, k)
-        end
-    end
-    names_unique = collect(names_set)
-
-    # Extract all values as matrix.
-    values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique]
-
-    return names_unique, values
-end
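# Illustrative sketch (not part of the diff): the second `names_values` method
# handles transitions whose stats carry different fields by padding with
# `missing` (name order is not guaranteed):
#
#     names_values([(a=1, b=2), (a=3,)])
#     # ([:a, :b], [1 2; 3 missing])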
-
-getlogevidence(transitions, sampler, state) = missing
-
-# Default MCMCChains.Chains constructor.
-function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Transition},
-    model::DynamicPPL.Model,
-    spl::AbstractSampler,
-    state,
-    chain_type::Type{MCMCChains.Chains};
-    save_state=false,
-    stats=missing,
-    sort_chain=false,
-    include_varname_to_symbol=true,
-    discard_initial=0,
-    thinning=1,
-    kwargs...,
-)
-    # Convert transitions to array format.
-    # Also retrieve the variable names.
-    varnames, vals = _params_to_array(model, ts)
-    varnames_symbol = map(Symbol, varnames)
-
-    # Get the values of the extra parameters in each transition.
-    extra_params, extra_values = get_transition_extras(ts)
-
-    # Extract names & construct param array.
-    nms = [varnames_symbol; extra_params]
-    parray = hcat(vals, extra_values)
-
-    # Get the average or final log evidence, if it exists.
-    le = getlogevidence(ts, spl, state)
-
-    # Set up the info tuple.
-    info = NamedTuple()
-
-    if include_varname_to_symbol
-        info = merge(info, (varname_to_symbol=OrderedDict(zip(varnames, varnames_symbol)),))
-    end
-
-    if save_state
-        info = merge(info, (model=model, sampler=spl, samplerstate=state))
-    end
-
-    # Merge in the timing info, if available.
-    if !ismissing(stats)
-        info = merge(info, (start_time=stats.start, stop_time=stats.stop))
-    end
-
-    # Concretize the array before giving it to MCMCChains.
-    parray = MCMCChains.concretize(parray)
-
-    # Chain construction.
-    chain = MCMCChains.Chains(
-        parray,
-        nms,
-        (internals=extra_params,);
-        evidence=le,
-        info=info,
-        start=discard_initial + 1,
-        thin=thinning,
-    )
-
-    return sort_chain ? sort(chain) : chain
-end
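# Illustrative sketch (not part of the diff): this method backs the default
# `MCMCChains.Chains` output of `sample`; keyword arguments such as
# `save_state` are forwarded to it. The model/sampler choice reuses the toy
# `demo` model assumed in the earlier sketch.

chn = sample(demo(), MH(), 500; save_state=true)
chn.info.samplerstate   # sampler state kept because save_state=true
names(chn, :internals)  # includes :lp, :logprior, and :loglikelihood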
-
-function AbstractMCMC.bundle_samples(
-    ts::Vector{<:Transition},
-    model::DynamicPPL.Model,
-    spl::AbstractSampler,
-    state,
-    chain_type::Type{Vector{NamedTuple}};
-    kwargs...,
-)
-    return map(ts) do t
-        # Construct a dictionary of pairs `vn => value`.
-        params = OrderedDict(getparams(model, t))
-        # Group the variable names by their symbol.
-        sym_to_vns = group_varnames_by_symbol(keys(params))
-        # Convert the values to a vector.
-        vals = map(values(sym_to_vns)) do vns
-            map(Base.Fix1(getindex, params), vns)
-        end
-        return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t))
-    end
-end
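# Illustrative sketch (not part of the diff): requesting plain NamedTuples
# instead of an MCMCChains.Chains object (toy model/sampler assumed, values
# made up; note that each symbol maps to a vector of its leaf values):
#
#     nts = sample(demo(), MH(), 100; chain_type=Vector{NamedTuple})
#     nts[1]  # (x = [0.12], y = [-0.48], lp = -3.2, logprior = -1.1, loglikelihood = -2.1)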
-
-"""
-    group_varnames_by_symbol(vns)
-
-Group the varnames by their symbol.
-
-# Arguments
-- `vns`: Iterable of `VarName`.
-
-# Returns
-- `OrderedDict{Symbol, Vector{VarName}}`: A dictionary mapping each symbol to a vector of varnames.
-"""
-function group_varnames_by_symbol(vns)
-    d = OrderedDict{Symbol,Vector{VarName}}()
-    for vn in vns
-        sym = DynamicPPL.getsym(vn)
-        if !haskey(d, sym)
-            d[sym] = VarName[]
-        end
-        push!(d[sym], vn)
-    end
-    return d
-end
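# Illustrative sketch (not part of the diff) of the grouping performed above:
#
#     group_varnames_by_symbol([@varname(x[1]), @varname(x[2]), @varname(y)])
#     # OrderedDict(:x => [x[1], x[2]], :y => [y])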
-
-function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
-    nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
-    return setinfo(c, merge(nt, c.info))
-end
-
 #######################################
 # Concrete algorithm implementations. #
 #######################################