
Commit b40f54b

Authored by shravanngoswamii, yebai, github-actions[bot], and sunxd3

**Integrate DifferentiationInterface.jl for gradient computation (#397)**

Replace `LogDensityProblemsAD` with [`DifferentiationInterface.jl`](https://github.com/JuliaDiff/DifferentiationInterface.jl) for automatic differentiation.

Changes:

- Add `adtype` parameter to `compile()` for specifying AD backends
- Support symbol shortcuts: `:ReverseDiff`, `:ForwardDiff`, `:Zygote`, `:Enzyme`
- Implement `BUGSModelWithGradient` wrapper with a `logdensity_and_gradient` method
- Backward compatible: existing code without `adtype` continues to work

Usage:

```julia
model = compile(model_def, data; adtype=:ReverseDiff)
```

```julia
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))  # Same as above
```

Closes #380

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>

1 parent 8546750 · commit b40f54b

22 files changed (+431 −222 lines)
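As a quick orientation to the API this commit introduces, the flow can be sketched end to end. This is only a sketch: the model definition and data are illustrative placeholders, not part of the commit.

```julia
using JuliaBUGS, ADTypes, LogDensityProblems, ReverseDiff

# Hypothetical toy model, for illustration only
model_def = @bugs begin
    mu ~ dnorm(0, 0.001)
    for i in 1:N
        y[i] ~ dnorm(mu, 1)
    end
end
data = (N = 3, y = [1.0, 2.0, 3.0])

# New in this commit: gradient support requested at compile time
model = compile(model_def, data; adtype = AutoReverseDiff(compile = true))

# The wrapped model implements the LogDensityProblems interface
θ = rand(LogDensityProblems.dimension(model))
ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(model, θ)
```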

JuliaBUGS/History.md

Lines changed: 9 additions & 0 deletions

```diff
@@ -1,5 +1,14 @@
 # JuliaBUGS Changelog
 
+## 0.12
+
+- **DifferentiationInterface.jl integration**: Use `adtype` parameter in `compile()` to enable gradient-based inference via [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
+  - Example: `model = compile(model_def, data; adtype=AutoReverseDiff())`
+  - Supports `AutoReverseDiff`, `AutoForwardDiff`, `AutoMooncake`
+
+- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported.
+  - Use `compile(...; adtype=...)` or `BUGSModelWithGradient(model, adtype)` instead.
+
 ## 0.10.1
 
 Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0)
```
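For readers upgrading, the breaking change above amounts to the following swap. This is a sketch; `model_def` and `data` stand in for an existing model definition and dataset.

```julia
# Before 0.12 (no longer supported):
#   using LogDensityProblemsAD, ReverseDiff
#   ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))

# From 0.12 on:
using JuliaBUGS, ADTypes
model = compile(model_def, data; adtype = AutoReverseDiff(compile = true))

# or, wrapping an already-compiled BUGSModel without recompiling:
model = BUGSModelWithGradient(compile(model_def, data), AutoReverseDiff(compile = true))
```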

JuliaBUGS/Project.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8"
 AdvancedMH = "0.8"
 BangBang = "0.4.1"
 Bijectors = "0.13, 0.14, 0.15.5"
+DifferentiationInterface = "0.7"
 Distributions = "0.23.8, 0.24, 0.25"
 Documenter = "0.27, 1"
 GLMakie = "0.10, 0.11, 0.12, 0.13"
```

JuliaBUGS/docs/Project.toml

Lines changed: 9 additions & 0 deletions

```diff
@@ -1,7 +1,16 @@
 [deps]
+AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+
+[sources]
+JuliaBUGS = {path = ".."}
 
 [compat]
 Documenter = "1.14"
```

JuliaBUGS/docs/src/example.md

Lines changed: 85 additions & 57 deletions

````diff
@@ -2,6 +2,7 @@
 
 ```@setup abc
 using JuliaBUGS
+using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ADTypes, ReverseDiff
 
 data = (
     r = [10, 23, 23, 26, 17, 5, 53, 55, 32, 46, 10, 8, 10, 8, 23, 0, 3, 22, 15, 32, 3],
@@ -190,86 +191,113 @@ initialize!(model, initializations)
 initialize!(model, rand(26))
 ```
 
-`LogDensityProblemsAD.jl` defined some extensions that support automatic differentiation packages.
-For example, with `ReverseDiff.jl`
+### Automatic Differentiation
+
+JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods like Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).
+
+#### Specifying an AD Backend
+
+To compile a model with gradient support, pass the `adtype` parameter to `compile`:
 
 ```julia
-using LogDensityProblemsAD, ReverseDiff
+# Compile with gradient support using an AD type from ADTypes.jl
+using ADTypes
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
+```
+
+Alternatively, if you already have a compiled `BUGSModel`, you can wrap it with `BUGSModelWithGradient` without recompiling:
 
-ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
+```julia
+base_model = compile(model_def, data)
+model = BUGSModelWithGradient(base_model, AutoReverseDiff(compile=true))
 ```
 
-Here `ad_model` will also implement all the interfaces of [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/).
-`LogDensityProblemsAD.jl` will automatically add the interface function [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient) to the model, which will return the log density and gradient of the model.
-And `ad_model` can be used in the same way as `model` in the example below.
+Available AD backends include:
+- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models)
+- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters)
+- `AutoMooncake()` - Mooncake (requires `UseGeneratedLogDensityFunction()` mode)
+
+For fine-grained control, you can configure the AD backend:
+
+```julia
+# ReverseDiff without compilation
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
+```
+
+The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient.
 
 ### Inference
 
-For a differentiable model, we can use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) to perform inference.
-For instance,
+For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`:
 
-```julia
-using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains
+```@example abc
+# Compile with gradient support
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
 
 n_samples, n_adapts = 2000, 1000
 
 D = LogDensityProblems.dimension(model); initial_θ = rand(D)
 
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     NUTS(0.8),
     n_samples;
     chain_type = Chains,
    n_adapts = n_adapts,
     init_params = initial_θ,
-    discard_initial = n_adapts
+    discard_initial = n_adapts,
+    progress = false
 )
+describe(samples_and_stats)
 ```
 
-This will return the MCMC Chain,
-
-```plaintext
-Chains MCMC chain (2000×40×1 Array{Real, 3}):
-
-Iterations        = 1001:1:3000
-Number of chains  = 1
-Samples per chain = 2000
-parameters        = alpha0, alpha12, alpha1, alpha2, tau, b[16], b[12], b[10], b[14], b[13], b[7], b[6], b[20], b[1], b[4], b[5], b[2], b[18], b[8], b[3], b[9], b[21], b[17], b[15], b[11], b[19], sigma
-internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt
-
-Summary Statistics
-  parameters      mean     std    mcse   ess_bulk   ess_tail    rhat  ess_per_sec
-      Symbol   Float64 Float64 Float64       Real    Float64 Float64      Missing
-
-      alpha0   -0.5642  0.2320  0.0084   766.9305  1022.5211  1.0021      missing
-     alpha12   -0.8489  0.5247  0.0170   946.0418  1044.1109  1.0002      missing
-      alpha1    0.0587  0.3715  0.0119   966.4367  1233.2257  1.0007      missing
-      alpha2    1.3852  0.3410  0.0127   712.2978   974.1566  1.0002      missing
-         tau    1.8880  0.7705  0.0447   348.9331   338.3655  1.0030      missing
-       b[16]   -0.2445  0.4459  0.0132  1528.0578   843.8225  1.0003      missing
-       b[12]    0.2050  0.3602  0.0086  1868.6126  1202.1363  0.9996      missing
-       b[10]   -0.3500  0.2893  0.0090  1047.3119  1245.9358  1.0008      missing
-           ⋮         ⋮       ⋮       ⋮          ⋮          ⋮       ⋮            ⋮
-                                                                  19 rows omitted
-
-Quantiles
-  parameters      2.5%    25.0%    50.0%    75.0%    97.5%
-      Symbol   Float64  Float64  Float64  Float64  Float64
-
-      alpha0   -1.0143  -0.7143  -0.5590  -0.4100  -0.1185
-     alpha12   -1.9063  -1.1812  -0.8296  -0.5153   0.1521
-      alpha1   -0.6550  -0.1822   0.0512   0.2885   0.8180
-      alpha2    0.7214   1.1663   1.3782   1.5998   2.0986
-         tau    0.5461   1.3941   1.8353   2.3115   3.6225
-       b[16]   -1.2359  -0.4836  -0.1909   0.0345   0.5070
-       b[12]   -0.4493  -0.0370   0.1910   0.4375   0.9828
-       b[10]   -0.9570  -0.5264  -0.3331  -0.1514   0.1613
-           ⋮         ⋮        ⋮        ⋮        ⋮        ⋮
-                                            19 rows omitted
+This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html).
+
+## Evaluation Modes and Automatic Differentiation
+
+JuliaBUGS supports multiple evaluation modes and AD backends. The evaluation mode determines how the log density is computed, and constrains which AD backends can be used.
+
+### Evaluation Modes
 
+| Mode | AD Backends |
+|------|-------------|
+| `UseGraph()` (default) | ReverseDiff, ForwardDiff |
+| `UseGeneratedLogDensityFunction()` | Mooncake |
+
+- **`UseGraph()`**: Evaluates by traversing the computational graph. Supports user-defined primitives registered via `@bugs_primitive`.
+- **`UseGeneratedLogDensityFunction()`**: Generates and compiles a Julia function for the log density.
+
+### AD Backends with `UseGraph()` Mode
+
+Use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) or [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with the default `UseGraph()` mode:
+
+```julia
+using ADTypes
+
+# ReverseDiff with tape compilation (recommended for large models)
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
+
+# ForwardDiff (efficient for small models with < 20 parameters)
+model = compile(model_def, data; adtype=AutoForwardDiff())
+
+# ReverseDiff without compilation (supports control flow)
+model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
 ```
 
-This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html).
+!!! warning "Compiled ReverseDiff does not support control flow"
+    Compiled tapes record a fixed execution path. If your model contains value-dependent control flow (e.g., `if x > 0`, `while`, truncation), the tape will only capture one branch and produce **incorrect gradients** when the control flow takes a different path. Use `AutoReverseDiff(compile=false)` or `AutoForwardDiff()` for models with control flow.
+
+### AD Backend with `UseGeneratedLogDensityFunction()` Mode
+
+Use [Mooncake.jl](https://github.com/compintell/Mooncake.jl) with the generated log density function mode:
+
+```julia
+using ADTypes
+
+model = compile(model_def, data)
+model = set_evaluation_mode(model, UseGeneratedLogDensityFunction())
+model = BUGSModelWithGradient(model, AutoMooncake(; config=nothing))
+```
 
 ## Parallel and Distributed Sampling with `AbstractMCMC`
 
@@ -283,7 +311,7 @@ The model compilation code remains the same, and we can sample multiple chains i
 ```julia
 n_chains = 4
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     AdvancedHMC.NUTS(0.65),
     AbstractMCMC.MCMCThreads(),
     n_samples,
@@ -311,7 +339,7 @@ For example:
 
 ```julia
 @everywhere begin
-    using JuliaBUGS, LogDensityProblems, LogDensityProblemsAD, AbstractMCMC, AdvancedHMC, MCMCChains, ReverseDiff # also other packages one may need
+    using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff
 
     # Define the functions to use
     # Use `@bugs_primitive` to register the functions to use in the model
@@ -322,7 +350,7 @@ end
 
 n_chains = nprocs() - 1 # use all the processes except the parent process
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     AdvancedHMC.NUTS(0.65),
     AbstractMCMC.MCMCDistributed(),
     n_samples,
````
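The control-flow warning in the docs above can be demonstrated with ReverseDiff directly, independent of JuliaBUGS. This sketch shows a compiled tape silently replaying the branch that was recorded:

```julia
using ReverseDiff

# The gradient of f depends on the sign of x[1]
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

tape = ReverseDiff.GradientTape(f, [1.0, 2.0])  # records the x[1] > 0 branch
ctape = ReverseDiff.compile(tape)

g = zeros(2)
ReverseDiff.gradient!(g, ctape, [-1.0, 2.0])
# g == [-2.0, 4.0]: the tape replays sum(abs2, x) (gradient 2x),
# while the true gradient of f at this input is [1.0, 1.0]
```

With `AutoReverseDiff(compile=false)` or `AutoForwardDiff()`, the branch is re-evaluated on every call, so the gradient follows the path actually taken.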

JuliaBUGS/examples/Project.toml

Lines changed: 3 additions & 2 deletions

```diff
@@ -4,17 +4,18 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+
+[sources]
+JuliaBUGS = {path = ".."}
```

JuliaBUGS/examples/bnn.jl

Lines changed: 8 additions & 8 deletions

```diff
@@ -1,17 +1,16 @@
 using JuliaBUGS
+using Distributions: Bernoulli, MvNormal
 
 using AbstractMCMC
 using ADTypes
 using AdvancedHMC
-using DifferentiationInterface
 using FillArrays
+using ForwardDiff
 using Functors
 using LinearAlgebra
 using LogDensityProblems
-using LogDensityProblemsAD
 using Lux
 using MCMCChains
-using Mooncake
 using Random
 
 ## data simulation
@@ -84,7 +83,7 @@ function make_prediction(parameters, xs; ps=ps, nn=nn)
     return Lux.apply(nn, f32(xs), f32(vector_to_parameters(parameters, ps)))
 end
 
-JuliaBUGS.@bugs_primitive parameter_distribution make_prediction
+JuliaBUGS.@bugs_primitive parameter_distribution make_prediction Bernoulli
 
 @eval JuliaBUGS begin
     ps = Main.ps
@@ -96,16 +95,17 @@ end
 
 data = (nparameters=Lux.parameterlength(nn), xs=xs_hcat, ts=ts, N=length(ts), sigma=sigma)
 
+# Use ForwardDiff with UseGraph mode (required for user-defined primitives)
 model = compile(model_def, data)
-
-ad_model = ADgradient(AutoMooncake(; config=Mooncake.Config()), model)
+model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph())
+model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff())
 
 # sampling is slow, so sample 10 of them to verify that this can work
 samples_and_stats = AbstractMCMC.sample(
-    ad_model,
+    model,
     NUTS(0.65),
     10;
     chain_type=Chains,
-    # n_adapts=1000,
+    # n_adapts=1000,
     # discard_initial=1000
 )
```
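The `Bernoulli` addition to `@bugs_primitive` above reflects a general rule under `UseGraph()` mode: any function or distribution called in the model body that JuliaBUGS does not already provide must be registered before `compile`. A minimal sketch with a hypothetical helper (`my_logit` is not from this commit):

```julia
using JuliaBUGS

# Hypothetical user-defined function, for illustration only
my_logit(p) = log(p / (1 - p))

# Register it so the graph evaluator can call it from model code
JuliaBUGS.@bugs_primitive my_logit
```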
