
Commit 43b0f17

jadechoghari, 2toinf, michel-aractingi, and imstevenpmwork authored
feat(policies): Add X-VLA (#2405)
* first commit
* more fixes
* add franka action
* update testing script
* add changes
* update files
* logits matching
* add imagenet as a norm type
* logits matching atol1e-2
* more eval fixes
* more changes
* xvla works on libero
* remove seed
* more refactoring
* more fixes
* more changes
* more changes
* more fixes
* migrate policy revert
* major pre-commit cleanup
* renaming
* revert to self.transformer
* refactor
* new changes
* clean
* update libero
* more changes
* make it work
* more changes:
* remove imagenet dependency
* style
* more
* more refactor
* remove proprio
* add loss
* more
* more
* add freeze/unfreeze options
* add testing
* upgrade transformers version
* update testing
* add installation
* remove .sh file
* fix testing
* silent linter in xvlatest
* fix failing test
* upgrade test, fix failing
* fix testing
* more fixes to testing
* require cuda in tests
* temp check
* add xvla docs
* fix styling
* update libero doc
* remove timm dep
* add different dtype support
* remove timm skip
* remove white lines
* Enhance X-VLA finetuning documentation with optimizer details (#2537) Added detailed instructions for implementing a custom optimizer and modifying parameter retrieval for X-VLA finetuning. Signed-off-by: Jinliang Zheng <[email protected]>
* fix style
* iterate on review
* iterate on cpilot
* revert xvla dep
* free up ci
* test(xvla): remove main test (#2565)
* Add xvla custom optim and dtype (#2567)
* add custom optim
* add custom optim
* add auto mode
* more changes
* add identity to all
* add auto
* release
* add docs
* make image smaller docs
* smaller image in doc
* evan smaller image doc
* finalize doc

---------

Signed-off-by: Jinliang Zheng <[email protected]>
Signed-off-by: Steven Palma <[email protected]>
Co-authored-by: Jinliang Zheng <[email protected]>
Co-authored-by: Michel Aractingi <[email protected]>
Co-authored-by: Steven Palma <[email protected]>
1 parent b0b7554 commit 43b0f17

22 files changed, +6620 −10 lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@
       title: π₀.₅ (Pi05)
     - local: groot
       title: NVIDIA GR00T N1.5
+    - local: xvla
+      title: X-VLA
   title: "Policies"
 - sections:
   - local: async

docs/source/libero.mdx

Lines changed: 5 additions & 0 deletions
@@ -62,6 +62,11 @@ lerobot-eval \

 - Pass a comma-separated list to `--env.task` for multi-suite evaluation.

+### Control Mode
+
+LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained to output actions under different control parameterizations.
+You can switch between them with `env.control_mode = "relative"` or `env.control_mode = "absolute"`.
+
 ### Policy inputs and outputs

 When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
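For orientation, a minimal Python sketch of selecting the control mode through the env config factory (a hedged example: the env type string "libero" and the forwarding of control_mode through make_env_config's **kwargs are assumptions based on the factory signature shown further below):

from lerobot.envs.factory import make_env_config

# control_mode defaults to "relative"; "absolute" switches the robot controllers
# to non-delta commands on reset (see src/lerobot/envs/libero.py below)
env_cfg = make_env_config("libero", task="libero_10", control_mode="absolute")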

docs/source/xvla.mdx

Lines changed: 570 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -133,6 +133,7 @@ groot = [
     "ninja>=1.11.1,<2.0.0",
     "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
 ]
+xvla = ["lerobot[transformers-dep]"]
 hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

 # Features
@@ -161,6 +162,7 @@ all = [
     "lerobot[pi]",
     "lerobot[smolvla]",
     # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
+    "lerobot[xvla]",
     "lerobot[hilserl]",
     "lerobot[async]",
     "lerobot[dev]",

src/lerobot/envs/configs.py

Lines changed: 2 additions & 1 deletion
@@ -245,7 +245,7 @@ def gym_kwargs(self) -> dict:
 class LiberoEnv(EnvConfig):
     task: str = "libero_10"  # can also choose libero_spatial, libero_object, etc.
     fps: int = 30
-    episode_length: int = 520
+    episode_length: int | None = None
     obs_type: str = "pixels_agent_pos"
     render_mode: str = "rgb_array"
     camera_name: str = "agentview_image,robot0_eye_in_hand_image"
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
             LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
         }
     )
+    control_mode: str = "relative"  # or "absolute"

     def __post_init__(self):
         if self.obs_type == "pixels":
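A small sketch of the updated config (both fields come from the diff above; leaving episode_length=None keeps the per-suite step limits from TASK_SUITE_MAX_STEPS in src/lerobot/envs/libero.py, while an integer overrides them):

from lerobot.envs.configs import LiberoEnv

env_cfg = LiberoEnv(task="libero_10", control_mode="absolute", episode_length=600)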

src/lerobot/envs/factory.py

Lines changed: 9 additions & 0 deletions
@@ -19,8 +19,10 @@
 import gymnasium as gym
 from gymnasium.envs.registration import registry as gym_registry

+from lerobot.configs.policies import PreTrainedConfig
 from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
 from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
 from lerobot.processor import ProcessorStep
 from lerobot.processor.env_processor import LiberoProcessorStep
 from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:

 def make_env_pre_post_processors(
     env_cfg: EnvConfig,
+    policy_cfg: PreTrainedConfig,
 ) -> tuple[
     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
     PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
     # Preprocessor and Postprocessor steps are Identity for most environments
     preprocessor_steps: list[ProcessorStep] = []
     postprocessor_steps: list[ProcessorStep] = []
+    if isinstance(policy_cfg, XVLAConfig):
+        from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
+
+        return make_xvla_libero_pre_post_processors()

     # For LIBERO environments, add the LiberoProcessorStep to preprocessor
     if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
@@ -136,6 +143,8 @@ def make_env(
             init_states=cfg.init_states,
             gym_kwargs=cfg.gym_kwargs,
             env_cls=env_cls,
+            control_mode=cfg.control_mode,
+            episode_length=cfg.episode_length,
         )
     elif "metaworld" in cfg.type:
         from lerobot.envs.metaworld import create_metaworld_envs
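A hedged sketch of the new routing in make_env_pre_post_processors (both imported names appear in this diff; instantiating XVLAConfig with all defaults is an assumption):

from lerobot.envs.configs import LiberoEnv
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.policies.xvla.configuration_xvla import XVLAConfig

# With an X-VLA policy config, the generic LIBERO processor steps are bypassed and
# the XVLA-specific LIBERO pre/post processor pipelines are returned instead.
env_pre, env_post = make_env_pre_post_processors(env_cfg=LiberoEnv(), policy_cfg=XVLAConfig())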

src/lerobot/envs/libero.py

Lines changed: 26 additions & 5 deletions
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
     return [0, 0, 0, 0, 0, 0, -1]


-OBS_STATE_DIM = 8
 ACTION_DIM = 7
-AGENT_POS_LOW = -1000.0
-AGENT_POS_HIGH = 1000.0
 ACTION_LOW = -1.0
 ACTION_HIGH = 1.0
 TASK_SUITE_MAX_STEPS: dict[str, int] = {
@@ -103,6 +100,7 @@ def __init__(
         task_suite: Any,
         task_id: int,
         task_suite_name: str,
+        episode_length: int | None = None,
         camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
         obs_type: str = "pixels",
         render_mode: str = "rgb_array",
@@ -114,6 +112,7 @@ def __init__(
         episode_index: int = 0,
         camera_name_mapping: dict[str, str] | None = None,
         num_steps_wait: int = 10,
+        control_mode: str = "relative",
     ):
         super().__init__()
         self.task_id = task_id
@@ -141,14 +140,19 @@ def __init__(
         self.camera_name_mapping = camera_name_mapping
         self.num_steps_wait = num_steps_wait
         self.episode_index = episode_index
+        self.episode_length = episode_length
         # Load once and keep
         self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
         self._init_state_id = self.episode_index  # tie each sub-env to a fixed init state

         self._env = self._make_envs_task(task_suite, self.task_id)
         default_steps = 500
-        self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
-
+        self._max_episode_steps = (
+            TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
+            if self.episode_length is None
+            else self.episode_length
+        )
+        self.control_mode = control_mode
         images = {}
         for cam in self.camera_name:
             images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -296,6 +300,15 @@ def reset(self, seed=None, **kwargs):
         # Increasing this value can improve determinism and reproducibility across resets.
         for _ in range(self.num_steps_wait):
             raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
+
+        if self.control_mode == "absolute":
+            for robot in self._env.robots:
+                robot.controller.use_delta = False
+        elif self.control_mode == "relative":
+            for robot in self._env.robots:
+                robot.controller.use_delta = True
+        else:
+            raise ValueError(f"Invalid control mode: {self.control_mode}")
         observation = self._format_raw_obs(raw_obs)
         info = {"is_success": False}
         return observation, info
@@ -341,8 +354,10 @@ def _make_env_fns(
     task_id: int,
     n_envs: int,
     camera_names: list[str],
+    episode_length: int | None,
     init_states: bool,
     gym_kwargs: Mapping[str, Any],
+    control_mode: str,
 ) -> list[Callable[[], LiberoEnv]]:
     """Build n_envs factory callables for a single (suite, task_id)."""

@@ -354,7 +369,9 @@ def _make_env(episode_index: int, **kwargs) -> LiberoEnv:
             task_suite_name=suite_name,
             camera_name=camera_names,
             init_states=init_states,
+            episode_length=episode_length,
             episode_index=episode_index,
+            control_mode=control_mode,
             **local_kwargs,
         )

@@ -374,6 +391,8 @@ def create_libero_envs(
     camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
     init_states: bool = True,
     env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
+    control_mode: str = "relative",
+    episode_length: int | None = None,
 ) -> dict[str, dict[int, Any]]:
     """
     Create vectorized LIBERO environments with a consistent return shape.
@@ -415,12 +434,14 @@
     for tid in selected:
         fns = _make_env_fns(
             suite=suite,
+            episode_length=episode_length,
             suite_name=suite_name,
             task_id=tid,
             n_envs=n_envs,
             camera_names=camera_names,
             init_states=init_states,
             gym_kwargs=gym_kwargs,
+            control_mode=control_mode,
         )
         out[suite_name][tid] = env_cls(fns)
         print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

src/lerobot/optim/optimizers.py

Lines changed: 101 additions & 0 deletions
@@ -104,6 +104,107 @@ def build(self, params: dict) -> torch.optim.Optimizer:
         return torch.optim.SGD(params, **kwargs)


+@OptimizerConfig.register_subclass("xvla-adamw")
+@dataclass
+class XVLAAdamWConfig(OptimizerConfig):
+    """Custom AdamW optimizer for XVLA with differential learning rates.
+
+    The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
+    for stable optimization, while all other components use the full LR.
+
+    This LR ratio is crucial for achieving strong and stable finetuning performance.
+
+    Soft-prompts can optionally use a separate learning rate with warm-up support.
+    Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
+    at a lower LR. Combine with a warmup scheduler for optimal results.
+
+    Note:
+        Completely matching official reported performance may require an additional
+        warm-up LR schedule for soft-prompts, which can bring minor improvements.
+        When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
+        `lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
+
+    Parameter Groups:
+        - Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
+        - Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
+        - Group 2 (other): All other parameters at full lr
+    """
+
+    lr: float = 1e-4
+    betas: tuple[float, float] = (0.9, 0.99)
+    eps: float = 1e-8
+    weight_decay: float = 0.0
+    grad_clip_norm: float = 10.0
+    # Soft-prompt specific settings
+    soft_prompt_lr_scale: float = 1.0  # Scale factor for soft-prompt LR (1.0 = same as base LR)
+    soft_prompt_warmup_lr_scale: float | None = None  # If set, start soft-prompts at this scale (e.g., 0.01)
+
+    def build(self, params: dict) -> torch.optim.Optimizer:
+        """
+        Build AdamW optimizer with differential learning rates.
+
+        Expects `named_parameters()` as input (dict of name -> param).
+        Applies:
+        - lr * 0.1 for all VLM-related parameters
+        - lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
+        - full lr for all other parameters
+
+        Args:
+            params: Dictionary of parameter names to parameters (from named_parameters())
+
+        Returns:
+            AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
+        """
+        assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
+
+        vlm_group, soft_prompt_group, other_group = [], [], []
+        for name, p in params.items():
+            if not p.requires_grad:
+                continue
+            if "vlm" in name.lower():
+                vlm_group.append(p)
+            elif "soft_prompt" in name.lower():
+                soft_prompt_group.append(p)
+            else:
+                other_group.append(p)
+
+        # Determine soft-prompt LR
+        soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
+        if self.soft_prompt_warmup_lr_scale is not None:
+            # Start at warmup scale, scheduler will warm up to soft_prompt_lr
+            soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
+
+        param_groups = [
+            {
+                "params": vlm_group,
+                "lr": self.lr * 0.1,
+                "weight_decay": self.weight_decay * 0.1,
+                "name": "vlm",
+            },
+            {
+                "params": soft_prompt_group,
+                "lr": soft_prompt_lr,
+                "weight_decay": self.weight_decay,
+                "name": "soft_prompts",
+            },
+            {
+                "params": other_group,
+                "lr": self.lr,
+                "weight_decay": self.weight_decay,
+                "name": "other",
+            },
+        ]
+
+        # Filter out empty groups
+        param_groups = [g for g in param_groups if len(g["params"]) > 0]
+
+        return torch.optim.AdamW(
+            param_groups,
+            betas=self.betas,
+            eps=self.eps,
+        )
+
+
 @OptimizerConfig.register_subclass("multi_adam")
 @dataclass
 class MultiAdamConfig(OptimizerConfig):
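Illustrative usage of the new optimizer config (a sketch only: the nn.Linear stands in for a real XVLA policy, whose parameter names would contain "vlm" and "soft_prompt"; registration under "xvla-adamw" also makes the config selectable by that name):

import torch.nn as nn

from lerobot.optim.optimizers import XVLAAdamWConfig

policy = nn.Linear(4, 4)  # stand-in module for a policy
optim_cfg = XVLAAdamWConfig(lr=1e-4, weight_decay=0.01, soft_prompt_lr_scale=0.1)
# build() expects named_parameters() as a dict, not a flat parameter list
optimizer = optim_cfg.build(dict(policy.named_parameters()))
# Resulting groups: "vlm" at lr * 0.1, "soft_prompts" at lr * soft_prompt_lr_scale, "other" at lr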

src/lerobot/policies/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from .smolvla.processor_smolvla import SmolVLANewLineProcessor
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
+from .xvla.configuration_xvla import XVLAConfig as XVLAConfig

 __all__ = [
     "ACTConfig",
@@ -31,4 +32,5 @@
     "TDMPCConfig",
     "VQBeTConfig",
     "GrootConfig",
+    "XVLAConfig",
 ]
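The re-export means the new config is importable from the package root, i.e. from lerobot.policies import XVLAConfig.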

src/lerobot/policies/factory.py

Lines changed: 17 additions & 2 deletions
@@ -41,6 +41,7 @@
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import validate_visual_features_consistency
 from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
 from lerobot.processor import PolicyAction, PolicyProcessorPipeline
 from lerobot.processor.converters import (
     batch_to_transition,
@@ -108,6 +109,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.groot.modeling_groot import GrootPolicy

         return GrootPolicy
+    elif name == "xvla":
+        from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
+
+        return XVLAPolicy
     else:
         try:
             return _get_policy_cls_from_policy_name(name=name)
@@ -154,6 +159,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return RewardClassifierConfig(**kwargs)
     elif policy_type == "groot":
         return GrootConfig(**kwargs)
+    elif policy_type == "xvla":
+        return XVLAConfig(**kwargs)
     else:
         try:
             config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -337,6 +344,15 @@ def make_pre_post_processors(
             config=policy_cfg,
             dataset_stats=kwargs.get("dataset_stats"),
         )
+    elif isinstance(policy_cfg, XVLAConfig):
+        from lerobot.policies.xvla.processor_xvla import (
+            make_xvla_pre_post_processors,
+        )
+
+        processors = make_xvla_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )

     else:
         try:
@@ -414,8 +430,7 @@ def make_policy(
         raise ValueError("env_cfg cannot be None when ds_meta is not provided")
     features = env_to_policy_features(env_cfg)

-    if not cfg.output_features:
-        cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
     if not cfg.input_features:
         cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
     kwargs["config"] = cfg

0 commit comments
