Skip to content

Commit 84dc2cf

Browse files
committed
pass prob as first arg
1 parent 0a70754 commit 84dc2cf

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/datafit.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,14 +211,15 @@ Turing.@model function bayesianODE(prob, alg, t, pdist, pkeys, data, datamap, no
211211

212212
pprior ~ product_distribution(pdist)
213213

214-
prob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
214+
prob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
215+
215216
sol = solve(prob, alg, saveat = t)
216217
if !SciMLBase.successful_retcode(sol)
217218
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
218219
return nothing
219220
end
220-
for i in eachindex(data)
221-
data[i] ~ MvNormal(datamap(sol), σ^2 * I)
221+
for (i,x) in enumerate(datamap(sol))
222+
data[i] ~ MvNormal(x, σ^2 * I)
222223
end
223224
return nothing
224225
end
@@ -422,7 +423,7 @@ function bayesian_datafit(prob,
422423
pkeys,
423424
last.(data),
424425
IndexKeyMap(prob, data),
425-
noise_prior)
426+
noise_prior)
426427
chain = Turing.sample(model,
427428
Turing.NUTS(0.65),
428429
mcmcensemble,
@@ -467,7 +468,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
467468
mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(),
468469
nchains = 4,
469470
niter = 1000)
470-
(pdist_, pkeys) = bayes_unpack_data(p)
471+
(pdist_, pkeys) = bayes_unpack_data(probs, p)
471472
pdist, grouppriorsfunc = flatten(pdist_)
472473

473474
model = ensemblebayesianODE(probs,
@@ -491,7 +492,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
491492
mcmcensemble::AbstractMCMC.AbstractMCMCEnsemble = Turing.MCMCThreads(),
492493
nchains = 4,
493494
niter = 1_000)
494-
pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data)
495+
pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(probs, p, data)
495496
pdist, grouppriorsfunc = flatten(pdist_)
496497
model = ensemblebayesianODE(probs,
497498
map(first default_algorithm, probs),

0 commit comments

Comments
 (0)