Blackjax sampler suffers divergences with TruncatedNormal likelihood #775

@noahg2

Describe the issue as clearly as possible:

The PyMC blackjax sampler appears to struggle with models that have a truncated likelihood. In particular, it often fails to converge and reports many divergences. Based on this initial discussion, the problem is likely in the sampler itself: other samplers complete successfully on the same model, and there don't appear to be any geometry issues at play.
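To make the geometry remark concrete, here is a minimal pure-Python sketch of a truncated normal log-density. It is not the PyMC or Blackjax implementation; the helper names are hypothetical, and the bounds of -10 and 10 are taken from the reproduction model. It only illustrates that, for parameter values plausible under the stated priors, the truncation term is tiny, so the target should look almost like an untruncated normal likelihood.

```python
import math


def normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def truncation_log_normalizer(mu, sigma, lower=-10.0, upper=10.0):
    """Log of the normal probability mass that falls inside [lower, upper]."""
    mass = normal_cdf((upper - mu) / sigma) - normal_cdf((lower - mu) / sigma)
    return math.log(mass)


def truncated_normal_logpdf(y, mu, sigma, lower=-10.0, upper=10.0):
    """Log-density of a normal truncated to [lower, upper]; -inf outside the bounds."""
    if not (lower <= y <= upper):
        return -math.inf
    z = (y - mu) / sigma
    normal_logpdf = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    # Subtracting the log mass renormalizes the density over the truncated support.
    return normal_logpdf - truncation_log_normalizer(mu, sigma, lower, upper)


# Illustrative (mu, sigma) pairs, chosen by hand rather than drawn from the model.
for mu, sigma in [(0.0, 0.5), (2.0, 1.0), (-3.0, 1.5)]:
    print(mu, sigma, truncation_log_normalizer(mu, sigma))
```

A near-zero normalizer is consistent with the observation that geometry is not the obvious culprit, but it does not by itself identify where the divergences come from.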

Screenshot of erroneous output:

Image

Steps/code to reproduce the bug:

import pymc as pm

N_OBSERVATIONS = 50

# Generative model: normal likelihood truncated to [-10, 10]
with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.TruncatedNormal("y", mu=mu, sigma=sigma, lower=-10, upper=10, size=(N_OBSERVATIONS,))
    prior_trace = pm.sample_prior_predictive(random_seed=100)

# Use a single prior predictive draw as the observed data
data = prior_trace.prior.y.isel(chain=0, draw=0)
# Condition the model on that draw and sample with the blackjax backend
with pm.observe(model, {y: data}):
    idata = pm.sample(nuts_sampler="blackjax")
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)

Expected result:

The sampler should complete without divergences and yield posteriors similar to those from the PyMC NUTS sampler.
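One way to check this expectation is to sample the same conditioned model with both backends and compare divergence counts. The sketch below assumes that the `InferenceData` returned by `pm.sample` exposes `sample_stats["diverging"]` for both backends; the helper names `count_divergences` and `compare_samplers` are hypothetical and not part of PyMC or Blackjax.

```python
def count_divergences(diverging):
    """Count divergent transitions in a (chain, draw) collection of booleans."""
    return sum(1 for chain in diverging for flag in chain if flag)


def compare_samplers(observed_model, samplers=("pymc", "blackjax"), seed=100):
    """Sample the same model with each NUTS backend and return divergence counts."""
    import pymc as pm  # imported lazily so count_divergences works without PyMC

    counts = {}
    for name in samplers:
        idata = pm.sample(nuts_sampler=name, random_seed=seed, model=observed_model)
        # Assumption: both backends record a boolean "diverging" sample statistic.
        counts[name] = count_divergences(idata.sample_stats["diverging"].values)
    return counts
```

With the reproduction above, this could be called as `compare_samplers(pm.observe(model, {y: data}))`; the expected result would be a count of zero for both backends.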

Error message:

Blackjax/JAX/jaxlib/Python version information:

BlackJax 1.2.4
Python 3.11.11 | packaged by conda-forge | (main, Dec  5 2024, 14:21:42) [Clang 18.1.8 ]
Jax 0.4.31
Jaxlib 0.4.31
PyMC 5.20.0

Context for the issue:

This issue appears to render the sampler unable to fit models with certain truncated likelihoods, which are a useful construct in a number of domains.
