Skip to content

Commit a6d56a2

Browse files
authored
Improve FastLDF type stability when all parameters are linked or unlinked (#1141)
* Improve type stability when all parameters are linked or unlinked
* Fix a merge conflict
* Fix Enzyme GC crash (locally at least)
* Fixes from review
1 parent 8547e25 commit a6d56a2

File tree

5 files changed

+99
-24
lines changed

5 files changed

+99
-24
lines changed

src/chains.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
130130
"""
131131
function ParamsWithStats(
132132
param_vector::AbstractVector,
133-
ldf::DynamicPPL.LogDensityFunction,
133+
ldf::DynamicPPL.LogDensityFunction{Tlink},
134134
stats::NamedTuple=NamedTuple();
135135
include_colon_eq::Bool=true,
136136
include_log_probs::Bool=true,
137-
)
137+
) where {Tlink}
138138
strategy = InitFromParams(
139-
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
139+
VectorWithRanges{Tlink}(
140+
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
141+
),
140142
nothing,
141143
)
142144
accs = if include_log_probs

src/contexts/init.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct RangeAndLinked
214214
end
215215

216216
"""
217-
VectorWithRanges(
217+
VectorWithRanges{Tlink}(
218218
iden_varname_ranges::NamedTuple,
219219
varname_ranges::Dict{VarName,RangeAndLinked},
220220
vect::AbstractVector{<:Real},
@@ -223,6 +223,12 @@ end
223223
A struct that wraps a vector of parameter values, plus information about how random
224224
variables map to ranges in that vector.
225225
226+
The type parameter `Tlink` can be either `true` or `false`, to mark that the variables in
227+
this `VectorWithRanges` are linked/not linked, or `nothing` if either the linking status is
228+
not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not
229+
affect functionality or correctness, but causes more work to be done at runtime, with
230+
possible impacts on type stability and performance.
231+
226232
In the simplest case, this could be accomplished only with a single dictionary mapping
227233
VarNames to ranges and link status. However, for performance reasons, we separate out
228234
VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
@@ -231,13 +237,26 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
231237
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
232238
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
233239
"""
234-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
240+
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
235241
# This NamedTuple stores the ranges for identity VarNames
236242
iden_varname_ranges::N
237243
# This Dict stores the ranges for all other VarNames
238244
varname_ranges::Dict{VarName,RangeAndLinked}
239245
# The full parameter vector which we index into to get variable values
240246
vect::T
247+
248+
function VectorWithRanges{Tlink}(
249+
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
250+
) where {Tlink,N,T}
251+
if !(Tlink isa Union{Bool,Nothing})
252+
throw(
253+
ArgumentError(
254+
"VectorWithRanges type parameter has to be one of `true`, `false`, or `nothing`.",
255+
),
256+
)
257+
end
258+
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
259+
end
241260
end
242261

243262
function _get_range_and_linked(
@@ -252,11 +271,15 @@ function init(
252271
::Random.AbstractRNG,
253272
vn::VarName,
254273
dist::Distribution,
255-
p::InitFromParams{<:VectorWithRanges},
256-
)
274+
p::InitFromParams{<:VectorWithRanges{T}},
275+
) where {T}
257276
vr = p.params
258277
range_and_linked = _get_range_and_linked(vr, vn)
259-
transform = if range_and_linked.is_linked
278+
# T can either be `nothing` (i.e., link status is mixed, in which
279+
# case we use the stored link status), or `true` / `false`, which
280+
# indicates that all variables are linked / unlinked.
281+
linked = isnothing(T) ? range_and_linked.is_linked : T
282+
transform = if linked
260283
from_linked_vec_transform(dist)
261284
else
262285
from_vec_transform(dist)

src/logdensityfunction.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
140140
`unflatten` + `evaluate!!` approach also fails with such models.
141141
"""
142142
struct LogDensityFunction{
143+
# true if all variables are linked; false if all variables are unlinked; nothing if
144+
# mixed
145+
Tlink,
143146
M<:Model,
144147
AD<:Union{ADTypes.AbstractADType,Nothing},
145148
F<:Function,
@@ -163,6 +166,21 @@ struct LogDensityFunction{
163166
# Figure out which variable corresponds to which index, and
164167
# which variables are linked.
165168
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
169+
# Figure out if all variables are linked, unlinked, or mixed
170+
link_statuses = Bool[]
171+
for ral in all_iden_ranges
172+
push!(link_statuses, ral.is_linked)
173+
end
174+
for (_, ral) in all_ranges
175+
push!(link_statuses, ral.is_linked)
176+
end
177+
Tlink = if all(link_statuses)
178+
true
179+
elseif all(!s for s in link_statuses)
180+
false
181+
else
182+
nothing
183+
end
166184
x = [val for val in varinfo[:]]
167185
dim = length(x)
168186
# Do AD prep if needed
@@ -172,12 +190,13 @@ struct LogDensityFunction{
172190
# Make backend-specific tweaks to the adtype
173191
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
174192
DI.prepare_gradient(
175-
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
193+
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
176194
adtype,
177195
x,
178196
)
179197
end
180198
return new{
199+
Tlink,
181200
typeof(model),
182201
typeof(adtype),
183202
typeof(getlogdensity),
@@ -209,36 +228,45 @@ end
209228
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
210229
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
211230

212-
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
231+
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
213232
model::M
214233
getlogdensity::F
215234
iden_varname_ranges::N
216235
varname_ranges::Dict{VarName,RangeAndLinked}
236+
237+
function LogDensityAt{Tlink}(
238+
model::M,
239+
getlogdensity::F,
240+
iden_varname_ranges::N,
241+
varname_ranges::Dict{VarName,RangeAndLinked},
242+
) where {Tlink,M,F,N}
243+
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
244+
end
217245
end
218-
function (f::LogDensityAt)(params::AbstractVector{<:Real})
246+
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
219247
strategy = InitFromParams(
220-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
248+
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
221249
)
222250
accs = ldf_accs(f.getlogdensity)
223251
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
224252
return f.getlogdensity(vi)
225253
end
226254

227255
function LogDensityProblems.logdensity(
228-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
229-
)
230-
return LogDensityAt(
256+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
257+
) where {Tlink}
258+
return LogDensityAt{Tlink}(
231259
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
232260
)(
233261
params
234262
)
235263
end
236264

237265
function LogDensityProblems.logdensity_and_gradient(
238-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
239-
)
266+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
267+
) where {Tlink}
240268
return DI.value_and_gradient(
241-
LogDensityAt(
269+
LogDensityAt{Tlink}(
242270
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
243271
),
244272
ldf._adprep,
@@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient(
247275
)
248276
end
249277

250-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
278+
function LogDensityProblems.capabilities(
279+
::Type{<:LogDensityFunction{T,M,Nothing}}
280+
) where {T,M}
251281
return LogDensityProblems.LogDensityOrder{0}()
252282
end
253283
function LogDensityProblems.capabilities(
254-
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
255-
) where {M}
284+
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
285+
) where {T,M}
256286
return LogDensityProblems.LogDensityOrder{1}()
257287
end
258288
function LogDensityProblems.dimension(ldf::LogDensityFunction)

test/integration/enzyme/main.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@ using Test: @test, @testset
55
import Enzyme: set_runtime_activity, Forward, Reverse, Const
66
using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test
77

8-
ADTYPES = Dict(
9-
"EnzymeForward" =>
8+
ADTYPES = (
9+
(
10+
"EnzymeForward",
1011
AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
11-
"EnzymeReverse" =>
12+
),
13+
(
14+
"EnzymeReverse",
1215
AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
16+
),
1317
)
1418

1519
@testset "$ad_key" for (ad_key, ad_type) in ADTYPES

test/logdensityfunction.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,22 @@ end
108108
end
109109
end
110110

111+
@testset "LogDensityFunction: Type stability" begin
112+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
113+
unlinked_vi = DynamicPPL.VarInfo(m)
114+
@testset "$islinked" for islinked in (false, true)
115+
vi = if islinked
116+
DynamicPPL.link!!(unlinked_vi, m)
117+
else
118+
unlinked_vi
119+
end
120+
ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
121+
x = vi[:]
122+
@inferred LogDensityProblems.logdensity(ldf, x)
123+
end
124+
end
125+
end
126+
111127
@testset "LogDensityFunction: performance" begin
112128
if Threads.nthreads() == 1
113129
# Evaluating these three models should not lead to any allocations (but only when

0 commit comments

Comments
 (0)