Dist muon #3264
base: main
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.

📝 Walkthrough

Adds distributed Muon optimizer support for FSDP2 (Fully Sharded Data Parallel v2) training. Changes include new pretraining configuration templates, conditional distributed optimizer initialization, refined validation checks for Muon-FSDP version compatibility, and comprehensive multi-GPU end-to-end tests validating various configurations and scaling scenarios.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
📖 Documentation Preview: https://6924bec7ff092b67ee40481e--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 1e74ae4)
Actionable comments posted: 3
🧹 Nitpick comments (2)
examples/qwen2/adamw-pretrain-fsdp2.yaml (1)

61-69: Use recommended FSDP2 schema in example (avoid deprecated nested `fsdp_version` and `fsdp_`-prefixed keys)

This AdamW example mirrors the Muon config by putting `fsdp_version` under `fsdp_config` and using `fsdp_`-prefixed keys. While validators will normalize this, they also emit deprecation warnings and can make logs confusing (e.g., generic FSDP version hints seeing the top-level version as unset/"not 2").

For a cleaner, recommended example, consider updating the FSDP block to:
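A minimal sketch of the non-deprecated layout (keys mirrored from this PR's e2e test config; values should match the example's intent):

```yaml
# Sketch only: top-level fsdp_version, un-prefixed keys under fsdp_config.
fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: false
  transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  state_dict_type: FULL_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  reshard_after_forward: true
```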
That way, users copying this config get the non-deprecated layout by default, and all FSDP validators see the intended version and keys without extra normalization work.
tests/e2e/multigpu/test_dist_muon_fsdp2.py (1)
52-371: Consider factoring out repeated config+subprocess scaffolding in tests

Each test in `TestDistMuonFSDP2` repeats very similar steps:

- Construct a small `DictDefault` config with minor variations.
- Write it to `config.yaml` under `temp_dir`.
- Invoke `execute_subprocess_async(["axolotl", "train", ...])` with an appropriate `--num-processes` and a unique port.
- Call `verify_training_success(temp_dir)`.

The duplication is not wrong, but it does make future tweaks (e.g., changing the dataset or base model for all tests) a bit more error-prone.
If you want to reduce repetition, you could introduce a small helper that takes only the per-test overrides, for example:
```diff
+def _run_dist_muon_fsdp2(temp_dir, overrides, num_processes: int):
+    cfg = DictDefault(
+        {
+            "base_model": "Qwen/Qwen2.5-0.5B",
+            "sequence_len": 2048,
+            "val_set_size": 0.01,
+            "datasets": [
+                {
+                    "path": "tatsu-lab/alpaca",
+                    "type": "alpaca",
+                    "split": "train[:10%]",
+                },
+            ],
+            "num_epochs": 1,
+            "max_steps": 2,
+            "micro_batch_size": 2,
+            "gradient_accumulation_steps": 1,
+            "output_dir": temp_dir,
+            "learning_rate": 0.02,
+            "optimizer": "muon",
+            "weight_decay": 0.01,
+            "lr_scheduler": "cosine",
+            "flash_attention": True,
+            "fsdp_version": 2,
+            "fsdp_config": {
+                "offload_params": False,
+                "cpu_ram_efficient_loading": False,
+                "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                "state_dict_type": "FULL_STATE_DICT",
+                "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                "reshard_after_forward": True,
+            },
+            "use_tensorboard": True,
+            "bf16": True,
+        } | overrides
+    )
+
+    Path(temp_dir).mkdir(parents=True, exist_ok=True)
+    with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    execute_subprocess_async(
+        [
+            "axolotl",
+            "train",
+            str(Path(temp_dir) / "config.yaml"),
+            "--num-processes",
+            str(num_processes),
+            "--main-process-port",
+            f"{get_torch_dist_unique_port()}",
+        ]
+    )
+
+    verify_training_success(temp_dir)
```

Then each test would only specify its unique bits (e.g., LoRA settings, `cpu_ram_efficient_loading`, `reshard_after_forward`) and the desired `num_processes`.

This is optional, but it can make these e2e tests easier to evolve over time.
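For illustration, a hypothetical call site using the helper above (the LoRA overrides are illustrative, not taken from the PR's actual tests):

```python
def test_dist_muon_fsdp2_lora(self, temp_dir):
    # Only the per-test deltas are specified; everything else comes from
    # the helper's base config.
    _run_dist_muon_fsdp2(
        temp_dir,
        overrides={
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_target_linear": True,
        },
        num_processes=2,
    )
```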
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)

- examples/qwen2/adamw-pretrain-fsdp2.yaml (1 hunks)
- examples/qwen2/muon-pretrain-fsdp2.yaml (1 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (1 hunks)
- tests/e2e/multigpu/test_dist_muon_fsdp2.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-22T22:14:35.531Z
Learnt from: gholmes829
Repo: axolotl-ai-cloud/axolotl PR: 3167
File: src/axolotl/utils/schemas/validation.py:819-834
Timestamp: 2025-09-22T22:14:35.531Z
Learning: In the axolotl codebase, validation methods maintain separation of concerns - early validators focus on core logic while `check_fsdp_config_kwargs_prefix` handles deprecated prefix normalization. This pattern should be preserved for consistency rather than mixing prefix handling into individual validators.
Applied to files:
src/axolotl/utils/schemas/validation.py
🧬 Code graph analysis (2)
src/axolotl/core/builders/base.py (1)
src/axolotl/utils/distributed.py (1)
- build_parallelism_config (299-316)
tests/e2e/multigpu/test_dist_muon_fsdp2.py (2)
src/axolotl/utils/dict.py (1)
- DictDefault (6-38)

tests/e2e/utils.py (2)

- most_recent_subdir (35-42)
- require_torch_2_7_0 (81-90)
🪛 Ruff (0.14.5)
src/axolotl/utils/schemas/validation.py
758-760: Avoid specifying long messages outside the exception class
(TRY003)
765-767: Avoid specifying long messages outside the exception class
(TRY003)
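For context on the TRY003 hits above: the rule flags long literal messages passed at the `raise` site. A minimal sketch of the conventional fix (the exception name, message, and check below are hypothetical, not the actual validation.py code):

```python
# TRY003 flags patterns like:
#     raise ValueError("muon optimizer requires FSDP version 2 when ...")
# One conventional fix: move the message into a dedicated exception class
# so the raise site stays short. Names here are illustrative only.
class MuonFSDPVersionError(ValueError):
    def __init__(self) -> None:
        super().__init__("Muon optimizer with FSDP requires FSDP version 2")


def check_muon_fsdp(optimizer: str, fsdp_version: int | None) -> None:
    if optimizer == "muon" and fsdp_version != 2:
        raise MuonFSDPVersionError()
```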
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (1)
src/axolotl/core/builders/base.py (1)
270-288: DistMuon vs Muon selection logic looks correct

The device-mesh-based branch cleanly selects `DistMuonOptimizerFactory` for distributed setups and falls back to `MuonOptimizerFactory` for single-GPU, reusing the existing `build_parallelism_config` pattern used for other custom optimizers. No changes needed here from a correctness standpoint.
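A rough sketch of the shape of that branch, for readers skimming the diff (names taken from this review; the actual code in `src/axolotl/core/builders/base.py` and the real signature of `build_parallelism_config` may differ):

```python
# Illustrative sketch only -- not the verbatim PR code. Assumes
# build_parallelism_config (src/axolotl/utils/distributed.py) yields a
# device mesh, or None outside distributed runs, and that the two factory
# classes exist as referenced in this review.
def build_muon_optimizer_factory(cfg):
    device_mesh = build_parallelism_config(cfg)  # signature assumed
    if device_mesh is not None:
        # FSDP2/distributed run: shard-aware Muon implementation
        return DistMuonOptimizerFactory(device_mesh)
    # Single-GPU fallback
    return MuonOptimizerFactory()
```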
Codecov Report

❌ Patch coverage is
using the two example configs I've included in this PR
Summary by CodeRabbit

New Features
- Distributed Muon optimizer support for FSDP2 training, plus new Qwen2 pretraining configuration templates.

Tests
- Multi-GPU end-to-end tests covering distributed Muon/FSDP2 configurations and scaling scenarios.