Description
Describe the issue as clearly as possible:
It seems that the BlackJAX sampler, as called from PyMC, struggles to sample from models that have a truncated likelihood. In particular, the sampler often fails to converge and reports many divergences. Based on this initial discussion, this is likely a problem with the sampler itself: other samplers complete successfully on the same model, and there don't appear to be any issues with the posterior geometry.
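For reference, a truncated likelihood renormalizes the base density over the allowed interval, so its log-density gains an extra term (minus the log of the normal probability mass between the bounds) that depends on mu and sigma. The plain-Python sketch below illustrates this; it is not PyMC's or BlackJAX's implementation, and the helper names (`normal_cdf`, `log_normalizer`, `truncated_normal_logpdf`) are hypothetical. With the bounds used in the reproduction (-10 and 10) and sigma around 0.5, that term is numerically zero, which is consistent with the truncation having little effect on the posterior geometry.

```python
import math


def normal_cdf(z: float) -> float:
    """Standard normal CDF, written with math.erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def log_normalizer(mu: float, sigma: float, lower: float, upper: float) -> float:
    """Log of the base normal's probability mass inside [lower, upper].

    This is the extra term a truncated likelihood subtracts from the
    untruncated log-density. The naive CDF difference used here loses
    precision when both bounds sit deep in the same tail.
    """
    mass = normal_cdf((upper - mu) / sigma) - normal_cdf((lower - mu) / sigma)
    return math.log(mass)


def truncated_normal_logpdf(x: float, mu: float, sigma: float,
                            lower: float, upper: float) -> float:
    """Log-density of Normal(mu, sigma) truncated to [lower, upper] (sketch only)."""
    if not lower <= x <= upper:
        return -math.inf  # zero density outside the truncation bounds
    z = (x - mu) / sigma
    # Untruncated normal log-density.
    base = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    # Renormalize by the mass kept inside the bounds.
    return base - log_normalizer(mu, sigma, lower, upper)


# With the reproduction's bounds (-10, 10) and sigma = 0.5 the truncation removes
# essentially no mass, so the correction term is 0.0 in double precision.
print(log_normalizer(0.0, 0.5, -10.0, 10.0))
# Bounds that are narrow relative to sigma make the correction substantial
# (about -0.38 for +/- one standard deviation).
print(log_normalizer(0.0, 1.0, -1.0, 1.0))
```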
Screenshot of erroneous output:
Steps/code to reproduce the bug:
```python
import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.TruncatedNormal("y", mu=mu, sigma=sigma, lower=-10, upper=10, size=(N_OBSERVATIONS,))
    prior_trace = pm.sample_prior_predictive(random_seed=100)

data = prior_trace.prior.y.isel(chain=0, draw=0)

with pm.observe(model, {y: data}):
    idata = pm.sample(nuts_sampler="blackjax")
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```
Expected result:
The sampler should complete without divergences and produce posteriors similar to those from PyMC's built-in NUTS sampler.
Error message:
Blackjax/JAX/jaxlib/Python version information:
BlackJax 1.2.4
Python 3.11.11 | packaged by conda-forge | (main, Dec 5 2024, 14:21:42) [Clang 18.1.8 ]
Jax 0.4.31
Jaxlib 0.4.31
PyMC 5.20.0
Context for the issue:
This issue appears to render the sampler unable to fit models with certain truncated likelihoods, which are a useful construct in a number of domains.