Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,49 @@
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.

For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.

As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
If you were previously relying on this behaviour, you will need to store a VarInfo separately.

#### Threadsafe evaluation

DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel.
Prior to DynamicPPL 0.39, thread safety for such models used to be enabled by default if Julia was launched with more than one thread.

In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
If you need it (see below for more discussion of when you _do_ need it), you **must** now manually mark it as so, using:

```julia
@model f() = ...
model = f()
model = setthreadsafe(model, true)
```

The problem with the previous on-by-default is that it can sacrifice a huge amount of performance when thread safety is not needed.
This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.

**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:

- tilde-statements
- calls to `@addlogprob!`
- any direct manipulation of the special `__varinfo__` variable

If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
**Notably, the following do not require threadsafe evaluation:**

- Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
- Sampling with `AbstractMCMC.MCMCThreads()`.

For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).

When threadsafe evaluation is enabled for a model, an internal flag is set on the model.
The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean.
This function is newly exported in this version of DynamicPPL.
Comment on lines +51 to +53
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realised that Turing needs this function (basically PG/SMC should error any time they encounter a model that needs threadsafe eval -- TuringLang/Turing.jl#2658), so we need to export it. Other changes just follow on from review!


#### Parent and leaf contexts

The `DynamicPPL.NodeTrait` function has been removed.
Expand Down
8 changes: 8 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref):
contextualize
```

Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
If this is the case, one must enable threadsafe evaluation for a model:

```@docs
setthreadsafe
requires_threadsafe
```

## Evaluation

With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
setthreadsafe,
requires_threadsafe,
extract_priors,
values_as_in_model,
# evaluation
Expand Down
54 changes: 40 additions & 14 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
modeldef = build_model_definition(expr)

# Generate main body
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true)

return build_output(modeldef, linenumbernode)
end
Expand Down Expand Up @@ -346,36 +346,59 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
generate_mainbody(mod, expr, warn, warn_threads) =
generate_mainbody!(mod, Symbol[], expr, warn, warn_threads)

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found, x, warn, warn_threads) = x
function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads)
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Flag to determine whether we've issued a warning for threadsafe macros Note that this
# detection is not fully correct. We can only detect the presence of a macro that has
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
# Threads.@threads from Base.Threads.

# Do we don't want escaped expressions because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
Meta.isexpr(expr, :escape) &&
return generate_mainbody(mod, found, expr.args[1], warn, warn_threads)

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
if (
expr.args[1] == Symbol("@threads") ||
expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
warn_threads
)
warn_threads = false
@warn (
"It looks like you are using `Threads.@threads` in your model definition." *
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
"\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." *
"\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required."
)
end
return generate_mainbody!(
mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads
)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_mainbody!(
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads
)
end

Expand All @@ -385,8 +408,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_tilde
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warn_threads),
generate_mainbody!(mod, found, R, warn, warn_threads),
),
)
end
Expand All @@ -397,13 +420,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_assign
return Base.remove_linenums!(
generate_assign(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn, warn_threads),
generate_mainbody!(mod, found, R, warn, warn_threads),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)...,
)
end

function generate_assign(left, right)
Expand Down Expand Up @@ -699,7 +725,7 @@ function build_output(modeldef, linenumbernode)
# to the call site
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...))
end

return MacroTools.@q begin
Expand Down
6 changes: 4 additions & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,10 @@ function check_model_and_trace(
# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)

# Force single-threaded execution.
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
# check on the merged accumulator, rather than checking it in the accumulate_assume
# calls. That way we can also correctly support multi-threaded evaluation.
_, varinfo = DynamicPPL.evaluate!!(model, varinfo)
Comment on lines -427 to +430
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the last thing that I'm a bit displeased about, but it's still better than on main (I wrote up an issue here #1157), so I thinkkkkk we can leave the proper fix to another PR.


# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
Expand Down
Loading
Loading