
Commit 048a310

add constraint tutorial

1 parent 4d8d95e commit 048a310
File tree

2 files changed: +276 −0 lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ makedocs(;
         "Scaling to Large Datasets" => "tutorials/subsampling.md",
         "Stan Models" => "tutorials/stan.md",
         "Normalizing Flows" => "tutorials/flows.md",
+        "Dealing with Constrained Posteriors" => "tutorials/constrained.md"
     ],
     "Algorithms" => [
         "`KLMinRepGradDescent`" => "klminrepgraddescent.md",

docs/src/tutorials/constrained.md

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# [Dealing with Constrained Posteriors](@id constrained)

In this tutorial, we will demonstrate in more detail how to deal with constrained posteriors.
Formally, by a constrained posterior, we mean that the target posterior has a density defined over a space that does not span the "full" Euclidean space $\mathbb{R}^d$:
```math
\pi : \mathcal{X} \to \mathbb{R}_{> 0} ,
```
where $\mathcal{X} \subset \mathbb{R}^d$ but $\mathcal{X} \neq \mathbb{R}^d$.

For instance, consider the basic hierarchical model for estimating the mean of the data $y_1, \ldots, y_n$:
```math
\begin{aligned}
\sigma &\sim \operatorname{LogNormal}(\alpha, \beta) \\
\mu &\sim \operatorname{Normal}(0, \sigma) \\
y_i &\sim \operatorname{Normal}(\mu, \sigma) .
\end{aligned}
```
The corresponding posterior
```math
\pi(\mu, \sigma \mid y_1, \ldots, y_n)
\propto
\operatorname{LogNormal}(\sigma; \alpha, \beta)
\operatorname{Normal}(\mu; 0, \sigma)
\prod_{i=1}^n \operatorname{Normal}(y_i; \mu, \sigma)
```
has a density defined over the space
```math
\mathcal{X} = \mathbb{R} \times \mathbb{R}_{> 0} .
```

There are also more complicated examples of constrained spaces.
For example, a $k$-dimensional variable with a Dirichlet prior is constrained to live on the $(k-1)$-dimensional probability simplex.

Now, most algorithms provided by `AdvancedVI`, such as

- `KLMinRepGradDescent`
- `KLMinRepGradProxDescent`
- `KLMinNaturalGradDescent`
- `FisherMinBatchMatch`

tend to assume that the target posterior is defined over the whole Euclidean space $\mathbb{R}^d$.
Therefore, to apply these algorithms, we need to do something about the constraints.
We will describe some recommended ways of doing this.

## Transforming the Posterior

The most widely applicable approach is to transform the posterior $\pi : \mathcal{X} \to \mathbb{R}_{>0}$ to be unconstrained.
That is, consider some bijective map $b : \mathcal{X} \to \mathbb{R}^{d}$ between $\mathcal{X}$ and the associated Euclidean space $\mathbb{R}^{d}$.
Using the inverse map $b^{-1}$ and the absolute determinant of its Jacobian, $\lvert \mathrm{J}_{b^{-1}} \rvert$, we can apply a change of variables to the posterior and obtain its unconstrained counterpart
```math
\pi_{b^{-1}} : \mathbb{R}^d \to \mathbb{R}_{>0} ,
\qquad
\pi_{b^{-1}}(\eta) = \pi(b^{-1}(\eta)) \, {\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} .
```
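
To make this concrete, here is the change of variables written out for the scale parameter $\sigma$ of the model above, using the standard log transform $b(\sigma) = \log \sigma$ (the same transform Bijectors.jl uses for `LogNormal`, as we will see below):

```math
b^{-1}(\eta) = \exp(\eta) ,
\qquad
\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert = \exp(\eta) ,
\qquad
\log \pi_{b^{-1}}(\eta) = \log \pi(\exp(\eta)) + \eta .
```
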
This idea, popularized by Stan[^CGHetal2017] and TensorFlow Probability[^DLTetal2017], is, in fact, how most probabilistic programming frameworks enable the use of off-the-shelf Markov chain Monte Carlo algorithms.
In the context of variational inference, we will first approximate the unconstrained posterior as

```math
q^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q, \pi_{b^{-1}}) ,
```

and then transform the optimal unconstrained approximation $q^*$ back to the constrained space by again applying a change of variables:

```math
q_{b}^* : \mathcal{X} \to \mathbb{R}_{>0} ,
\qquad
q_{b}^*(z) = q^*(b(z)) \, {\lvert \mathrm{J}_{b}(z) \rvert} .
```

Sampling from $q_{b}^*$ amounts to pushing each sample from $q^*$ through $b^{-1}$:

```math
z \sim q_{b}^* \quad\Leftrightarrow\quad z \stackrel{\mathrm{d}}{=} b^{-1}(\eta), \quad \eta \sim q^* .
```

The idea of applying a change of variables to the variational approximation to match a constrained posterior was popularized by automatic differentiation variational inference (ADVI)[^KTRGB2017].

[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14), 1–45.

Now, there are two ways to do this in Julia.
First, let's define the constrained posterior example above using the `LogDensityProblems` interface for illustration:

```@example constraints
using Distributions
using LogDensityProblems

struct Mean
    y::Vector{Float64}
end

function LogDensityProblems.logdensity(prob::Mean, θ)
    μ, σ = θ[1], θ[2]
    ℓp_μ = logpdf(Normal(0, σ), μ)                               # prior of μ
    ℓp_σ = logpdf(LogNormal(0, 3), σ)                            # prior of σ
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, prob.y)  # likelihood
    return ℓp_μ + ℓp_σ + ℓl_y
end

LogDensityProblems.dimension(::Mean) = 2

LogDensityProblems.capabilities(::Type{Mean}) = LogDensityProblems.LogDensityOrder{0}()

n_data = 30
prob = Mean(randn(n_data))
nothing
```
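
As a quick illustration of the constraint (a sketch; the second call is shown commented out because it throws), evaluating the log density at a point with $\sigma \leq 0$ fails, since `Normal(0, σ)` rejects non-positive scale parameters:

```julia
# Valid point: σ = 1.0 > 0, so the log density is finite
LogDensityProblems.logdensity(prob, [0.1, 1.0])

# Invalid point: σ = -1.0 lies outside 𝒳, so constructing
# Normal(0, σ) throws a DomainError
# LogDensityProblems.logdensity(prob, [0.1, -1.0])
```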

We need to find the right transformation associated with a `LogNormal` prior.
Most of the common bijective transformations can be found in the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package[^FXTYG2020].
See the following:

```@example constraints
using Bijectors

b_σ = Bijectors.bijector(LogNormal(0, 3))
```

and the inverse transformation can be obtained as

```@example constraints
binv_σ = Bijectors.inverse(b_σ)
```
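
Since `b_σ` maps $\mathbb{R}_{>0}$ to $\mathbb{R}$ and `binv_σ` maps back, composing the two is the identity; a quick sketch:

```julia
σ = 2.5
η = b_σ(σ)     # log(2.5); an unconstrained value
binv_σ(η) ≈ σ  # true: the round trip recovers σ
```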

Multiple bijectors can also be stacked to form a joint bijector using `Bijectors.Stacked`.
For example:

```@example constraints
function Bijectors.bijector(::Mean)
    # identity for μ (component 1), log for σ (component 2)
    return Bijectors.Stacked(
        Bijectors.bijector.([Normal(0, 1), LogNormal(0, 3)]), [1:1, 2:2],
    )
end

b = Bijectors.bijector(prob)
binv = Bijectors.inverse(b)
```
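
Applying `b` to a constrained point $(\mu, \sigma)$ leaves $\mu$ untouched and log-transforms $\sigma$, while `binv` maps back; for instance:

```julia
θ = [0.5, 2.0]  # (μ, σ) with σ > 0
η = b(θ)        # [0.5, log(2.0)]; component-wise per the Stacked ranges
binv(η) ≈ θ     # true
```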

Refer to the documentation of `Bijectors.jl` for more details.

## Wrap the `LogDensityProblem`

The most general and convenient way to obtain an unconstrained posterior using a `Bijector` is to wrap our original `LogDensityProblem` to form a new `LogDensityProblem`.
This approach only requires the user to implement the model-specific `Bijectors.bijector` function as above.
The rest can be done by simply copy-pasting the code below:

```@example constraints
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob
    binv::BInv
end

function TransformedLogDensityProblem(prob)
    b = Bijectors.bijector(prob)
    binv = Bijectors.inverse(b)
    return TransformedLogDensityProblem{typeof(prob),typeof(binv)}(prob, binv)
end

function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
    (; prob, binv) = prob_trans
    # Map the unconstrained point back to the constrained space and
    # pick up the log-absolute-determinant Jacobian adjustment
    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(binv, θ_trans)
    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
end

function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
    (; prob, binv) = prob_trans
    b = Bijectors.inverse(binv)
    # The unconstrained dimension is the output size of the bijector
    # applied to the constrained dimension
    d = LogDensityProblems.dimension(prob)
    return prod(Bijectors.output_size(b, (d,)))
end

function LogDensityProblems.capabilities(
    ::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
    return LogDensityProblems.capabilities(Prob)
end
nothing
```

Wrapping `prob` with `TransformedLogDensityProblem` yields our unconstrained posterior:

```@example constraints
prob_trans = TransformedLogDensityProblem(prob)

x = randn(LogDensityProblems.dimension(prob_trans)) # sample on the unconstrained support
LogDensityProblems.logdensity(prob_trans, x)
```

We can also wrap `prob_trans` with `LogDensityProblemsAD.ADgradient` to make it differentiable:

```@example constraints
using LogDensityProblemsAD
using ADTypes, ReverseDiff

# compile=true builds a compiled ReverseDiff tape; x supplies a prototype input
prob_trans_ad = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoReverseDiff(; compile=true), prob_trans; x=randn(2)
)
```
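
The wrapped problem then answers gradient queries via `LogDensityProblems.logdensity_and_gradient`; a quick check:

```julia
x = randn(LogDensityProblems.dimension(prob_trans_ad))
ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(prob_trans_ad, x)
```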

Let's now run VI to verify that it works.
Here, we will use `FisherMinBatchMatch`, which expects an unconstrained posterior.

```@example constraints
using AdvancedVI
using LinearAlgebra

# Initialize the variational approximation as a full-rank Gaussian
d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d)))

q_opt, info, _ = AdvancedVI.optimize(
    FisherMinBatchMatch(), 100, prob_trans_ad, q; show_progress=false
)
nothing
```
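
Note that the parameters of `q_opt` live in the unconstrained space: for instance, the second component of its mean corresponds to $\log \sigma$, not $\sigma$ (a quick sketch, assuming `mean` is defined for the variational family as in the `AdvancedVI` examples):

```julia
m = mean(q_opt)  # ≈ [E[μ], E[log σ]] under the unconstrained approximation
exp(m[2])        # a rough point estimate of σ on the constrained scale
```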

We have now obtained a variational approximation `q_opt` of the unconstrained posterior associated with `prob_trans`.
It remains to transform `q_opt` back to the constrained space we were originally interested in.
This can be done by wrapping it in a `Bijectors.TransformedDistribution`:

```@example constraints
q_opt_trans = Bijectors.TransformedDistribution(q_opt, binv)
```
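
Sampling from `q_opt_trans` is equivalent to drawing from `q_opt` and pushing the draw through `binv`, mirroring the identity $z \stackrel{\mathrm{d}}{=} b^{-1}(\eta)$ above; a minimal sketch:

```julia
η = rand(q_opt)  # unconstrained draw from q*
z = binv(η)      # constrained draw; z[2] = exp(η[2]) > 0
```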

Let's visualize the resulting approximation of the posterior of $\sigma$:

```@example constraints
using Plots

x = rand(q_opt_trans, 1000)

# The data were simulated as standard normal, so the true σ is 1
Plots.stephist(x[2, :]; normed=true, xlabel="Posterior of σ", label=nothing, xlims=(0, 2))
Plots.vline!([1.0]; label="True Value")
savefig("constrained_histogram.svg")
```

![](constrained_histogram.svg)

We can see that the transformed variational approximation is indeed a meaningful approximation of the marginal posterior $\pi(\sigma \mid y_1, \ldots, y_n)$ we were originally interested in.

## Bake a Bijector into the `LogDensityProblem`

A problem with the general approach above is that automatically differentiating through `TransformedLogDensityProblem` can be a bit inefficient (due to `Stacked`), especially with reverse-mode AD.
Therefore, another effective but less automatic approach is to bake the transformation and Jacobian adjustment into the `LogDensityProblem` itself.
Here is an example for our mean estimation model:

```@example constraints
struct MeanTransformed{BInvS}
    y::Vector{Float64}
    binv_σ::BInvS
end

function MeanTransformed(y::Vector{Float64})
    binv_σ = Bijectors.bijector(LogNormal(0, 3)) |> Bijectors.inverse
    return MeanTransformed(y, binv_σ)
end

function LogDensityProblems.logdensity(prob::MeanTransformed, θ)
    (; y, binv_σ) = prob
    μ = θ[1]

    # Apply the inverse bijector and accumulate the Jacobian adjustment
    σ, ℓabsdetjac_σ = with_logabsdet_jacobian(binv_σ, θ[2])

    ℓp_μ = logpdf(Normal(0, σ), μ)
    ℓp_σ = logpdf(LogNormal(0, 3), σ)
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, y)
    return ℓp_μ + ℓp_σ + ℓl_y + ℓabsdetjac_σ
end

LogDensityProblems.dimension(::MeanTransformed) = 2

LogDensityProblems.capabilities(::Type{MeanTransformed}) = LogDensityProblems.LogDensityOrder{0}()

n_data = 30
prob_bakedtrans = MeanTransformed(randn(n_data))
nothing
```
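
As a sanity check (a sketch, reusing the `Bijectors.bijector(::Mean)` method defined earlier), the baked-in problem should agree with the generic wrapper on the same data, since both apply the same transform and Jacobian adjustment:

```julia
prob_same = TransformedLogDensityProblem(Mean(prob_bakedtrans.y))
x = randn(2)
LogDensityProblems.logdensity(prob_bakedtrans, x) ≈
    LogDensityProblems.logdensity(prob_same, x)  # true
```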

Now, `prob_bakedtrans` can be used identically to `prob_trans` above.
For problems with larger dimensions, however, baking the bijector into the problem as above could be significantly more efficient.

[^CGHetal2017]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B., Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming language. Journal of Statistical Software, 76, 1–32.
[^DLTetal2017]: Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). TensorFlow Distributions. arXiv preprint arXiv:1711.10604.
[^FXTYG2020]: Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020). Bijectors.jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1–17). PMLR.
