Skip to content

Commit 8547e25

Browse files
authored
Implement predict, returned, logjoint, ... with OnlyAccsVarInfo (#1130)
* Use OnlyAccsVarInfo for many re-evaluation functions * drop `fast_` prefix * Add a changelog
1 parent 766f663 commit 8547e25

File tree

5 files changed

+102
-144
lines changed

5 files changed

+102
-144
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,14 @@ You should not need to use these directly, please use `AbstractPPL.condition` an
3232

3333
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
3434

35+
The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed.
36+
3537
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
3638
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
3739

40+
The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, have been changed such that if the chain does not contain all variables in the model, an error is thrown.
41+
Previously the behaviour would have been to sample missing variables.
42+
3843
## 0.38.9
3944

4045
Remove warning when using Enzyme as the AD backend.

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 80 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,19 @@
11
module DynamicPPLMCMCChainsExt
22

3-
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
3+
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
44
using MCMCChains: MCMCChains
55

6-
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
7-
8-
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
9-
return _has_varname_to_symbol(chain.info)
10-
end
11-
12-
function _check_varname_indexing(c::MCMCChains.Chains)
13-
return DynamicPPL.supports_varname_indexing(c) ||
14-
error("This `Chains` object does not support indexing using `VarName`s.")
15-
end
16-
17-
function DynamicPPL.getindex_varname(
6+
function getindex_varname(
187
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
198
)
20-
_check_varname_indexing(c)
219
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
2210
end
23-
function DynamicPPL.varnames(c::MCMCChains.Chains)
24-
_check_varname_indexing(c)
11+
function get_varnames(c::MCMCChains.Chains)
12+
haskey(c.info, :varname_to_symbol) ||
13+
error("This `Chains` object does not support indexing using `VarName`s.")
2514
return keys(c.info.varname_to_symbol)
2615
end
2716

28-
function chain_sample_to_varname_dict(
29-
c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
30-
) where {Tval}
31-
_check_varname_indexing(c)
32-
d = Dict{DynamicPPL.VarName,Tval}()
33-
for vn in DynamicPPL.varnames(c)
34-
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
35-
end
36-
return d
37-
end
38-
3917
"""
4018
AbstractMCMC.from_samples(
4119
::Type{MCMCChains.Chains},
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
11896
# Get parameters
11997
params_matrix = map(idxs) do (sample_idx, chain_idx)
12098
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
121-
for vn in DynamicPPL.varnames(chain)
122-
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
99+
for vn in get_varnames(chain)
100+
d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
123101
end
124102
d
125103
end
@@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples(
177155
return sort_chain ? sort(chain) : chain
178156
end
179157

158+
"""
159+
reevaluate_with_chain(
160+
rng::AbstractRNG,
161+
model::Model,
162+
chain::MCMCChains.Chains
163+
accs::NTuple{N,AbstractAccumulator};
164+
fallback=nothing,
165+
)
166+
167+
Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`,
168+
returning an matrix of `(retval, updated_at)` tuples.
169+
170+
This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
171+
initialisation strategy when re-evaluating the model. For many usecases the fallback should
172+
not be provided (as we expect the chain to contain all necessary variables); but for
173+
`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
174+
the posterior predictions).
175+
"""
176+
function reevaluate_with_chain(
177+
rng::Random.AbstractRNG,
178+
model::DynamicPPL.Model,
179+
chain::MCMCChains.Chains,
180+
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
181+
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
182+
) where {N}
183+
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
184+
vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
185+
return map(params_with_stats) do ps
186+
DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback))
187+
end
188+
end
189+
function reevaluate_with_chain(
190+
model::DynamicPPL.Model,
191+
chain::MCMCChains.Chains,
192+
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
193+
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
194+
) where {N}
195+
return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback)
196+
end
197+
180198
"""
181199
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
182200
@@ -245,30 +263,18 @@ function DynamicPPL.predict(
245263
include_all=false,
246264
)
247265
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
248-
249-
# Set up a VarInfo with the right accumulators
250-
varinfo = DynamicPPL.setaccs!!(
251-
DynamicPPL.VarInfo(),
252-
(
253-
DynamicPPL.LogPriorAccumulator(),
254-
DynamicPPL.LogLikelihoodAccumulator(),
255-
DynamicPPL.ValuesAsInModelAccumulator(false),
256-
),
266+
accs = (
267+
DynamicPPL.LogPriorAccumulator(),
268+
DynamicPPL.LogLikelihoodAccumulator(),
269+
DynamicPPL.ValuesAsInModelAccumulator(false),
257270
)
258-
_, varinfo = DynamicPPL.init!!(model, varinfo)
259-
varinfo = DynamicPPL.typed_varinfo(varinfo)
260-
261-
params_and_stats = AbstractMCMC.to_samples(
262-
DynamicPPL.ParamsWithStats, parameter_only_chain
271+
predictions = map(
272+
DynamicPPL.ParamsWithStats last,
273+
reevaluate_with_chain(
274+
rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior()
275+
),
263276
)
264-
predictions = map(params_and_stats) do ps
265-
_, varinfo = DynamicPPL.init!!(
266-
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
267-
)
268-
DynamicPPL.ParamsWithStats(varinfo)
269-
end
270277
chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)
271-
272278
parameter_names = if include_all
273279
MCMCChains.names(chain_result, :parameters)
274280
else
@@ -348,18 +354,7 @@ julia> returned(model, chain)
348354
"""
349355
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
350356
chain = MCMCChains.get_sections(chain_full, :parameters)
351-
varinfo = DynamicPPL.VarInfo(model)
352-
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
353-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
354-
return map(params_with_stats) do ps
355-
first(
356-
DynamicPPL.init!!(
357-
model,
358-
varinfo,
359-
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
360-
),
361-
)
362-
end
357+
return map(first, reevaluate_with_chain(model, chain, (), nothing))
363358
end
364359

365360
"""
@@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities(
452447
::Type{Tout}=MCMCChains.Chains,
453448
::Val{whichlogprob}=Val(:both),
454449
) where {whichlogprob,Tout}
455-
vi = DynamicPPL.VarInfo(model)
456450
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
457451
accname = DynamicPPL.accumulator_name(acc)
458-
vi = DynamicPPL.setaccs!!(vi, (acc,))
459452
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
460-
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
461-
pointwise_logps = map(iters) do (sample_idx, chain_idx)
462-
# Extract values from the chain
463-
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
464-
# Re-evaluate the model
465-
_, vi = DynamicPPL.init!!(
466-
model,
467-
vi,
468-
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
469-
)
470-
DynamicPPL.getacc(vi, Val(accname)).logps
471-
end
472-
453+
pointwise_logps =
454+
map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi)
455+
DynamicPPL.getacc(vi, Val(accname)).logps
456+
end
473457
# pointwise_logps is a matrix of OrderedDicts
474458
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
475459
for d in pointwise_logps
@@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain)
556540
```
557541
"""
558542
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
559-
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
560-
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
561-
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
562-
vn_parent => DynamicPPL.values_from_chain(
563-
var_info, vn_parent, chain, chain_idx, iteration_idx
564-
) for vn_parent in keys(var_info)
565-
)
566-
DynamicPPL.logjoint(model, argvals_dict)
567-
end
543+
return map(
544+
DynamicPPL.getlogjoint last,
545+
reevaluate_with_chain(
546+
model,
547+
chain,
548+
(DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
549+
nothing,
550+
),
551+
)
568552
end
569553

570554
"""
@@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain)
596580
```
597581
"""
598582
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
599-
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
600-
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
601-
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
602-
vn_parent => DynamicPPL.values_from_chain(
603-
var_info, vn_parent, chain, chain_idx, iteration_idx
604-
) for vn_parent in keys(var_info)
605-
)
606-
DynamicPPL.loglikelihood(model, argvals_dict)
607-
end
583+
return map(
584+
DynamicPPL.getloglikelihood last,
585+
reevaluate_with_chain(
586+
model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing
587+
),
588+
)
608589
end
609590

610591
"""
@@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain)
637618
```
638619
"""
639620
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
640-
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
641-
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
642-
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
643-
vn_parent => DynamicPPL.values_from_chain(
644-
var_info, vn_parent, chain, chain_idx, iteration_idx
645-
) for vn_parent in keys(var_info)
646-
)
647-
DynamicPPL.logprior(model, argvals_dict)
648-
end
621+
return map(
622+
DynamicPPL.getlogprior last,
623+
reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing),
624+
)
649625
end
650626

651627
end

src/chains.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,3 @@
1-
"""
2-
supports_varname_indexing(chain::AbstractChains)
3-
4-
Return `true` if `chain` supports indexing using `VarName` in place of the
5-
variable name index.
6-
"""
7-
supports_varname_indexing(::AbstractChains) = false
8-
9-
"""
10-
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)
11-
12-
Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.
13-
14-
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
15-
"""
16-
function getindex_varname end
17-
18-
"""
19-
varnames(chains::AbstractChains)
20-
21-
Return an iterator over the varnames present in `chains`.
22-
23-
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
24-
"""
25-
function varnames end
26-
271
"""
282
ParamsWithStats
293

src/logdensityfunction.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,21 +193,21 @@ end
193193
# LogDensityProblems.jl interface #
194194
###################################
195195
"""
196-
fast_ldf_accs(getlogdensity::Function)
196+
ldf_accs(getlogdensity::Function)
197197
198198
Determine which accumulators are needed for fast evaluation with the given
199199
`getlogdensity` function.
200200
"""
201-
fast_ldf_accs(::Function) = default_accumulators()
202-
fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
203-
function fast_ldf_accs(::typeof(getlogjoint))
201+
ldf_accs(::Function) = default_accumulators()
202+
ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
203+
function ldf_accs(::typeof(getlogjoint))
204204
return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
205205
end
206-
function fast_ldf_accs(::typeof(getlogprior_internal))
206+
function ldf_accs(::typeof(getlogprior_internal))
207207
return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator()))
208208
end
209-
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
210-
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
209+
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
210+
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
211211

212212
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
213213
model::M
@@ -219,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real})
219219
strategy = InitFromParams(
220220
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
221221
)
222-
accs = fast_ldf_accs(f.getlogdensity)
222+
accs = ldf_accs(f.getlogdensity)
223223
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
224224
return f.getlogdensity(vi)
225225
end

src/model.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,12 +1181,15 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
11811181
```
11821182
"""
11831183
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
1184-
vi = DynamicPPL.setaccs!!(VarInfo(), ())
11851184
# Note: we can't use `fix(model, parameters)` because
11861185
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
1187-
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
1188-
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
1189-
new_model = setleafcontext(model, ctx)
1190-
# We can't use new_model() because that overwrites it with an InitContext of its own.
1191-
return first(evaluate!!(new_model, vi))
1186+
return first(
1187+
init!!(
1188+
model,
1189+
DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()),
1190+
# Use `nothing` as the fallback to ensure that any missing parameters cause an
1191+
# error
1192+
InitFromParams(parameters, nothing),
1193+
),
1194+
)
11921195
end

0 commit comments

Comments
 (0)