Skip to content

Commit 16d9dfd

Browse files
committed
Fix optimisation interface
1 parent 2f4794b commit 16d9dfd

File tree

2 files changed

+52
-57
lines changed

src/optimisation/Optimisation.jl

Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ required by Optimization.jl.
8181
"""
8282
ModeResult{
8383
V<:NamedArrays.NamedArray,
84-
M<:NamedArrays.NamedArray,
85-
O<:Optim.MultivariateOptimizationResults,
86-
S<:NamedArrays.NamedArray,
84+
O<:Any,
85+
M<:OptimLogDensity,
8786
P<:AbstractDict{<:VarName,<:Any}
87+
E<:ModeEstimator,
8888
}
8989
9090
A wrapper struct to store various results from a MAP or MLE estimation.
@@ -98,6 +98,7 @@ struct ModeResult{
9898
O<:Any,
9999
M<:OptimLogDensity,
100100
P<:AbstractDict{<:AbstractPPL.VarName,<:Any},
101+
E<:ModeEstimator,
101102
} <: StatsBase.StatisticalModel
102103
"A vector with the resulting point estimates."
103104
values::V
@@ -109,6 +110,10 @@ struct ModeResult{
109110
f::M
110111
"Dictionary of parameter values"
111112
params::P
113+
"Whether the optimization was done in a transformed space."
114+
linked::Bool
115+
"The type of mode estimation (MAP or MLE)."
116+
estimator::E
112117
end
113118

114119
function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
@@ -218,42 +223,16 @@ end
218223
function StatsBase.informationmatrix(
219224
m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...
220225
)
221-
# Calculate Hessian and information matrix.
222-
223-
# Convert the values to their unconstrained states to make sure the
224-
# Hessian is computed with respect to the untransformed parameters.
225-
old_ldf = m.f.ldf
226-
linked = DynamicPPL.is_transformed(old_ldf.varinfo)
227-
if linked
228-
new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model)
229-
new_f = OptimLogDensity(
230-
DynamicPPL.LogDensityFunction(
231-
old_ldf.model, old_ldf.getlogdensity, new_vi; adtype=old_ldf.adtype
232-
),
233-
)
234-
m = Accessors.@set m.f = new_f
235-
end
226+
# This needs to be calculated in unlinked space
227+
model = m.f.ldf.model
228+
vi = DynamicPPL.VarInfo(model)
229+
getlogdensity = _choose_getlogdensity(m.estimator)
230+
new_optimld = OptimLogDensity(DynamicPPL.LogDensityFunction(model, getlogdensity, vi))
236231

237232
# Calculate the Hessian, which is the information matrix because the negative of the log
238233
# likelihood was optimized
239234
varnames = StatsBase.coefnames(m)
240-
info = hessian_function(m.f, m.values.array[:, 1])
241-
242-
# Link it back if we invlinked it.
243-
if linked
244-
invlinked_ldf = m.f.ldf
245-
new_vi = DynamicPPL.link!!(invlinked_ldf.varinfo, invlinked_ldf.model)
246-
new_f = OptimLogDensity(
247-
DynamicPPL.LogDensityFunction(
248-
invlinked_ldf.model,
249-
old_ldf.getlogdensity,
250-
new_vi;
251-
adtype=invlinked_ldf.adtype,
252-
),
253-
)
254-
m = Accessors.@set m.f = new_f
255-
end
256-
235+
info = hessian_function(new_optimld, m.values.array[:, 1])
257236
return NamedArrays.NamedArray(info, (varnames, varnames))
258237
end
259238

@@ -272,11 +251,7 @@ Return the values of all the variables with the symbol(s) `var_symbol` in the mo
272251
argument should be either a `Symbol` or a vector of `Symbol`s.
273252
"""
274253
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
275-
log_density = m.f.ldf
276-
# Get all the variable names in the model. This is the same as the list of keys in
277-
# m.values, but they are more convenient to filter when they are VarNames rather than
278-
# Symbols.
279-
vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
254+
vals_dict = m.params
280255
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
281256
vns_and_vals = mapreduce(collect, vcat, iters)
282257
varnames = collect(map(first, vns_and_vals))
@@ -296,18 +271,27 @@ end
296271
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])
297272

298273
"""
299-
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
274+
ModeResult(
275+
log_density::OptimLogDensity,
276+
solution::SciMLBase.OptimizationSolution,
277+
linked::Bool,
278+
estimator::ModeEstimator,
279+
)
300280
301281
Create a `ModeResult` for a given `log_density` objective and a `solution` given by `solve`.
282+
The `linked` argument indicates whether the optimization was done in a transformed space.
302283
303284
`Optimization.solve` returns its own result type. This function converts that into the
304285
richer format of `ModeResult`. It also takes care of transforming them back to the original
305286
parameter space in case the optimization was done in a transformed space.
306287
"""
307-
function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
308-
varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
309-
# `getparams` performs invlinking if needed
310-
vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
288+
function ModeResult(
289+
log_density::OptimLogDensity,
290+
solution::SciMLBase.OptimizationSolution,
291+
linked::Bool,
292+
estimator::ModeEstimator,
293+
)
294+
vals = DynamicPPL.ParamsWithStats(solution.u, log_density.ldf).params
311295
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
312296
vns_vals_iter = mapreduce(collect, vcat, iters)
313297
syms = map(Symbol first, vns_vals_iter)
@@ -318,6 +302,8 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
318302
-solution.objective,
319303
log_density,
320304
vals,
305+
linked,
306+
estimator,
321307
)
322308
end
323309

