
Commit af4ad18

run furmatter to constrained
1 parent 69ae57a


docs/src/tutorials/constrained.md

Lines changed: 32 additions & 21 deletions
@@ -2,52 +2,63 @@
In this tutorial, we will demonstrate how to deal with constrained posteriors in more detail.
Formally, by a constrained posterior, we mean that the target posterior has a density defined over a space that does not span the "full" Euclidean space $\mathbb{R}^d$:

```math
\pi : \mathcal{X} \to \mathbb{R}_{> 0} ,
```

where $\mathcal{X} \subset \mathbb{R}^d$ but $\mathcal{X} \neq \mathbb{R}^d$.

For instance, consider the basic hierarchical model for estimating the mean of the data $y_1, \ldots, y_n$:

```math
\begin{aligned}
\sigma &\sim \operatorname{LogNormal}(\alpha, \beta) \\
\mu &\sim \operatorname{Normal}(0, \sigma) \\
y_i &\sim \operatorname{Normal}(\mu, \sigma) .
\end{aligned}
```
The corresponding posterior

```math
\pi(\mu, \sigma \mid y_1, \ldots, y_n)
\propto
\operatorname{LogNormal}(\sigma; \alpha, \beta)
\operatorname{Normal}(\mu; 0, \sigma)
\prod_{i=1}^n \operatorname{Normal}(y_i; \mu, \sigma)
```

has a density with respect to the space

```math
\mathcal{X} = \mathbb{R}_{> 0} \times \mathbb{R} .
```
There are also more complicated examples of constrained spaces.
For example, a $k$-dimensional variable with a Dirichlet prior will be constrained to live on the probability simplex.

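Concretely, a draw from a $k$-dimensional Dirichlet lives on the set

```math
\Delta^{k-1}
= \left\{ x \in \mathbb{R}^k \;:\; x_i \geq 0 \text{ for all } i, \; \sum_{i=1}^k x_i = 1 \right\} ,
```

which is only a $(k-1)$-dimensional subset of $\mathbb{R}^k$.
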
Now, most algorithms provided by `AdvancedVI`, such as

  - `KLMinRepGradDescent`
  - `KLMinRepGradProxDescent`
  - `KLMinNaturalGradDescent`
  - `FisherMinBatchMatch`

tend to assume that the target posterior is defined over the whole Euclidean space $\mathbb{R}^d$.
Therefore, to apply these algorithms, we need to do something about the constraints.
We will describe some recommended ways of doing this.

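To see the problem concretely, the log-density of the model above is undefined for $\sigma \leq 0$: in `Distributions.jl`, evaluating a density outside its support returns `-Inf` (a minimal sketch, assuming only that `Distributions.jl` is installed):

```julia
using Distributions

# The LogNormal prior has support (0, ∞); an unconstrained optimizer
# that proposes σ ≤ 0 immediately hits a degenerate log-density.
logpdf(LogNormal(0, 3), -0.1)  # -Inf
```
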
## Transforming the Posterior

The most widely applicable way is to transform the posterior $\pi : \mathcal{X} \to \mathbb{R}_{>0}$ to be unconstrained.
That is, consider some bijective map $b : \mathcal{X} \to \mathbb{R}^{d}$ between $\mathcal{X}$ and the associated Euclidean space $\mathbb{R}^{d}$.
Using the inverse map $b^{-1}$ and its Jacobian $\mathrm{J}_{b^{-1}}$, we can apply a change of variables to the posterior and obtain its unconstrained counterpart

```math
\pi_{b^{-1}} : \mathbb{R}^d \to \mathbb{R}_{>0} , \qquad
\pi_{b^{-1}}(\eta) = \pi(b^{-1}(\eta)) \, {\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} .
```

This idea, popularized by Stan[^CGHetal2017] and TensorFlow Probability[^DLTetal2017], is, in fact, how most probabilistic programming frameworks enable the use of off-the-shelf Markov chain Monte Carlo algorithms.
In the context of variational inference, we will first approximate the unconstrained posterior as

@@ -70,7 +81,6 @@ z \sim q_{b}^* \quad\Leftrightarrow\quad z \stackrel{\mathrm{d}}{=} b^{-1}(\eta)
The idea of applying a change of variables to the variational approximation to match a constrained posterior was popularized by automatic differentiation variational inference[^KTRGB2017].

[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14), 1-45.

Now, there are two ways to do this in Julia.
First, let's define the constrained posterior example above using the `LogDensityProblems` interface for illustration:

@@ -83,7 +93,7 @@ end

function LogDensityProblems.logdensity(prob::Mean, θ)
    μ, σ = θ[1], θ[2]
    ℓp_μ = logpdf(Normal(0, σ), μ)
    ℓp_σ = logpdf(LogNormal(0, 3), σ)
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, prob.y)
    return ℓp_μ + ℓp_σ + ℓl_y
@@ -120,7 +130,7 @@ For example:
```@example constraints
function Bijectors.bijector(::Mean)
    return Bijectors.Stacked(
        Bijectors.bijector.([Normal(0, 1), LogNormal(1, 1)]), [1:1, 2:2]
    )
end
```
@@ -130,11 +140,10 @@ binv = Bijectors.inverse(b)

Refer to the documentation of `Bijectors.jl` for more details.

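As a quick illustration of what such a bijector does (a sketch, assuming `Bijectors.jl` and `Distributions.jl` are loaded as above), the bijector associated with a `LogNormal` maps its support $(0, \infty)$ to $\mathbb{R}$ via $\log$, and its inverse maps back via $\exp$ while reporting the log-Jacobian term:

```julia
using Bijectors, Distributions

b = Bijectors.bijector(LogNormal(0, 3))  # log : (0, ∞) → ℝ
binv = Bijectors.inverse(b)              # exp : ℝ → (0, ∞)

# Map an unconstrained value back to the constrained space, together
# with the log-absolute-determinant of the Jacobian of the inverse map
σ, ℓjac = with_logabsdet_jacobian(binv, 0.3)  # σ = exp(0.3), ℓjac = 0.3
```
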
## Wrap the `LogDensityProblem`

The most general and simplest way to obtain an unconstrained posterior using a `Bijector` is to wrap our original `LogDensityProblem` to form a new `LogDensityProblem`.
This approach only requires the user to implement the model-specific `Bijectors.bijector` function as above.
The rest can be done by simply copy-pasting the code below:

```@example constraints
@@ -179,13 +188,14 @@ x = randn(LogDensityProblems.dimension(prob_trans)) # sample on an unconstrained
LogDensityProblems.logdensity(prob_trans, x)
```

We can also wrap `prob_trans` with `LogDensityProblemsAD.ADgradient` to make it differentiable.

```@example constraints
using LogDensityProblemsAD
using ADTypes, ReverseDiff

prob_trans_ad = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoReverseDiff(; compile=true), prob_trans; x=randn(2)
)
```
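
The wrapped object can then be queried through the usual `LogDensityProblems` interface (a minimal sketch, assuming the objects defined above are in scope):

```julia
# Evaluate the log-density and its gradient at an unconstrained point
η = randn(LogDensityProblems.dimension(prob_trans_ad))
ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(prob_trans_ad, η)
```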

@@ -218,16 +228,15 @@ using Plots

x = rand(q_opt_trans, 1000)

Plots.stephist(x[2, :]; normed=true, xlabel="Posterior of σ", label=nothing, xlims=(0, 2))
Plots.vline!([1.0]; label="True Value")
savefig("constrained_histogram.svg")
```

![](constrained_histogram.svg)

We can see that the transformed variational approximation is indeed a meaningful approximation of the original posterior $\pi(\sigma \mid y_1, \ldots, y_n)$ we were interested in.

## Bake a Bijector into the `LogDensityProblem`

A problem with the general approach above is that automatically differentiating through `TransformedLogDensityProblem` can be a bit inefficient (due to `Stacked`), especially with reverse-mode AD.
@@ -241,26 +250,28 @@ struct MeanTransformed{BInvS}
end

function MeanTransformed(y::Vector{Float64})
    binv_σ = Bijectors.inverse(Bijectors.bijector(LogNormal(0, 3)))
    return MeanTransformed(y, binv_σ)
end

function LogDensityProblems.logdensity(prob::MeanTransformed, θ)
    (; y, binv_σ) = prob
    μ = θ[1]

    # Apply the inverse bijector and compute the log-Jacobian correction
    σ, ℓabsdetjac_σ = with_logabsdet_jacobian(binv_σ, θ[2])

    ℓp_μ = logpdf(Normal(0, σ), μ)
    ℓp_σ = logpdf(LogNormal(0, 3), σ)
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, y)
    return ℓp_μ + ℓp_σ + ℓl_y + ℓabsdetjac_σ
end

LogDensityProblems.dimension(::MeanTransformed) = 2

function LogDensityProblems.capabilities(::Type{<:MeanTransformed})
    return LogDensityProblems.LogDensityOrder{0}()
end

n_data = 30
prob_bakedtrans = MeanTransformed(randn(n_data))
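
As a sanity check (a sketch, assuming the definitions above), the baked-in problem now accepts any point in $\mathbb{R}^2$, with the constraint on $\sigma$ handled internally:

```julia
x = randn(LogDensityProblems.dimension(prob_bakedtrans))  # unconstrained draw
LogDensityProblems.logdensity(prob_bakedtrans, x)         # finite log-density
```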
