@@ -881,30 +881,54 @@ end
881881 [init_strategy::AbstractInitStrategy=InitFromPrior()]
882882 )
883883
884- Evaluate the `model` and replace the values of the model's random variables
885- in the given `varinfo` with new values, using a specified initialisation strategy.
886- If the values in `varinfo` are not set, they will be added
887- using a specified initialisation strategy.
884+ Evaluate the `model` and replace the values of the model's random variables in the given
885+ `varinfo` with new values, using a specified initialisation strategy. If the values in
886+ `varinfo` are not set, they will be added using a specified initialisation strategy.
888887
889888If `init_strategy` is not provided, defaults to `InitFromPrior()`.
890889
891890Returns a tuple of the model's return value, plus the updated `varinfo` object.
892891"""
893- function init!! (
892+ @inline function init!! (
893+ # Note that this `@inline` is mandatory for performance, especially for
894+ # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for
895+ # trivial models) and much slower runtime.
894896 rng:: Random.AbstractRNG ,
895897 model:: Model ,
896- varinfo :: AbstractVarInfo ,
897- init_strategy :: AbstractInitStrategy = InitFromPrior () ,
898+ vi :: AbstractVarInfo ,
899+ strategy :: AbstractInitStrategy ,
898900)
899- new_model = setleafcontext (model, InitContext (rng, init_strategy))
900- return evaluate!! (new_model, varinfo)
901+ ctx = InitContext (rng, strategy)
902+ model = DynamicPPL. setleafcontext (model, ctx)
903+ # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
904+ # it _should_ do, but this is wrong regardless.
905+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
906+ return if Threads. nthreads () > 1
907+ # TODO (penelopeysm): The logic for setting eltype of accs is very similar to that
908+ # used in `unflatten`. The reason why we need it here is because the VarInfo `vi`
909+ # won't have been filled with parameters prior to `init!!` being called.
910+ #
911+ # Note that this eltype promotion is only needed for threadsafe evaluation. In an
912+ # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a
913+ # similar method. In other words, it should not be here, and it should not be inside
914+ # `unflatten` either. The problem is performance. Shifting this code around can have
915+ # massive, inexplicable, impacts on performance. This should be investigated
916+ # properly.
917+ param_eltype = DynamicPPL. get_param_eltype (strategy)
918+ accs = map (vi. accs) do acc
919+ DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
920+ end
921+ vi = DynamicPPL. setaccs!! (vi, accs)
922+ tsvi = ThreadSafeVarInfo (resetaccs!! (vi))
923+ retval, tsvi_new = DynamicPPL. _evaluate!! (model, tsvi)
924+ return retval, setaccs!! (vi, DynamicPPL. getaccs (tsvi_new))
925+ else
926+ return DynamicPPL. _evaluate!! (model, resetaccs!! (vi))
927+ end
901928end
902- function init!! (
903- model:: Model ,
904- varinfo:: AbstractVarInfo ,
905- init_strategy:: AbstractInitStrategy = InitFromPrior (),
906- )
907- return init!! (Random. default_rng (), model, varinfo, init_strategy)
929+ @inline function init!! (model:: Model , vi:: AbstractVarInfo , strategy:: AbstractInitStrategy )
930+ # This `@inline` is also mandatory for performance
931+ return init!! (Random. default_rng (), model, vi, strategy)
908932end
909933
910934"""
0 commit comments