Skip to content

Commit db57a1d

Browse files
committed
Fixes for DPPL
1 parent 98c4c11 commit db57a1d

File tree

13 files changed

+84
-483
lines changed

13 files changed

+84
-483
lines changed

ext/TuringDynamicHMCExt.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ $(TYPEDFIELDS)
3737
"""
3838
struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S}
3939
logdensity::L
40-
vi::V
4140
"Cache of sample, log density, and gradient of log density evaluation."
4241
cache::C
4342
metric::M
@@ -70,9 +69,8 @@ function Turing.Inference.initialstep(
7069
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
7170

7271
# Create first sample and state.
73-
vi = DynamicPPL.unflatten(vi, Q.q)
74-
sample = Turing.Inference.Transition(model, vi, nothing)
75-
state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ)
72+
sample = DynamicPPL.ParamsWithStats(Q.q, ℓ)
73+
state = DynamicNUTSState(ℓ, Q, steps.H.κ, steps.ϵ)
7674

7775
return sample, state
7876
end
@@ -85,15 +83,13 @@ function AbstractMCMC.step(
8583
kwargs...,
8684
)
8785
# Compute next sample.
88-
vi = state.vi
8986
ℓ = state.logdensity
9087
steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize)
9188
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
9289

9390
# Create next sample and state.
94-
vi = DynamicPPL.unflatten(vi, Q.q)
95-
sample = Turing.Inference.Transition(model, vi, nothing)
96-
newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize)
91+
sample = DynamicPPL.ParamsWithStats(Q.q, ℓ)
92+
newstate = DynamicNUTSState(ℓ, Q, state.metric, state.stepsize)
9793

9894
return sample, newstate
9995
end

src/mcmc/Inference.jl

Lines changed: 0 additions & 306 deletions
Original file line numberDiff line numberDiff line change
@@ -114,312 +114,6 @@ function mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::
114114
return log(rand()) + logp_current ≤ logp_proposal + log_proposal_ratio
115115
end
116116

117-
######################
118-
# Default Transition #
119-
######################
120-
getstats(::Any) = NamedTuple()
121-
getstats(nt::NamedTuple) = nt
122-
123-
struct Transition{T,F<:AbstractFloat,N<:NamedTuple}
124-
θ::T
125-
logprior::F
126-
loglikelihood::F
127-
stat::N
128-
129-
"""
130-
Transition(model::Model, vi::AbstractVarInfo, stats; reevaluate=true)
131-
132-
Construct a new `Turing.Inference.Transition` object using the outputs of a
133-
sampler step.
134-
135-
Here, `vi` represents a VarInfo _for which the appropriate parameters have
136-
already been set_. However, the accumulators (e.g. logp) may in general
137-
have junk contents. The role of this method is to re-evaluate `model` and
138-
thus set the accumulators to the correct values.
139-
140-
`stats` is any object on which `Turing.Inference.getstats` can be called to
141-
return a NamedTuple of statistics. This could be, for example, the transition
142-
returned by an (unwrapped) external sampler. Or alternatively, it could
143-
simply be a NamedTuple itself (for which `getstats` acts as the identity).
144-
145-
By default, the model is re-evaluated in order to obtain values of:
146-
- the values of the parameters as per user parameterisation (`vals_as_in_model`)
147-
- the various components of the log joint probability (`logprior`, `loglikelihood`)
148-
that are guaranteed to be correct.
149-
150-
If you **know** for a fact that the VarInfo `vi` already contains this information,
151-
then you can set `reevaluate=false` to skip the re-evaluation step.
152-
153-
!!! warning
154-
Note that in general this is unsafe and may lead to wrong results.
155-
156-
If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
157-
the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
158-
and `LogLikelihoodAccumulator` set up with the correct values. Note that the
159-
`ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
160-
must be set up to track `x := y` statements.
161-
"""
162-
function Transition(
163-
model::DynamicPPL.Model, vi::AbstractVarInfo, stats; reevaluate=true
164-
)
165-
# Avoid mutating vi as it may be used later e.g. when constructing
166-
# sampler states.
167-
vi = deepcopy(vi)
168-
if reevaluate
169-
vi = DynamicPPL.setaccs!!(
170-
vi,
171-
(
172-
DynamicPPL.ValuesAsInModelAccumulator(true),
173-
DynamicPPL.LogPriorAccumulator(),
174-
DynamicPPL.LogLikelihoodAccumulator(),
175-
),
176-
)
177-
_, vi = DynamicPPL.evaluate!!(model, vi)
178-
end
179-
180-
# Extract all the information we need
181-
vals_as_in_model = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
182-
logprior = DynamicPPL.getlogprior(vi)
183-
loglikelihood = DynamicPPL.getloglikelihood(vi)
184-
185-
# Get additional statistics
186-
stats = getstats(stats)
187-
return new{typeof(vals_as_in_model),typeof(logprior),typeof(stats)}(
188-
vals_as_in_model, logprior, loglikelihood, stats
189-
)
190-
end
191-
192-
function Transition(
193-
model::DynamicPPL.Model,
194-
untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata},
195-
stats;
196-
reevaluate=true,
197-
)
198-
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
199-
# much faster to convert it to a typed varinfo first, hence this method.
200-
# https://github.com/TuringLang/Turing.jl/issues/2604
201-
return Transition(
202-
model, DynamicPPL.typed_varinfo(untyped_vi), stats; reevaluate=reevaluate
203-
)
204-
end
205-
end
206-
207-
function getstats_with_lp(t::Transition)
208-
return merge(
209-
t.stat,
210-
(
211-
lp=t.logprior + t.loglikelihood,
212-
logprior=t.logprior,
213-
loglikelihood=t.loglikelihood,
214-
),
215-
)
216-
end
217-
function getstats_with_lp(vi::AbstractVarInfo)
218-
return (
219-
lp=DynamicPPL.getlogjoint(vi),
220-
logprior=DynamicPPL.getlogprior(vi),
221-
loglikelihood=DynamicPPL.getloglikelihood(vi),
222-
)
223-
end
224-
225-
##########################
226-
# Chain making utilities #
227-
##########################
228-
229-
# TODO(penelopeysm): Separate Turing.Inference.getparams (should only be
230-
# defined for AbstractVarInfo and Turing.Inference.Transition; returns varname
231-
# => value maps) from AbstractMCMC.getparams (defined for any sampler transition,
232-
# returns vector).
233-
"""
234-
Turing.Inference.getparams(model::DynamicPPL.Model, t::Any)
235-
236-
Return a vector of parameter values from the given sampler transition `t` (i.e.,
237-
the first return value of AbstractMCMC.step). By default, returns the `t.θ` field.
238-
239-
!!! note
240-
This method only needs to be implemented for external samplers. It will be
241-
removed in future releases and replaced with `AbstractMCMC.getparams`.
242-
"""
243-
getparams(::DynamicPPL.Model, t) = t.θ
244-
"""
245-
Turing.Inference.getparams(model::DynamicPPL.Model, t::AbstractVarInfo)
246-
247-
Return a key-value map of parameters from the varinfo.
248-
"""
249-
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
250-
t = Transition(model, vi, nothing)
251-
return getparams(model, t)
252-
end
253-
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
254-
names_set = OrderedSet{VarName}()
255-
# Extract the parameter names and values from each transition.
256-
dicts = map(ts) do t
257-
# In general getparams returns a dict of VarName => values. We need to also
258-
# split it up into constituent elements using
259-
# `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
260-
# won't understand it.
261-
vals = getparams(model, t)
262-
nms_and_vs = if isempty(vals)
263-
Tuple{VarName,Any}[]
264-
else
265-
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
266-
mapreduce(collect, vcat, iters)
267-
end
268-
nms = map(first, nms_and_vs)
269-
vs = map(last, nms_and_vs)
270-
for nm in nms
271-
push!(names_set, nm)
272-
end
273-
# Convert the names and values to a single dictionary.
274-
return OrderedDict(zip(nms, vs))
275-
end
276-
names = collect(names_set)
277-
vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
278-
279-
return names, vals
280-
end
281-
282-
function get_transition_extras(ts::AbstractVector)
283-
# Extract stats + log probabilities from each transition or VarInfo
284-
extra_data = map(getstats_with_lp, ts)
285-
return names_values(extra_data)
286-
end
287-
288-
function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names}
289-
values = [getfield(data, name) for data in extra_data, name in names]
290-
return collect(names), values
291-
end
292-
293-
function names_values(xs::AbstractVector{<:NamedTuple})
294-
# Obtain all parameter names.
295-
names_set = Set{Symbol}()
296-
for x in xs
297-
for k in keys(x)
298-
push!(names_set, k)
299-
end
300-
end
301-
names_unique = collect(names_set)
302-
303-
# Extract all values as matrix.
304-
values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique]
305-
306-
return names_unique, values
307-
end
308-
309-
getlogevidence(transitions, sampler, state) = missing
310-
311-
# Default MCMCChains.Chains constructor.
312-
function AbstractMCMC.bundle_samples(
313-
ts::Vector{<:Transition},
314-
model::DynamicPPL.Model,
315-
spl::AbstractSampler,
316-
state,
317-
chain_type::Type{MCMCChains.Chains};
318-
save_state=false,
319-
stats=missing,
320-
sort_chain=false,
321-
include_varname_to_symbol=true,
322-
discard_initial=0,
323-
thinning=1,
324-
kwargs...,
325-
)
326-
# Convert transitions to array format.
327-
# Also retrieve the variable names.
328-
varnames, vals = _params_to_array(model, ts)
329-
varnames_symbol = map(Symbol, varnames)
330-
331-
# Get the values of the extra parameters in each transition.
332-
extra_params, extra_values = get_transition_extras(ts)
333-
334-
# Extract names & construct param array.
335-
nms = [varnames_symbol; extra_params]
336-
parray = hcat(vals, extra_values)
337-
338-
# Get the average or final log evidence, if it exists.
339-
le = getlogevidence(ts, spl, state)
340-
341-
# Set up the info tuple.
342-
info = NamedTuple()
343-
344-
if include_varname_to_symbol
345-
info = merge(info, (varname_to_symbol=OrderedDict(zip(varnames, varnames_symbol)),))
346-
end
347-
348-
if save_state
349-
info = merge(info, (model=model, sampler=spl, samplerstate=state))
350-
end
351-
352-
# Merge in the timing info, if available
353-
if !ismissing(stats)
354-
info = merge(info, (start_time=stats.start, stop_time=stats.stop))
355-
end
356-
357-
# Concretize the array before giving it to MCMCChains.
358-
parray = MCMCChains.concretize(parray)
359-
360-
# Chain construction.
361-
chain = MCMCChains.Chains(
362-
parray,
363-
nms,
364-
(internals=extra_params,);
365-
evidence=le,
366-
info=info,
367-
start=discard_initial + 1,
368-
thin=thinning,
369-
)
370-
371-
return sort_chain ? sort(chain) : chain
372-
end
373-
374-
function AbstractMCMC.bundle_samples(
375-
ts::Vector{<:Transition},
376-
model::DynamicPPL.Model,
377-
spl::AbstractSampler,
378-
state,
379-
chain_type::Type{Vector{NamedTuple}};
380-
kwargs...,
381-
)
382-
return map(ts) do t
383-
# Construct a dictionary of pairs `vn => value`.
384-
params = OrderedDict(getparams(model, t))
385-
# Group the variable names by their symbol.
386-
sym_to_vns = group_varnames_by_symbol(keys(params))
387-
# Convert the values to a vector.
388-
vals = map(values(sym_to_vns)) do vns
389-
map(Base.Fix1(getindex, params), vns)
390-
end
391-
return merge(NamedTuple(zip(keys(sym_to_vns), vals)), getstats_with_lp(t))
392-
end
393-
end
394-
395-
"""
396-
group_varnames_by_symbol(vns)
397-
398-
Group the varnames by their symbol.
399-
400-
# Arguments
401-
- `vns`: Iterable of `VarName`.
402-
403-
# Returns
404-
- `OrderedDict{Symbol, Vector{VarName}}`: A dictionary mapping symbol to a vector of varnames.
405-
"""
406-
function group_varnames_by_symbol(vns)
407-
d = OrderedDict{Symbol,Vector{VarName}}()
408-
for vn in vns
409-
sym = DynamicPPL.getsym(vn)
410-
if !haskey(d, sym)
411-
d[sym] = VarName[]
412-
end
413-
push!(d[sym], vn)
414-
end
415-
return d
416-
end
417-
418-
function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples)
419-
nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples))
420-
return setinfo(c, merge(nt, c.info))
421-
end
422-
423117
#######################################
424118
# Concrete algorithm implementations. #
425119
#######################################

0 commit comments

Comments
 (0)