
Commit 2a1755a

committed
update READMe
1 parent 7a6e902 commit 2a1755a

File tree

1 file changed: +53 −13 lines


README.md

Lines changed: 53 additions & 13 deletions
@@ -3,13 +3,13 @@
 [![Tests](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Tests.yml/badge.svg?branch=main)
 [![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)
 
-| AD Backend | Integration Status |
-| ------------- | ------------- |
-| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
-| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
-| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
-| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
-| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |
+| AD Backend | Integration Status |
+|:----------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) | [![ForwardDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ForwardDiff.yml?query=branch%3Amain) |
+| [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl) | [![ReverseDiff](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/ReverseDiff.yml?query=branch%3Amain) |
+| [Zygote](https://github.com/FluxML/Zygote.jl) | [![Zygote](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Zygote.yml?query=branch%3Amain) |
+| [Mooncake](https://github.com/chalk-lab/Mooncake.jl) | [![Mooncake](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Mooncake.yml?query=branch%3Amain) |
+| [Enzyme](https://github.com/EnzymeAD/Enzyme.jl) | [![Enzyme](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml/badge.svg?branch=main)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/Enzyme.yml?query=branch%3Amain) |
 
 # AdvancedVI.jl
 
@@ -69,7 +69,7 @@ end;
 Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`.
 We will use [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) for this purpose.
-This corresponds to the automatic differentiation variational inference (ADVI) formulation[^KTRGB2017].
+The bijector corresponding to the joint support of our model can be constructed as follows:
 
 ```julia
 using Bijectors: Bijectors
@@ -85,6 +85,36 @@ end;
 
 A simpler approach would be to use [`Turing`](https://github.com/TuringLang/Turing.jl), where a `Turing.Model` can automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
 
+Since most VI algorithms assume that the posterior is unconstrained, we will apply a change of variables to our model to make it unconstrained.
+This amounts to wrapping it in a `LogDensityProblem` that applies the transformation and a Jacobian adjustment.
+
+```julia
+struct TransformedLogDensityProblem{Prob,Trans}
+    prob::Prob
+    transform::Trans
+end
+
+function TransformedLogDensityProblem(prob, transform)
+    return TransformedLogDensityProblem{typeof(prob),typeof(transform)}(prob, transform)
+end
+
+function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
+    (; prob, transform) = prob_trans
+    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(transform, θ_trans)
+    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
+end
+
+function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
+    return LogDensityProblems.dimension(prob_trans.prob)
+end
+
+function LogDensityProblems.capabilities(
+    ::Type{TransformedLogDensityProblem{Prob,Trans}}
+) where {Prob,Trans}
+    return LogDensityProblems.capabilities(Prob)
+end;
+```
+
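To make the Jacobian adjustment above concrete, here is a hypothetical one-dimensional sketch (the names and the Exponential target are illustrative, not part of the README): transforming θ = exp(θ_trans) maps the positive support onto all of ℝ, and the log-density on the unconstrained space picks up the term log|dθ/dθ_trans| = θ_trans.

```julia
# Hypothetical 1-D illustration of the change of variables (names are illustrative).
# Constrained target on θ > 0: an Exponential(1) density, log p(θ) = -θ.
logdensity_constrained(θ) = -θ

# Unconstrained parameterization θ = exp(θ_trans);
# the Jacobian adjustment is log|dθ/dθ_trans| = θ_trans.
function logdensity_unconstrained(θ_trans)
    θ = exp(θ_trans)
    return logdensity_constrained(θ) + θ_trans
end

logdensity_unconstrained(0.0)  # log p(1) + 0 = -1.0
```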
 For the dataset, we will use the popular [sonar classification dataset](https://archive.ics.uci.edu/dataset/151/connectionist+bench+sonar+mines+vs+rocks) from the UCI repository.
 This can be automatically downloaded using [`OpenML`](https://github.com/JuliaAI/OpenML.jl).
 The sonar dataset corresponds to the dataset id 40.
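As a sketch, the download step might look like the following; `OpenML.load` and the `DataFrame` conversion are assumptions about the OpenML.jl interface, so check its documentation before relying on them (network access is required, so no output is shown):

```julia
# Hedged sketch of the download step; `OpenML.load(id)` returning a
# Tables.jl-compatible table is an assumption about the OpenML.jl API.
using OpenML: OpenML
using DataFrames: DataFrame

table = OpenML.load(40)   # dataset id 40, per the text above
sonar = DataFrame(table)  # sonar features plus the class label column
```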
@@ -109,7 +139,10 @@ X = hcat(X, ones(size(X, 1)));
 The model can now be instantiated as follows:
 
 ```julia
-model = LogReg(X, y);
+prob = LogReg(X, y);
+b = Bijectors.bijector(prob)
+binv = Bijectors.inverse(b)
+prob_trans = TransformedLogDensityProblem(prob, binv)
 ```
 
 For the VI algorithm, we will use `KLMinRepGradDescent`:
@@ -136,15 +169,15 @@ For this, it is straightforward to use `LogDensityProblemsAD`:
 using DifferentiationInterface: DifferentiationInterface
 using LogDensityProblemsAD: LogDensityProblemsAD
 
-model_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), model);
+prob_trans_ad = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), prob_trans);
 ```
 
 For the variational family, we will consider a `FullRankGaussian` approximation:
 
 ```julia
 using LinearAlgebra
 
-d = LogDensityProblems.dimension(model_ad)
+d = LogDensityProblems.dimension(prob_trans_ad)
 q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.37*I, d, d)))
 q = MeanFieldGaussian(zeros(d), Diagonal(ones(d)));
 ```
@@ -161,12 +194,19 @@ We can now run VI:
 
 ```julia
 max_iter = 10^3
-q, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed;);
+q_opt, info, _ = AdvancedVI.optimize(alg, max_iter, prob_trans_ad, q);
+```
+
+Recall that we applied a change of variables to the posterior to make it unconstrained.
+`q_opt` therefore approximates the unconstrained posterior, not the original constrained posterior that we set out to approximate.
+We finally apply the inverse change of variables to `q_opt` so that it approximates the original problem.
+
+```julia
+q_trans = Bijectors.TransformedDistribution(q_opt, binv)
 ```
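Once transformed, draws from `q_trans` live on the original constrained support. As a hedged usage sketch (assuming the transformed distribution supports `rand`, which is standard for Bijectors.jl transformed distributions but not shown in the README):

```julia
# Hedged sketch: sampling from the transformed approximation. `rand` pushes
# unconstrained Gaussian draws through `binv`, so the resulting samples
# respect the original constraints of the model.
θ_draws = rand(q_trans, 100)  # 100 approximate posterior draws
```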

 For more examples and details, please refer to the documentation.
 
 [^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning*. PMLR.
 [^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In *International Conference on Machine Learning*. PMLR.
 [^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In *International Conference on Learning Representations*.
-[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. *Journal of Machine Learning Research*.
