
Commit 83b8bd7

Add a flag to control the use of views in gradients

1 parent: 9e6948e

7 files changed: +94 additions, −4 deletions

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.6"
+version = "0.6.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/general.md

Lines changed: 8 additions & 0 deletions

@@ -15,6 +15,14 @@ optimize
 Each algorithm may interact differently with the arguments of `optimize`.
 Therefore, please refer to the documentation of each different algorithm for a detailed description on their behavior and their requirements.
 
+The `prob` argument to `optimize` must satisfy the LogDensityProblems.jl interface.
+Some algorithms in AdvancedVI call the `logdensity_and_gradient` or `logdensity_gradient_and_hessian` methods with a view of an array rather than a plain vector.
+If views are not supported by the `prob` argument, you should define this method to return `false`:
+
+```@docs
+use_view_in_gradient
+```
+
 ## [Monitoring the Objective Value](@id estimate_objective)
 
 Furthermore, each algorithm has an associated variational objective subject to *minimization*. (By convention, we assume all objectives are minimized rather than maximized.)
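For a concrete picture of how a user would opt out of views, the sketch below defines a hypothetical problem type `MyProblem` whose gradient method only accepts plain `Vector{Float64}`s and then overloads the flag for it. The type, its log density, and the name `MyProblem` are illustrative assumptions, not part of this commit; the overload pattern itself mirrors the one used in the test added below.

```julia
# Hypothetical example: a LogDensityProblems-compatible target whose gradient
# method is restricted to plain Vector{Float64}, so views must be disabled.
using AdvancedVI, LogDensityProblems

struct MyProblem end

LogDensityProblems.dimension(::MyProblem) = 2
LogDensityProblems.capabilities(::Type{<:MyProblem}) =
    LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(::MyProblem, x) = -sum(abs2, x) / 2
function LogDensityProblems.logdensity_and_gradient(::MyProblem, x::Vector{Float64})
    # Only plain vectors are accepted; a SubArray would raise a MethodError.
    return -sum(abs2, x) / 2, -x
end

# Tell AdvancedVI to pass `z[:, b]` (a copy) instead of `view(z, :, b)`
# when evaluating gradients for this problem.
AdvancedVI.use_view_in_gradient(::MyProblem) = false
```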

src/AdvancedVI.jl

Lines changed: 2 additions & 0 deletions

@@ -307,6 +307,8 @@ function optimize end
 export optimize
 
 include("utils.jl")
+export use_view_in_gradient
+
 include("optimize.jl")
 
 ## Parameter Space SGD Implementations

src/algorithms/fisherminbatchmatch.jl

Lines changed: 6 additions & 1 deletion

@@ -92,7 +92,12 @@ function rand_batch_match_samples_with_objective!(
     z = C*u .+ μ
     logπ_sum = zero(eltype(μ))
     for b in 1:n_samples
-        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b))
+        zb = if use_view_in_gradient(prob)
+            view(z, :, b)
+        else
+            z[:, b]
+        end
+        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, zb)
         grad_buf[:, b] = gb
         logπ_sum += logπb
     end

src/algorithms/gauss_expected_grad_hess.jl

Lines changed: 10 additions & 2 deletions

@@ -44,7 +44,11 @@ function gaussian_expectation_gradient_and_hessian!(
     m, C = q.location, q.scale
     z = C*u .+ m
     for b in 1:n_samples
-        zb, ub = view(z, :, b), view(u, :, b)
+        zb, ub = if use_view_in_gradient(prob)
+            view(z, :, b), view(u, :, b)
+        else
+            z[:, b], u[:, b]
+        end
         logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
         logπ_avg += logπ/n_samples
 
@@ -60,7 +64,11 @@
     # Second-order: use naive sample average
     z = rand(rng, q, n_samples)
     for b in 1:n_samples
-        zb = view(z, :, b)
+        zb = if use_view_in_gradient(prob)
+            view(z, :, b)
+        else
+            z[:, b]
+        end
         logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
             prob, zb
         )

src/utils.jl

Lines changed: 10 additions & 0 deletions

@@ -12,3 +12,13 @@ function catsamples_and_acc(
     ∑y = last(state_curr) + last(state_new)
     return (x, ∑y)
 end
+
+"""
+    use_view_in_gradient(prob)::Bool
+
+When calling `logdensity_and_gradient(prob, x)`, this determines whether `x` can be passed
+as a view. This is usually better for efficiency and hence the default is `true`. However,
+some `prob`s may not support views (e.g. if gradient preparation has already been done with
+a full vector).
+"""
+use_view_in_gradient(@nospecialize(prob::Any)) = true
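As a hedged illustration of the failure mode the docstring alludes to (gradient code specialized to a concrete vector type), the snippet below uses a made-up `strict_gradient` function; it is not part of AdvancedVI and only shows why falling back to a copy can be necessary.

```julia
# Made-up example of a gradient routine that only accepts plain vectors,
# e.g. because it was prepared or compiled against a full Vector{Float64}.
strict_gradient(x::Vector{Float64}) = -x

x = randn(3)
strict_gradient(x)                 # works: x is a Vector{Float64}
# strict_gradient(view(x, 1:2))    # would throw a MethodError: a SubArray is not a Vector{Float64}

# With `use_view_in_gradient(prob) = false`, the algorithms above pass
# `z[:, b]` (a freshly allocated Vector) instead of `view(z, :, b)`,
# trading one copy per sample for compatibility.
```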

test/general/view.jl

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+@testset "use_view_in_gradient" begin
+    # Set up a LogDensityProblem that does not accept views
+    struct LogDensityNoView end
+    dims = 2
+    LogDensityProblems.dimension(::LogDensityNoView) = dims
+    LogDensityProblems.capabilities(::Type{<:LogDensityNoView}) =
+        LogDensityProblems.LogDensityOrder{1}()
+    function LogDensityProblems.logdensity(::LogDensityNoView, x::AbstractArray)
+        return sum(x .^ 2)
+    end
+    function LogDensityProblems.logdensity(::LogDensityNoView, ::SubArray)
+        error("Cannot use view")
+    end
+    function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, x::AbstractArray)
+        ld = sum(x .^ 2)
+        grad = 2 .* x
+        return ld, grad
+    end
+    function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, ::SubArray)
+        error("Cannot use view")
+    end
+
+    names_and_algs = [
+        ("KLMinNaturalGradDescent", KLMinNaturalGradDescent(; stepsize=1e-2, n_samples=10)),
+        (
+            "KLMinSqrtNaturalGradDescent",
+            KLMinSqrtNaturalGradDescent(; stepsize=1e-2, n_samples=10),
+        ),
+        ("KLMinWassFwdBwd", KLMinWassFwdBwd(; stepsize=1e-2, n_samples=10)),
+        ("FisherMinBatchMatch", FisherMinBatchMatch()),
+    ]
+
+    # Attempt to run VI without setting `use_view_in_gradient` to false
+    AdvancedVI.use_view_in_gradient(::LogDensityNoView) = true
+    @testset "$name" for (name, algorithm) in names_and_algs
+        @test_throws "Cannot use view" optimize(
+            algorithm,
+            10,
+            LogDensityNoView(),
+            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            show_progress=false,
+        )
+    end
+
+    # Then run VI with `use_view_in_gradient` set to false
+    AdvancedVI.use_view_in_gradient(::LogDensityNoView) = false
+    @testset "$name" for (name, algorithm) in names_and_algs
+        @test optimize(
+            algorithm,
+            10,
+            LogDensityNoView(),
+            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            show_progress=false,
+        ) isa Any
+    end
+
+end
