|
1 | 1 | module DynamicPPLMCMCChainsExt |
2 | 2 |
|
3 | | -using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC |
| 3 | +using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random |
4 | 4 | using MCMCChains: MCMCChains |
5 | 5 |
|
6 | | -_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names |
7 | | - |
8 | | -function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) |
9 | | - return _has_varname_to_symbol(chain.info) |
10 | | -end |
11 | | - |
12 | | -function _check_varname_indexing(c::MCMCChains.Chains) |
13 | | - return DynamicPPL.supports_varname_indexing(c) || |
14 | | - error("This `Chains` object does not support indexing using `VarName`s.") |
15 | | -end |
16 | | - |
17 | | -function DynamicPPL.getindex_varname( |
| 6 | +function getindex_varname( |
18 | 7 | c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx |
19 | 8 | ) |
20 | | - _check_varname_indexing(c) |
21 | 9 | return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx] |
22 | 10 | end |
23 | | -function DynamicPPL.varnames(c::MCMCChains.Chains) |
24 | | - _check_varname_indexing(c) |
| 11 | +function get_varnames(c::MCMCChains.Chains) |
| 12 | + haskey(c.info, :varname_to_symbol) || |
| 13 | + error("This `Chains` object does not support indexing using `VarName`s.") |
25 | 14 | return keys(c.info.varname_to_symbol) |
26 | 15 | end |
27 | 16 |
|
28 | | -function chain_sample_to_varname_dict( |
29 | | - c::MCMCChains.Chains{Tval}, sample_idx, chain_idx |
30 | | -) where {Tval} |
31 | | - _check_varname_indexing(c) |
32 | | - d = Dict{DynamicPPL.VarName,Tval}() |
33 | | - for vn in DynamicPPL.varnames(c) |
34 | | - d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) |
35 | | - end |
36 | | - return d |
37 | | -end |
38 | | - |
39 | 17 | """ |
40 | 18 | AbstractMCMC.from_samples( |
41 | 19 | ::Type{MCMCChains.Chains}, |
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples( |
118 | 96 | # Get parameters |
119 | 97 | params_matrix = map(idxs) do (sample_idx, chain_idx) |
120 | 98 | d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() |
121 | | - for vn in DynamicPPL.varnames(chain) |
122 | | - d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) |
| 99 | + for vn in get_varnames(chain) |
| 100 | + d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) |
123 | 101 | end |
124 | 102 | d |
125 | 103 | end |
@@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples( |
177 | 155 | return sort_chain ? sort(chain) : chain |
178 | 156 | end |
179 | 157 |
|
| 158 | +""" |
| 159 | + reevaluate_with_chain( |
| 160 | + rng::AbstractRNG, |
| 161 | + model::Model, |
| 162 | + chain::MCMCChains.Chains |
| 163 | + accs::NTuple{N,AbstractAccumulator}; |
| 164 | + fallback=nothing, |
| 165 | + ) |
| 166 | +
|
| 167 | +Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`, |
| 168 | +returning an matrix of `(retval, updated_at)` tuples. |
| 169 | +
|
| 170 | +This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the |
| 171 | +initialisation strategy when re-evaluating the model. For many usecases the fallback should |
| 172 | +not be provided (as we expect the chain to contain all necessary variables); but for |
| 173 | +`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating |
| 174 | +the posterior predictions). |
| 175 | +""" |
| 176 | +function reevaluate_with_chain( |
| 177 | + rng::Random.AbstractRNG, |
| 178 | + model::DynamicPPL.Model, |
| 179 | + chain::MCMCChains.Chains, |
| 180 | + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, |
| 181 | + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, |
| 182 | +) where {N} |
| 183 | + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
| 184 | + vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) |
| 185 | + return map(params_with_stats) do ps |
| 186 | + DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback)) |
| 187 | + end |
| 188 | +end |
| 189 | +function reevaluate_with_chain( |
| 190 | + model::DynamicPPL.Model, |
| 191 | + chain::MCMCChains.Chains, |
| 192 | + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, |
| 193 | + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, |
| 194 | +) where {N} |
| 195 | + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) |
| 196 | +end |
| 197 | + |
180 | 198 | """ |
181 | 199 | predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) |
182 | 200 |
|
@@ -245,30 +263,18 @@ function DynamicPPL.predict( |
245 | 263 | include_all=false, |
246 | 264 | ) |
247 | 265 | parameter_only_chain = MCMCChains.get_sections(chain, :parameters) |
248 | | - |
249 | | - # Set up a VarInfo with the right accumulators |
250 | | - varinfo = DynamicPPL.setaccs!!( |
251 | | - DynamicPPL.VarInfo(), |
252 | | - ( |
253 | | - DynamicPPL.LogPriorAccumulator(), |
254 | | - DynamicPPL.LogLikelihoodAccumulator(), |
255 | | - DynamicPPL.ValuesAsInModelAccumulator(false), |
256 | | - ), |
| 266 | + accs = ( |
| 267 | + DynamicPPL.LogPriorAccumulator(), |
| 268 | + DynamicPPL.LogLikelihoodAccumulator(), |
| 269 | + DynamicPPL.ValuesAsInModelAccumulator(false), |
257 | 270 | ) |
258 | | - _, varinfo = DynamicPPL.init!!(model, varinfo) |
259 | | - varinfo = DynamicPPL.typed_varinfo(varinfo) |
260 | | - |
261 | | - params_and_stats = AbstractMCMC.to_samples( |
262 | | - DynamicPPL.ParamsWithStats, parameter_only_chain |
| 271 | + predictions = map( |
| 272 | + DynamicPPL.ParamsWithStats ∘ last, |
| 273 | + reevaluate_with_chain( |
| 274 | + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() |
| 275 | + ), |
263 | 276 | ) |
264 | | - predictions = map(params_and_stats) do ps |
265 | | - _, varinfo = DynamicPPL.init!!( |
266 | | - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) |
267 | | - ) |
268 | | - DynamicPPL.ParamsWithStats(varinfo) |
269 | | - end |
270 | 277 | chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) |
271 | | - |
272 | 278 | parameter_names = if include_all |
273 | 279 | MCMCChains.names(chain_result, :parameters) |
274 | 280 | else |
@@ -348,18 +354,7 @@ julia> returned(model, chain) |
348 | 354 | """ |
349 | 355 | function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) |
350 | 356 | chain = MCMCChains.get_sections(chain_full, :parameters) |
351 | | - varinfo = DynamicPPL.VarInfo(model) |
352 | | - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) |
353 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
354 | | - return map(params_with_stats) do ps |
355 | | - first( |
356 | | - DynamicPPL.init!!( |
357 | | - model, |
358 | | - varinfo, |
359 | | - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), |
360 | | - ), |
361 | | - ) |
362 | | - end |
| 357 | + return map(first, reevaluate_with_chain(model, chain, (), nothing)) |
363 | 358 | end |
364 | 359 |
|
365 | 360 | """ |
@@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities( |
452 | 447 | ::Type{Tout}=MCMCChains.Chains, |
453 | 448 | ::Val{whichlogprob}=Val(:both), |
454 | 449 | ) where {whichlogprob,Tout} |
455 | | - vi = DynamicPPL.VarInfo(model) |
456 | 450 | acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() |
457 | 451 | accname = DynamicPPL.accumulator_name(acc) |
458 | | - vi = DynamicPPL.setaccs!!(vi, (acc,)) |
459 | 452 | parameter_only_chain = MCMCChains.get_sections(chain, :parameters) |
460 | | - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) |
461 | | - pointwise_logps = map(iters) do (sample_idx, chain_idx) |
462 | | - # Extract values from the chain |
463 | | - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) |
464 | | - # Re-evaluate the model |
465 | | - _, vi = DynamicPPL.init!!( |
466 | | - model, |
467 | | - vi, |
468 | | - DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), |
469 | | - ) |
470 | | - DynamicPPL.getacc(vi, Val(accname)).logps |
471 | | - end |
472 | | - |
| 453 | + pointwise_logps = |
| 454 | + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) |
| 455 | + DynamicPPL.getacc(vi, Val(accname)).logps |
| 456 | + end |
473 | 457 | # pointwise_logps is a matrix of OrderedDicts |
474 | 458 | all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() |
475 | 459 | for d in pointwise_logps |
@@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain) |
556 | 540 | ``` |
557 | 541 | """ |
558 | 542 | function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
559 | | - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model |
560 | | - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) |
561 | | - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( |
562 | | - vn_parent => DynamicPPL.values_from_chain( |
563 | | - var_info, vn_parent, chain, chain_idx, iteration_idx |
564 | | - ) for vn_parent in keys(var_info) |
565 | | - ) |
566 | | - DynamicPPL.logjoint(model, argvals_dict) |
567 | | - end |
| 543 | + return map( |
| 544 | + DynamicPPL.getlogjoint ∘ last, |
| 545 | + reevaluate_with_chain( |
| 546 | + model, |
| 547 | + chain, |
| 548 | + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), |
| 549 | + nothing, |
| 550 | + ), |
| 551 | + ) |
568 | 552 | end |
569 | 553 |
|
570 | 554 | """ |
@@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) |
596 | 580 | ``` |
597 | 581 | """ |
598 | 582 | function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
599 | | - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model |
600 | | - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) |
601 | | - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( |
602 | | - vn_parent => DynamicPPL.values_from_chain( |
603 | | - var_info, vn_parent, chain, chain_idx, iteration_idx |
604 | | - ) for vn_parent in keys(var_info) |
605 | | - ) |
606 | | - DynamicPPL.loglikelihood(model, argvals_dict) |
607 | | - end |
| 583 | + return map( |
| 584 | + DynamicPPL.getloglikelihood ∘ last, |
| 585 | + reevaluate_with_chain( |
| 586 | + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing |
| 587 | + ), |
| 588 | + ) |
608 | 589 | end |
609 | 590 |
|
610 | 591 | """ |
@@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain) |
637 | 618 | ``` |
638 | 619 | """ |
639 | 620 | function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
640 | | - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model |
641 | | - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) |
642 | | - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( |
643 | | - vn_parent => DynamicPPL.values_from_chain( |
644 | | - var_info, vn_parent, chain, chain_idx, iteration_idx |
645 | | - ) for vn_parent in keys(var_info) |
646 | | - ) |
647 | | - DynamicPPL.logprior(model, argvals_dict) |
648 | | - end |
| 621 | + return map( |
| 622 | + DynamicPPL.getlogprior ∘ last, |
| 623 | + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), |
| 624 | + ) |
649 | 625 | end |
650 | 626 |
|
651 | 627 | end |
0 commit comments