# [Dealing with Constrained Posteriors](@id constrained)

In this tutorial, we will demonstrate how to deal with constrained posteriors in more detail.
Formally, by constrained posteriors, we mean that the target posterior has a density defined over a space that does not span the "full" Euclidean space $\mathbb{R}^d$:
```math
\pi : \mathcal{X} \to \mathbb{R}_{> 0} ,
```
where $\mathcal{X} \subset \mathbb{R}^d$ but not $\mathcal{X} = \mathbb{R}^d$.

For instance, consider the basic hierarchical model for estimating the mean of the data $y_1, \ldots, y_n$:
```math
\begin{aligned}
    \sigma &\sim \operatorname{LogNormal}(\alpha, \beta) \\
    \mu &\sim \operatorname{Normal}(0, \sigma) \\
    y_i &\sim \operatorname{Normal}(\mu, \sigma) .
\end{aligned}
```
The corresponding posterior
```math
\pi(\mu, \sigma \mid y_1, \ldots, y_n)
\propto
\operatorname{Normal}(\mu; 0, \sigma)
\operatorname{LogNormal}(\sigma; \alpha, \beta)
\prod_{i=1}^n \operatorname{Normal}(y_i; \mu, \sigma)
```
has a density defined over the space
```math
    \mathcal{X} = \mathbb{R} \times \mathbb{R}_{> 0} .
```
There are also more complicated examples of constrained spaces.
For example, a $k$-dimensional variable with a Dirichlet prior will be constrained to live on a $k$-dimensional simplex.

Now, most algorithms provided by `AdvancedVI`, such as:

- `KLMinRepGradDescent`
- `KLMinRepGradProxDescent`
- `KLMinNaturalGradDescent`
- `FisherMinBatchMatch`

tend to assume the target posterior is defined over the whole Euclidean space $\mathbb{R}^d$.
Therefore, to apply these algorithms, we need to do something about the constraints.
We will describe some recommended ways of doing this.
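To see why this matters, consider a small illustration (using only `Distributions`; this snippet is not part of the example developed below): a Gaussian approximation over all of $\mathbb{R}^2$ will occasionally propose values of $\sigma \leq 0$, where the log density of the posterior above is $-\infty$.

```julia
using Distributions

# A value of σ outside the support of the LogNormal prior has zero density,
# so the log density is -Inf and gradients are not meaningful there.
logpdf(LogNormal(0, 3), -0.5)
```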

## Transforming the Posterior

The most widely applicable approach is to transform the posterior $\pi : \mathcal{X} \to \mathbb{R}_{>0}$ to be unconstrained.
That is, consider some bijective map $b : \mathcal{X} \to \mathbb{R}^{d}$ between $\mathcal{X}$ and the associated Euclidean space $\mathbb{R}^{d}$.
Using the inverse map $b^{-1}$ and its Jacobian $\mathrm{J}_{b^{-1}}$, we can apply a change of variables to the posterior and obtain its unconstrained counterpart
```math
\pi_{b^{-1}} : \mathbb{R}^d \to \mathbb{R}_{>0} , \qquad
\pi_{b^{-1}}(\eta) = \pi(b^{-1}(\eta)) \, {\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} .
```
This idea, popularized by Stan[^CGHetal2017] and TensorFlow Probability[^DLTetal2017], is, in fact, how most probabilistic programming frameworks enable the use of off-the-shelf Markov chain Monte Carlo algorithms.
In the context of variational inference, we will first approximate the unconstrained posterior as

```math
q^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q, \pi_{b^{-1}})
```

and then transform the optimal unconstrained approximation $q^*$ back to the constrained space by again applying a change of variables:

```math
q_{b}^* : \mathcal{X} \to \mathbb{R}_{>0} , \qquad
q_{b}^*(z) = q^*(b(z)) \, {\lvert \mathrm{J}_{b}(z) \rvert} .
```

Sampling from $q_{b}^*$ amounts to pushing each sample from $q^*$ through $b^{-1}$:

```math
z \sim q_{b}^* \quad\Leftrightarrow\quad z \stackrel{\mathrm{d}}{=} b^{-1}(\eta) ; \quad \eta \sim q^* .
```

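As a tiny self-contained sketch of this pushforward (using only `Distributions`, with $\exp$ playing the role of $b^{-1}$ for a positivity constraint; this is an illustration rather than part of the example below):

```julia
using Distributions

# Unconstrained approximation over ℝ, standing in for q*.
q_unconstrained = Normal(0.0, 1.0)

# Draw in the unconstrained space, then map through b⁻¹ = exp.
# The resulting draws are positive, i.e., samples from the pushforward q_b*.
η = rand(q_unconstrained, 5)
z = exp.(η)
```
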
The idea of applying a change of variables to the variational approximation to match a constrained posterior was popularized by automatic differentiation variational inference (ADVI)[^KTRGB2017].

[^KTRGB2017]: Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14), 1-45.

Now, there are two ways to do this in Julia.
First, let's define the constrained posterior example above using the `LogDensityProblems` interface for illustration:

```@example constraints
using Distributions
using LogDensityProblems

struct Mean
    y::Vector{Float64}
end

function LogDensityProblems.logdensity(prob::Mean, θ)
    μ, σ = θ[1], θ[2]
    ℓp_μ = logpdf(Normal(0, σ), μ)
    ℓp_σ = logpdf(LogNormal(0, 3), σ)
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, prob.y)
    return ℓp_μ + ℓp_σ + ℓl_y
end

LogDensityProblems.dimension(::Mean) = 2

LogDensityProblems.capabilities(::Type{Mean}) = LogDensityProblems.LogDensityOrder{0}()

n_data = 30
prob = Mean(randn(n_data))
nothing
```

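As a quick check, the log density can be evaluated at any point of the constrained space $\mathcal{X} = \mathbb{R} \times \mathbb{R}_{>0}$, that is, as long as the second coordinate $\sigma$ is positive:

```@example constraints
LogDensityProblems.logdensity(prob, [0.0, 1.0]) # θ = (μ, σ) with σ = 1 > 0
```
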
We need to find the right transformation associated with a `LogNormal` prior.
Most of the common bijective transformations can be found in the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package[^FXTYG2020].
See the following:

```@example constraints
using Bijectors

b_σ = Bijectors.bijector(LogNormal(0, 3))
```

and the inverse transformation can be obtained as

```@example constraints
binv_σ = Bijectors.inverse(b_σ)
```

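To see what this pair does in practice, `b_σ` maps a positive value of $\sigma$ to the real line and `binv_σ` maps it back:

```@example constraints
σ = 2.0
η = b_σ(σ)  # log(2.0): the positive value mapped to the real line
binv_σ(η)   # recovers the original σ
```
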
Multiple bijectors can also be stacked to form a joint bijector using `Bijectors.Stacked`.
For example:

```@example constraints
function Bijectors.bijector(::Mean)
    return Bijectors.Stacked(
        Bijectors.bijector.([Normal(0, 1), LogNormal(0, 3)]), [1:1, 2:2],
    )
end

b = Bijectors.bijector(prob)
binv = Bijectors.inverse(b)
```

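As a quick sanity check (using the `b` and `binv` we just constructed), the stacked bijector leaves $\mu$ unchanged and log-transforms $\sigma$, and its inverse recovers the original point:

```@example constraints
θ = [0.5, 2.0] # (μ, σ) in the constrained space
η = b(θ)       # μ is unchanged, σ is mapped through log
binv(η)        # recovers θ
```
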
Refer to the documentation of `Bijectors.jl` for more details.

## Wrap the `LogDensityProblem`

The most general and straightforward way to obtain an unconstrained posterior using a `Bijector` is to wrap our original `LogDensityProblem` to form a new `LogDensityProblem`.
This approach only requires the user to implement the model-specific `Bijectors.bijector` function as above.
The rest can be done by simply copy-pasting the code below:

```@example constraints
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob
    binv::BInv
end

function TransformedLogDensityProblem(prob)
    b = Bijectors.bijector(prob)
    binv = Bijectors.inverse(b)
    return TransformedLogDensityProblem{typeof(prob),typeof(binv)}(prob, binv)
end

function LogDensityProblems.logdensity(prob_trans::TransformedLogDensityProblem, θ_trans)
    (; prob, binv) = prob_trans
    # Change of variables: map back to the constrained space and add the log-Jacobian term.
    θ, logabsdetjac = Bijectors.with_logabsdet_jacobian(binv, θ_trans)
    return LogDensityProblems.logdensity(prob, θ) + logabsdetjac
end

function LogDensityProblems.dimension(prob_trans::TransformedLogDensityProblem)
    (; prob, binv) = prob_trans
    b = Bijectors.inverse(binv)
    # Dimension of the unconstrained space, which may differ from the constrained one.
    d = LogDensityProblems.dimension(prob)
    return prod(Bijectors.output_size(b, (d,)))
end

function LogDensityProblems.capabilities(
    ::Type{TransformedLogDensityProblem{Prob,BInv}}
) where {Prob,BInv}
    return LogDensityProblems.capabilities(Prob)
end
nothing
```

Wrapping `prob` with `TransformedLogDensityProblem` yields our unconstrained posterior.

```@example constraints
prob_trans = TransformedLogDensityProblem(prob)

x = randn(LogDensityProblems.dimension(prob_trans)) # a point in the unconstrained space
LogDensityProblems.logdensity(prob_trans, x)
```

We can also wrap `prob_trans` with `LogDensityProblemsAD.ADgradient` to make it differentiable.

```@example constraints
using LogDensityProblemsAD
using ADTypes, ReverseDiff

prob_trans_ad = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoReverseDiff(; compile=true), prob_trans; x=randn(2)
)
```

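As a quick sanity check (this snippet is not part of the original walkthrough), we can evaluate the log density and its gradient at a random unconstrained point:

```@example constraints
LogDensityProblems.logdensity_and_gradient(prob_trans_ad, randn(2))
```
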
Let's now run VI to verify that it works.
Here, we will use `FisherMinBatchMatch`, which expects an unconstrained posterior.

```@example constraints
using AdvancedVI
using LinearAlgebra

d = LogDensityProblems.dimension(prob_trans_ad)
q = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d)))

q_opt, info, _ = AdvancedVI.optimize(
    FisherMinBatchMatch(), 100, prob_trans_ad, q; show_progress=false
)
nothing
```

We have now obtained a variational approximation `q_opt` of the unconstrained posterior associated with `prob_trans`.
It remains to transform `q_opt` back to the constrained space we were originally interested in.
This can be done by wrapping it into a `Bijectors.TransformedDistribution`.

```@example constraints
q_opt_trans = Bijectors.TransformedDistribution(q_opt, binv)
```

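Samples from `q_opt_trans` now live on the constrained space. For instance, a Monte Carlo estimate of the posterior means of $(\mu, \sigma)$ can be obtained as follows (a quick check, not part of the original walkthrough):

```@example constraints
using Statistics

samples = rand(q_opt_trans, 10_000) # one sample per column
mean(samples; dims=2)               # approximate posterior means of (μ, σ)
```
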
```@example constraints
using Plots

x = rand(q_opt_trans, 1000)

Plots.stephist(x[2,:], normed=true, xlabel="Posterior of σ", label=nothing, xlims=(0, 2))
Plots.vline!([1.0], label="True Value")
savefig("constrained_histogram.svg")
```

![](constrained_histogram.svg)

We can see that the transformed variational approximation is indeed a meaningful approximation of the original posterior $\pi(\sigma \mid y_1, \ldots, y_n)$ we were interested in.

## Bake a Bijector into the `LogDensityProblem`

A problem with the general approach above is that automatically differentiating through `TransformedLogDensityProblem` can be a bit inefficient (due to `Stacked`), especially with reverse-mode AD.
Therefore, another effective but less automatic approach is to bake the transformation and Jacobian adjustment into the `LogDensityProblem` itself.
Here is an example for our mean estimation model:

```@example constraints
struct MeanTransformed{BInvS}
    y::Vector{Float64}
    binv_σ::BInvS
end

function MeanTransformed(y::Vector{Float64})
    binv_σ = Bijectors.bijector(LogNormal(0, 3)) |> Bijectors.inverse
    return MeanTransformed(y, binv_σ)
end

function LogDensityProblems.logdensity(prob::MeanTransformed, θ)
    (; y, binv_σ) = prob
    μ = θ[1]

    # Map the unconstrained parameter back to σ > 0 and compute the log-Jacobian term.
    σ, ℓabsdetjac_σ = Bijectors.with_logabsdet_jacobian(binv_σ, θ[2])

    ℓp_μ = logpdf(Normal(0, σ), μ)
    ℓp_σ = logpdf(LogNormal(0, 3), σ)
    ℓl_y = mapreduce(yi -> logpdf(Normal(μ, σ), yi), +, y)
    return ℓp_μ + ℓp_σ + ℓl_y + ℓabsdetjac_σ
end

LogDensityProblems.dimension(::MeanTransformed) = 2

LogDensityProblems.capabilities(::Type{<:MeanTransformed}) = LogDensityProblems.LogDensityOrder{0}()

n_data = 30
prob_bakedtrans = MeanTransformed(randn(n_data))
nothing
```

Now, `prob_bakedtrans` can be used in exactly the same way as `prob_trans` above.
For higher-dimensional problems, however, baking the bijector into the problem like this could be significantly more efficient.

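For example (mirroring the setup used for `prob_trans` earlier), we can wrap it with `ADgradient` and evaluate the unconstrained log density and its gradient:

```@example constraints
prob_bakedtrans_ad = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoReverseDiff(; compile=true), prob_bakedtrans; x=randn(2)
)
LogDensityProblems.logdensity_and_gradient(prob_bakedtrans_ad, randn(2))
```
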
[^CGHetal2017]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B., Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming language. Journal of Statistical Software, 76, 1-32.
[^DLTetal2017]: Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). TensorFlow Distributions. arXiv preprint arXiv:1711.10604.
[^FXTYG2020]: Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020). Bijectors.jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR.