@@ -456,14 +456,11 @@ function AbstractMCMC.step_warmup(
456456end
457457
458458"""
459- setparams_varinfo!!(model, sampler::AbstractSampler, state, params::AbstractVarInfo)
459+ setparams_varinfo!!(model::DynamicPPL.Model , sampler::AbstractSampler, state, params::AbstractVarInfo)
460460
461461A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an
462462`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to
463463`AbstractMCMC.setparams!!(model, state, params[:])`.
464-
465- `model` is typically a `DynamicPPL.Model`, but can also be e.g. an
466- `AbstractMCMC.LogDensityModel`.
467464"""
468465function setparams_varinfo!! (
469466 model:: DynamicPPL.Model , :: AbstractSampler , state, params:: AbstractVarInfo
@@ -488,12 +485,18 @@ function setparams_varinfo!!(
488485end
489486
490487function setparams_varinfo!! (
491- :: DynamicPPL.Model , :: ExternalSampler , state:: TuringState , params:: AbstractVarInfo
488+ model:: DynamicPPL.Model ,
489+ sampler:: ExternalSampler ,
490+ state:: TuringState ,
491+ params:: AbstractVarInfo ,
492492)
493+ new_ldf = DynamicPPL. LogDensityFunction (
494+ model, DynamicPPL. getlogjoint_internal, params; adtype= sampler. adtype
495+ )
493496 new_inner_state = AbstractMCMC. setparams!! (
494- AbstractMCMC. LogDensityModel (state . ldf ), state. state, params[:]
497+ AbstractMCMC. LogDensityModel (new_ldf ), state. state, params[:]
495498 )
496- return TuringState (new_inner_state, params, state . ldf )
499+ return TuringState (new_inner_state, params, params[:], new_ldf )
497500end
498501
499502function setparams_varinfo!! (
0 commit comments