@@ -394,18 +380,19 @@ function default_solver(constraints::ModeEstimationConstraints)
394380
end
395381

396382
"""
397-
OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)
383+
OptimizationProblem(log_density::OptimLogDensity, initial_params::AbstractVector, adtype, constraints)
398384
399385
Create an `OptimizationProblem` for the objective function defined by `log_density`.
400386
401387
Note that the adtype parameter here overrides any adtype parameter the
402388
OptimLogDensity was constructed with.
403389
"""
404-
function Optimization.OptimizationProblem(log_density::OptimLogDensity, adtype, constraints)
390+
function Optimization.OptimizationProblem(
391+
log_density::OptimLogDensity, initial_params::AbstractVector, adtype, constraints
392+
)
405393
# Note that OptimLogDensity is a callable that evaluates the model with given
406394
# parameters. Hence we can use it in the objective function as below.
407395
f = Optimization.OptimizationFunction(log_density, adtype; cons=constraints.cons)
408-
initial_params = log_density.ldf.varinfo[:]
409396
prob = if !has_constraints(constraints)
410397
Optimization.OptimizationProblem(f, initial_params)
411398
else
@@ -421,6 +408,12 @@ function Optimization.OptimizationProblem(log_density::OptimLogDensity, adtype,
421408
return prob
422409
end
423410

411+
# Note that we use `getlogjoint` rather than `getlogjoint_internal`: this is intentional,
412+
# because even though the VarInfo may be linked, the optimisation target should not take the
413+
# Jacobian term into account.
414+
_choose_getlogdensity(::MAP) = DynamicPPL.getlogjoint
415+
_choose_getlogdensity(::MLE) = DynamicPPL.getloglikelihood
416+
424417
"""
425418
estimate_mode(
426419
model::DynamicPPL.Model,
@@ -478,13 +471,6 @@ function estimate_mode(
478471
solver = default_solver(constraints)
479472
end
480473

481-
# Create an OptimLogDensity object that can be used to evaluate the objective function,
482-
# i.e. the negative log density.
483-
# Note that we use `getlogjoint` rather than `getlogjoint_internal`: this
484-
# is intentional, because even though the VarInfo may be linked, the
485-
# optimisation target should not take the Jacobian term into account.
486-
getlogdensity = estimator isa MAP ? DynamicPPL.getlogjoint : DynamicPPL.getloglikelihood
487-
488474
# Set its VarInfo to the initial parameters.
489475
# TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated
490476
# (using `LogDensityProblems.logdensity(ldf, x)`) the parameters in the
@@ -502,17 +488,24 @@ function estimate_mode(
502488
if optimise_in_unconstrained_space
503489
vi = DynamicPPL.link(vi, model)
504490
end
491+
# Re-extract initial parameters (which may now be linked).
492+
initial_params = vi[:]
505493

506494
# Note that we don't need adtype here, because it's specified inside the
507495
# OptimizationProblem
496+
getlogdensity = _choose_getlogdensity(estimator)
508497
ldf = DynamicPPL.LogDensityFunction(model, getlogdensity, vi)
498+
# Create an OptimLogDensity object that can be used to evaluate the objective function,
499+
# i.e. the negative log density.
509500
log_density = OptimLogDensity(ldf)
510501

511-
prob = Optimization.OptimizationProblem(log_density, adtype, constraints)
502+
prob = Optimization.OptimizationProblem(
503+
log_density, initial_params, adtype, constraints
504+
)
512505
solution = Optimization.solve(prob, solver; kwargs...)
513506
# TODO(mhauru) We return a ModeResult for compatibility with the older Optim.jl
514507
# interface. Might we want to break that and develop a better return type?
515-
return ModeResult(log_density, solution)
508+
return ModeResult(log_density, solution, optimise_in_unconstrained_space, estimator)
516509
end
517510

518511
"""

test/optimisation/Optimisation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ using Turing
670670
0.0,
671671
optim_ld,
672672
Dict{AbstractPPL.VarName,Float64}(@varname(x) => 0.0, @varname(y) => 0.0),
673+
false,
674+
MLE(),
673675
)
674676
ct = coeftable(m)
675677
@assert isnan(ct.cols[2][1])

0 commit comments

Comments (0)