9 changes: 9 additions & 0 deletions JuliaBUGS/History.md
@@ -1,5 +1,14 @@
# JuliaBUGS Changelog

## 0.12

- **DifferentiationInterface.jl integration**: Use `adtype` parameter in `compile()` to enable gradient-based inference via [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
- Example: `model = compile(model_def, data; adtype=AutoReverseDiff())`
- Supports `AutoReverseDiff`, `AutoForwardDiff`, `AutoMooncake`

- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported.
- Use `compile(...; adtype=...)` or `BUGSModelWithGradient(model, adtype)` instead.

## 0.10.1

Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0)
2 changes: 2 additions & 0 deletions JuliaBUGS/Project.toml
@@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8"
AdvancedMH = "0.8"
BangBang = "0.4.1"
Bijectors = "0.13, 0.14, 0.15.5"
DifferentiationInterface = "0.7"
Distributions = "0.23.8, 0.24, 0.25"
Documenter = "0.27, 1"
GLMakie = "0.10, 0.11, 0.12, 0.13"
9 changes: 9 additions & 0 deletions JuliaBUGS/docs/Project.toml
@@ -1,7 +1,16 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[sources]
JuliaBUGS = {path = ".."}

[compat]
Documenter = "1.14"
142 changes: 85 additions & 57 deletions JuliaBUGS/docs/src/example.md
@@ -2,6 +2,7 @@

```@setup abc
using JuliaBUGS
using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ADTypes, ReverseDiff

data = (
r = [10, 23, 23, 26, 17, 5, 53, 55, 32, 46, 10, 8, 10, 8, 23, 0, 3, 22, 15, 32, 3],
@@ -190,86 +191,113 @@ initialize!(model, initializations)
initialize!(model, rand(26))
```

### Automatic Differentiation

JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods such as Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS).

#### Specifying an AD Backend

To compile a model with gradient support, pass the `adtype` parameter to `compile`:

```julia
# Compile with gradient support using an AD backend type from ADTypes.jl
using ADTypes
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
```

Alternatively, if you already have a compiled `BUGSModel`, you can wrap it with `BUGSModelWithGradient` without recompiling:

```julia
base_model = compile(model_def, data)
model = BUGSModelWithGradient(base_model, AutoReverseDiff(compile=true))
```

Available AD backends include:
- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models)
- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters)
- `AutoMooncake()` - Mooncake (requires `UseGeneratedLogDensityFunction()` mode)

For fine-grained control, you can configure the AD backend:

```julia
# ReverseDiff without compilation
model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
```

The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient.
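As a hedged sketch of what that interface looks like in practice (reusing `model_def` and `data` from earlier in this example), the log density and gradient can be queried directly:

```julia
using JuliaBUGS, ADTypes, ReverseDiff, LogDensityProblems

# Compile with gradient support; assumes `model_def` and `data` as above.
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

# Query the LogDensityProblems.jl interface directly.
D = LogDensityProblems.dimension(model)   # number of parameters
θ = rand(D)                                # a point in unconstrained space
logp, grad = LogDensityProblems.logdensity_and_gradient(model, θ)
@assert length(grad) == D
```

Samplers such as AdvancedHMC consume exactly these two calls, which is why any model compiled with an `adtype` can be passed to them unchanged.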

### Inference

For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`:

```@example abc
# Compile with gradient support
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

n_samples, n_adapts = 2000, 1000

D = LogDensityProblems.dimension(model); initial_θ = rand(D)

samples_and_stats = AbstractMCMC.sample(
model,
NUTS(0.8),
n_samples;
chain_type = Chains,
n_adapts = n_adapts,
init_params = initial_θ,
discard_initial = n_adapts,
progress = false
)
describe(samples_and_stats)
```

This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html).

## Evaluation Modes and Automatic Differentiation

JuliaBUGS supports multiple evaluation modes and AD backends. The evaluation mode determines how the log density is computed, and constrains which AD backends can be used.

### Evaluation Modes

| Mode | AD Backends |
|------|-------------|
| `UseGraph()` (default) | ReverseDiff, ForwardDiff |
| `UseGeneratedLogDensityFunction()` | Mooncake |

- **`UseGraph()`**: Evaluates by traversing the computational graph. Supports user-defined primitives registered via `@bugs_primitive`.
- **`UseGeneratedLogDensityFunction()`**: Generates and compiles a Julia function for the log density.
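Since user-defined primitives require `UseGraph()`, a hedged illustration of the registration step (the `softplus` helper here is hypothetical, not part of JuliaBUGS) might look like:

```julia
using JuliaBUGS

# Hypothetical user-defined function; shown only to illustrate
# `@bugs_primitive` registration. Models calling it must be evaluated
# with the default `UseGraph()` mode.
softplus(x) = log1p(exp(x))
JuliaBUGS.@bugs_primitive softplus
```

After registration, `softplus` can be called inside a `@bugs` model definition like any built-in function.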

### AD Backends with `UseGraph()` Mode

Use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) or [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with the default `UseGraph()` mode:

```julia
using ADTypes

# ReverseDiff with tape compilation (recommended for large models)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))

# ForwardDiff (efficient for small models with < 20 parameters)
model = compile(model_def, data; adtype=AutoForwardDiff())

# ReverseDiff without compilation (supports control flow)
model = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
```

!!! warning "Compiled ReverseDiff does not support control flow"
    Compiled tapes record a fixed execution path. If your model contains value-dependent control flow (e.g., `if x > 0`, `while`, truncation), the tape will only capture one branch and produce **incorrect gradients** when the control flow takes a different path. Use `AutoReverseDiff(compile=false)` or `AutoForwardDiff()` for models with control flow.
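The pitfall can be reproduced with ReverseDiff alone; this sketch is independent of JuliaBUGS:

```julia
using ReverseDiff

# f branches on the sign of x[1]; a compiled tape records only the
# branch taken at recording time.
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

# Record and compile the tape at a point where x[1] > 0.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

# Replay at a point where the other branch should fire.
g = ReverseDiff.gradient!(similar([1.0, 2.0]), tape, [-1.0, 2.0])
# True gradient at [-1.0, 2.0] is [1.0, 1.0] (the sum(x) branch), but the
# compiled tape silently replays the recorded branch and returns [-2.0, 4.0].
```

Uncompiled ReverseDiff (`AutoReverseDiff(compile=false)`) re-records the tape on every call and therefore follows the correct branch, at the cost of some speed.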

### AD Backend with `UseGeneratedLogDensityFunction()` Mode

Use [Mooncake.jl](https://github.com/compintell/Mooncake.jl) with the generated log density function mode:

```julia
using ADTypes

model = compile(model_def, data)
model = set_evaluation_mode(model, UseGeneratedLogDensityFunction())
model = BUGSModelWithGradient(model, AutoMooncake(; config=nothing))
```

## Parallel and Distributed Sampling with `AbstractMCMC`

@@ -283,7 +311,7 @@ The model compilation code remains the same, and we can sample multiple chains i
```julia
n_chains = 4
samples_and_stats = AbstractMCMC.sample(
model,
AdvancedHMC.NUTS(0.65),
AbstractMCMC.MCMCThreads(),
n_samples,
@@ -311,7 +339,7 @@ For example:

```julia
@everywhere begin
using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff

# Define the functions to use
# Use `@bugs_primitive` to register the functions to use in the model
@@ -322,7 +350,7 @@ end

n_chains = nprocs() - 1 # use all the processes except the parent process
samples_and_stats = AbstractMCMC.sample(
model,
AdvancedHMC.NUTS(0.65),
AbstractMCMC.MCMCDistributed(),
n_samples,
5 changes: 3 additions & 2 deletions JuliaBUGS/examples/Project.toml
@@ -4,17 +4,18 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[sources]
JuliaBUGS = {path = ".."}
16 changes: 8 additions & 8 deletions JuliaBUGS/examples/bnn.jl
@@ -1,17 +1,16 @@
using JuliaBUGS
using Distributions: Bernoulli, MvNormal

using AbstractMCMC
using ADTypes
using AdvancedHMC
using DifferentiationInterface
using FillArrays
using ForwardDiff
using Functors
using LinearAlgebra
using LogDensityProblems
using Lux
using MCMCChains
using Random

## data simulation
@@ -84,7 +83,7 @@ function make_prediction(parameters, xs; ps=ps, nn=nn)
return Lux.apply(nn, f32(xs), f32(vector_to_parameters(parameters, ps)))
end

JuliaBUGS.@bugs_primitive parameter_distribution make_prediction Bernoulli

@eval JuliaBUGS begin
ps = Main.ps
@@ -96,16 +95,17 @@ end

data = (nparameters=Lux.parameterlength(nn), xs=xs_hcat, ts=ts, N=length(ts), sigma=sigma)

# Use ForwardDiff with UseGraph mode (required for user-defined primitives)
model = compile(model_def, data)

model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph())
model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff())

# sampling is slow, so sample 10 of them to verify that this can work
samples_and_stats = AbstractMCMC.sample(
model,
NUTS(0.65),
10;
chain_type=Chains,
# n_adapts=1000,
# discard_initial=1000
)