Blackjax MCMC error #779

@danleonte

Describe the issue as clearly as possible:

The example usage code from barker_mcmc

barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)

raises an error

barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
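
To illustrate what this message means, here is a minimal standard-library sketch; it does not import blackjax. It assumes that the name being called resolves to a submodule rather than to a callable object. `fake_barker`, `describe` and `try_call` are hypothetical names used only for this illustration.

```python
import types

# Hypothetical stand-in for a package attribute that resolves to a submodule.
fake_barker = types.ModuleType("fake_barker")


def describe(obj):
    """Return a short label for what a name resolves to."""
    if isinstance(obj, types.ModuleType):
        return "module"
    return "callable" if callable(obj) else "other"


def try_call(obj, *args):
    """Call obj(*args); return the TypeError text if the call fails, else None."""
    try:
        obj(*args)
    except TypeError as exc:
        return str(exc)
    return None


# Calling the module object reproduces the same kind of error as in the report.
message = try_call(fake_barker, None, 0.1)
print(describe(fake_barker), "->", message)
```

Running the same kind of check on the real name (for example `type(blackjax.barker)`) would show whether it is a module or a callable in the installed version.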

Possibly related to #723. I got it to work for a single chain:

import jax
import jax.numpy as jnp
import blackjax

# Assumes log_density_fn, rng_key and num_samples are defined elsewhere.
step_size = 0.1

# initial_states = jnp.repeat(jnp.array([15., 15., 0., 1., 0.])[jnp.newaxis, :], axis=0, repeats=num_chains)
initial_position = jnp.array([15., 15., 0., 1., 0.])  # shape (5,)
barker_state = blackjax.mcmc.barker.init(
    position=initial_position, logdensity_fn=log_density_fn)
barker_kernel = blackjax.barker.build_kernel()
# Single step, to check that the kernel runs.
barker_kernel(rng_key=rng_key, state=barker_state,
              logdensity_fn=log_density_fn, step_size=step_size)


@jax.jit
def one_step(state, subkey):
    new_state, info = barker_kernel(subkey, state, log_density_fn, step_size)
    # Carry the new state forward; record the position held before this step.
    return new_state, state.position


# blackjax.mcmc.barker.build_kernel(jax.random.PRNGKey(42),barker_state)
keys = jax.random.split(rng_key, num_samples)
result = jax.lax.scan(one_step, barker_state, keys)

However, I can't get window adaptation to work; it keeps raising an error. Judging by

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)

from the quickstart page

I need to pass the Barker kernel in some way, but the current one raises an error. Alternatively, this window adaptation attempt

barker_kernel = blackjax.mcmc.barker.as_top_level_api(
    logdensity_fn=log_density_fn,
    step_size=step_size,
    inverse_mass_matrix=jnp.eye(5))
barker_kernel.step(rng, initial_state)

blackjax.window_adaptation(algorithm=barker_kernel, logdensity_fn=log_density_fn, is_mass_matrix_diagonal=False)

results in a different error

   mcmc_kernel = algorithm.build_kernel(integrator)
AttributeError: 'SamplingAlgorithm' object has no attribute 'build_kernel'
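
The traceback shows the adaptation routine calling `algorithm.build_kernel(integrator)`, so it appears to expect an object that exposes `build_kernel`, not an already-built sampler with only `init` and `step`. The sketch below mimics that distinction with standard-library stand-ins. Every `Fake*`/`fake_*` name is hypothetical and none of this is blackjax code; whether the real `window_adaptation` supports the Barker kernel at all is not established here.

```python
from typing import Callable, NamedTuple


class FakeSamplingAlgorithm(NamedTuple):
    """Stand-in for a built sampler: it only carries init and step."""
    init: Callable
    step: Callable


class FakeAlgorithmFactory:
    """Stand-in for an un-built algorithm: exposes build_kernel and is callable."""

    @staticmethod
    def build_kernel(integrator=None):
        # Return a dummy kernel; a real kernel would advance a chain state.
        return lambda *args, **kwargs: "kernel called"

    def __call__(self, logdensity_fn, step_size):
        return FakeSamplingAlgorithm(
            init=lambda position: position,
            step=lambda key, state: state,
        )


def fake_window_adaptation(algorithm, integrator=None):
    # Mirrors only the one line visible in the traceback.
    return algorithm.build_kernel(integrator)


factory = FakeAlgorithmFactory()
kernel = fake_window_adaptation(factory)       # works: the factory has build_kernel
instance = factory(lambda x: 0.0, 0.1)         # built sampler, no build_kernel
try:
    fake_window_adaptation(instance)
    error_text = None
except AttributeError as exc:
    error_text = str(exc)
print(error_text)
```

Under these assumptions, passing the result of `as_top_level_api(...)` fails for the same reason as the built instance here: it has already been reduced to `init` and `step`.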

Can you please advise on this? Thanks.

Steps/code to reproduce the bug:

barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)

Expected result:

.

Error message:

barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable

Blackjax/JAX/jaxlib/Python version information:

blackjax: 1.2.5
jax: 0.5.0 
on CPU

Context for the issue:

No response
