Description
Describe the issue as clearly as possible:
The example usage code from barker_mcmc

```python
barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)
```

raises an error:

```
barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
```

Possibly related to #723. I got it to work for a single chain:
```python
import jax
import jax.numpy as jnp
import blackjax

# log_density_fn, rng_key, step_size and num_samples are defined elsewhere.
# initial_states = jnp.repeat(jnp.array([15.,15.,0.,1.,0.])[jnp.newaxis,:],axis=0, repeats = num_chains)
initial_position = jnp.array([15., 15., 0., 1., 0.])  # shape (5,)
barker_state = blackjax.mcmc.barker.init(
    position=initial_position, logdensity_fn=log_density_fn)
barker_kernel = blackjax.barker.build_kernel()
barker_kernel(rng_key=rng_key, state=barker_state,
              logdensity_fn=log_density_fn, step_size=0.1)

@jax.jit
def one_step(state, subkey):
    new_state, info = barker_kernel(subkey, state, log_density_fn, step_size)
    return new_state, state.position

# blackjax.mcmc.barker.build_kernel(jax.random.PRNGKey(42),barker_state)
keys = jax.random.split(rng_key, num_samples)
result = jax.lax.scan(one_step, barker_state, keys)
```

But I can't wrap my head around the window adaptation; it keeps giving an error. Judging by
the quickstart example

```python
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
```

I need to pass the Barker kernel in some way, but the current one raises an error. Alternatively, this window adaptation attempt
```python
barker_kernel = blackjax.mcmc.barker.as_top_level_api(logdensity_fn=log_density_fn,
                                                      step_size=step_size,
                                                      inverse_mass_matrix=jnp.eye(5))
barker_kernel.step(rng, initial_state)
blackjax.window_adaptation(algorithm=barker_kernel, logdensity_fn=log_density_fn,
                           is_mass_matrix_diagonal=False)
```
results in a different error:

```
mcmc_kernel = algorithm.build_kernel(integrator)
AttributeError: 'SamplingAlgorithm' object has no attribute 'build_kernel'
```
Can you please advise on this? Thanks.
Steps/code to reproduce the bug:
```python
barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)
```

Expected result:
No response

Error message:
```
barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
```

Blackjax/JAX/jaxlib/Python version information:
blackjax: 1.2.5
jax: 0.5.0
on CPU

Context for the issue:
No response