| AD Backend | Integration Status |
|---|---|
| ForwardDiff | |
| ReverseDiff | |
| Zygote | |
| Mooncake | |
| Enzyme | |
AdvancedVI provides implementations of variational inference (VI) algorithms, a family of methods that aim for scalable approximate Bayesian inference by leveraging optimization.
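Concretely, most algorithms in this family posit a variational family $\mathcal{Q} = \{q_{\lambda}\}$ and maximize the evidence lower bound (ELBO)

$$
\mathrm{ELBO}(\lambda)
= \mathbb{E}_{\theta \sim q_{\lambda}}\left[\log p(\theta, \mathcal{D})\right]
- \mathbb{E}_{\theta \sim q_{\lambda}}\left[\log q_{\lambda}(\theta)\right],
$$

which is equivalent to minimizing the exclusive KL divergence from $q_{\lambda}$ to the posterior.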
AdvancedVI is part of the Turing probabilistic programming ecosystem.
The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities, so that other packages, e.g. Turing, only need to write a light wrapper for integration.
For example, integrating Turing with AdvancedVI.ADVI only involves converting a Turing.Model into a LogDensityProblem and extracting a corresponding Bijectors.bijector.
We will describe a simple example to demonstrate the basic usage of AdvancedVI.
AdvancedVI works with differentiable models specified through the LogDensityProblem interface.
Let's look at a basic logistic regression example with a hierarchical prior.
For a dataset $(X, y)$ with feature matrix $X \in \mathbb{R}^{n \times d}$ and binary labels $y \in \{0, 1\}^n$, the model is

$$
\begin{aligned}
\sigma &\sim \mathrm{LogNormal}(0, 3), \\
\beta &\sim \mathcal{N}\left(0_d, \sigma^2 \mathrm{I}_d\right), \\
y_i &\sim \mathrm{BernoulliLogit}\left(x_i^{\top} \beta\right), \qquad i = 1, \ldots, n.
\end{aligned}
$$

The LogDensityProblem corresponding to this model can be constructed as follows:
```julia
using LogDensityProblems: LogDensityProblems
using Distributions
using FillArrays

struct LogReg{XType,YType}
    X::XType
    y::YType
end

function LogDensityProblems.logdensity(model::LogReg, θ)
    (; X, y) = model
    d = size(X, 2)
    β, σ = θ[1:d], θ[end]
    logprior_β = logpdf(MvNormal(Zeros(d), σ), β)  # β ~ N(0, σ²I)
    logprior_σ = logpdf(LogNormal(0, 3), σ)        # σ ~ LogNormal(0, 3)
    logit = X * β
    loglike_y = mapreduce((li, yi) -> logpdf(BernoulliLogit(li), yi), +, logit, y)
    return loglike_y + logprior_β + logprior_σ
end

function LogDensityProblems.dimension(model::LogReg)
    return size(model.X, 2) + 1  # d coefficients plus σ
end

function LogDensityProblems.capabilities(::Type{<:LogReg})
    return LogDensityProblems.LogDensityOrder{0}()  # zeroth-order only (no gradients)
end;
```

Since the support of σ is constrained to be positive, while most VI algorithms assume an unconstrained Euclidean support, we need to use a bijector to transform θ.
We will use Bijectors for this purpose.
The bijector corresponding to the joint support of our model can be constructed as follows:
```julia
using Bijectors: Bijectors

function Bijectors.bijector(model::LogReg)
    d = size(model.X, 2)
    # Identity over β (already unconstrained) stacked with log over σ.
    return Bijectors.Stacked(
        Bijectors.bijector.([MvNormal(Zeros(d), 1.0), LogNormal(0, 3)]),
        [1:d, (d + 1):(d + 1)],
    )
end;
```

A simpler approach would be to use Turing, where a Turing.Model can automatically be converted into a LogDensityProblem and a corresponding bijector is generated automatically.
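Before moving on, here is a hypothetical smoke test (not part of the original example) that exercises both the log-density and the bijector on synthetic data:

```julia
# Hypothetical smoke test on a toy problem with 10 observations and 3 features.
X_toy, y_toy = randn(10, 3), rand(Bool, 10)
toy = LogReg(X_toy, y_toy)

θ = vcat(randn(3), 1.0)                # last coordinate is σ = 1.0 > 0
LogDensityProblems.logdensity(toy, θ)  # returns a finite Float64

b = Bijectors.bijector(toy)            # maps constrained -> unconstrained
binv = Bijectors.inverse(b)
θ_unc = b(θ)                           # θ_unc[end] == log(σ)
θ ≈ binv(θ_unc)                        # round trip holds up to floating point
```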
Since most VI algorithms assume that the posterior is unconstrained, we will apply a change-of-variable to our model to make it unconstrained.
This amounts to wrapping it into a LogDensityProblem that applies the transformation and the corresponding Jacobian adjustment.
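In equations, if $b$ maps the constrained parameters to $\mathbb{R}^{d+1}$ and $b^{-1}$ is its inverse, the log-density of the transformed (unconstrained) posterior follows from the change-of-variable formula:

$$
\log \pi_{\mathrm{trans}}\left(\theta_{\mathrm{trans}}\right)
= \log \pi\left(b^{-1}\left(\theta_{\mathrm{trans}}\right)\right)
+ \log \left| \det J_{b^{-1}}\left(\theta_{\mathrm{trans}}\right) \right| .
$$

The second term is exactly what `Bijectors.with_logabsdet_jacobian` returns below.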
```julia
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob
    binv::BInv
end

function TransformedLogDensityProblem(prob)
    b = Bijectors.bijector(prob)
    binv = Bijectors.inverse(b)
    return TransformedLogDensityProblem{typeof(prob),typeof(binv)}(prob, binv)
end

# Log-density of the pushforward: log π(b⁻¹(θ_trans)) plus the Jacobian adjustment.
function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
    (; prob, binv) = prob_trans
    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(binv, θ_trans)
    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
end

function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
    (; prob, binv) = prob_trans
    b = Bijectors.inverse(binv)
    d = LogDensityProblems.dimension(prob)
    return prod(Bijectors.output_size(b, (d,)))
end

function LogDensityProblems.capabilities(
    ::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
    return LogDensityProblems.capabilities(Prob)
end;
```

For the dataset, we will use the popular sonar classification dataset from the UCI repository.
This can be automatically downloaded using OpenML.
The sonar dataset corresponds to dataset id 40.
```julia
using OpenML: OpenML
using DataFrames: DataFrames

data = Array(DataFrames.DataFrame(OpenML.load(40)))
X = Matrix{Float64}(data[:, 1:(end - 1)])
y = Vector{Bool}(data[:, end] .== "Mine");
```

Let's apply some basic pre-processing and add an intercept column:
```julia
using Statistics

X = (X .- mean(X; dims=1)) ./ std(X; dims=1)  # standardize each feature column
X = hcat(X, ones(size(X, 1)));                # append an intercept column
```

The model can now be instantiated as follows:
```julia
prob = LogReg(X, y);
prob_trans = TransformedLogDensityProblem(prob)
```

For the VI algorithm, we will use KLMinRepGradDescent:
```julia
using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(ADTypes.AutoReverseDiff(); operator=ClipScale())
```

This algorithm minimizes the exclusive (reverse) KL divergence via stochastic gradient descent in the (Euclidean) space of the parameters of the variational approximation, using the reparametrization gradient[^1][^2][^3]. This approach is also commonly referred to as automatic differentiation VI, black-box VI, or stochastic gradient VI.
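To sketch the idea behind the reparametrization gradient (following the references below): samples from a location-scale family $q_{m,C}$ can be written as $z = m + C u$ with $u \sim \mathcal{N}(0, \mathrm{I})$, so a Monte Carlo estimate of the ELBO can be differentiated directly with respect to the variational parameters:

$$
\nabla_{m,C}\, \mathrm{ELBO}
\approx \frac{1}{M} \sum_{i=1}^{M} \nabla_{m,C}
\left[ \log \pi\left(m + C u_i\right) - \log q_{m,C}\left(m + C u_i\right) \right],
\qquad u_i \sim \mathcal{N}(0, \mathrm{I}) .
$$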
Projection or proximal operators can be supplied through the keyword argument operator.
For this example, we will use a Gaussian variational family, which is part of the broader location-scale family.
Location-scale families require the scale matrix to have strictly positive eigenvalues at all times, and the projection operator ClipScale ensures this.
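For a lower-triangular scale $C$, the eigenvalues are its diagonal entries, so ClipScale can be thought of as (roughly; this sketches the idea rather than AdvancedVI's exact definition) the projection

$$
\operatorname{proj}_{\epsilon}(C)_{ij} =
\begin{cases}
\max\left(C_{ii}, \epsilon\right) & \text{if } i = j, \\
C_{ij} & \text{otherwise},
\end{cases}
$$

for some small $\epsilon > 0$, applied after each gradient step.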
KLMinRepGradDescent, in particular, requires the target LogDensityProblem to provide gradients.
It is straightforward to supply these via LogDensityProblemsAD:
```julia
using DifferentiationInterface: DifferentiationInterface
using LogDensityProblemsAD: LogDensityProblemsAD

prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
```

For the variational family, we will consider a FullRankGaussian approximation:
```julia
using LinearAlgebra

d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d)))
# A cheaper mean-field (diagonal covariance) alternative:
# q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
```

We can now run VI:
```julia
max_iter = 10^3
q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
```

Recall that we applied a change-of-variable to the posterior to make it unconstrained.
The optimized approximation q_opt therefore targets the unconstrained posterior, not the original constrained posterior that we set out to approximate.
Finally, we need to apply a change-of-variable to q_opt so that it approximates our original problem.
```julia
b = Bijectors.bijector(prob)
binv = Bijectors.inverse(b)
q_trans = Bijectors.TransformedDistribution(q_opt, binv)
```

For more examples and details, please refer to the documentation.
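As a final, hypothetical illustration (not part of the original example), q_trans behaves like an ordinary Distributions.jl distribution, so we can draw approximate posterior samples and summarize them:

```julia
# Hypothetical post-processing: sample from the constrained approximation.
samples = rand(q_trans, 1_000)  # each column is one draw of (β, σ)
β_mean = vec(mean(samples[1:(end - 1), :]; dims=2))
σ_mean = mean(samples[end, :])
```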
[^1]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.

[^2]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.

[^3]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In *International Conference on Learning Representations*.