
Conversation


@SalmanMohammadi SalmanMohammadi commented Nov 14, 2025

image

Using the two example configs I've included in this PR.

Summary by CodeRabbit

  • New Features

    • Enabled Muon optimizer support for distributed training with FSDP2.
    • Added example configurations for Qwen2 model pretraining with AdamW and Muon optimizers, including FSDP2 settings and comprehensive hyperparameters.
    • Improved FSDP version-aware compatibility validation for Muon optimizer.
  • Tests

    • Added multi-GPU end-to-end tests validating distributed Muon with FSDP2 configurations across 2 and 4 GPU setups, including full fine-tuning and LoRA scenarios.


coderabbitai bot commented Nov 14, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Adds distributed Muon optimizer support for FSDP2 (Fully Sharded Data Parallel v2) training. Changes include new pretraining configuration templates, conditional distributed optimizer initialization, refined validation checks for Muon-FSDP version compatibility, and comprehensive multi-GPU end-to-end tests validating various configurations and scaling scenarios.

Changes

Cohort / File(s) Summary
Pretraining Configuration Templates
examples/qwen2/adamw-pretrain-fsdp2.yaml, examples/qwen2/muon-pretrain-fsdp2.yaml
Adds two new YAML configuration files for Qwen2 pretraining with FSDP2. One uses AdamW optimizer while the other uses Muon optimizer. Both specify training hyperparameters (2048 sequence length, 1 epoch), gradient accumulation, flash attention, gradient checkpointing, and fsdp_config with version 2 settings.
Distributed Optimizer Support
src/axolotl/core/builders/base.py
Implements conditional selection between DistMuonOptimizerFactory (for distributed setups with a device mesh) and MuonOptimizerFactory (for single-GPU runs), injecting device_mesh into the optimizer kwargs when needed; see the sketch after this table.
Validator Schema Refinement
src/axolotl/utils/schemas/validation.py
Refines Muon optimizer compatibility checks to disallow DeepSpeed while permitting FSDP only when fsdp_version == 2. Assumes version 1 by default and raises targeted errors for incompatible configurations.
End-to-End Tests
tests/e2e/multigpu/test_dist_muon_fsdp2.py
Introduces TestDistMuonFSDP2 test class with five test methods covering: full fine-tuning (FFT) and LoRA fine-tuning on 2 GPUs, 4-GPU scaling, and parametrized variations of cpu_ram_efficient_loading and reshard_after_forward flags. Tests verify artifact generation, checkpoint creation, and loss sanity.
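
To make the builder change concrete, here is a rough sketch of the selection logic described in the table above. It is not the actual code in src/axolotl/core/builders/base.py: the factory names come from this summary, while the stub classes, helper signature, and kwargs handling are assumptions for illustration.

class MuonOptimizerFactory:  # stand-in stub for the single-GPU Muon factory
    ...

class DistMuonOptimizerFactory:  # stand-in stub for the distributed (FSDP2) Muon factory
    ...

def select_muon_optimizer_factory(device_mesh, optimizer_kwargs: dict):
    """Pick the Muon factory, injecting the device mesh for distributed runs."""
    if device_mesh is not None:
        # Distributed path: DistMuon needs the mesh to coordinate sharded updates
        return DistMuonOptimizerFactory, {**optimizer_kwargs, "device_mesh": device_mesh}
    # Single-GPU path: plain Muon, no mesh required
    return MuonOptimizerFactory, optimizer_kwargs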

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Configuration files (examples/qwen2/*.yaml): Repetitive structure and straightforward; minimal review overhead.
  • Optimizer builder logic (src/axolotl/core/builders/base.py): Requires careful verification of conditional branching and device_mesh injection for distributed paths.
  • Validation refinement (src/axolotl/utils/schemas/validation.py): Version-aware logic and error messaging warrant attention to ensure the fsdp_version interpretation is correct and the error messages are clear; see the sketch after this list.
  • Test suite (tests/e2e/multigpu/test_dist_muon_fsdp2.py): Substantial and introduces multiple test scenarios; reviewer should verify configuration generation, distributed training launch correctness, and artifact validation logic.
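
As a rough illustration of the validation refinement mentioned above, a minimal sketch of the version-aware check follows. It is not the actual pydantic validator in src/axolotl/utils/schemas/validation.py; the field access and error wording are assumptions based on this summary.

def check_muon_compat(cfg: dict) -> dict:
    # Sketch only: disallow DeepSpeed with Muon, and allow FSDP only on version 2.
    if cfg.get("optimizer") != "muon":
        return cfg
    if cfg.get("deepspeed"):
        raise ValueError("Muon optimizer is not compatible with DeepSpeed.")
    if cfg.get("fsdp_config") is not None:
        # fsdp_version is assumed to default to 1 when unset; only FSDP2 is supported
        if cfg.get("fsdp_version", 1) != 2:
            raise ValueError("Muon optimizer with FSDP requires fsdp_version: 2.")
    return cfg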

Possibly related PRs

  • FSDP2 fix validation and add tests #2910: Modifies FSDP-related validation logic in src/axolotl/utils/schemas/validation.py around fsdp_version and optimizer compatibility checks, overlapping with the Muon-FSDP validation refinements in this PR.

Suggested reviewers

  • winglian

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check (❓ Inconclusive): The title 'Dist muon' is extremely vague and does not clearly convey the main changes in the pull request, which include adding distributed Muon optimizer support with FSDP2, new configuration examples, validation updates, and comprehensive end-to-end tests. Consider expanding the title to be more descriptive, such as 'Add distributed Muon optimizer support with FSDP2' or 'Support DistMuon with FSDP2 configuration and testing' to better reflect the scope of changes.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.


@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review November 19, 2025 19:26

github-actions bot commented Nov 19, 2025

📖 Documentation Preview: https://6924bec7ff092b67ee40481e--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 1e74ae4


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
examples/qwen2/adamw-pretrain-fsdp2.yaml (1)

61-69: Use recommended FSDP2 schema in example (avoid deprecated nested fsdp_version and fsdp_ prefixes)

This AdamW example mirrors the Muon config by putting fsdp_version under fsdp_config and using fsdp_-prefixed keys. The validators will normalize this, but they also emit deprecation warnings, and the logs can become confusing (e.g., generic FSDP version hints treat the top-level version as unset/“not 2”).

For a cleaner, recommended example, consider moving fsdp_version: 2 to the top level of the config and dropping the fsdp_ prefixes from the keys under fsdp_config, as sketched below.

That way, users copying this config get the non-deprecated layout by default, and all FSDP validators see the intended version and keys without extra normalization work.
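
For illustration, a hedged sketch of such a layout follows, written as the Python dict equivalent that the e2e tests in this PR build before dumping to YAML; the exact keys and values in the example file may differ.

# Sketch of the recommended (non-deprecated) layout: fsdp_version at the top level
# and un-prefixed keys under fsdp_config. Values here are illustrative assumptions.
recommended_fsdp_cfg = {
    "fsdp_version": 2,  # top level, not nested under fsdp_config
    "fsdp_config": {
        "offload_params": False,
        "cpu_ram_efficient_loading": True,
        "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    },
}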

tests/e2e/multigpu/test_dist_muon_fsdp2.py (1)

52-371: Consider factoring out repeated config+subprocess scaffolding in tests

Each test in TestDistMuonFSDP2 repeats very similar steps:

  • Construct a small DictDefault config with minor variations.
  • Write it to config.yaml under temp_dir.
  • Invoke execute_subprocess_async(["axolotl", "train", ...]) with an appropriate --num-processes and unique port.
  • Call verify_training_success(temp_dir).

The duplication is not wrong, but it does make future tweaks (e.g., changing dataset or base model for all tests) a bit more error-prone.

If you want to reduce repetition, you could introduce a small helper that takes only the per-test overrides, for example:

+def _run_dist_muon_fsdp2(temp_dir, overrides, num_processes: int):
+    cfg = DictDefault(
+        {
+            "base_model": "Qwen/Qwen2.5-0.5B",
+            "sequence_len": 2048,
+            "val_set_size": 0.01,
+            "datasets": [
+                {
+                    "path": "tatsu-lab/alpaca",
+                    "type": "alpaca",
+                    "split": "train[:10%]",
+                },
+            ],
+            "num_epochs": 1,
+            "max_steps": 2,
+            "micro_batch_size": 2,
+            "gradient_accumulation_steps": 1,
+            "output_dir": temp_dir,
+            "learning_rate": 0.02,
+            "optimizer": "muon",
+            "weight_decay": 0.01,
+            "lr_scheduler": "cosine",
+            "flash_attention": True,
+            "fsdp_version": 2,
+            "fsdp_config": {
+                "offload_params": False,
+                "cpu_ram_efficient_loading": False,
+                "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                "state_dict_type": "FULL_STATE_DICT",
+                "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                "reshard_after_forward": True,
+            },
+            "use_tensorboard": True,
+            "bf16": True,
+        } | overrides
+    )
+
+    Path(temp_dir).mkdir(parents=True, exist_ok=True)
+    with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+    execute_subprocess_async(
+        [
+            "axolotl",
+            "train",
+            str(Path(temp_dir) / "config.yaml"),
+            "--num-processes",
+            str(num_processes),
+            "--main-process-port",
+            f"{get_torch_dist_unique_port()}",
+        ]
+    )
+
+    verify_training_success(temp_dir)

Then each test would only specify its unique bits (e.g., LoRA settings, cpu_ram_efficient_loading, reshard_after_forward) and the desired num_processes.

This is optional, but it can make these e2e tests easier to evolve over time.
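
For instance, a 2-GPU LoRA variant could then shrink to something like the following hypothetical test body (the LoRA keys are the usual axolotl config names, not copied from the actual test file):

# Hypothetical usage of the _run_dist_muon_fsdp2 helper sketched above.
@require_torch_2_7_0
def test_lora_2gpu(self, temp_dir):
    _run_dist_muon_fsdp2(
        temp_dir,
        overrides={
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
        },
        num_processes=2,
    )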

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 301e228 and c6e0c2e.

📒 Files selected for processing (5)
  • examples/qwen2/adamw-pretrain-fsdp2.yaml (1 hunks)
  • examples/qwen2/muon-pretrain-fsdp2.yaml (1 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
  • tests/e2e/multigpu/test_dist_muon_fsdp2.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-22T22:14:35.531Z
Learnt from: gholmes829
Repo: axolotl-ai-cloud/axolotl PR: 3167
File: src/axolotl/utils/schemas/validation.py:819-834
Timestamp: 2025-09-22T22:14:35.531Z
Learning: In the axolotl codebase, validation methods maintain separation of concerns - early validators focus on core logic while `check_fsdp_config_kwargs_prefix` handles deprecated prefix normalization. This pattern should be preserved for consistency rather than mixing prefix handling into individual validators.

Applied to files:

  • src/axolotl/utils/schemas/validation.py
🧬 Code graph analysis (2)
src/axolotl/core/builders/base.py (1)
src/axolotl/utils/distributed.py (1)
  • build_parallelism_config (299-316)
tests/e2e/multigpu/test_dist_muon_fsdp2.py (2)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/utils.py (2)
  • most_recent_subdir (35-42)
  • require_torch_2_7_0 (81-90)
🪛 Ruff (0.14.5)
src/axolotl/utils/schemas/validation.py

758-760: Avoid specifying long messages outside the exception class

(TRY003)


765-767: Avoid specifying long messages outside the exception class

(TRY003)
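
For context, TRY003 flags long error messages constructed at the raise site; one generally compliant pattern (shown here as a generic illustration, not a proposed change to validation.py) moves the message into a dedicated exception class:

# Generic TRY003 illustration, not taken from the axolotl codebase.
# Flagged style:
#     raise ValueError("Muon optimizer with FSDP requires fsdp_version == 2")
# One compliant alternative:
class MuonFSDPVersionError(ValueError):
    """Raised when Muon is combined with an unsupported FSDP version."""

    def __init__(self) -> None:
        super().__init__("Muon optimizer with FSDP requires fsdp_version == 2")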

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (1)
src/axolotl/core/builders/base.py (1)

270-288: DistMuon vs Muon selection logic looks correct

The device-mesh–based branch cleanly selects DistMuonOptimizerFactory for distributed setups and falls back to MuonOptimizerFactory for single-GPU, reusing the existing build_parallelism_config pattern used for other custom optimizers. No changes needed here from a correctness standpoint.

@SalmanMohammadi SalmanMohammadi changed the title from "[WIP] Dist muon" to "Dist muon" on Nov 20, 2025

codecov bot commented Nov 21, 2025

Codecov Report

❌ Patch coverage is 92.68293% with 3 lines in your changes missing coverage. Please review.

Files with missing lines:
  • src/axolotl/core/builders/base.py: 57.14% patch coverage (3 lines missing ⚠️)

