@@ -320,14 +320,14 @@ function compile(
320320 ),
321321 )
322322 base_model = BUGSModel (g, nonmissing_eval_env, model_def, data, initial_params)
323-
323+
324324 # If adtype provided, wrap with gradient capabilities
325325 if adtype != = nothing
326326 # Convert symbol to ADType if needed
327327 adtype_obj = _resolve_adtype (adtype)
328328 return _wrap_with_gradient (base_model, adtype_obj)
329329 end
330-
330+
331331 return base_model
332332end
333333
@@ -344,17 +344,19 @@ Supported symbol shortcuts:
344344"""
345345function _resolve_adtype (adtype:: Symbol )
346346 if adtype === :ReverseDiff
347- return ADTypes. AutoReverseDiff (compile= true )
347+ return ADTypes. AutoReverseDiff (; compile= true )
348348 elseif adtype === :ForwardDiff
349349 return ADTypes. AutoForwardDiff ()
350350 elseif adtype === :Zygote
351351 return ADTypes. AutoZygote ()
352352 elseif adtype === :Enzyme
353353 return ADTypes. AutoEnzyme ()
354354 else
355- error (" Unknown AD backend symbol: $adtype . " *
356- " Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
357- " Or use an ADTypes object like AutoReverseDiff(compile=true)." )
355+ error (
356+ " Unknown AD backend symbol: $adtype . " *
357+ " Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
358+ " Or use an ADTypes object like AutoReverseDiff(compile=true)." ,
359+ )
358360 end
359361end
360362
@@ -366,17 +368,13 @@ function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.Abstra
366368 # Get initial parameters for preparation
367369 # Use invokelatest to handle world age issues with generated functions
368370 x = Base. invokelatest (getparams, base_model)
369-
371+
370372 # Prepare gradient using DifferentiationInterface
371373 # Use invokelatest to handle world age issues when calling logdensity during preparation
372374 prep = Base. invokelatest (
373- DI. prepare_gradient,
374- Model. _logdensity_switched,
375- adtype,
376- x,
377- DI. Constant (base_model)
375+ DI. prepare_gradient, Model. _logdensity_switched, adtype, x, DI. Constant (base_model)
378376 )
379-
377+
380378 return Model. BUGSModelWithGradient (adtype, prep, base_model)
381379end
382380# function compile(
0 commit comments