diff --git a/JuliaBUGS/History.md b/JuliaBUGS/History.md index cb5567431..0e43ea767 100644 --- a/JuliaBUGS/History.md +++ b/JuliaBUGS/History.md @@ -1,5 +1,14 @@ # JuliaBUGS Changelog +## 0.12 + +- **DifferentiationInterface.jl integration**: Use `adtype` parameter in `compile()` to enable gradient-based inference via [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + - Example: `model = compile(model_def, data; adtype=AutoReverseDiff())` + - Supports `AutoReverseDiff`, `AutoForwardDiff`, `AutoMooncake` + +- **Breaking**: `LogDensityProblemsAD.ADgradient` is no longer supported. + - Use `compile(...; adtype=...)` or `BUGSModelWithGradient(model, adtype)` instead. + ## 0.10.1 Expose docs for changes in [v0.10.0](https://github.com/TuringLang/JuliaBUGS.jl/releases/tag/JuliaBUGS-v0.10.0) diff --git a/JuliaBUGS/Project.toml b/JuliaBUGS/Project.toml index 81449b157..dc0345ad9 100644 --- a/JuliaBUGS/Project.toml +++ b/JuliaBUGS/Project.toml @@ -9,6 +9,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -52,6 +53,7 @@ AdvancedHMC = "0.6, 0.7, 0.8" AdvancedMH = "0.8" BangBang = "0.4.1" Bijectors = "0.13, 0.14, 0.15.5" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" GLMakie = "0.10, 0.11, 0.12, 0.13" diff --git a/JuliaBUGS/docs/Project.toml b/JuliaBUGS/docs/Project.toml index a8e5b92ac..ca5e166dc 100644 --- a/JuliaBUGS/docs/Project.toml +++ b/JuliaBUGS/docs/Project.toml @@ -1,7 +1,16 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[sources] +JuliaBUGS = {path = ".."} [compat] Documenter = "1.14" diff --git a/JuliaBUGS/docs/src/example.md b/JuliaBUGS/docs/src/example.md index eaba1a01e..e8a162f23 100644 --- a/JuliaBUGS/docs/src/example.md +++ b/JuliaBUGS/docs/src/example.md @@ -2,6 +2,7 @@ ```@setup abc using JuliaBUGS +using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains, ADTypes, ReverseDiff data = ( r = [10, 23, 23, 26, 17, 5, 53, 55, 32, 46, 10, 8, 10, 8, 23, 0, 3, 22, 15, 32, 3], @@ -190,86 +191,113 @@ initialize!(model, initializations) initialize!(model, rand(26)) ``` -`LogDensityProblemsAD.jl` defined some extensions that support automatic differentiation packages. -For example, with `ReverseDiff.jl` +### Automatic Differentiation + +JuliaBUGS integrates with automatic differentiation (AD) through [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), enabling gradient-based inference methods like Hamiltonian Monte Carlo (HMC) and No-U-Turn Sampler (NUTS). + +#### Specifying an AD Backend + +To compile a model with gradient support, pass the `adtype` parameter to `compile`: ```julia -using LogDensityProblemsAD, ReverseDiff +# Compile with gradient support using ADTypes from ADTypes.jl +using ADTypes +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) +``` + +Alternatively, if you already have a compiled `BUGSModel`, you can wrap it with `BUGSModelWithGradient` without recompiling: -ad_model = ADgradient(:ReverseDiff, model; compile=Val(true)) +```julia +base_model = compile(model_def, data) +model = BUGSModelWithGradient(base_model, AutoReverseDiff(compile=true)) ``` -Here `ad_model` will also implement all the interfaces of [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/). -`LogDensityProblemsAD.jl` will automatically add the interface function [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient) to the model, which will return the log density and gradient of the model. -And `ad_model` can be used in the same way as `model` in the example below. +Available AD backends include: +- `AutoReverseDiff(compile=true)` - ReverseDiff with tape compilation (recommended for most models) +- `AutoForwardDiff()` - ForwardDiff (efficient for models with few parameters) +- `AutoMooncake()` - Mooncake (requires `UseGeneratedLogDensityFunction()` mode) + +For fine-grained control, you can configure the AD backend: + +```julia +# ReverseDiff without compilation +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) +``` + +The compiled model with gradient support implements the [`LogDensityProblems.jl`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/) interface, including [`logdensity_and_gradient`](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.logdensity_and_gradient), which returns both the log density and its gradient. ### Inference -For a differentiable model, we can use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) to perform inference. -For instance, +For gradient-based inference, we use [`AdvancedHMC.jl`](https://github.com/TuringLang/AdvancedHMC.jl) with models compiled with an `adtype`: -```julia -using AdvancedHMC, AbstractMCMC, LogDensityProblems, MCMCChains +```@example abc +# Compile with gradient support +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) n_samples, n_adapts = 2000, 1000 D = LogDensityProblems.dimension(model); initial_θ = rand(D) samples_and_stats = AbstractMCMC.sample( - ad_model, + model, NUTS(0.8), n_samples; chain_type = Chains, n_adapts = n_adapts, init_params = initial_θ, - discard_initial = n_adapts + discard_initial = n_adapts, + progress = false ) +describe(samples_and_stats) ``` -This will return the MCMC Chain, - -```plaintext -Chains MCMC chain (2000×40×1 Array{Real, 3}): - -Iterations = 1001:1:3000 -Number of chains = 1 -Samples per chain = 2000 -parameters = alpha0, alpha12, alpha1, alpha2, tau, b[16], b[12], b[10], b[14], b[13], b[7], b[6], b[20], b[1], b[4], b[5], b[2], b[18], b[8], b[3], b[9], b[21], b[17], b[15], b[11], b[19], sigma -internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt - -Summary Statistics - parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec - Symbol Float64 Float64 Float64 Real Float64 Float64 Missing - - alpha0 -0.5642 0.2320 0.0084 766.9305 1022.5211 1.0021 missing - alpha12 -0.8489 0.5247 0.0170 946.0418 1044.1109 1.0002 missing - alpha1 0.0587 0.3715 0.0119 966.4367 1233.2257 1.0007 missing - alpha2 1.3852 0.3410 0.0127 712.2978 974.1566 1.0002 missing - tau 1.8880 0.7705 0.0447 348.9331 338.3655 1.0030 missing - b[16] -0.2445 0.4459 0.0132 1528.0578 843.8225 1.0003 missing - b[12] 0.2050 0.3602 0.0086 1868.6126 1202.1363 0.9996 missing - b[10] -0.3500 0.2893 0.0090 1047.3119 1245.9358 1.0008 missing - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted - -Quantiles - parameters 2.5% 25.0% 50.0% 75.0% 97.5% - Symbol Float64 Float64 Float64 Float64 Float64 - - alpha0 -1.0143 -0.7143 -0.5590 -0.4100 -0.1185 - alpha12 -1.9063 -1.1812 -0.8296 -0.5153 0.1521 - alpha1 -0.6550 -0.1822 0.0512 0.2885 0.8180 - alpha2 0.7214 1.1663 1.3782 1.5998 2.0986 - tau 0.5461 1.3941 1.8353 2.3115 3.6225 - b[16] -1.2359 -0.4836 -0.1909 0.0345 0.5070 - b[12] -0.4493 -0.0370 0.1910 0.4375 0.9828 - b[10] -0.9570 -0.5264 -0.3331 -0.1514 0.1613 - ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ - 19 rows omitted +This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). + +## Evaluation Modes and Automatic Differentiation + +JuliaBUGS supports multiple evaluation modes and AD backends. The evaluation mode determines how the log density is computed, and constrains which AD backends can be used. + +### Evaluation Modes +| Mode | AD Backends | +|------|-------------| +| `UseGraph()` (default) | ReverseDiff, ForwardDiff | +| `UseGeneratedLogDensityFunction()` | Mooncake | + +- **`UseGraph()`**: Evaluates by traversing the computational graph. Supports user-defined primitives registered via `@bugs_primitive`. +- **`UseGeneratedLogDensityFunction()`**: Generates and compiles a Julia function for the log density. + +### AD Backends with `UseGraph()` Mode + +Use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) or [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) with the default `UseGraph()` mode: + +```julia +using ADTypes + +# ReverseDiff with tape compilation (recommended for large models) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=true)) + +# ForwardDiff (efficient for small models with < 20 parameters) +model = compile(model_def, data; adtype=AutoForwardDiff()) + +# ReverseDiff without compilation (supports control flow) +model = compile(model_def, data; adtype=AutoReverseDiff(compile=false)) ``` -This is consistent with the result in the [OpenBUGS seeds example](https://chjackson.github.io/openbugsdoc/Examples/Seeds.html). +!!! warning "Compiled ReverseDiff does not support control flow" + Compiled tapes record a fixed execution path. If your model contains value-dependent control flow (e.g., `if x > 0`, `while`, truncation), the tape will only capture one branch and produce **incorrect gradients** when the control flow takes a different path. Use `AutoReverseDiff(compile=false)` or `AutoForwardDiff()` for models with control flow. + +### AD Backend with `UseGeneratedLogDensityFunction()` Mode + +Use [Mooncake.jl](https://github.com/compintell/Mooncake.jl) with the generated log density function mode: + +```julia +using ADTypes + +model = compile(model_def, data) +model = set_evaluation_mode(model, UseGeneratedLogDensityFunction()) +model = BUGSModelWithGradient(model, AutoMooncake(; config=nothing)) +``` ## Parallel and Distributed Sampling with `AbstractMCMC` @@ -283,7 +311,7 @@ The model compilation code remains the same, and we can sample multiple chains i ```julia n_chains = 4 samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCThreads(), n_samples, @@ -311,7 +339,7 @@ For example: ```julia @everywhere begin - using JuliaBUGS, LogDensityProblems, LogDensityProblemsAD, AbstractMCMC, AdvancedHMC, MCMCChains, ReverseDiff # also other packages one may need + using JuliaBUGS, LogDensityProblems, AbstractMCMC, AdvancedHMC, MCMCChains, ADTypes, ReverseDiff # Define the functions to use # Use `@bugs_primitive` to register the functions to use in the model @@ -322,7 +350,7 @@ end n_chains = nprocs() - 1 # use all the processes except the parent process samples_and_stats = AbstractMCMC.sample( - ad_model, + model, AdvancedHMC.NUTS(0.65), AbstractMCMC.MCMCDistributed(), n_samples, diff --git a/JuliaBUGS/examples/Project.toml b/JuliaBUGS/examples/Project.toml index 401edf467..8a043e94c 100644 --- a/JuliaBUGS/examples/Project.toml +++ b/JuliaBUGS/examples/Project.toml @@ -4,17 +4,18 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + +[sources] +JuliaBUGS = {path = ".."} diff --git a/JuliaBUGS/examples/bnn.jl b/JuliaBUGS/examples/bnn.jl index f7666904f..afb21e0a8 100644 --- a/JuliaBUGS/examples/bnn.jl +++ b/JuliaBUGS/examples/bnn.jl @@ -1,17 +1,16 @@ using JuliaBUGS +using Distributions: Bernoulli, MvNormal using AbstractMCMC using ADTypes using AdvancedHMC -using DifferentiationInterface using FillArrays +using ForwardDiff using Functors using LinearAlgebra using LogDensityProblems -using LogDensityProblemsAD using Lux using MCMCChains -using Mooncake using Random ## data simulation @@ -84,7 +83,7 @@ function make_prediction(parameters, xs; ps=ps, nn=nn) return Lux.apply(nn, f32(xs), f32(vector_to_parameters(parameters, ps))) end -JuliaBUGS.@bugs_primitive parameter_distribution make_prediction +JuliaBUGS.@bugs_primitive parameter_distribution make_prediction Bernoulli @eval JuliaBUGS begin ps = Main.ps @@ -96,16 +95,17 @@ end data = (nparameters=Lux.parameterlength(nn), xs=xs_hcat, ts=ts, N=length(ts), sigma=sigma) +# Use ForwardDiff with UseGraph mode (required for user-defined primitives) model = compile(model_def, data) - -ad_model = ADgradient(AutoMooncake(; config=Mooncake.Config()), model) +model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) +model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # sampling is slow, so sample 10 of them to verify that this can work samples_and_stats = AbstractMCMC.sample( - ad_model, + model, NUTS(0.65), 10; chain_type=Chains, - # n_adapts=1000, + # n_adapts=1000, # discard_initial=1000 ) diff --git a/JuliaBUGS/examples/gp.jl b/JuliaBUGS/examples/gp.jl index fd8188863..562937e61 100644 --- a/JuliaBUGS/examples/gp.jl +++ b/JuliaBUGS/examples/gp.jl @@ -7,14 +7,11 @@ using JuliaBUGS using JuliaBUGS: @model # Required packages for GP modeling and MCMC -using AbstractGPs, Distributions, LogExpFunctions -using LogDensityProblems, LogDensityProblemsAD +using AbstractGPs, Distributions, LogExpFunctions, ForwardDiff +using LogDensityProblems +using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains -# Differentiation backend -using DifferentiationInterface -using Mooncake: Mooncake - # --- Data Definition --- # Golf putting data from Gelman et al. (BDA3, Chapter 5) @@ -120,94 +117,18 @@ model = gp_golf_putting( data.jitter, # Numerical stability term ) -# Generate the log density function for optimal performance -model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction()) - -# --- MCMC Setup with Custom LogDensityProblems Wrapper --- - -# We need a wrapper around the JuliaBUGS model to interface with LogDensityProblems -# and utilize automatic differentiation (AD) via Mooncake.jl for gradient computation, -# which is required by AdvancedHMC. - -struct BUGSMooncakeModel{T,P} - model::T # The JuliaBUGS model - prep::P # Pre-allocated workspace for gradient computation using Mooncake -end - -# Define the function to compute the log density using the JuliaBUGS model's internal function -f(x) = model.log_density_computation_function(model.evaluation_env, x) - -# Prepare the differentiation backend (Mooncake) -backend = AutoMooncake(; config=nothing) -x_init = rand(LogDensityProblems.dimension(model)) # Initial point for testing/preparation -prep = prepare_gradient(f, backend, x_init) - -# Create the wrapped model instance -bugsmooncake = BUGSMooncakeModel(model, prep) - -# --- LogDensityProblems Interface Implementation for the Wrapper --- - -# Define logdensity function for the wrapper -function LogDensityProblems.logdensity(model::BUGSMooncakeModel, x::AbstractVector) - return f(x) # Calls the underlying JuliaBUGS log density function -end - -# Define logdensity_and_gradient function using the prepared DifferentiationInterface setup -function LogDensityProblems.logdensity_and_gradient( - model::BUGSMooncakeModel, x::AbstractVector -) - # Computes both the log density and its gradient using Mooncake AD - return DifferentiationInterface.value_and_gradient( - f, model.prep, AutoMooncake(; config=nothing), x - ) -end - -# Define dimension function -function LogDensityProblems.dimension(model::BUGSMooncakeModel) - return LogDensityProblems.dimension(model.model) # Delegates to the original model -end - -# Define a custom bundle_samples function to convert the AdvancedHMC.Transition to a Chains object -function AbstractMCMC.bundle_samples( - ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSMooncakeModel}, - sampler::AdvancedHMC.AbstractHMCSampler, - state, - chain_type::Type{Chains}; - discard_initial=0, - thinning=1, - kwargs..., -) - stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1])))) - stats_values = [ - vcat([ts[i].z.ℓπ.value..., collect(values(AdvancedHMC.stat(ts[i])))...]) for - i in eachindex(ts) - ] - - return JuliaBUGS.gen_chains( - logdensitymodel.logdensity.model, - [t.z.θ for t in ts], - stats_names, - stats_values; - discard_initial=discard_initial, - thinning=thinning, - kwargs..., - ) -end - -# Specify capabilities (indicates gradient availability) -function LogDensityProblems.capabilities(::Type{<:BUGSMooncakeModel}) - return LogDensityProblems.LogDensityOrder{1}() # Can compute up to the gradient -end +# Use graph evaluation mode with ForwardDiff AD (required for user-defined primitives) +model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) +grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # --- MCMC Sampling --- # Sample from the posterior distribution using AdvancedHMC's NUTS sampler samples_and_stats = AbstractMCMC.sample( - AbstractMCMC.LogDensityModel(bugsmooncake), # Wrap the model for AbstractMCMC + grad_model, AdvancedHMC.NUTS(0.65), # No-U-Turn Sampler 1000; # Total number of samples chain_type=Chains, # Store results as MCMCChains object n_adapts=500, # Number of adaptation steps for NUTS - discard_initial=500, # Number of initial samples (warmup) to discard; + discard_initial=500, # Number of initial samples (warmup) to discard ) diff --git a/JuliaBUGS/examples/sir.jl b/JuliaBUGS/examples/sir.jl index 108d47ce1..5e915e5b3 100644 --- a/JuliaBUGS/examples/sir.jl +++ b/JuliaBUGS/examples/sir.jl @@ -6,7 +6,8 @@ using JuliaBUGS using JuliaBUGS: @model using Distributions using DifferentialEquations -using LogDensityProblems, LogDensityProblemsAD +using LogDensityProblems +using ADTypes using AbstractMCMC, AdvancedHMC, MCMCChains using Distributed # For distributed example @@ -112,8 +113,8 @@ model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGraph()) # --- MCMC Sampling: NUTS with ForwardDiff AD --- -# Create an AD-aware wrapper for the model using ForwardDiff for gradients -ad_model_forwarddiff = ADgradient(:ForwardDiff, model) +# Create gradient-enabled model using ForwardDiff +grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) # MCMC settings n_samples = 1000 @@ -121,7 +122,7 @@ n_adapts = 500 # Run the NUTS sampler samples_nuts_fwd = AbstractMCMC.sample( - ad_model_forwarddiff, + grad_model, AdvancedHMC.NUTS(0.65), # No-U-Turn Sampler with step size adaptation target n_samples; chain_type=Chains, # Store results as MCMCChains object diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl index 179ef02b5..eca960856 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl @@ -4,9 +4,8 @@ using AbstractMCMC using AdvancedHMC using ADTypes using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD using JuliaBUGS.Random using MCMCChains: Chains @@ -40,10 +39,10 @@ end function _gibbs_internal_hmc( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) - ) + # Create gradient model on-the-fly + ad_model = BUGSModelWithGradient(cond_model, ad_backend) + x = getparams(cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take HMC/NUTS step if isnothing(state) @@ -53,7 +52,7 @@ function _gibbs_internal_hmc( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else # Use existing state for subsequent steps @@ -67,7 +66,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedHMC.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, sampler::AdvancedHMC.AbstractHMCSampler, state, chain_type::Type{Chains}; diff --git a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl index ca30555be..edab75d99 100644 --- a/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl @@ -4,9 +4,8 @@ using AbstractMCMC using AdvancedMH using ADTypes using JuliaBUGS -using JuliaBUGS: BUGSModel, getparams, initialize! +using JuliaBUGS: BUGSModel, BUGSModelWithGradient, getparams, initialize! using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD using JuliaBUGS.Random using MCMCChains: Chains @@ -52,10 +51,10 @@ end function _gibbs_internal_mh( rng::Random.AbstractRNG, cond_model::BUGSModel, sampler, ad_backend, state ) - # Wrap model with AD gradient computation for gradient-based proposals - logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) - ) + # Create gradient model on-the-fly + ad_model = BUGSModelWithGradient(cond_model, ad_backend) + x = getparams(cond_model) + logdensitymodel = AbstractMCMC.LogDensityModel(ad_model) # Take MH step with gradient information if isnothing(state) @@ -64,7 +63,7 @@ function _gibbs_internal_mh( logdensitymodel, sampler; n_adapts=0, # Disable adaptation within Gibbs - initial_params=getparams(cond_model), + initial_params=x, ) else t, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0) @@ -105,7 +104,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:AdvancedMH.Transition}, - logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, sampler::AdvancedMH.MHSampler, state, chain_type::Type{Chains}; @@ -113,7 +112,6 @@ function AbstractMCMC.bundle_samples( thinning=1, kwargs..., ) - # Same extraction for gradient-based MH samplers param_samples = [t.params for t in ts] stats_names = [:lp] stats_values = [[t.lp] for t in ts] diff --git a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl index eec864093..224883c36 100644 --- a/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl @@ -2,10 +2,14 @@ module JuliaBUGSMCMCChainsExt using AbstractMCMC using JuliaBUGS -using JuliaBUGS: BUGSModel, find_generated_quantities_variables, evaluate!!, getparams +using JuliaBUGS: + BUGSModel, + BUGSModelWithGradient, + find_generated_quantities_variables, + evaluate!!, + getparams using JuliaBUGS.AbstractPPL using JuliaBUGS.Accessors -using JuliaBUGS.LogDensityProblemsAD using MCMCChains: Chains function JuliaBUGS.gen_chains( @@ -22,14 +26,14 @@ function JuliaBUGS.gen_chains( end function JuliaBUGS.gen_chains( - model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}, + model::AbstractMCMC.LogDensityModel{<:BUGSModelWithGradient}, samples, stats_names, stats_values; kwargs..., ) - # Extract BUGSModel from ADGradient wrapper - bugs_model = model.logdensity.ℓ + # Extract BUGSModel from gradient wrapper + bugs_model = model.logdensity.base_model return JuliaBUGS.gen_chains(bugs_model, samples, stats_names, stats_values; kwargs...) end diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index 7f51e13e3..5e279e4e4 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -6,6 +6,7 @@ using Accessors using ADTypes using BangBang using Bijectors: Bijectors +using DifferentiationInterface using Distributions using Graphs, MetaGraphsNext using LinearAlgebra @@ -17,6 +18,7 @@ using Serialization: Serialization using StaticArrays import Base: ==, hash, Symbol, size +import DifferentiationInterface as DI import Distributions: truncated export @bugs @@ -234,16 +236,21 @@ function validate_bugs_expression(expr, line_num) end """ - compile(model_def, data[, initial_params]; skip_validation=false) + compile(model_def, data[, initial_params]; adtype=nothing) -Compile the model with model definition and data. Optionally, initializations can be provided. -If initializations are not provided, values will be sampled from the prior distributions. +Compile a BUGS model. Returns `BUGSModel`, or `BUGSModelWithGradient` if `adtype` is provided. -By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro). -Set `skip_validation=true` to skip validation (for @model macro usage). +# Arguments +- `model_def::Expr`: Model definition from `@bugs` macro +- `data::NamedTuple`: Observed data +- `initial_params::NamedTuple`: Initial parameter values (optional, defaults to prior samples) +- `adtype`: AD backend from ADTypes.jl (e.g., `AutoReverseDiff()`, `AutoForwardDiff()`, `AutoMooncake()`) -The compiled model uses `UseGraph` evaluation mode by default. To use the optimized generated -log-density function, call `set_evaluation_mode(model, UseGeneratedLogDensityFunction())`. +# Examples +```julia +model = compile(model_def, data) +model = compile(model_def, data; adtype=AutoReverseDiff()) +``` """ function compile( model_def::Expr, @@ -251,6 +258,7 @@ function compile( initial_params::NamedTuple=NamedTuple(); skip_validation::Bool=false, eval_module::Module=@__MODULE__, + adtype::Union{Nothing,ADTypes.AbstractADType,Symbol}=nothing, ) # Validate functions by default (for @bugs macro usage) # Skip validation only for @model macro @@ -279,18 +287,15 @@ function compile( values(eval_env), ), ) - return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params, true) + base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params, true) + + # If adtype provided, wrap with gradient capabilities + if adtype !== nothing + return Base.invokelatest(Model.BUGSModelWithGradient, base_model, adtype) + end + + return base_model end -# function compile( -# model_str::String, -# data::NamedTuple, -# initial_params::NamedTuple=NamedTuple(); -# replace_period::Bool=true, -# no_enclosure::Bool=false, -# ) -# model_def = _bugs_string_input(model_str, replace_period, no_enclosure) -# return compile(model_def, data, initial_params) -# end """ register_bugs_function(func_name::Symbol) diff --git a/JuliaBUGS/src/gibbs.jl b/JuliaBUGS/src/gibbs.jl index fa71d85b9..8a1c4e870 100644 --- a/JuliaBUGS/src/gibbs.jl +++ b/JuliaBUGS/src/gibbs.jl @@ -432,7 +432,7 @@ function AbstractMCMC.step( # For gradient-based samplers, wrap with AD _, ad_backend = sub_sampler logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(ad_backend, cond_model) + Model.BUGSModelWithGradient(cond_model, ad_backend) ) else # For non-gradient samplers, use model directly diff --git a/JuliaBUGS/src/model/Model.jl b/JuliaBUGS/src/model/Model.jl index 0a89eef4e..f5d059041 100644 --- a/JuliaBUGS/src/model/Model.jl +++ b/JuliaBUGS/src/model/Model.jl @@ -2,8 +2,10 @@ module Model using Accessors using AbstractPPL +using ADTypes using BangBang using Bijectors +import DifferentiationInterface as DI using Distributions using Graphs using LinearAlgebra @@ -26,6 +28,9 @@ export set_evaluation_mode, set_observed_values! # Evaluation mode types export UseGraph, UseGeneratedLogDensityFunction, UseAutoMarginalization +# Gradient wrapper +export BUGSModelWithGradient + # Internal evaluation functions (exported for testing, not re-exported to users) export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!! export evaluate_with_marginalization_values!! diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 82e80232b..a8eb9e45c 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -55,3 +55,107 @@ end function LogDensityProblems.capabilities(::AbstractBUGSModel) return LogDensityProblems.LogDensityOrder{0}() end + +""" + BUGSModelWithGradient{AD,P,M} + +Wrap a `BUGSModel` with AD capabilities for gradient-based inference. + +Implements `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. + +# Fields +- `adtype::AD`: AD backend (e.g., `AutoReverseDiff()`) +- `prep::P`: Prepared gradient from DifferentiationInterface +- `base_model::M`: The underlying `BUGSModel` + +See also [`compile`](@ref). +""" +struct BUGSModelWithGradient{AD<:ADTypes.AbstractADType,P,M<:BUGSModel} + adtype::AD + prep::P + base_model::M +end + +""" + BUGSModelWithGradient(model::BUGSModel, adtype::ADTypes.AbstractADType) + +Construct a gradient-enabled model wrapper from a BUGSModel and an AD backend. + +# AD Backend Compatibility + +Different AD backends have different compatibility with evaluation modes: + +- **`UseGeneratedLogDensityFunction`**: Only compatible with mutation-supporting backends + like `AutoMooncake` and `AutoEnzyme`. The generated functions mutate arrays in-place. +- **`UseGraph`**: Compatible with `AutoReverseDiff`, `AutoForwardDiff`, and other + tape-based or forward-mode backends. Also works with Mooncake and Enzyme. + +If an incompatible combination is detected, a warning is issued and the model is +automatically switched to `UseGraph` mode. + +# Example +```julia +model = compile(model_def, data) +grad_model = BUGSModelWithGradient(model, AutoReverseDiff(compile=true)) +``` +""" +function BUGSModelWithGradient(model::BUGSModel, adtype::ADTypes.AbstractADType) + # Check AD backend compatibility with evaluation mode + model = _check_ad_compatibility(model, adtype) + + x = getparams(model) + prep = DI.prepare_gradient(_logdensity_for_gradient, adtype, x, DI.Constant(model)) + return BUGSModelWithGradient(adtype, prep, model) +end + +# AD backends that support mutation (required for UseGeneratedLogDensityFunction) +_supports_mutation(::ADTypes.AutoMooncake) = true +_supports_mutation(::ADTypes.AutoEnzyme) = true +_supports_mutation(::ADTypes.AbstractADType) = false + +function _check_ad_compatibility(model::BUGSModel, adtype::ADTypes.AbstractADType) + if model.evaluation_mode isa UseGeneratedLogDensityFunction && + !_supports_mutation(adtype) + @warn "AD backend $(typeof(adtype)) does not support mutation required by " * + "UseGeneratedLogDensityFunction mode. Switching to UseGraph mode." maxlog = 1 + return set_evaluation_mode(model, UseGraph()) + end + return model +end + +# Forward base BUGSModel interface +function LogDensityProblems.logdensity(model::BUGSModelWithGradient, x::AbstractVector) + return LogDensityProblems.logdensity(model.base_model, x) +end + +function LogDensityProblems.dimension(model::BUGSModelWithGradient) + return LogDensityProblems.dimension(model.base_model) +end + +function LogDensityProblems.capabilities(::Type{<:BUGSModelWithGradient}) + return LogDensityProblems.LogDensityOrder{1}() # Gradient available +end + +""" + _logdensity_for_gradient(x, model) + +Target function for gradient computation via DifferentiationInterface. +The parameter vector `x` comes first (the argument to differentiate w.r.t.), +and the model is passed as a constant context (not differentiated). +""" +function _logdensity_for_gradient(x::AbstractVector, model::BUGSModel) + return _eval_logdensity(model, model.evaluation_mode, x) +end + +""" + LogDensityProblems.logdensity_and_gradient(model::BUGSModelWithGradient, x) + +Compute log density and its gradient using DifferentiationInterface. +""" +function LogDensityProblems.logdensity_and_gradient( + model::BUGSModelWithGradient, x::AbstractVector +) + return DI.value_and_gradient( + _logdensity_for_gradient, model.prep, model.adtype, x, DI.Constant(model.base_model) + ) +end diff --git a/JuliaBUGS/test/BUGSPrimitives/distributions.jl b/JuliaBUGS/test/BUGSPrimitives/distributions.jl index 69505e2f7..82c4f04af 100644 --- a/JuliaBUGS/test/BUGSPrimitives/distributions.jl +++ b/JuliaBUGS/test/BUGSPrimitives/distributions.jl @@ -15,9 +15,10 @@ end A[1:2, 1:2] ~ dwish(B[:, :], 2) C[1:2] ~ dmnorm(mu[:], A[:, :]) end - model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],)) + ad_model = compile( + model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff() + ) - ad_model = ADgradient(:ReverseDiff, model) theta = [ 0.7931743744870574, 0.5151017206811268, diff --git a/JuliaBUGS/test/Project.toml b/JuliaBUGS/test/Project.toml index 9c02130ee..f03a6ce3b 100644 --- a/JuliaBUGS/test/Project.toml +++ b/JuliaBUGS/test/Project.toml @@ -7,6 +7,9 @@ AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -31,12 +34,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ADTypes = "1.14.0" AbstractMCMC = "5" -AbstractPPL = "0.8.4, 0.9, 0.10, 0.11" -AdvancedHMC = "0.6, 0.7" +AbstractPPL = "0.8.4, 0.9, 0.10, 0.11, 0.12, 0.13" +AdvancedHMC = "0.6, 0.7, 0.8" AdvancedMH = "0.8" BangBang = "0.4.1" ChainRules = "1" +DifferentiationInterface = "0.7" Distributions = "0.23.8, 0.24, 0.25" +ForwardDiff = "1" +Mooncake = "0.4" Documenter = "0.27, 1" Graphs = "1" JuliaSyntax = "1" @@ -44,7 +50,7 @@ LinearAlgebra = "1.10" LogDensityProblems = "2" LogDensityProblemsAD = "1" LogExpFunctions = "0.3" -MCMCChains = "6" +MCMCChains = "6, 7" MacroTools = "0.5" MetaGraphsNext = "0.6, 0.7" OrderedCollections = "1" diff --git a/JuliaBUGS/test/ad_compatibility.jl b/JuliaBUGS/test/ad_compatibility.jl new file mode 100644 index 000000000..c8ffe6d9c --- /dev/null +++ b/JuliaBUGS/test/ad_compatibility.jl @@ -0,0 +1,116 @@ +@testset "AD Backend Compatibility" begin + # Use a simpler model for testing AD compatibility + # (similar to existing tests in JuliaBUGSAdvancedHMCExt.jl) + model_def = @bugs begin + mu ~ dnorm(0, 1) + for i in 1:N + y[i] ~ dnorm(mu, 1) + end + end + data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8]) + + @testset "UseGraph mode" begin + model = compile(model_def, data) + @test model.evaluation_mode isa JuliaBUGS.UseGraph + + x = JuliaBUGS.getparams(model) + + @testset "AutoReverseDiff" begin + grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoReverseDiff()) + @test grad_model isa JuliaBUGS.BUGSModelWithGradient + @test grad_model.base_model.evaluation_mode isa JuliaBUGS.UseGraph + + # Test gradient computation works + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + + @testset "AutoForwardDiff" begin + grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) + @test grad_model isa JuliaBUGS.BUGSModelWithGradient + @test grad_model.base_model.evaluation_mode isa JuliaBUGS.UseGraph + + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + + @testset "Gradient consistency across backends" begin + rd_model = JuliaBUGS.BUGSModelWithGradient(model, AutoReverseDiff()) + fd_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) + + logp_rd, grad_rd = LogDensityProblems.logdensity_and_gradient(rd_model, x) + logp_fd, grad_fd = LogDensityProblems.logdensity_and_gradient(fd_model, x) + + @test logp_rd ≈ logp_fd + @test grad_rd ≈ grad_fd rtol = 1e-6 + end + end + + @testset "UseGeneratedLogDensityFunction mode" begin + model = compile(model_def, data) + model = JuliaBUGS.set_evaluation_mode( + model, JuliaBUGS.UseGeneratedLogDensityFunction() + ) + @test model.evaluation_mode isa JuliaBUGS.UseGeneratedLogDensityFunction + + x = JuliaBUGS.getparams(model) + + @testset "AutoReverseDiff - should warn and switch to UseGraph" begin + grad_model = @test_warn "does not support mutation" JuliaBUGS.BUGSModelWithGradient( + model, AutoReverseDiff() + ) + @test grad_model.base_model.evaluation_mode isa JuliaBUGS.UseGraph + + # Should still work after switching + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + + @testset "AutoForwardDiff - should switch to UseGraph" begin + # Note: Warning is suppressed due to maxlog=1 (already shown in ReverseDiff test) + grad_model = JuliaBUGS.BUGSModelWithGradient(model, AutoForwardDiff()) + @test grad_model.base_model.evaluation_mode isa JuliaBUGS.UseGraph + + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + + @testset "AutoMooncake - should work without warning" begin + grad_model = JuliaBUGS.BUGSModelWithGradient( + model, AutoMooncake(; config=nothing) + ) + @test grad_model.base_model.evaluation_mode isa + JuliaBUGS.UseGeneratedLogDensityFunction + + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + end + + @testset "compile with adtype parameter" begin + @testset "AutoReverseDiff" begin + grad_model = compile(model_def, data; adtype=AutoReverseDiff()) + @test grad_model isa JuliaBUGS.BUGSModelWithGradient + + x = JuliaBUGS.getparams(grad_model.base_model) + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + + @testset "AutoForwardDiff" begin + grad_model = compile(model_def, data; adtype=AutoForwardDiff()) + @test grad_model isa JuliaBUGS.BUGSModelWithGradient + + x = JuliaBUGS.getparams(grad_model.base_model) + logp, grad = LogDensityProblems.logdensity_and_gradient(grad_model, x) + @test isfinite(logp) + @test all(isfinite, grad) + end + end +end diff --git a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl index 9115ac9dd..31cedc9de 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl @@ -6,10 +6,9 @@ y = x[1] + x[3] end data = (mu=[0, 0], sigma=[1 0; 0 1]) - model = compile(model_def, data) - ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 10, 0 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) samples_and_stats = Base.invokelatest( AbstractMCMC.sample, @@ -33,13 +32,14 @@ (; model_def, data, inits, reference_results) = Base.getfield( JuliaBUGS.BUGSExamples, example ) - model = JuliaBUGS.compile(model_def, data, inits) - ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true)) + ad_model = JuliaBUGS.compile( + model_def, data, inits; adtype=AutoReverseDiff(; compile=true) + ) n_samples, n_adapts = 1000, 1000 - D = LogDensityProblems.dimension(model) - initial_θ = Base.invokelatest(JuliaBUGS.getparams, model) + D = LogDensityProblems.dimension(ad_model) + initial_θ = Base.invokelatest(JuliaBUGS.getparams, ad_model.base_model) samples_and_stats = Base.invokelatest( AbstractMCMC.sample, diff --git a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl index 0c4383e40..c22de01b0 100644 --- a/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl +++ b/JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl @@ -27,10 +27,10 @@ ) model = compile(model_def, data, (;)) - ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true)) + ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(; compile=true)) n_samples, n_adapts = 2000, 1000 - D = LogDensityProblems.dimension(model) + D = LogDensityProblems.dimension(ad_model) initial_θ = rand(D) hmc_chain = Base.invokelatest( @@ -73,8 +73,7 @@ n_samples, n_adapts = 20000, 5000 - mh_chain = Base.invokelatest( - AbstractMCMC.sample, + mh_chain = AbstractMCMC.sample( model, RWMH(MvNormal(zeros(D), I)), n_samples; @@ -109,10 +108,9 @@ sigma[2] ~ InverseGamma(2, 3) sigma[3] ~ InverseGamma(2, 3) end - model = compile(model_def, (;)) - ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true)) - hmc_chain = Base.invokelatest( - AbstractMCMC.sample, ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains + ad_model = compile(model_def, (;); adtype=AutoReverseDiff(; compile=true)) + hmc_chain = AbstractMCMC.sample( + ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains ) @test Set(hmc_chain.name_map[:parameters]) == Set([ Symbol("sigma[3]"), diff --git a/JuliaBUGS/test/parallel_sampling.jl b/JuliaBUGS/test/parallel_sampling.jl index 7871aca7f..23dfa4ebf 100644 --- a/JuliaBUGS/test/parallel_sampling.jl +++ b/JuliaBUGS/test/parallel_sampling.jl @@ -19,9 +19,8 @@ data = (N=N, x=x_data) inits = (mu=0.0, tau=1.0) - model = compile(model_def, data, inits) - # Use compile=Val(false) for thread safety with ReverseDiff - ad_model = ADgradient(:ReverseDiff, model; compile=Val(false)) + # Use compile=false for thread safety with ReverseDiff + ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(; compile=false)) # Single chain reference n_samples = 200 diff --git a/JuliaBUGS/test/runtests.jl b/JuliaBUGS/test/runtests.jl index 98a021394..7bbb1d102 100644 --- a/JuliaBUGS/test/runtests.jl +++ b/JuliaBUGS/test/runtests.jl @@ -40,6 +40,8 @@ using AdvancedHMC using AdvancedMH using MCMCChains using ReverseDiff +using ForwardDiff +using Mooncake JuliaBUGS.@bugs_primitive Beta Bernoulli Categorical Exponential Gamma InverseGamma Normal Uniform LogNormal Poisson JuliaBUGS.@bugs_primitive Diagonal Dirichlet LKJ MvNormal @@ -96,6 +98,7 @@ const TEST_GROUPS = OrderedDict{String,Function}( "inference_mh" => () -> include("independent_mh.jl"), "gibbs" => () -> include("gibbs.jl"), "parallel_sampling" => () -> include("parallel_sampling.jl"), + "ad_compatibility" => () -> include("ad_compatibility.jl"), "experimental" => () -> include("experimental/ProbabilisticGraphicalModels/runtests.jl"), )