Integrate DifferentiationInterface.jl for gradient computation (#397)
Replace `LogDensityProblemsAD` with
[`DifferentiationInterface.jl`](https://github.com/JuliaDiff/DifferentiationInterface.jl)
for automatic differentiation.
Changes:
- Add an `adtype` parameter to `compile()` for specifying AD backends
- Support symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`
- Implement a `BUGSModelWithGradient` wrapper with a `logdensity_and_gradient` method
- Backward compatible: existing code without `adtype` continues to work

Usage:
```julia
model = compile(model_def, data; adtype=:ReverseDiff)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))  # same as above
```

Closes #380
---------
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>

`JuliaBUGS/History.md`

# JuliaBUGS Changelog

## 0.12

- **DifferentiationInterface.jl integration**: Use the `adtype` parameter in `compile()` to enable gradient-based inference via [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

### Automatic Differentiation

JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).

#### Specifying an AD Backend

To compile a model with gradient support, pass the `adtype` parameter to `compile`:

```julia
# Compile with gradient support using an AD type from ADTypes.jl
using ADTypes
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
```

Alternatively, if you already have a compiled `BUGSModel`, you can wrap it with `BUGSModelWithGradient` without recompiling:

```julia
model = BUGSModelWithGradient(base_model, AutoReverseDiff(compile=true))
```

Available AD backends include:

- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models)
- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters)

For fine-grained control, you can configure the AD backend:

```julia
# ReverseDiff without compilation
model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
```

The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient.
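
For example, a minimal sketch of querying the log density and its gradient directly (the evaluation point here is an arbitrary random vector, not part of the original example):

```julia
using LogDensityProblems

D = LogDensityProblems.dimension(model)
logp, grad = LogDensityProblems.logdensity_and_gradient(model, rand(D))
```
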
### Inference
For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`:
```@example abc
using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains

# Compile with gradient support
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

n_samples, n_adapts = 2000, 1000

D = LogDensityProblems.dimension(model); initial_θ = rand(D)
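# A sketch of the NUTS sampling call; the target acceptance rate and the
# keyword arguments below are assumptions, not settings confirmed by this page
samples_and_stats = AbstractMCMC.sample(
    model,
    AdvancedHMC.NUTS(0.8),
    n_samples;
    chain_type = Chains,
    n_adapts = n_adapts,
    initial_params = initial_θ,
    discard_initial = n_adapts,
)
```
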
This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html).

## Evaluation Modes and Automatic Differentiation

JuliaBUGS supports multiple evaluation modes and AD backends. The evaluation mode determines how the log density is computed and constrains which AD backends can be used.

- **`UseGraph()`**: Evaluates by traversing the computational graph. Supports user-defined primitives registered via `@bugs_primitive`.
- **`UseGeneratedLogDensityFunction()`**: Generates and compiles a Julia function for the log density.

### AD Backends with `UseGraph()` Mode

Use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) or [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with the default `UseGraph()` mode:

```julia
using ADTypes

# ReverseDiff with tape compilation (recommended for large models)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

# ForwardDiff (efficient for small models with < 20 parameters)
model = compile(model_def, data; adtype=AutoForwardDiff())

# ReverseDiff without compilation (supports control flow)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
```

!!! warning "Compiled ReverseDiff does not support control flow"
    Compiled tapes record a fixed execution path. If your model contains value-dependent control flow (e.g., `if x > 0`, `while`, truncation), the tape will only capture one branch and produce **incorrect gradients** when the control flow takes a different path. Use `AutoReverseDiff(compile=false)` or `AutoForwardDiff()` for models with control flow.
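
As an illustration, here is a hypothetical primitive with a value-dependent branch; the function name is made up, and the name-based registration form is an assumption based on the `@bugs_primitive` comment in the distributed example below:

```julia
using JuliaBUGS, ADTypes

# Hypothetical primitive: which branch runs depends on the runtime value of x
positive_part(x) = x > 0 ? x : zero(x)
JuliaBUGS.@bugs_primitive positive_part

# Safe choice: an uncompiled tape re-traces the branch on every evaluation
model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
```
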
### AD Backend with `UseGeneratedLogDensityFunction()` Mode

Use [Mooncake.jl](https://github.com/compintell/Mooncake.jl) with the generated log density function mode:

```julia
using ADTypes

model = compile(model_def, data)
model = set_evaluation_mode(model, UseGeneratedLogDensityFunction())
model = BUGSModelWithGradient(model, AutoMooncake(; config=nothing))
```

## Parallel and Distributed Sampling with `AbstractMCMC`

The model compilation code remains the same, and we can sample multiple chains in parallel:

```julia
n_chains = 4

samples_and_stats = AbstractMCMC.sample(
    model,
    AdvancedHMC.NUTS(0.65),
    AbstractMCMC.MCMCThreads(),
    n_samples,
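    # `n_chains` is the ensemble sampler's positional chain count; the keyword
    # arguments below are assumptions mirroring the single-chain example
    n_chains;
    chain_type = Chains,
    n_adapts = n_adapts,
    initial_params = initial_θ,
    discard_initial = n_adapts,
)
```
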
For example:

```julia
@everywhere begin
    using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff

    # Define the functions to use
    # Use `@bugs_primitive` to register the functions to use in the model
end

n_chains = nprocs() - 1  # use all the processes except the parent process
```
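
A sketch of the corresponding distributed sampling call, assuming the same sampler settings and keyword arguments as the threaded example above:

```julia
samples_and_stats = AbstractMCMC.sample(
    model,
    AdvancedHMC.NUTS(0.65),
    AbstractMCMC.MCMCDistributed(),
    n_samples,
    n_chains;
    chain_type = Chains,
    n_adapts = n_adapts,
    initial_params = initial_θ,
    discard_initial = n_adapts,
)
```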