@@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair})
206206 (pdist, IndexKeyMap (prob, pkeys))
207207end
208208
209- Turing. @model function bayesianODE (prob, t, pdist, pkeys, data, noise_prior)
209+ Turing. @model function bayesianODE (prob, alg, t, pdist, pkeys, data, datamap , noise_prior)
210210 σ ~ noise_prior
211211
212212 pprior ~ product_distribution (pdist)
213213
214214 prob = _remake (prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
215- sol = solve (prob, saveat = t)
215+ sol = solve (prob, alg, saveat = t)
216216 if ! SciMLBase. successful_retcode (sol)
217217 Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
218218 return nothing
219219 end
220220 for i in eachindex (data)
221- data[i]. second ~ MvNormal (sol[data[i] . first] , σ^ 2 * I)
221+ data[i] ~ MvNormal (datamap ( sol) , σ^ 2 * I)
222222 end
223223 return nothing
224224end
225225
226- Turing. @model function bayesianODE (prob,
226+ Turing. @model function bayesianODE (prob, alg,
227227 pdist,
228228 pkeys,
229229 ts,
@@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob,
236236 pprior ~ product_distribution (pdist)
237237
238238 prob = _remake (prob, (prob. tspan[1 ], lastt), pkeys, pprior)
239- sol = solve (prob)
239+ sol = solve (prob, alg )
240240 if ! SciMLBase. successful_retcode (sol)
241241 Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
242242 return nothing
@@ -264,18 +264,19 @@ end
264264Base. length (ws:: WeightedSol ) = length (first (ws. sols))
265265Base. size (ws:: WeightedSol ) = (length (first (ws. sols)),)
266266function Base. getindex (ws:: WeightedSol{T} , i:: Int ) where {T}
267- s = zero (T)
268- w = zero (T)
269- for j in eachindex (ws. weights)
267+ s:: T = zero (T)
268+ w:: T = zero (T)
269+ @inbounds for j in eachindex (ws. weights)
270270 w += ws. weights[j]
271271 s += ws. weights[j] * ws. sols[j][i]
272272 end
273273 return s + (one (T) - w) * ws. sols[end ][i]
274274end
275- function WeightedSol (sols, select, weights)
276- T = eltype (weights)
277- s = map (Base. Fix2 (getindex, select), sols)
278- WeightedSol {T} (s, weights)
275+ function WeightedSol (sols, select, i:: Int , weights)
276+ s = map (sols, select) do sol, sel
277+ @view (sol[sel. indices[i], :])
278+ end
279+ WeightedSol {eltype(weights)} (s, weights)
279280end
280281function bayes_unpack_data (probs, p:: Tuple{Vararg{<:AbstractVector{<:Pair}}} , data)
281282 pdist, pkeys = bayes_unpack_data (probs, p)
@@ -305,43 +306,46 @@ function flatten(x::Tuple)
305306 reduce (vcat, x), Grouper (map (length, x))
306307end
307308
308- function getsols (probs, probspkeys, ppriors, t:: AbstractArray )
309- map (probs, probspkeys, ppriors) do prob, pkeys, pprior
309+ function getsols (probs, algs, probspkeys, ppriors, t:: AbstractArray )
310+ map (probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
310311 newprob = _remake (prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
311- solve (newprob, saveat = t)
312+ solve (newprob, alg, saveat = t)
312313 end
313314end
314- function getsols (probs, probspkeys, ppriors, lastt:: Number )
315- map (probs, probspkeys, ppriors) do prob, pkeys, pprior
315+ function getsols (probs, algs, probspkeys, ppriors, lastt:: Number )
316+ map (probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
316317 newprob = _remake (prob, (prob. tspan[1 ], lastt), pkeys, pprior)
317- solve (newprob)
318+ solve (newprob, alg )
318319 end
319320end
320321
321322Turing. @model function ensemblebayesianODE (probs:: Union{Tuple, AbstractVector} ,
323+ algs,
322324 t,
323325 pdist,
324326 grouppriorsfunc,
325327 probspkeys,
326328 data,
329+ datamaps,
327330 noise_prior)
328331 σ ~ noise_prior
329332 ppriors ~ product_distribution (pdist)
330333
331334 Nprobs = length (probs)
332335 Nprobs⁻¹ = inv (Nprobs)
333336 weights ~ MvNormal (Distributions. Fill (Nprobs⁻¹, Nprobs - 1 ), Nprobs⁻¹)
334- sols = getsols (probs, probspkeys, grouppriorsfunc (ppriors), t)
337+ sols = getsols (probs, algs, probspkeys, grouppriorsfunc (ppriors), t)
335338 if ! all (SciMLBase. successful_retcode, sols)
336339 Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
337340 return nothing
338341 end
339342 for i in eachindex (data)
340- data[i]. second ~ MvNormal (WeightedSol (sols, data[i] . first , weights), σ^ 2 * I)
343+ data[i] ~ MvNormal (WeightedSol (sols, datamaps, i , weights), σ^ 2 * I)
341344 end
342345 return nothing
343346end
344347Turing. @model function ensemblebayesianODE (probs:: Union{Tuple, AbstractVector} ,
348+ algs,
345349 pdist,
346350 grouppriorsfunc,
347351 probspkeys,
@@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
353357 σ ~ noise_prior
354358 ppriors ~ product_distribution (pdist)
355359
356- sols = getsols (probs, probspkeys, grouppriorsfunc (ppriors), lastt)
360+ sols = getsols (probs, algs, probspkeys, grouppriorsfunc (ppriors), lastt)
357361
358362 Nprobs = length (probs)
359363 Nprobs⁻¹ = inv (Nprobs)
@@ -411,7 +415,14 @@ function bayesian_datafit(prob,
411415 nchains = 4 ,
412416 niter = 1000 )
413417 (pdist, pkeys) = bayes_unpack_data (prob, p)
414- model = bayesianODE (prob, t, pdist, pkeys, data, noise_prior)
418+ model = bayesianODE (prob,
419+ first (default_algorithm (prob)),
420+ t,
421+ pdist,
422+ pkeys,
423+ last .(data),
424+ IndexKeyMap (prob, data),
425+ noise_prior)
415426 chain = Turing. sample (model,
416427 Turing. NUTS (0.65 ),
417428 mcmcensemble,
@@ -430,7 +441,15 @@ function bayesian_datafit(prob,
430441 nchains = 4 ,
431442 niter = 1_000 )
432443 pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data (prob, p, data)
433- model = bayesianODE (prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior)
444+ model = bayesianODE (prob,
445+ first (default_algorithm (prob)),
446+ pdist,
447+ pkeys,
448+ ts,
449+ lastt,
450+ timeseries,
451+ datakeys,
452+ noise_prior)
434453 chain = Turing. sample (model,
435454 Turing. NUTS (0.65 ),
436455 mcmcensemble,
@@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
451470 (pdist_, pkeys) = bayes_unpack_data (p)
452471 pdist, grouppriorsfunc = flatten (pdist_)
453472
454- model = ensemblebayesianODE (probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior)
473+ model = ensemblebayesianODE (probs,
474+ map (first ∘ default_algorithm, probs),
475+ t, pdist, grouppriorsfunc, pkeys, last .(data),
476+ map (Base. Fix2 (IndexKeyMap, data), probs), noise_prior)
455477 chain = Turing. sample (model,
456478 Turing. NUTS (0.65 ),
457479 mcmcensemble,
@@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
472494 pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data (p, data)
473495 pdist, grouppriorsfunc = flatten (pdist_)
474496 model = ensemblebayesianODE (probs,
497+ map (first ∘ default_algorithm, probs),
475498 pdist,
476499 grouppriorsfunc,
477500 pkeys,
0 commit comments