2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.6"
version = "0.6.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8 changes: 8 additions & 0 deletions docs/src/general.md
@@ -15,6 +15,14 @@ optimize
Each algorithm may interact differently with the arguments of `optimize`.
Therefore, please refer to the documentation of each different algorithm for a detailed description on their behavior and their requirements.

The `prob` argument to `optimize` must satisfy the LogDensityProblems.jl interface.
Some algorithms in AdvancedVI call the `logdensity_and_gradient` or `logdensity_gradient_and_hessian` methods with a view of an array (a `SubArray`) rather than a plain vector.
If the `prob` argument does not support views, define the following method to return `false` (see the example below):

```@docs
use_view_in_gradient
```
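
For example, the following is a minimal sketch of opting out, where `MyProblem` is a hypothetical user-defined target standing in for any type whose gradient code cannot handle `SubArray` inputs:

```julia
using AdvancedVI

# `MyProblem` is a hypothetical target; in practice it would also implement the
# LogDensityProblems.jl interface (`logdensity`, `logdensity_and_gradient`, ...).
struct MyProblem end

# Opt out of views: the algorithms will then pass a copied column `z[:, b]`
# instead of `view(z, :, b)` when evaluating gradients.
AdvancedVI.use_view_in_gradient(::MyProblem) = false
```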

## [Monitoring the Objective Value](@id estimate_objective)

Furthermore, each algorithm has an associated variational objective subject to *minimization*. (By convention, we assume all objectives are minimized rather than maximized.)
2 changes: 2 additions & 0 deletions src/AdvancedVI.jl
@@ -307,6 +307,8 @@ function optimize end
export optimize

include("utils.jl")
export use_view_in_gradient

include("optimize.jl")

## Parameter Space SGD Implementations
7 changes: 6 additions & 1 deletion src/algorithms/fisherminbatchmatch.jl
@@ -92,7 +92,12 @@ function rand_batch_match_samples_with_objective!(
z = C*u .+ μ
logπ_sum = zero(eltype(μ))
for b in 1:n_samples
logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b))
zb = if use_view_in_gradient(prob)
view(z, :, b)
else
z[:, b]
end
logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, zb)
grad_buf[:, b] = gb
logπ_sum += logπb
end
12 changes: 10 additions & 2 deletions src/algorithms/gauss_expected_grad_hess.jl
@@ -44,7 +44,11 @@ function gaussian_expectation_gradient_and_hessian!(
m, C = q.location, q.scale
z = C*u .+ m
for b in 1:n_samples
zb, ub = view(z, :, b), view(u, :, b)
zb, ub = if use_view_in_gradient(prob)
view(z, :, b), view(u, :, b)
else
z[:, b], u[:, b]
end
logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
logπ_avg += logπ/n_samples

@@ -60,7 +64,11 @@
# Second-order: use naive sample average
z = rand(rng, q, n_samples)
for b in 1:n_samples
zb = view(z, :, b)
zb = if use_view_in_gradient(prob)
view(z, :, b)
else
z[:, b]
end
logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
prob, zb
)
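
The same view-or-copy branch now appears in `fisherminbatchmatch.jl` and twice in `gauss_expected_grad_hess.jl`. A small helper could capture the pattern; the sketch below is hypothetical (the name `column_for_gradient` is assumed and not defined in AdvancedVI):

```julia
# Hypothetical helper (assumed name, not part of this PR): select column `b` of `z`
# either as a view or as a copy, depending on what `prob` supports.
column_for_gradient(prob, z::AbstractMatrix, b::Integer) =
    use_view_in_gradient(prob) ? view(z, :, b) : z[:, b]

# It would be used inside the sample loops as:
#     zb = column_for_gradient(prob, z, b)
#     logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, zb)
```
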
10 changes: 10 additions & 0 deletions src/utils.jl
@@ -12,3 +12,13 @@ function catsamples_and_acc(
∑y = last(state_curr) + last(state_new)
return (x, ∑y)
end

"""
    use_view_in_gradient(prob)::Bool

Determine whether `x` may be passed as a view (a `SubArray`) when calling
`logdensity_and_gradient(prob, x)` or `logdensity_gradient_and_hessian(prob, x)`.
Passing a view avoids copying and is usually more efficient, hence the default is `true`.
However, some `prob`s may not support views (e.g. if gradient preparation has already been
done with a full vector).
"""
use_view_in_gradient(@nospecialize(prob::Any)) = true
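
To illustrate the caveat in the docstring, the following is a minimal sketch of a target whose gradient method is only defined for a concrete `Vector{Float64}` (for instance because an AD preparation object was built for that type), so passing a `SubArray` view would raise a `MethodError`; the name `StrictGradientProblem` is hypothetical:

```julia
using AdvancedVI, LogDensityProblems

# Hypothetical target whose gradient only accepts a plain `Vector{Float64}`.
struct StrictGradientProblem end

LogDensityProblems.dimension(::StrictGradientProblem) = 2
LogDensityProblems.capabilities(::Type{<:StrictGradientProblem}) =
    LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(::StrictGradientProblem, x::Vector{Float64}) = -sum(abs2, x)/2
function LogDensityProblems.logdensity_and_gradient(::StrictGradientProblem, x::Vector{Float64})
    return -sum(abs2, x)/2, -x
end

# A view of a matrix column is a `SubArray`, not a `Vector{Float64}`, so opt out:
AdvancedVI.use_view_in_gradient(::StrictGradientProblem) = false
```
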
63 changes: 63 additions & 0 deletions test/general/view.jl
@@ -0,0 +1,63 @@
@testset "use_view_in_gradient" begin
# Set up a LogDensityProblem that does not accept views
struct LogDensityNoView end
dims = 2
LogDensityProblems.dimension(::LogDensityNoView) = dims
LogDensityProblems.capabilities(::Type{<:LogDensityNoView}) = LogDensityProblems.LogDensityOrder{
1
}()
function LogDensityProblems.logdensity(::LogDensityNoView, x::AbstractArray)
return sum(x .^ 2)
end
function LogDensityProblems.logdensity(::LogDensityNoView, ::SubArray)
return error("Cannot use view")
end
function LogDensityProblems.logdensity_and_gradient(
::LogDensityNoView, x::AbstractArray
)
ld = sum(x .^ 2)
grad = 2 .* x
return ld, grad
end
function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, ::SubArray)
return error("Cannot use view")
end

names_and_algs = [
("KLMinNaturalGradDescent", KLMinNaturalGradDescent(; stepsize=1e-2, n_samples=10)),
(
"KLMinSqrtNaturalGradDescent",
KLMinSqrtNaturalGradDescent(; stepsize=1e-2, n_samples=10),
),
("KLMinWassFwdBwd", KLMinWassFwdBwd(; stepsize=1e-2, n_samples=10)),
("FisherMinBatchMatch", FisherMinBatchMatch()),
]

# Attempt to run VI without setting `use_view_in_gradient` to false
AdvancedVI.use_view_in_gradient(::LogDensityNoView) = true
@testset "$name" for (name, algorithm) in names_and_algs
@test_throws "Cannot use view" optimize(
algorithm,
10,
LogDensityNoView(),
FullRankGaussian(
zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims))
);
show_progress=false,
)
end

# Then run VI with `use_view_in_gradient` set to false
AdvancedVI.use_view_in_gradient(::LogDensityNoView) = false
@testset "$name" for (name, algorithm) in names_and_algs
@test optimize(
algorithm,
10,
LogDensityNoView(),
FullRankGaussian(
zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims))
);
show_progress=false,
) isa Any
end
end