Skip to content

Commit 2f4794b

Browse files
committed
fix externalsampler bug
1 parent 324b76a commit 2f4794b

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

src/mcmc/external_sampler.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,20 @@ function externalsampler(
122122
end
123123

124124
# TODO(penelopeysm): Can't we clean this up somehow?
125-
struct TuringState{S,V,L<:DynamicPPL.LogDensityFunction}
125+
struct TuringState{S,V,P<:AbstractVector,L<:DynamicPPL.LogDensityFunction}
126126
state::S
127127
# Note that this varinfo is used only for structure. Its parameters and other info do
128128
# not need to be accurate
129129
varinfo::V
130+
# These are the actual parameters that this state is at
131+
params::P
130132
ldf::L
131133
end
132134

133-
# get_varinfo should return something from which the correct parameters can be
134-
# obtained, hence we use state.varinfo rather than state.ldf.varinfo
135-
get_varinfo(state::TuringState) = state.varinfo
135+
# get_varinfo must return something from which the correct parameters can be obtained
136+
function get_varinfo(state::TuringState)
137+
return DynamicPPL.unflatten(state.varinfo, state.params)
138+
end
136139
get_varinfo(state::AbstractVarInfo) = state
137140

138141
function AbstractMCMC.step(
@@ -188,7 +191,7 @@ function AbstractMCMC.step(
188191
new_stats = AbstractMCMC.getstats(state_inner)
189192
return (
190193
DynamicPPL.ParamsWithStats(new_parameters, f, new_stats),
191-
TuringState(state_inner, varinfo, f),
194+
TuringState(state_inner, varinfo, new_parameters, f),
192195
)
193196
end
194197

@@ -211,6 +214,6 @@ function AbstractMCMC.step(
211214
new_stats = AbstractMCMC.getstats(state_inner)
212215
return (
213216
DynamicPPL.ParamsWithStats(new_parameters, f, new_stats),
214-
TuringState(state_inner, state.varinfo, f),
217+
TuringState(state_inner, state.varinfo, new_parameters, f),
215218
)
216219
end

src/mcmc/gibbs.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,11 @@ function AbstractMCMC.step_warmup(
456456
end
457457

458458
"""
459-
setparams_varinfo!!(model, sampler::AbstractSampler, state, params::AbstractVarInfo)
459+
setparams_varinfo!!(model::DynamicPPL.Model, sampler::AbstractSampler, state, params::AbstractVarInfo)
460460
461461
A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an
462462
`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to
463463
`AbstractMCMC.setparams!!(model, state, params[:])`.
464-
465-
`model` is typically a `DynamicPPL.Model`, but can also be e.g. an
466-
`AbstractMCMC.LogDensityModel`.
467464
"""
468465
function setparams_varinfo!!(
469466
model::DynamicPPL.Model, ::AbstractSampler, state, params::AbstractVarInfo
@@ -488,12 +485,18 @@ function setparams_varinfo!!(
488485
end
489486

490487
function setparams_varinfo!!(
491-
::DynamicPPL.Model, ::ExternalSampler, state::TuringState, params::AbstractVarInfo
488+
model::DynamicPPL.Model,
489+
sampler::ExternalSampler,
490+
state::TuringState,
491+
params::AbstractVarInfo,
492492
)
493+
new_ldf = DynamicPPL.LogDensityFunction(
494+
model, DynamicPPL.getlogjoint_internal, params; adtype=sampler.adtype
495+
)
493496
new_inner_state = AbstractMCMC.setparams!!(
494-
AbstractMCMC.LogDensityModel(state.ldf), state.state, params[:]
497+
AbstractMCMC.LogDensityModel(new_ldf), state.state, params[:]
495498
)
496-
return TuringState(new_inner_state, params, state.ldf)
499+
return TuringState(new_inner_state, params, params[:], new_ldf)
497500
end
498501

499502
function setparams_varinfo!!(

0 commit comments

Comments
 (0)