JAX prefers immutable objects, but neither Python nor JAX provides an immutable
dictionary. 😢
This repository defines a light-weight immutable map
(lower-level than a dict) that JAX understands as a PyTree. 🎉 🕶️
Using `pip`:

```bash
pip install xmmutablemap
```

Using `uv`:

```bash
uv add xmmutablemap
```

From source, using `pip`:

```bash
pip install git+https://github.com/GalacticDynamics/xmmutablemap.git
```

Building from source:

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/xmmutablemap.git
cd xmmutablemap
pip install -e .  # editable mode
```

`xmmutablemap` provides the class `ImmutableMap`, which is a full implementation of Python's `Mapping` ABC.
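
Because `ImmutableMap` implements the full `Mapping` interface, all of the familiar read-only `dict` operations are available. A quick sketch (the keys and values here are arbitrary):

```python
from xmmutablemap import ImmutableMap

m = ImmutableMap(a=1, b=2, c=3)

# Read-only dict-style access, courtesy of the Mapping ABC
print(m["a"])            # 1
print(len(m))            # 3
print("b" in m)          # True
print(m.get("z", 0))     # 0
print(sorted(m))         # ['a', 'b', 'c']  (iteration yields keys)
print(dict(m.items()))   # {'a': 1, 'b': 2, 'c': 3}
```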
If you've used a `dict` then you already know how to use `ImmutableMap`! The
two things `ImmutableMap` adds are 1) immutability (and related benefits like
hashability) and 2) compatibility with JAX.
```python
from xmmutablemap import ImmutableMap

print(ImmutableMap(a=1, b=2, c=3))
# ImmutableMap({'a': 1, 'b': 2, 'c': 3})

print(ImmutableMap({"a": 1, "b": 2.0, "c": "3"}))
# ImmutableMap({'a': 1, 'b': 2.0, 'c': '3'})
```

One of the key benefits of `ImmutableMap` is its compatibility with JAX. Since
it's immutable and hashable, it can be used in places where JAX would normally
complain about mutable objects like regular dictionaries.
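
To make the immutability and hashability concrete, here is a minimal sketch. It assumes the usual behavior of an immutable mapping: since `ImmutableMap` exposes no item assignment, assigning to a key raises a `TypeError`, and instances whose values are hashable can themselves be hashed and used as dictionary keys.

```python
from xmmutablemap import ImmutableMap

m = ImmutableMap(a=1, b=2)

# No mutation API: item assignment is rejected
try:
    m["a"] = 10  # type: ignore[index]
except TypeError as err:
    print(f"mutation rejected: {err}")

# Hashable (as long as its values are hashable), so it can key a dict.
# This is the same property JAX relies on when it needs hashable, static values.
cache = {m: "result for these settings"}
print(cache[m])  # result for these settings
```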
Here's an example showing how `ImmutableMap` can be used as a default value in a
dataclass, which is particularly useful with JAX:
```python
import functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp

from xmmutablemap import ImmutableMap


@functools.partial(
    jax.tree_util.register_dataclass, data_fields=["params"], meta_fields=["batch_size"]
)
@dataclass(frozen=True)
class Config:
    """Configuration with immutable default parameters."""

    # This works! ImmutableMap is immutable and hashable
    params: ImmutableMap[str, float] = ImmutableMap(
        learning_rate=0.001, momentum=0.9, weight_decay=1e-4
    )
    batch_size: int = 32


# JAX can safely transform functions using this dataclass
@jax.jit
def train_step(config: Config, data: jnp.ndarray) -> jnp.ndarray:
    """Example training step that uses config parameters."""
    lr = config.params["learning_rate"]
    return data * lr


# This works perfectly
config = Config()
data = jnp.array([1.0, 2.0, 3.0])
result = train_step(config, data)
print(f"Result: {result}")
# Result: [0.001 0.002 0.003]
```

- **Immutability**: Once created, an `ImmutableMap` cannot be modified, preventing accidental mutations that could break JAX's functional programming model
- **Hashability**: JAX can safely cache and memoize functions that use `ImmutableMap` instances
- **PyTree Support**: `ImmutableMap` is registered as a JAX PyTree, so it works seamlessly with JAX transformations like `jit`, `grad`, `vmap`, etc. (see the sketch after this list)
- **Safe Defaults**: Can be used as a default value in dataclasses without the typical pitfalls of mutable defaults
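
As a quick illustration of the PyTree registration, here is a sketch (the exact repr of the printed arrays may vary by JAX version): the values of an `ImmutableMap` are treated as leaves, so JAX tree utilities and transformations handle it much like they would a `dict`.

```python
import jax
import jax.numpy as jnp

from xmmutablemap import ImmutableMap

params = ImmutableMap(w=jnp.ones(3), b=jnp.zeros(3))

# PyTree registration: the values are the leaves
print(jax.tree_util.tree_leaves(params))

# tree_map rebuilds an ImmutableMap with the same keys
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
print(doubled["w"])  # [2. 2. 2.]

# And it passes straight through jit / grad / vmap
@jax.jit
def loss(p: ImmutableMap, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(p["w"] * x + p["b"])

grads = jax.grad(loss)(params, jnp.arange(3.0))
print(grads["w"])  # [0. 1. 2.]
```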
We welcome contributions!