From 83b8bd7f13c89486edacbe248d01e4bf80b9a99c Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Wed, 3 Dec 2025 11:28:49 +0000
Subject: [PATCH 1/3] Add a flag to control the use of views in gradients

---
 Project.toml                               |  2 +-
 docs/src/general.md                        |  8 +++
 src/AdvancedVI.jl                          |  2 +
 src/algorithms/fisherminbatchmatch.jl      |  7 ++-
 src/algorithms/gauss_expected_grad_hess.jl | 12 ++++-
 src/utils.jl                               | 10 ++++
 test/general/view.jl                       | 57 ++++++++++++++++++++++
 7 files changed, 94 insertions(+), 4 deletions(-)
 create mode 100644 test/general/view.jl

diff --git a/Project.toml b/Project.toml
index 437f5216d..09a91c2f5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.6"
+version = "0.6.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

diff --git a/docs/src/general.md b/docs/src/general.md
index f7f14e0b6..5756fbce3 100644
--- a/docs/src/general.md
+++ b/docs/src/general.md
@@ -15,6 +15,14 @@ optimize
 Each algorithm may interact differently with the arguments of `optimize`.
 Therefore, please refer to the documentation of each different algorithm for a detailed description on their behavior and their requirements.
 
+The `prob` argument to `optimize` must satisfy the LogDensityProblems.jl interface.
+Some algorithms in AdvancedVI call the `logdensity_and_gradient` or `logdensity_gradient_and_hessian` methods with a view of an array rather than a plain vector.
+If the `prob` argument does not support views, you should define the following method to return `false`:
+
+```@docs
+use_view_in_gradient
+```
+
 ## [Monitoring the Objective Value](@id estimate_objective)
 
 Furthermore, each algorithm has an associated variational objective subject to *minimization*. (By convention, we assume all objectives are minimized rather than maximized.)
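For illustration (not part of the patch itself), here is a minimal sketch of the opt-out described in the documentation above. `MyProblem` is a hypothetical problem type whose gradient method only accepts plain `Vector`s; the `AdvancedVI.use_view_in_gradient(::T) = false` overload is the mechanism this patch introduces:

```julia
using AdvancedVI, LogDensityProblems

# Hypothetical stand-in for any LogDensityProblems.jl implementation whose
# methods are restricted to `Vector` inputs and therefore reject `SubArray`s.
struct MyProblem end

LogDensityProblems.dimension(::MyProblem) = 2
function LogDensityProblems.capabilities(::Type{MyProblem})
    return LogDensityProblems.LogDensityOrder{1}()
end
LogDensityProblems.logdensity(::MyProblem, x::Vector{Float64}) = -sum(abs2, x)/2
function LogDensityProblems.logdensity_and_gradient(::MyProblem, x::Vector{Float64})
    return -sum(abs2, x)/2, -x
end

# Opt out of views: AdvancedVI then passes a freshly allocated column copy
# `z[:, b]` instead of `view(z, :, b)` when evaluating gradients.
AdvancedVI.use_view_in_gradient(::MyProblem) = false
```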
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 1d07fd975..35ba97553 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -307,6 +307,8 @@ function optimize end
 export optimize
 
 include("utils.jl")
+export use_view_in_gradient
+
 include("optimize.jl")
 
 ## Parameter Space SGD Implementations

diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl
index b794a12af..805cb2d83 100644
--- a/src/algorithms/fisherminbatchmatch.jl
+++ b/src/algorithms/fisherminbatchmatch.jl
@@ -92,7 +92,12 @@ function rand_batch_match_samples_with_objective!(
     z = C*u .+ μ
     logπ_sum = zero(eltype(μ))
     for b in 1:n_samples
-        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b))
+        zb = if use_view_in_gradient(prob)
+            view(z, :, b)
+        else
+            z[:, b]
+        end
+        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, zb)
         grad_buf[:, b] = gb
         logπ_sum += logπb
     end

diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl
index af5c21b22..561cb808f 100644
--- a/src/algorithms/gauss_expected_grad_hess.jl
+++ b/src/algorithms/gauss_expected_grad_hess.jl
@@ -44,7 +44,11 @@ function gaussian_expectation_gradient_and_hessian!(
     m, C = q.location, q.scale
     z = C*u .+ m
     for b in 1:n_samples
-        zb, ub = view(z, :, b), view(u, :, b)
+        zb, ub = if use_view_in_gradient(prob)
+            view(z, :, b), view(u, :, b)
+        else
+            z[:, b], u[:, b]
+        end
         logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
 
         logπ_avg += logπ/n_samples
@@ -60,7 +64,11 @@ function gaussian_expectation_gradient_and_hessian!(
     # Second-order: use naive sample average
     z = rand(rng, q, n_samples)
     for b in 1:n_samples
-        zb = view(z, :, b)
+        zb = if use_view_in_gradient(prob)
+            view(z, :, b)
+        else
+            z[:, b]
+        end
         logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
             prob, zb
         )

diff --git a/src/utils.jl b/src/utils.jl
index 6481d68d5..dd2b1e26e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -12,3 +12,13 @@ function catsamples_and_acc(
     ∑y = last(state_curr) + last(state_new)
     return (x, ∑y)
 end
+
+"""
+    use_view_in_gradient(prob)::Bool
+
+When calling `logdensity_and_gradient(prob, x)` or `logdensity_gradient_and_hessian(prob, x)`,
+this determines whether `x` may be passed as a view. Views are usually more efficient, hence
+the default is `true`. However, some `prob`s may not support views (e.g., if gradient
+preparation has already been done with a plain vector).
+"""
+use_view_in_gradient(@nospecialize(prob::Any)) = true
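To illustrate the distinction the docstring above draws (an illustrative snippet, not from the patch): `view` aliases the parent array without copying but produces a `SubArray`, which methods constrained to `Vector` will reject, whereas indexing allocates a fresh `Vector`:

```julia
z = randn(3, 4)

zb_view = view(z, :, 1)  # SubArray: aliases column 1, no new buffer allocated
zb_copy = z[:, 1]        # Vector{Float64}: copies column 1 into fresh memory

# A method restricted to `Vector` accepts the copy but not the view.
f(x::Vector{Float64}) = sum(abs2, x)
f(zb_copy)    # works
# f(zb_view)  # would throw MethodError: no method matching f(::SubArray{...})
```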
diff --git a/test/general/view.jl b/test/general/view.jl
new file mode 100644
index 000000000..75ef9ae49
--- /dev/null
+++ b/test/general/view.jl
@@ -0,0 +1,57 @@
+@testset "use_view_in_gradient" begin
+    # Set up a LogDensityProblem that does not accept views
+    struct LogDensityNoView end
+    dims = 2
+    LogDensityProblems.dimension(::LogDensityNoView) = dims
+    LogDensityProblems.capabilities(::Type{<:LogDensityNoView}) =
+        LogDensityProblems.LogDensityOrder{1}()
+    function LogDensityProblems.logdensity(::LogDensityNoView, x::AbstractArray)
+        return sum(x .^ 2)
+    end
+    function LogDensityProblems.logdensity(::LogDensityNoView, ::SubArray)
+        error("Cannot use view")
+    end
+    function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, x::AbstractArray)
+        ld = sum(x .^ 2)
+        grad = 2 .* x
+        return ld, grad
+    end
+    function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, ::SubArray)
+        error("Cannot use view")
+    end
+
+    names_and_algs = [
+        ("KLMinNaturalGradDescent", KLMinNaturalGradDescent(; stepsize=1e-2, n_samples=10)),
+        (
+            "KLMinSqrtNaturalGradDescent",
+            KLMinSqrtNaturalGradDescent(; stepsize=1e-2, n_samples=10),
+        ),
+        ("KLMinWassFwdBwd", KLMinWassFwdBwd(; stepsize=1e-2, n_samples=10)),
+        ("FisherMinBatchMatch", FisherMinBatchMatch()),
+    ]
+
+    # Attempt to run VI without setting `use_view_in_gradient` to false
+    AdvancedVI.use_view_in_gradient(::LogDensityNoView) = true
+    @testset "$name" for (name, algorithm) in names_and_algs
+        @test_throws "Cannot use view" optimize(
+            algorithm,
+            10,
+            LogDensityNoView(),
+            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            show_progress=false,
+        )
+    end
+
+    # Then run VI with `use_view_in_gradient` set to false
+    AdvancedVI.use_view_in_gradient(::LogDensityNoView) = false
+    @testset "$name" for (name, algorithm) in names_and_algs
+        @test optimize(
+            algorithm,
+            10,
+            LogDensityNoView(),
+            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            show_progress=false,
+        ) isa Any
+    end
+
+end
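An aside on the implementation in this patch: the same `use_view_in_gradient` branch now appears in three loops across `fisherminbatchmatch.jl` and `gauss_expected_grad_hess.jl`. One possible consolidation, sketched here with the hypothetical helper name `maybe_view_column` (not something the patch defines), would be:

```julia
# Hypothetical helper: return column `b` of `z` either as a view or as a copy,
# depending on whether `prob` accepts views.
function maybe_view_column(prob, z::AbstractMatrix, b::Int)
    return use_view_in_gradient(prob) ? view(z, :, b) : z[:, b]
end

# The loop bodies above would then reduce to, e.g.:
#     zb = maybe_view_column(prob, z, b)
#     logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, zb)
```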
From 3b63c92253a9430f850aa689b25b8e64dc1b196b Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Wed, 3 Dec 2025 11:36:13 +0000
Subject: [PATCH 2/3] format

---
 test/general/view.jl | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/test/general/view.jl b/test/general/view.jl
index 75ef9ae49..6805ee323 100644
--- a/test/general/view.jl
+++ b/test/general/view.jl
@@ -9,15 +9,17 @@
         return sum(x .^ 2)
     end
     function LogDensityProblems.logdensity(::LogDensityNoView, ::SubArray)
-        error("Cannot use view")
+        return error("Cannot use view")
     end
-    function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, x::AbstractArray)
+    function LogDensityProblems.logdensity_and_gradient(
+        ::LogDensityNoView, x::AbstractArray
+    )
         ld = sum(x .^ 2)
         grad = 2 .* x
         return ld, grad
     end
     function LogDensityProblems.logdensity_and_gradient(::LogDensityNoView, ::SubArray)
-        error("Cannot use view")
+        return error("Cannot use view")
     end
 
     names_and_algs = [
@@ -37,7 +39,9 @@
             algorithm,
             10,
             LogDensityNoView(),
-            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            FullRankGaussian(
+                zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims))
+            );
             show_progress=false,
         )
     end
@@ -49,9 +53,10 @@
             algorithm,
             10,
             LogDensityNoView(),
-            FullRankGaussian(zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims)));
+            FullRankGaussian(
+                zeros(dims), LowerTriangular(Matrix{Float64}(0.6 * I, dims, dims))
+            );
             show_progress=false,
         ) isa Any
     end
-
 end

From 4e873eeac1f5dea8dd3c656c2701e5e24616684a Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Wed, 3 Dec 2025 11:39:00 +0000
Subject: [PATCH 3/3] format (again?? why is this complaining)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/general/view.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/general/view.jl b/test/general/view.jl
index 6805ee323..d4e5b26a9 100644
--- a/test/general/view.jl
+++ b/test/general/view.jl
@@ -3,8 +3,9 @@
     struct LogDensityNoView end
     dims = 2
     LogDensityProblems.dimension(::LogDensityNoView) = dims
-    LogDensityProblems.capabilities(::Type{<:LogDensityNoView}) =
-        LogDensityProblems.LogDensityOrder{1}()
+    LogDensityProblems.capabilities(::Type{<:LogDensityNoView}) = LogDensityProblems.LogDensityOrder{
+        1
+    }()
     function LogDensityProblems.logdensity(::LogDensityNoView, x::AbstractArray)
         return sum(x .^ 2)
     end
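Finally, one alternative worth noting (a sketch under assumptions, not something this patch series implements): rather than defining `use_view_in_gradient`, a view-intolerant problem could be wrapped in a type that materializes views before forwarding. `CollectingProblem` below is a hypothetical name:

```julia
using LogDensityProblems

# Hypothetical wrapper that copies any AbstractVector input (including a
# SubArray) into a plain Vector before forwarding to the wrapped problem.
struct CollectingProblem{P}
    inner::P
end

LogDensityProblems.dimension(p::CollectingProblem) = LogDensityProblems.dimension(p.inner)
function LogDensityProblems.capabilities(::Type{CollectingProblem{P}}) where {P}
    return LogDensityProblems.capabilities(P)
end
function LogDensityProblems.logdensity(p::CollectingProblem, x::AbstractVector)
    return LogDensityProblems.logdensity(p.inner, collect(x))
end
function LogDensityProblems.logdensity_and_gradient(p::CollectingProblem, x::AbstractVector)
    return LogDensityProblems.logdensity_and_gradient(p.inner, collect(x))
end
function LogDensityProblems.logdensity_gradient_and_hessian(p::CollectingProblem, x::AbstractVector)
    return LogDensityProblems.logdensity_gradient_and_hessian(p.inner, collect(x))
end
```

The trade-off relative to the flag added in this patch is an extra allocation per evaluation even for problems that could have accepted views.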