Update for DynamicPPL 0.39 #2715
base: breaking
Conversation
Turing.jl documentation for PR #2715 is available at:
Codecov Report: ❌ Patch coverage is

@@            Coverage Diff             @@
##           breaking    #2715    +/-   ##
============================================
- Coverage     86.33%   85.44%    -0.90%
============================================
  Files            21       21
  Lines          1383     1257     -126
============================================
- Hits           1194     1074     -120
+ Misses          189      183       -6
```julia
######################
# Default Transition #
######################
```
All this stuff has basically been upstreamed to DynamicPPL and/or AbstractMCMC.
```julia
accs = DynamicPPL.AccumulatorTuple((
    DynamicPPL.ValuesAsInModelAccumulator(true),
    DynamicPPL.LogPriorAccumulator(),
    DynamicPPL.LogLikelihoodAccumulator(),
))
vi = DynamicPPL.OnlyAccsVarInfo(accs)
_, vi = DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromPrior())
return DynamicPPL.ParamsWithStats(vi), nothing
```
Actually quite neat that the whole Prior sampler is just defined with DynamicPPL stuff now.
As has been the case in past PRs of this sort, this file provides a gentle introduction to the kinds of changes being made.
Generally, the current status is this: MCMC states often bundle a varinfo, not because it needs to be an accurate varinfo, but more as a 'home' to unflatten a vector of parameters into (see #2642). The logp is usually not updated, because the only thing the next iteration needs is to be able to call `vi[:]`.
This PR generally attempts to remove these varinfos from states, and only ever store the parameter vector + the LDF. Often the only reason we carried around a varinfo was so that we could re-evaluate with ValuesAsInModelAcc. However, because ParamsWithStats now has a method that takes the vector + LDF and returns the values-as-in-model, we can use that without needing a varinfo.
... That's the ideal, at least ...
The reality is that most samplers still need to carry around a varinfo, specifically so that samplers can be used inside Gibbs. (DynamicHMCExt doesn't need to, because it's not 'Gibbs-enabled'.) This suggests that a potential, and immediate, way of decoupling varinfo from the individual samplers would be to have Gibbs handle this extra varinfo overhead (i.e. make gibbs store a (varinfo, state) tuple, rather than just state).
That's probably one for the (near-ish) future. For now, at least the scope of the varinfo argument has been reduced by quite a bit: it's no longer used in the actual AbstractMCMC.step implementations of most samplers.
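To make that target state concrete, here is a minimal sketch (hypothetical type and function names; only the fact that `ParamsWithStats` has a method taking a parameter vector plus an LDF comes from this PR, and the argument order here is an assumption):

```julia
using DynamicPPL

# Hypothetical sampler state: stores only the flat parameter vector, with no
# varinfo bundled in. The LDF lives alongside it in the sampler/model setup.
struct MySamplerState{V<:AbstractVector{<:Real}}
    θ::V
end

# Reconstruct the transition on demand from the vector and the
# LogDensityFunction; this stands in for the ParamsWithStats(vector, ldf)
# method described above (argument order assumed).
make_transition(state::MySamplerState, ldf) = DynamicPPL.ParamsWithStats(state.θ, ldf)
```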
```julia
n_walkers = _get_n_walkers(spl)
chains = map(1:n_walkers) do i
    this_walker_samples = [s[i] for s in samples]
    AbstractMCMC.bundle_samples(
        this_walker_samples, model, spl, state, chain_type; kwargs...
    )
end
```
This is probably more inefficient than the old code, but I am not particularly fussed since chain construction isn't a bottleneck, and it's also way cleaner.
| """ | ||
| set_namedtuple!(vi::VarInfo, nt::NamedTuple) | ||
| Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. | ||
| """ | ||
| function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple) | ||
| for (n, vals) in pairs(nt) | ||
| vns = vi.metadata[n].vns | ||
| if vals isa AbstractVector | ||
| vals = unvectorize(vals) | ||
| end | ||
| if length(vns) == 1 | ||
| # Only one variable, assign the values to it | ||
| DynamicPPL.setindex!(vi, vals, vns[1]) | ||
| else | ||
| # Spread the values across the variables | ||
| length(vns) == length(vals) || error("Unequal number of variables and values") | ||
| for (vn, val) in zip(vns, vals) | ||
| DynamicPPL.setindex!(vi, val, vn) | ||
| end | ||
| end | ||
| end | ||
| end |
This function is exactly the same as the one that's 'deleted' above. GitHub diffs being weird, sorry.
Unfortunately I had to do some surgery on the optimisation interface. I would have preferred to leave it for another time, but the optim interface frequently assumed that the LDF carried a varinfo with it.
I believe there is still a Mooncake failure on 1.12 with ADTypeCheckContext, but otherwise everything on CI should pass, unless I messed something up terribly. My suspicion is that it's a Mooncake issue, not Turing; however, I'll only look into this later.

Edit: Confirmed locally, the test passes on 1.11 and fails on 1.12. chalk-lab/Mooncake.jl#871
Still needs a changelog (also more bullet points for the changelog are welcome), but the code can be reviewed :)
mhauru
left a comment
Happy with the code, just needs the HISTORY.md entry. A few small questions.
Do I understand correctly that the old Transition and the new ParamsWithStats will (typically?) cause the same number of evaluations, though the latter may be a bit faster due to use of OnlyAccsVarInfo? So a small positive performance change would be expected from that?
| """ | ||
| throw(ArgumentError(msg)) | ||
| end | ||
| function get_gibbs_global_varinfo(context::GibbsContext) |
Any particular reason for this change other than code style?
NodeTrait is gone, so I had to rewrite it and I think this is just what I ended up with. Separating the GibbsContext method is optional though - would you rather keep that inside the AbstractParentContext method?
Ah, I was thinking of the refactor from if-else to method dispatch. There's something elegant about doing it with dispatch, but I sometimes find it more readable when there's just a single method with an if-else logic (that gets compiled away). Curious if you have a reason to prefer one. Regardless, happy to leave this as-is.
If the intent is for the compiler to optimise it away, then I think that method dispatch is a more direct way of expressing that. Is it a documented guarantee that the compiler will optimise `if x isa T` branches away?
In terms of style, I think I gravitate towards method dispatch because it's more declarative than imperative. Same reason why I prefer `return if foo; x; else y; end` over `if foo; return x; else return y; end`.
> Is it a documented guarantee that the compiler will optimise `if x isa T` branches away?
Good point, I think not. In practice I strongly expect it to happen in simple cases like these, but there's no guarantee.
I think "more declarative" is a more precise way of expressing what I meant by "more elegant". The downside is that you may have to read the code in a weird order, and the different declarations could be scattered all over your code base.
Anyway, good chat, but no code changes needed.
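For illustration, a minimal self-contained sketch of the two styles under discussion (hypothetical types, not the PR's actual Gibbs code):

```julia
abstract type AbstractCtx end
struct LeafCtx <: AbstractCtx end
struct ParentCtx <: AbstractCtx
    child::AbstractCtx
end

# Style 1: method dispatch -- declarative, but the definitions can end up
# scattered across the code base.
isleaf(::LeafCtx) = true
isleaf(::ParentCtx) = false

# Style 2: a single method with an isa-branch; in simple cases like this the
# compiler typically folds the branch away, though (as noted above) that is
# not a documented guarantee.
isleaf2(ctx::AbstractCtx) = ctx isa LeafCtx
```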
test/mcmc/is.jl (Outdated)

```julia
@test all(isone, chains[:x])
@test chains.logevidence ≈ -2 * log(2)
logevidence = log(mean(exp.(chains[:loglikelihood])))
```
Isn't this the same as the above `logsumexp(chain[:loglikelihood]) - log(N)`, but maybe less numerically stable and slower?
Yes, it's probably both less numerically stable and slower. But it's imo clearer, avoids pulling in an extra import, and doesn't cause an issue in the test.
I have a slight preference for the specialised function just because it's generally good practice to use it and the tests could lead by example. This is at disagreement level 1.5.
Haha to me it's a microoptimisation rather than good practice 😅
I'll change it but leave in a comment saying that "this is equivalent to .... but more numerically stable"
Seems it makes very little difference for speed, but just guards against over- and underflow:
```julia
julia> function f(x)
           display(logsumexp(x))
           display(@b logsumexp(x))
           display(log(sum(exp.(x))))
           display(@b log(sum(exp.(x))))
           return nothing
       end
f (generic function with 1 method)

julia> f(randn(10_000))
9.67801864521181
37.958 μs
9.678018645211806
35.291 μs (3 allocs: 96.062 KiB)

julia> f(randn(10_000)*1000)
3682.3316144326304
43.666 μs
Inf
77.375 μs (3 allocs: 96.062 KiB)
```
Yeah, but the chain isn't generating those values. Not a big deal, changed now.
Yup, that's right. The performance difference is probably not very big - in my opinion the nice thing is that the behaviour is encapsulated in DPPL.
Changelog added.

... Darn, forgot to mention the logevidence thing.
mhauru
left a comment
One optional addition to HISTORY.md, happy to approve.
I'm very excited for fast LDF to hit the streets and see people go screaming.
- your model has other kinds of parallelism but does not include tilde-statements inside;
- or you are using `MCMCThreads()` or `MCMCDistributed()` to sample multiple chains in parallel, but your model itself does not use parallelism.
Suggested addition: If your model does include parallelised tilde-statements or `@addlogprob!` calls, and you evaluate it/sample from it without setting `setthreadsafe(model, true)`, then you may get statistically incorrect results without any warnings or errors.
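For context, here is a sketch of the kind of model this entry covers (illustrative only: the model is made up, and `setthreadsafe` is taken from the suggested text above, so treat the exact API as an assumption):

```julia
using Turing

# Hypothetical model with a parallelised tilde-statement: observations are
# looped over on multiple threads inside the model body.
@model function para(y)
    m ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(m)
    end
end

# Per the entry above, such a model must opt in to thread-safe evaluation;
# `setthreadsafe` here mirrors the suggested changelog text (assumed API).
model = setthreadsafe(para(randn(100)), true)
```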
When sampling using MCMCChains, the chain object will no longer have its `chain.logevidence` field set. Instead, you can calculate this yourself from the log-likelihoods stored in the chain. For SMC samplers, the log-evidence of the entire trajectory is stored in `chain[:logevidence]` (which is the same for every particle in the 'chain').
Just to check, is this only for SMC or also for PG?
Having thought about this for 3 more seconds, this probably makes no sense for PG. Ignore me.
The main change in DPPL 0.39 is OnlyAccsVarInfo and faster evaluation. This PR uses fast evaluation in MCMC sampling where it can. MCMC sampling mostly works, as can be seen from the tests.

Of note: the chain object no longer has its `chain.logevidence` field set. The reason is that we now use the `bundle_samples` method in DynamicPPL, which has no way of reliably determining the log-evidence from the transition. If we wanted to fix this, we would have to add a `getlogevidence` function in AbstractMCMC. I personally don't consider this a problem. The reason why log-evidence used to be stored was because chains did not provide enough information for people to calculate it themselves (specifically, chains only stored logp, not loglikelihood). Now that `chn[:loglikelihood]` contains the log-likelihood, it should be OK for people to calculate this themselves.
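For illustration, one way to calculate it yourself (a minimal sketch; `loglikes` is a dummy stand-in for `vec(chn[:loglikelihood])` from an importance-sampling run):

```julia
using StatsFuns: logsumexp

# Dummy stand-in for vec(chn[:loglikelihood]).
loglikes = [-2.1, -1.9, -2.0, -2.2]

# log-evidence = log(mean(exp.(loglikes))), computed via logsumexp to guard
# against overflow/underflow (cf. the review discussion above).
logevidence = logsumexp(loglikes) - log(length(loglikes))
```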