
Commit 47566a8

Merge branch 'bump_advancedvi_0.5' into breaking

2 parents: ed9b4f2 + 31256e3

9 files changed (+259, −179 lines)

HISTORY.md

Lines changed: 90 additions & 0 deletions
@@ -72,6 +72,96 @@ When sampling using MCMCChains, the chain object will no longer have its `chain.
Instead, you can calculate this yourself from the log-likelihoods stored in the chain.
For SMC samplers, the log-evidence of the entire trajectory is stored in `chain[:logevidence]` (which is the same for every particle in the 'chain').

## AdvancedVI 0.6

Turing.jl v0.42 updates `AdvancedVI.jl` compatibility to 0.6 (we skipped the breaking 0.5 update as it does not introduce new features).
`AdvancedVI@0.6` introduces major structural changes, including breaking changes to the interface, as well as multiple new features.
The changes summarized below are those that affect end users of Turing.
For a more comprehensive list of changes, please refer to the [changelog](https://github.com/TuringLang/AdvancedVI.jl/blob/main/HISTORY.md) of `AdvancedVI`.

### Breaking changes

A new interface layer for defining variational algorithms was introduced in `AdvancedVI` v0.5. As a result, the function `Turing.vi` now takes a keyword argument `algorithm`, where the object `algorithm <: AdvancedVI.AbstractVariationalAlgorithm` contains all the algorithm-specific configuration. Keyword arguments of `vi` that were algorithm-specific, such as `objective`, `operator`, `averager`, and so on, have therefore been moved into fields of the relevant `<: AdvancedVI.AbstractVariationalAlgorithm` structs.

In addition, the outputs have changed. Previously, `vi` returned both the last iterate of the algorithm, `q`, and the iterate average, `q_avg`. Now, for algorithms that perform parameter averaging, only `q_avg` is returned, so the number of returned values is reduced from four to three.

For example,

```julia
q, q_avg, info, state = vi(
    model, q, n_iters; objective=RepGradELBO(10), operator=AdvancedVI.ClipScale()
)
```

is now

```julia
q_avg, info, state = vi(
    model,
    q,
    n_iters;
    algorithm=KLMinRepGradDescent(adtype; n_samples=10, operator=AdvancedVI.ClipScale()),
)
```

Here, `adtype` is an automatic differentiation backend selection from ADTypes, such as `AutoForwardDiff()`.
Similarly,

```julia
vi(
    model,
    q,
    n_iters;
    objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
    operator=AdvancedVI.ProximalLocationScaleEntropy(),
)
```

is now

```julia
vi(model, q, n_iters; algorithm=KLMinRepGradProxDescent(adtype; n_samples=10))
```

Lastly, to obtain the last iterate `q` of `KLMinRepGradDescent`, which is not returned under the new interface, simply set the averaging strategy to `AdvancedVI.NoAveraging()`. That is,

```julia
q, info, state = vi(
    model,
    q,
    n_iters;
    algorithm=KLMinRepGradDescent(
        adtype;
        n_samples=10,
        operator=AdvancedVI.ClipScale(),
        averager=AdvancedVI.NoAveraging(),
    ),
)
```

Additionally,

- The default hyperparameters of `DoG` and `DoWG` have been altered.
- The deprecated interface from older `AdvancedVI` versions has been removed (accordingly, the `ADVI` export is removed from Turing in this commit).
- `estimate_objective` now always returns the value to be minimized by the optimization algorithm. For example, for ELBO-maximization algorithms, `estimate_objective` will return the *negative ELBO*. This is a breaking change from the previous behavior, where the ELBO was returned. (See the first sketch after this list.)
- The default initial values of `q_meanfield_gaussian`, `q_fullrank_gaussian`, and `q_locationscale` have changed. Specifically, the default initial value of the scale matrix has been changed from `I` to `0.6*I`.
- When using algorithms that expect to operate in unconstrained spaces, the user is now explicitly expected to provide a `Bijectors.TransformedDistribution` wrapping an unconstrained distribution. (Refer to the docstring of `vi`; see also the second sketch after this list.)
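
As a minimal sketch of the new sign convention for `estimate_objective` (the exact call signature below is an assumption for illustration, not verified API; refer to the `AdvancedVI` docstrings):

```julia
# Hedged sketch: `estimate_objective` returns the quantity being *minimized*,
# so for a KL/ELBO-based algorithm the ELBO is recovered by negation.
# The argument order and `n_samples` keyword here are assumptions.
neg_elbo = estimate_objective(algorithm, q_avg, model; n_samples=256)
elbo = -neg_elbo
```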
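
For the last point, a minimal sketch of constructing such a wrapped distribution. The use of `Bijectors.bijector(model)` to obtain the constrained-to-unconstrained map is an assumption here; in practice, the `q_meanfield_gaussian` and `q_fullrank_gaussian` helpers return an appropriately wrapped distribution for you.

```julia
using Bijectors, Distributions, LinearAlgebra

d = 2                                # number of model parameters (illustrative)
q0 = MvNormal(zeros(d), 0.6^2 * I)   # unconstrained base; scale 0.6*I matches the new default
b = Bijectors.bijector(model)        # constrained-to-unconstrained map (assumed available)
q = Bijectors.transformed(q0, Bijectors.inverse(b))  # a Bijectors.TransformedDistribution
```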

### New Features

`AdvancedVI@0.6` adds numerous new features, including the following new VI algorithms:

- `KLMinWassFwdBwd`: Also known as "Wasserstein variational inference," this algorithm minimizes the KL divergence under the Wasserstein-2 metric.
- `KLMinNaturalGradDescent`: Also known as "online variational Newton," this is the canonical "black-box" natural gradient variational inference algorithm, which minimizes the KL divergence via mirror descent, using the KL divergence as the Bregman divergence.
- `KLMinSqrtNaturalGradDescent`: A recent variant of `KLMinNaturalGradDescent` that operates in the Cholesky-factor parameterization of Gaussians instead of precision matrices.
- `FisherMinBatchMatch`: Known as "batch-and-match," this algorithm minimizes a covariance-weighted variant of the second-order Fisher divergence via a proximal point-type algorithm.

Any of the new algorithms above can readily be used by simply swapping the `algorithm` keyword argument of `vi`.
For example, to use batch-and-match:

```julia
vi(model, q, n_iters; algorithm=FisherMinBatchMatch())
```
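
To put the pieces together, here is a minimal end-to-end sketch of the new interface. The toy model, iteration count, and hyperparameter choices are illustrative assumptions, not part of this commit:

```julia
using Turing
import AdvancedVI

# A toy model, purely for illustration.
@model function demo(x)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1); lower=0)
    for i in eachindex(x)
        x[i] ~ Normal(μ, σ)
    end
end

model = demo(randn(100))

# Numerically non-degenerate mean-field Gaussian initialization.
q0 = q_meanfield_gaussian(model)

# KL divergence minimization via SGD with the reparameterization gradient.
q_avg, info, state = vi(
    model,
    q0,
    1_000;  # number of iterations (illustrative)
    algorithm=KLMinRepGradDescent(
        AutoForwardDiff(); n_samples=10, operator=AdvancedVI.ClipScale()
    ),
)
```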

## External sampler interface

The interface for defining an external sampler has been reworked.

Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -53,7 +53,7 @@ Accessors = "0.1"
 AdvancedHMC = "0.8.3"
 AdvancedMH = "0.8.9"
 AdvancedPS = "0.7"
-AdvancedVI = "0.4"
+AdvancedVI = "0.6"
 BangBang = "0.4.2"
 Bijectors = "0.14, 0.15"
 Compat = "4.15.0"
```

docs/src/api.md

Lines changed: 13 additions & 6 deletions
```diff
@@ -110,12 +110,19 @@ Turing.jl provides several strategies to initialise parameters for models.
 
 See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough.
 
-| Exported symbol | Documentation | Description |
-|:--- |:--- |:--- |
-| `vi` | [`Turing.vi`](@ref) | Perform variational inference |
-| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
-| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
-| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
+| Exported symbol | Documentation | Description |
+|:--- |:--- |:--- |
+| `vi` | [`Turing.vi`](@ref) | Perform variational inference |
+| `q_locationscale` | [`Turing.Variational.q_locationscale`](@ref) | Find a numerically non-degenerate initialization for a location-scale variational family |
+| `q_meanfield_gaussian` | [`Turing.Variational.q_meanfield_gaussian`](@ref) | Find a numerically non-degenerate initialization for a mean-field Gaussian family |
+| `q_fullrank_gaussian` | [`Turing.Variational.q_fullrank_gaussian`](@ref) | Find a numerically non-degenerate initialization for a full-rank Gaussian family |
+| `KLMinRepGradDescent` | [`Turing.Variational.KLMinRepGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the reparameterization gradient |
+| `KLMinRepGradProxDescent` | [`Turing.Variational.KLMinRepGradProxDescent`](@ref) | KL divergence minimization via stochastic proximal gradient descent with the reparameterization gradient over location-scale variational families |
+| `KLMinScoreGradDescent` | [`Turing.Variational.KLMinScoreGradDescent`](@ref) | KL divergence minimization via stochastic gradient descent with the score gradient |
+| `KLMinWassFwdBwd` | [`Turing.Variational.KLMinWassFwdBwd`](@ref) | KL divergence minimization via Wasserstein proximal gradient descent |
+| `KLMinNaturalGradDescent` | [`Turing.Variational.KLMinNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent |
+| `KLMinSqrtNaturalGradDescent` | [`Turing.Variational.KLMinSqrtNaturalGradDescent`](@ref) | KL divergence minimization via natural gradient descent in the square-root parameterization |
+| `FisherMinBatchMatch` | [`Turing.Variational.FisherMinBatchMatch`](@ref) | Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm |
 
 ### Automatic differentiation types
```

src/Turing.jl

Lines changed: 8 additions & 2 deletions
```diff
@@ -47,7 +47,7 @@ include("stdlib/distributions.jl")
 include("stdlib/RandomMeasures.jl")
 include("mcmc/Inference.jl") # inference algorithms
 using .Inference
-include("variational/VariationalInference.jl")
+include("variational/Variational.jl")
 using .Variational
 
 include("optimisation/Optimisation.jl")
@@ -119,10 +119,16 @@ export
     externalsampler,
     # Variational inference - AdvancedVI
     vi,
-    ADVI,
     q_locationscale,
     q_meanfield_gaussian,
     q_fullrank_gaussian,
+    KLMinRepGradProxDescent,
+    KLMinRepGradDescent,
+    KLMinScoreGradDescent,
+    KLMinNaturalGradDescent,
+    KLMinSqrtNaturalGradDescent,
+    KLMinWassFwdBwd,
+    FisherMinBatchMatch,
     # ADTypes
     AutoForwardDiff,
     AutoReverseDiff,
```
