From cbccb387871a5e1f522c1e222c51ab88b03c0392 Mon Sep 17 00:00:00 2001 From: Andrew Aikawa Date: Fri, 11 Jul 2025 22:57:44 -0700 Subject: [PATCH 001/128] [benchmark] add h200 bench (#1361) DO NOT MERGE: WIP This is a baseline for multi-node pretraining on H200s, since currently there don't see seem to be any numbers out for H200. --- ...llama3-8b_h200_202506_trainy-whitefiber.md | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md diff --git a/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md b/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md new file mode 100644 index 0000000000..9ba1490f3e --- /dev/null +++ b/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md @@ -0,0 +1,54 @@ +This was performed by Trainy team on WhiteFiber in June 2025, to get a baseline of performance +of the Trainy platform on H200s platform over multiple hosts. + +### Models + +Llama 3.1 8B + +### Hardware + +Each host has + +- 8 NVIDIA H200 GPUs connected via NVLink. +- Hosts are inter-connected with a backend RDMA fabric with 400Gb/s (Mellanox CX-7) per GPU. + +### Configuration + +Runs were invoked with the following, where `NUM_NODES` was `4` and `8` +``` + torchrun \ + --nnodes $NUM_NODES \ + --nproc_per_node 8 \ + --rdzv_id 101 \ + --rdzv_backend c10d \ + --rdzv_endpoint "$MASTER_ADDR:29500" \ + torchtitan/train.py \ + --job.config-file torchtitan/models/llama3/train_configs/llama3_8b.toml \ + --metrics.enable_wandb \ + --training.local_batch_size=2 \ + --training.compile \ + --model.converters="float8" \ + --float8.enable_fsdp_float8_all_gather \ + --float8.precompute_float8_dynamic_scale_for_fsdp \ + --float8.force_recompute_fp8_weight_in_bwd \ + --profiling.profile_freq 1000000 + --training.steps 2000 +``` + +### Results + +Detailed performance results and training configurations can be found in the tables below along and can visualized in [this WandB report](https://api.wandb.ai/links/asaiacai/w4c46stp). `TPS` and `Memory(GiB)` are arbitrarily sampled at the 100th iteration: + +| NUM_NODES | TPS/GPU | Memory(GiB) | +| ----- | ----: | ----: | +| 4 | 10938 | 47.96 | +| 8 | 10753 | 46.97 | + + +### Versions and Dates + +| repo | commit | date | +| --- | --- | --- | +| torch | [2.8.0a0+5228986c39](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-05.html) | 2025/05/29 | +| torchao | [0afa4c1](https://github.com/pytorch/ao/commit/0afa4c1bd28c82921e360ddbd1b27c9d6da5b947) | 2025/06/13 | +| torchtitan | [e7c0cae](https://github.com/pytorch/torchtitan/commit/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc) | 2025/06/13 | From 05e47c38d99fdb1dd39aeba76f080e529a425c5c Mon Sep 17 00:00:00 2001 From: Kfir Goldberg Date: Sun, 13 Jul 2025 21:59:36 +0300 Subject: [PATCH 002/128] fixing dtype in flux eval (#1388) passing dtype argument to preprocess_data in generate_image --- torchtitan/experiments/flux/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index e9cb3a5438..1dd733fc55 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -101,7 +101,7 @@ def generate_image( batch = preprocess_data( device=device, - dtype=torch.bfloat16, + dtype=dtype, autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, From 2764a77af3b8664ffceb9c8c772a2e11d8283220 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 14 Jul 2025 09:31:06 -0700 Subject: [PATCH 003/128] [float8] Fix module filter function (#1391) In a prior PR we added the `_init_filter_fn()` to configure a module filter function at Float8 component init time, but didn't actually use it. This went unnoticed because the existing module filter (`partial(module_filter_fn, filter_fqns=self.filter_fqns)` behaves the same way except for the case where the user uses `auto_filter_small_kn`. In this PR we fix that by using the `self.filter_fn`. ## Test plan - Test auto_filter_small_kn and verify the wk/wv are filtered for Llama3 8b: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter_small_kn" --model.print-after-conversion` - Test without auto_filter_small_kn and verify all linears are converted: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter" --model.print-after-conversion --- torchtitan/components/quantization/float8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index a5b2f967cf..91d42164a6 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -165,7 +165,7 @@ def convert(self, model: nn.Module): convert_to_float8_training( model, config=self.config, - module_filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns), + module_filter_fn=self.filter_fn, ) logger.info( "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" From 7f5c3b66ac19f6826ee2515d721d8ee6b3f38696 Mon Sep 17 00:00:00 2001 From: Sarthak Pati Date: Mon, 14 Jul 2025 14:18:45 -0400 Subject: [PATCH 004/128] added badges for pip and conda, and explicit installation instructions (#1390) Fixes #1389 --- README.md | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8c81fa5000..d75dca769e 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,12 @@ [![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu.yaml?query=branch%3Amain) [![arXiv](https://img.shields.io/badge/arXiv-2410.06511-b31b1b.svg)](https://arxiv.org/abs/2410.06511) -[![ICLR](https://img.shields.io/badge/ICLR-2025-blue.svg)](https://iclr.cc/virtual/2025/poster/29620) +[![ICLR](https://img.shields.io/badge/ICLR-2025-violet.svg)](https://iclr.cc/virtual/2025/poster/29620) [![forum](https://img.shields.io/badge/pytorch-forum-DE3412.svg)](https://discuss.pytorch.org/c/distributed/torchtitan/44) [![license](https://img.shields.io/badge/license-BSD_3--Clause-lightgrey.svg)](./LICENSE) +[![pip](https://img.shields.io/pypi/v/torchtitan?color=blue)](https://pypi.org/project/torchtitan/) +[![conda](https://img.shields.io/conda/vn/conda-forge/torchtitan?color=green)](https://anaconda.org/conda-forge/torchtitan) + @@ -86,6 +89,22 @@ You may want to see how the model is defined or how parallelism techniques are a ## Installation +### Nightly + +Coming soon. + +### Stable + +Via pip: +```sh +pip install torchtitan +``` +Or via conda: +```sh +conda install conda-forge::torchtitan +``` +### Sources + ```bash git clone https://github.com/pytorch/torchtitan cd torchtitan From 890897080db01cd1208deb34d2d25ba1462403c4 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:23:19 +0200 Subject: [PATCH 005/128] fix wrong b200 flops number (#1393) This pr fix what seems to be a wrong estimation of the peak flops for the B200. With the current code the peak flops of B200 is 4.5x bigger that H100 which seems off. It seems that the number reported https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 are with 2:4 sparsity ? --- torchtitan/tools/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index cfc8c6f930..aaa0da8f89 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -97,7 +97,7 @@ def get_peak_flops(device_name: str) -> int: return 989e12 elif "B200" in device_name: # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 - return 4.5e15 + return 2.25e15 elif "MI300X" in device_name or "MI325X" in device_name: # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html From 6204cdff9ca6bfd1fadd6c09621ea41507207dea Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 14 Jul 2025 13:43:40 -0700 Subject: [PATCH 006/128] refactor ParallelDims and CheckpointManager (#1384) This PR does the following: 1. move `world_mesh` into `ParallelDims`, as they have a close relationship 2. move `enable_loss_parallel` out of `ParallelDims` constructor 3. add a convenient property `seq_len_divisor` to `ParallelDims` 4. set `dataloader` and `ft_manager` as optional in `CheckpointManager` 5. some minor improvements on typing and code organization --- scripts/estimate/estimation.py | 34 +++++---- scripts/generate/test_generate.py | 11 ++- tests/unit_tests/test_model_converter.py | 1 - tests/unit_tests/test_train_spec.py | 33 ++++++--- torchtitan/components/checkpoint.py | 10 +-- torchtitan/components/optimizer.py | 2 - torchtitan/components/tokenizer.py | 4 +- torchtitan/components/validate.py | 16 ++--- torchtitan/distributed/parallel_dims.py | 42 +++++++---- torchtitan/distributed/utils.py | 17 +++-- .../experiments/deepseek_v3/__init__.py | 4 +- .../experiments/deepseek_v3/train_ds_real.py | 4 +- torchtitan/experiments/flux/__init__.py | 4 +- .../experiments/flux/infra/parallelize.py | 6 +- torchtitan/experiments/flux/train.py | 7 +- torchtitan/experiments/llama4/__init__.py | 4 +- .../experiments/llama4/infra/parallelize.py | 13 +++- torchtitan/experiments/llama4/optimizer.py | 10 ++- torchtitan/experiments/multimodal/__init__.py | 4 +- .../experiments/simple_fsdp/__init__.py | 4 +- .../experiments/simple_fsdp/parallelize.py | 21 ++++-- .../simple_fsdp/tests/test_numerics.py | 4 +- torchtitan/models/deepseek_v3/__init__.py | 4 +- .../models/deepseek_v3/infra/parallelize.py | 14 +++- torchtitan/models/llama3/__init__.py | 4 +- torchtitan/models/llama3/infra/parallelize.py | 16 ++--- torchtitan/models/llama3/infra/pipeline.py | 12 ++-- torchtitan/protocols/train_spec.py | 16 +++-- torchtitan/tools/utils.py | 6 +- torchtitan/train.py | 72 ++++++++++--------- 30 files changed, 220 insertions(+), 179 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 33d4dc17f9..aade405b4d 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -39,6 +39,12 @@ def estimate_memory(job_config: JobConfig): job_config.training.compile = False job_config.parallelism.enable_compiled_autograd = False + # init fake pg + store = FakeStore() + torch.distributed.init_process_group( + "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store + ) + parallelism_config = job_config.parallelism parallel_dims = ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, @@ -48,8 +54,9 @@ def estimate_memory(job_config: JobConfig): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) + # ParallelDims.build_mesh has to happen outside of the FakeTensorMode + _ = parallel_dims.world_mesh # only FSDP and HSDP are supported if ( @@ -68,28 +75,21 @@ def estimate_memory(job_config: JobConfig): device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) - # init fake pg - store = FakeStore() - torch.distributed.init_process_group( - "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store - ) - train_spec = get_train_spec(job_config.model.name) - # build meshes - world_mesh = parallel_dims.build_mesh(device_type="cuda") - # build tokenizer tokenizer = train_spec.build_tokenizer_fn(job_config) + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) train_context = dist_utils.get_train_context( - parallel_dims.loss_parallel_enabled, + loss_parallel_enabled, job_config.parallelism.enable_compiled_autograd, ) # build model (using meta init) - model_cls = train_spec.cls - model_args = train_spec.config[job_config.model.flavor] + model_args = train_spec.model_args[job_config.model.flavor] model_args.update_from_config(job_config, tokenizer) with ( @@ -101,14 +101,14 @@ def estimate_memory(job_config: JobConfig): f"Building {train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): - model = model_cls(model_args) + model = train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) model_converters.convert(model) # apply PT-D DP/TP parallelisms and activation checkpointing - train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config) + train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device="cuda") if not active_fake_mode(): @@ -117,9 +117,7 @@ def estimate_memory(job_config: JobConfig): # build optimizer after applying parallelisms to the model ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers( - [model], job_config, parallel_dims, world_mesh, ft_manager - ) + optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index ef31c18500..07966c2763 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -106,14 +106,13 @@ def test_generate( # Tokenizer setup tokenizer = train_spec.build_tokenizer_fn(config) - model_cls = train_spec.cls - model_args = train_spec.config[config.model.flavor] + model_args = train_spec.model_args[config.model.flavor] model_args.update_from_config(config, tokenizer) init_device = "meta" if world_size > 1 else device with torch.device(init_device): logger.info(f"Init model on init_device: {init_device}") - model = model_cls(model_args) + model = train_spec.model_cls(model_args) world_mesh = None # Init distributed env @@ -127,14 +126,12 @@ def test_generate( pp=1, ep=1, world_size=world_size, - enable_loss_parallel=False, ) - # Build world mesh for parallelism - world_mesh = parallel_dims.build_mesh(device_type=device_type) + world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - apply_tp_minus_sp(model, world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) dist_utils.set_determinism(world_mesh, device, seed, deterministic) diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index 704e81a91d..6b9d9515f4 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -23,7 +23,6 @@ def build_parallel_dims(job_config, world_size): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) return parallel_dims diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index c364af3859..5b01454771 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -9,12 +9,14 @@ import pytest import torch import torch.nn as nn +from torchtitan.components.ft import FTManager from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers, OptimizersContainer from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.config_manager import JobConfig from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.llama3 import parallelize_llama, pipeline_llama from torchtitan.protocols.train_spec import ( apply_to_train_specs, @@ -39,7 +41,10 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: def fake_build_optimizers( - model_parts: list[nn.Module], job_config: JobConfig + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager, ) -> OptimizersContainer: optimizer_kwargs = { "lr": 0.1, @@ -57,11 +62,11 @@ def fake_build_optimizers( class TestTrainSpec: def test_register_train_spec(self): - fake_config = {"fake": None} + fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( name="fake", - cls=FakeModel, - config=fake_config, + model_cls=FakeModel, + model_args=fake_config, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, @@ -78,11 +83,11 @@ def test_register_train_spec(self): new_spec = get_train_spec("fake2") def test_optim_hook(self): - fake_config = {"fake": None} + fake_config = {"fake": BaseModelArgs()} spec = TrainSpec( name="fake2", - cls=FakeModel, - config=fake_config, + model_cls=FakeModel, + model_args=fake_config, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=fake_build_optimizers, @@ -111,21 +116,27 @@ def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec: original_build_optimizers_fn = spec.build_optimizers_fn def my_build_optimizer_fn( - model_parts: list[nn.Module], job_config: JobConfig + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager, ) -> OptimizersContainer: - optimizers = original_build_optimizers_fn(model_parts, job_config) + optimizers = original_build_optimizers_fn( + model_parts, job_config, parallel_dims, ft_manager + ) optimizers.register_step_post_hook( partial(my_hook, model_parts=model_parts) ) return optimizers spec.build_optimizers_fn = my_build_optimizer_fn + return spec apply_to_train_specs(register_optimizer_hook_to_spec) - model = new_spec.cls(BaseModelArgs()) + model = new_spec.model_cls(BaseModelArgs()) model_parts = [model] - optimizers = new_spec.build_optimizers_fn(model_parts, JobConfig()) + optimizers = new_spec.build_optimizers_fn(model_parts, None, None, None) assert optimizers.optimizers[0].__class__.__name__ == "Adam" batch = torch.randn(8, 8) model(batch).sum().backward() diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index ff055cbe75..1bc07f2f27 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -26,8 +26,8 @@ ) from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType from torch.distributed.checkpoint.stateful import Stateful -from torch.utils.data import DataLoader +from torchtitan.components.dataloader import BaseDataLoader from torchtitan.components.ft import FTManager from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer @@ -180,17 +180,19 @@ class CheckpointManager: def __init__( self, - dataloader: DataLoader, + dataloader: BaseDataLoader | None, model_parts: list[nn.Module], optimizers: OptimizersContainer, lr_schedulers: LRSchedulersContainer, states: dict[str, Any], job_config: JobConfig, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.ft_manager = ft_manager.manager if ft_manager.enabled else None + self.ft_manager = ( + ft_manager.manager if ft_manager and ft_manager.enabled else None + ) if self.ft_manager: optimizers.init_cache_state_dict() diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index cd3604f297..d2ff514cff 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -15,7 +15,6 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.device_mesh import DeviceMesh from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft @@ -244,7 +243,6 @@ def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 45ecf34f96..6ca11d6711 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -7,17 +7,15 @@ import json -import logging import os from abc import ABC, abstractmethod from typing import Any, Optional, Union from tokenizers import AddedToken, Tokenizer from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger from typing_extensions import override -logger = logging.getLogger(__name__) - class BaseTokenizer(ABC): # base tokenizer interface, for typing purpose mainly diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 904c65ca5c..7f678514b9 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -50,14 +50,12 @@ def __init__( dp_rank: int, tokenizer: BaseTokenizer, parallel_dims: ParallelDims, - world_mesh: torch.distributed.DeviceMesh, loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], ): self.job_config = job_config self.parallel_dims = parallel_dims - self.world_mesh = world_mesh self.loss_fn = loss_fn self.validation_dataloader = build_hf_validation_dataloader( job_config=job_config, @@ -78,6 +76,8 @@ def validate( model = model_parts[0] model.eval() + parallel_dims = self.parallel_dims + accumulated_losses = [] device_type = utils.device_type num_steps = 0 @@ -96,13 +96,13 @@ def validate( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=self.world_mesh["cp"], + cp_mesh=parallel_dims.world_mesh["cp"], cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, ) - if self.parallel_dims.cp_enabled + if parallel_dims.cp_enabled else None ) @@ -119,8 +119,10 @@ def validate( # Compute average loss loss = torch.sum(torch.stack(accumulated_losses)) loss /= num_steps - if self.parallel_dims.dp_cp_enabled: - global_avg_loss = dist_utils.dist_mean(loss, self.world_mesh["dp_cp"]) + if parallel_dims.dp_cp_enabled: + global_avg_loss = dist_utils.dist_mean( + loss, parallel_dims.world_mesh["dp_cp"] + ) else: global_avg_loss = loss @@ -144,7 +146,6 @@ def build_validator( dp_rank: int, tokenizer: BaseTokenizer, parallel_dims: ParallelDims, - world_mesh: torch.distributed.DeviceMesh, loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], @@ -156,7 +157,6 @@ def build_validator( dp_rank=dp_rank, tokenizer=tokenizer, parallel_dims=parallel_dims, - world_mesh=world_mesh, loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 08986b2207..01e14cc0b0 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -10,6 +10,7 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torchtitan.tools.logging import logger +from torchtitan.tools.utils import device_type __all__ = ["ParallelDims"] @@ -24,7 +25,8 @@ class ParallelDims: pp: int ep: int world_size: int - enable_loss_parallel: bool + + _world_mesh: DeviceMesh = None def __post_init__(self): self._validate() @@ -55,16 +57,16 @@ def _validate(self): # EP would borrow all cp and some dp_shard degree assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - def build_mesh(self, device_type: str) -> DeviceMesh: + def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel # is not very clean, due to the limited support from DeviceMesh # for creating two staggered meshes. Will improve. if self.ep > 1: - return self._build_mesh_with_ep(device_type) + return self._build_mesh_with_ep() else: - return self._build_mesh_without_ep(device_type) + return self._build_mesh_without_ep() - def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: + def _build_mesh_with_ep(self) -> DeviceMesh: # With ep, dp_shard and ep are derived submeshes: # dp_shard = dp_shard_mod_ep * dp_shard_in_ep # ep = dp_shard_in_ep * cp @@ -128,7 +130,7 @@ def _build_mesh_with_ep(self, device_type: str) -> DeviceMesh: return mesh - def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: + def _build_mesh_without_ep(self) -> DeviceMesh: dims = [] names = [] for d, name in zip( @@ -173,6 +175,14 @@ def _build_mesh_without_ep(self, device_type: str) -> DeviceMesh: return mesh + @property + def world_mesh(self) -> str: + # doing late init so ParallelDims can still be used as a lightweight + # dataclass without having to initialize the world mesh + if self._world_mesh is None: + self._world_mesh = self.build_mesh() + return self._world_mesh + @property def dp_enabled(self): return self.dp_replicate > 1 or self.dp_shard > 1 @@ -206,18 +216,24 @@ def pp_enabled(self): return self.pp > 1 @property - def loss_parallel_enabled(self): - return self.tp > 1 and self.enable_loss_parallel + def ep_enabled(self): + return self.ep > 1 @cached_property def non_data_parallel_size(self): return self.cp * self.tp * self.pp - @property - def ep_enabled(self): - return self.ep > 1 + @cached_property + def seq_len_divisor(self): + # Sequence Parallel requires that seq_len be divisible by TP degree. + # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001 - @property + # Context Parallel requires that seq_len be divisible by 2 * CP degree, + # when load balancing is enabled (by default). + # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246 + return self.tp * (self.cp * 2) + + @cached_property def dense_params_mesh_ndim(self): - # Note: EP params mesh ndim is 1 more due to the 'ep' mesh + # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 3f824d5fec..58c5df0ca5 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -307,7 +307,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, - parallel_dims: ParallelDims | None = None, + ep_dense_params_mesh_ndim: int | None = None, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -329,14 +329,15 @@ def clip_grad_norm_( If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: ``None`` - pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. - parallel_dims: ParallelDims object which contains Expert Parallel related info. + pp_mesh: Pipeline Parallel device mesh. If not None, will reduce gradient norm across PP stages. + ep_dense_params_mesh_ndim: Mesh ndim of the dense params when EP is used. If EP is not used, + set it to ``None``. Returns: Total norm of the parameter gradients (viewed as a single vector). """ - if parallel_dims and parallel_dims.ep_enabled: + if ep_dense_params_mesh_ndim is not None: return _clip_grad_norm_with_ep( parameters, max_norm, @@ -344,7 +345,7 @@ def clip_grad_norm_( error_if_nonfinite, foreach, pp_mesh, - parallel_dims, + ep_dense_params_mesh_ndim, ) if isinstance(parameters, torch.Tensor): @@ -388,10 +389,8 @@ def _clip_grad_norm_with_ep( error_if_nonfinite: bool, foreach: bool | None, pp_mesh: DeviceMesh | None, - parallel_dims: ParallelDims, + dense_params_mesh_ndim: int, ) -> torch.Tensor: - assert parallel_dims.ep_enabled - ep_params = [] non_ep_params = [] ep_grads = [] @@ -401,7 +400,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - if p.device_mesh.ndim == parallel_dims.dense_params_mesh_ndim: + if p.device_mesh.ndim == dense_params_mesh_ndim: non_ep_params.append(p) non_ep_grads.append(p.grad) else: diff --git a/torchtitan/experiments/deepseek_v3/__init__.py b/torchtitan/experiments/deepseek_v3/__init__.py index eb515bcfce..f93d0d80e5 100644 --- a/torchtitan/experiments/deepseek_v3/__init__.py +++ b/torchtitan/experiments/deepseek_v3/__init__.py @@ -42,8 +42,8 @@ register_train_spec( TrainSpec( name="deepseek3", - cls=DeepseekForCausalLM, - config=deepseek_configs, + model_cls=DeepseekForCausalLM, + model_args=deepseek_configs, parallelize_fn=parallelize_deepseek, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/deepseek_v3/train_ds_real.py b/torchtitan/experiments/deepseek_v3/train_ds_real.py index 398360a6ee..be4a92da53 100644 --- a/torchtitan/experiments/deepseek_v3/train_ds_real.py +++ b/torchtitan/experiments/deepseek_v3/train_ds_real.py @@ -155,8 +155,8 @@ def run_full_model( pp=pp_size, cp=1, tp=1, + ep=1, world_size=world_mesh.size(), - enable_loss_parallel=False, ) metrics_processor = build_metrics_processor( @@ -180,7 +180,7 @@ def run_full_model( loss_fn = cross_entropy_loss # torch.nn.functional.cross_entropy ft_manager = ft.init_ft_manager(config) - optimizer = build_optimizers([model], config, ft_manager) + optimizer = build_optimizers([model], config, proxy_parallel_dims, ft_manager) lr_scheduler = build_lr_schedulers(optimizer, config) diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index 5fe8ba3eeb..12613a7930 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -108,8 +108,8 @@ register_train_spec( TrainSpec( name="flux", - cls=FluxModel, - config=flux_configs, + model_cls=FluxModel, + model_args=flux_configs, parallelize_fn=parallelize_flux, pipelining_fn=None, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/flux/infra/parallelize.py b/torchtitan/experiments/flux/infra/parallelize.py index 460c7f5886..69fef68c50 100644 --- a/torchtitan/experiments/flux/infra/parallelize.py +++ b/torchtitan/experiments/flux/infra/parallelize.py @@ -21,7 +21,6 @@ def parallelize_flux( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -36,7 +35,7 @@ def parallelize_flux( apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -117,7 +116,6 @@ def apply_ac(model: nn.Module, ac_config): def parallelize_encoders( t5_model: nn.Module, clip_model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -132,7 +130,7 @@ def parallelize_encoders( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) fsdp_config = { - "mesh": world_mesh[tuple(dp_mesh_dim_names)], + "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 269abe1c5d..c328d12b71 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -36,7 +36,7 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader dist_utils.set_determinism( - self.world_mesh, + self.parallel_dims.world_mesh, self.device, job_config.training.seed, job_config.training.deterministic, @@ -54,11 +54,11 @@ def __init__(self, job_config: JobConfig): ) # load components - model_config = self.train_spec.config[job_config.model.flavor] + model_args = self.train_spec.model_args[job_config.model.flavor] self.autoencoder = load_ae( job_config.encoder.autoencoder_path, - model_config.autoencoder_params, + model_args.autoencoder_params, device=self.device, dtype=self._dtype, random_init=job_config.training.test_mode, @@ -77,7 +77,6 @@ def __init__(self, job_config: JobConfig): self.t5_encoder, self.clip_encoder = parallelize_encoders( t5_model=self.t5_encoder, clip_model=self.clip_encoder, - world_mesh=self.world_mesh, parallel_dims=self.parallel_dims, job_config=job_config, ) diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 9f7affc099..798555ae43 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -94,8 +94,8 @@ register_train_spec( TrainSpec( name="llama4", - cls=Transformer, - config=llama4_configs, + model_cls=Transformer, + model_args=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_llama4_optimizers, diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index d681cd6a16..1b62011286 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -38,7 +38,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -49,6 +48,16 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ if parallel_dims.tp_enabled: if ( @@ -71,7 +80,7 @@ def parallelize_llama( apply_non_moe_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index d4829de88a..11870f5fef 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh from torchtitan.components.ft import FTManager from torchtitan.components.optimizer import build_optimizers, OptimizersContainer @@ -17,10 +16,11 @@ # for MoE auxiliary-loss-free load balancing def _update_expert_bias( model_parts: list[nn.Module], - world_mesh: dict[str, DeviceMesh], parallel_dims: ParallelDims, ): - dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + dp_cp_mesh = ( + parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. for model_part in model_parts: @@ -48,20 +48,18 @@ def build_llama4_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - world_mesh: DeviceMesh, ft_manager: FTManager, ) -> OptimizersContainer: optimizers = build_optimizers( model_parts=model_parts, job_config=job_config, parallel_dims=parallel_dims, - world_mesh=world_mesh, ft_manager=ft_manager, ) optimizers.register_step_pre_hook( lambda *args, **kwargs: _update_expert_bias( - model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + model_parts, parallel_dims=parallel_dims ) ) diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py index f3ba2a2d4c..bbb37d5c59 100644 --- a/torchtitan/experiments/multimodal/__init__.py +++ b/torchtitan/experiments/multimodal/__init__.py @@ -24,8 +24,8 @@ register_train_spec( TrainSpec( name="llama4_multimodal", - cls=MultimodalDecoder, - config=llama4_mm_configs, + model_cls=MultimodalDecoder, + model_args=llama4_mm_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/simple_fsdp/__init__.py b/torchtitan/experiments/simple_fsdp/__init__.py index 80a2b3c3a3..2b578dd4b3 100644 --- a/torchtitan/experiments/simple_fsdp/__init__.py +++ b/torchtitan/experiments/simple_fsdp/__init__.py @@ -20,8 +20,8 @@ register_train_spec( TrainSpec( name="llama3_simple_fsdp", - cls=SimpleFSDPTransformer, - config=llama3_configs, + model_cls=SimpleFSDPTransformer, + model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index c386fd3d32..7a94adea39 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -7,8 +7,6 @@ import torch import torch.nn as nn -from torch.distributed import DeviceMesh - from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp @@ -19,7 +17,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -30,6 +27,16 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel @@ -48,11 +55,11 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - tp_mesh = world_mesh["tp"] + tp_mesh = parallel_dims.world_mesh["tp"] apply_tp( model, - world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) @@ -84,7 +91,7 @@ def parallelize_llama( model = data_parallel( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], mode=dp_mode, ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py index 3c15ce573b..428182655e 100644 --- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -38,10 +38,10 @@ def init_test(self): cp=1, tp=1, pp=1, + ep=1, world_size=self.world_size, - enable_loss_parallel=True, ) - self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda") + self.device_mesh = self.parallel_dims.world_mesh def get_input(self): inputs = torch.randn(8, 8).cuda() diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index e86917bbc0..141b740ce6 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -113,8 +113,8 @@ register_train_spec( TrainSpec( name="deepseek_v3", - cls=DeepSeekV3Model, - config=deepseekv3_configs, + model_cls=DeepSeekV3Model, + model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, pipelining_fn=None, build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 44e0bc6bbd..1ba45f86d4 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -26,10 +26,20 @@ # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: # TODO(jianiw): This branch needs to be tested and enabled @@ -54,7 +64,7 @@ def parallelize_deepseekv3( apply_non_moe_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, ) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 2e9a11d47b..26895274c2 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -73,8 +73,8 @@ register_train_spec( TrainSpec( name="llama3", - cls=Transformer, - config=llama3_configs, + model_cls=Transformer, + model_args=llama3_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, build_optimizers_fn=build_optimizers, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index df395adcb2..d67e283721 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -34,7 +34,6 @@ def parallelize_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): @@ -45,16 +44,15 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - # TODO: TP currently cannot handle uneven seq_len because we set `use_local_output=True` - # (to use plain Tensors), which was because of the bug in computation of complex - # numbers with DTensors when setting `use_local_output=False`. - # See https://github.com/pytorch/pytorch/issues/130646 and - # https://github.com/pytorch/torchtitan/issues/1306 for details. + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. assert ( - job_config.training.seq_len % (parallel_dims.tp * parallel_dims.cp) == 0 + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree - ({parallel_dims.tp}) and CP degree ({parallel_dims.cp}). + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ if parallel_dims.tp_enabled: @@ -78,7 +76,7 @@ def parallelize_llama( apply_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index 7ad73a229c..dfb424b5b5 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -8,6 +8,7 @@ import copy +import torch import torch.nn as nn from torch.distributed import DeviceMesh from torch.distributed.pipelining import PipelineStage @@ -25,7 +26,7 @@ generate_split_points, stage_ids_this_rank, ) -from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction +from torchtitan.protocols.train_spec import ParallelizeFunction from torchtitan.tools.logging import logger from ..model.args import TransformerModelArgs @@ -33,15 +34,14 @@ def pipeline_llama( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, - device: DeviceType, + device: torch.device, model_config: TransformerModelArgs, parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = world_mesh["pp"] + pp_mesh = parallel_dims.world_mesh["pp"] stages, model_parts = pipeline_llama_manual_split( model, pp_mesh, parallel_dims, job_config, device, model_config @@ -52,7 +52,7 @@ def pipeline_llama( # optimizer, and checkpointing for i, m in enumerate(model_parts): # apply SPMD-style PT-D techniques - m = parallelize_fn(m, world_mesh, parallel_dims, job_config) + m = parallelize_fn(m, parallel_dims, job_config) model_parts[i] = m # NOTE: this is to update the model in the stage # in case the model is modified e.g. by torch.compile @@ -77,7 +77,7 @@ def pipeline_llama_manual_split( pp_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, - device: DeviceType, + device: torch.device, model_config: TransformerModelArgs, ) -> tuple[list[PipelineStage], list[nn.Module]]: """ diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index e7caa89f05..3ee8707714 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -7,13 +7,12 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. from abc import abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable from dataclasses import dataclass from typing import Protocol, TypeAlias import torch import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader @@ -27,8 +26,6 @@ from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims -DeviceType = int | str | torch.device - @dataclass class BaseModelArgs: @@ -65,6 +62,11 @@ def __init__(self, model_args: BaseModelArgs) -> None: @abstractmethod def init_weights(self, buffer_device: torch.device | None = None) -> None: + """Initialize model weights. + + Args: + buffer_device: Optional device to place buffers on during initialization. + """ pass @@ -76,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, ParallelDims, DeviceMesh, FTManager], + [list[nn.Module], JobConfig, ParallelDims, FTManager], OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ @@ -89,8 +91,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: @dataclass class TrainSpec: name: str - cls: type[nn.Module] - config: Mapping[str, BaseModelArgs] + model_cls: type[ModelProtocol] + model_args: dict[str, BaseModelArgs] parallelize_fn: ParallelizeFunction pipelining_fn: PipeliningFunction | None build_optimizers_fn: OptimizersBuilder diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index aaa0da8f89..4f10a088af 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -23,10 +23,8 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -def get_device_info(): - device_type = _get_available_device_type() - if device_type is None: - device_type = "cuda" # default device_type: cuda +def get_device_info() -> tuple[str, torch.device]: + device_type = _get_available_device_type() or "cuda" device_module = _get_device_module(device_type) # default device_module:torch.cuda return device_type, device_module diff --git a/torchtitan/train.py b/torchtitan/train.py index 3dc8a61b28..ea7a8e2efb 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -34,32 +34,33 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): + # core configs job_config: JobConfig - gc_handler: utils.GarbageCollection - parallel_dims: ParallelDims train_spec: train_spec_module.TrainSpec - world_mesh: torch.distributed.DeviceMesh - gradient_accumulation_steps: int + # swappable training components in TrainSpec dataloader: train_spec_module.BaseDataLoader - metrics_processor: train_spec_module.MetricsProcessor - checkpointer: CheckpointManager - train_context: Generator[None, None, None] - model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction optimizers: train_spec_module.OptimizersContainer lr_schedulers: train_spec_module.LRSchedulersContainer + validator: train_spec_module.BaseValidator + metrics_processor: train_spec_module.MetricsProcessor - validator: train_spec_module.BaseValidator | None + # non-swappable training components + checkpointer: CheckpointManager + ft_manager: ft.FTManager + # runtime utilities + device: torch.device + gc_handler: utils.GarbageCollection + train_context: Generator[None, None, None] + gradient_accumulation_steps: int pp_has_first_stage: bool pp_has_last_stage: bool - device: torch.device - - # states + # additional training states step: int # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @@ -82,7 +83,8 @@ def __init__(self, job_config: JobConfig): # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) - # init distributed + # init distributed and build meshes + dist_utils.init_distributed(job_config) world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism self.parallel_dims = parallel_dims = ParallelDims( @@ -93,12 +95,9 @@ def __init__(self, job_config: JobConfig): pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, ) - dist_utils.init_distributed(job_config) - # build meshes - self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type) + world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() @@ -141,8 +140,7 @@ def __init__(self, job_config: JobConfig): ) # build model (using meta init) - model_cls = self.train_spec.cls - model_args = self.train_spec.config[job_config.model.flavor] + model_args = self.train_spec.model_args[job_config.model.flavor] # set the model args from training job configs model_args.update_from_config(job_config, tokenizer) @@ -150,7 +148,7 @@ def __init__(self, job_config: JobConfig): f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) with torch.device("meta"): - model = model_cls(model_args) + model = self.train_spec.model_cls(model_args) # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) @@ -231,7 +229,6 @@ def __init__(self, job_config: JobConfig): self.pp_has_last_stage, ) = self.train_spec.pipelining_fn( model, - world_mesh, parallel_dims, job_config, self.device, @@ -253,9 +250,7 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model = self.train_spec.parallelize_fn( - model, world_mesh, parallel_dims, job_config - ) + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device) with torch.no_grad(): @@ -283,7 +278,7 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, parallel_dims, world_mesh, self.ft_manager + self.model_parts, job_config, parallel_dims, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( self.optimizers, job_config @@ -312,8 +307,11 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) self.train_context = dist_utils.get_train_context( - parallel_dims.loss_parallel_enabled, + loss_parallel_enabled, parallelism_config.enable_compiled_autograd, ) self.maybe_enable_amp = dist_utils.maybe_enable_amp( @@ -335,7 +333,6 @@ def __init__(self, job_config: JobConfig): dp_rank=dp_rank, tokenizer=tokenizer, parallel_dims=parallel_dims, - world_mesh=world_mesh, loss_fn=self.train_spec.build_loss_fn(job_config), validation_context=self.train_context, maybe_enable_amp=self.maybe_enable_amp, @@ -391,7 +388,7 @@ def forward_backward_step( inputs = input_dict["input"] optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=self.world_mesh["cp"], + cp_mesh=parallel_dims.world_mesh["cp"], cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -457,8 +454,14 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, - parallel_dims=parallel_dims, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_dense_params_mesh_ndim=( + parallel_dims.dense_params_mesh_ndim + if parallel_dims.ep_enabled + else None + ), ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() @@ -480,8 +483,8 @@ def train_step( ) ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -546,14 +549,13 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.world_mesh, + world_mesh=self.parallel_dims.world_mesh, ) if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) - self.metrics_processor.close() logger.info("Training completed") def state_dict(self) -> dict[str, Any]: @@ -565,6 +567,8 @@ def load_state_dict(self, state_dict: dict[str, Any]): def close(self) -> None: if self.checkpointer: self.checkpointer.close() + if self.metrics_processor: + self.metrics_processor.close() if __name__ == "__main__": From db52d57cad7a0227c1b67014e04c6c214631a6fc Mon Sep 17 00:00:00 2001 From: ankitageorge Date: Mon, 14 Jul 2025 17:17:39 -0400 Subject: [PATCH 007/128] Add support for saving HF format tensors with DCP (#1351) If checkpoint.last_save_in_safetensors_format is set, then save the checkpoint with DCP HF components that will save the checkpoint in .safetensors files instead of regular DCP format on final save. On load, we can decide which type of load to do based on checkpoint type. Successful save: ``` (titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0,1,2 + LOG_RANK=0,1,2 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training [rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training [rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training [rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model [rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921 [rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model [rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1. [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1. [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,046 tflops: 60.58 mfu: 19.42% [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 971 tflops: 56.23 mfu: 18.02% [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1 [rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,038 tflops: 60.13 mfu: 19.27% [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2 [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758 [rank0]:Done writing data. Took %.2f secs. 66.62590146064758 [rank0]:Done consolidating. Took %.2f secs. 66.62735033035278 [rank0]:time taken for all reduce: 141.72666668891907 [rank1]:time taken for all reduce: 141.73284125328064 [rank2]:time taken for all reduce: 141.72900009155273 [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread. [rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds. [rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread. [rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed [rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread. [rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed. ``` Successful load: ``` (titan) [ankitageorge@devvm6863.rva0 /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056 [rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200). [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3. [rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.) [rank0]: tensor = torch.frombuffer( [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step: 1 loss: 12.0247 grad_norm: 42.7524 memory: 42.12GiB(53.23%) tps: 236 tflops: 13.67 mfu: 4.38% [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 ``` --------- Co-authored-by: ankitageorge Co-authored-by: ankitageorge Co-authored-by: tianyu-l <150487191+tianyu-l@users.noreply.github.com> --- .ci/docker/requirements.txt | 1 + docs/checkpoint.md | 5 + tests/integration_tests.py | 16 +++ tests/unit_tests/test_checkpoint.py | 4 +- torchtitan/components/checkpoint.py | 187 +++++++++++++++++++++++++--- torchtitan/config_manager.py | 11 ++ 6 files changed, 202 insertions(+), 22 deletions(-) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index c33bfe4d84..9bf30b502c 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -7,3 +7,4 @@ wandb fsspec tyro tokenizers >= 0.15.0 +safetensors diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 5275db1a25..0ffcafb02c 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -85,3 +85,8 @@ e.g. ```bash NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` + + +## How to load / save a checkpoint in HF safetensors format +For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format). +For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_weights_only` to load the checkpoint in HF format. diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 6218b5a5f6..f3000eef7e 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -118,6 +118,22 @@ def build_test_list(): "Checkpoint Integration Test - Save Load Full Checkpoint", "full_checkpoint", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.folder hf_checkpoint", + "--checkpoint.last_save_in_safetensors_format", + "--checkpoint.last_save_model_weights_only", + ], + [ + "--checkpoint.enable_checkpoint", + "--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/", + ], + ], + "Checkpoint Integration Test - save load full checkpoint in HF safetensors format", + "full_checkpoint_hf_safetensors", + ), OverrideDefinitions( [ [ diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index 3317a51fed..2f8127bfd6 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -144,7 +144,7 @@ def tearDown(self): shutil.rmtree(self.base_temp_dir) time.sleep(0.1) - def fake_save(self, state_dict: dict, checkpoint_id: str): + def fake_save(self, state_dict: dict, checkpoint_id: str, storage_writer=None): os.makedirs(checkpoint_id, exist_ok=True) sd_to_save = {} for key, val in state_dict.items(): @@ -584,7 +584,7 @@ def __init__(self): @mock.patch("torchtitan.components.checkpoint.dcp.load") @mock.patch("torchtitan.components.checkpoint.dcp.save") def test_verify_prefix(self, mock_save, mock_load, mock_rank): - def fake_save(state_dict: dict, checkpoint_id: str): + def fake_save(state_dict: dict, checkpoint_id: str, storage_writer=None): self.assertIn("bias", state_dict) self.assertIn("weight", state_dict) # No model prefix diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 1bc07f2f27..f71417de80 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -12,12 +12,17 @@ import shutil import threading import time +from concurrent.futures import Future from typing import Any import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.nn as nn +from torch.distributed.checkpoint import ( + HuggingFaceStorageReader, + HuggingFaceStorageWriter, +) from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, @@ -49,6 +54,11 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +class CheckpointType(str, enum.Enum): + DCP = "DCP" + SAFETENSORS = "safetensors" + + # For now, we will manually pop the freqs_cis buffer, as we made this permanent # temporarily and we don't want to include it in the exported state_dict. # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404 @@ -92,12 +102,6 @@ class SaveDone: pass -@torch.no_grad() -def save_with_gc(state, checkpoint_id): - dcp.save(state, checkpoint_id=checkpoint_id) - GarbageCollection.collect("GC collection invoked by checkpointer.") - - def purge_thread(purge_queue: queue.Queue): """Thread to purge the old checkpoints. @@ -190,6 +194,9 @@ def __init__( ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint + self.last_save_in_safetensors_format = ( + ckpt_config.last_save_in_safetensors_format + ) self.ft_manager = ( ft_manager.manager if ft_manager and ft_manager.enabled else None ) @@ -314,6 +321,98 @@ def close(self): if self.stager is not None: self.stager.close() + @torch.no_grad() + def dcp_save( + self, + state_dict: dict[str, Any], + checkpoint_id: str, + async_mode: AsyncMode, + enable_garbage_collection: bool = False, + save_in_safetensors_format: bool = False, + ) -> Future | None: + """Save the checkpoint with dcp. + Args: + state_dict (dict): The state dict to save. + checkpoint_id (str): The checkpoint id to save. + async_mode (AsyncMode): Whether the checkpoint is async. + enable_garbage_collection (bool): Whether to enable garbage collection after save. + save_in_safetensors_format (bool): Whether to save in safetensors format. + + Returns: + Future: The future object if the checkpoint is async, otherwise None. + """ + + ret: Future | None = None + + storage_writer: HuggingFaceStorageWriter | None = None + checkpoint_save_id: str | None = None + if save_in_safetensors_format: + fqn_to_index_mapping = {} + num_fqns_per_file = 30 + # the use of 30 is just a heuristic for now. + # Once these fqns map to HF ones, we can use the fqn mapping + # from the model.safetensors.index.json file + for i, key in enumerate(state_dict.keys()): + group_num = (i // num_fqns_per_file) + 1 + fqn_to_index_mapping[key] = group_num + + storage_writer = HuggingFaceStorageWriter( + path=checkpoint_id, + save_distributed=True, + fqn_to_index_mapping=fqn_to_index_mapping, + enable_consolidation=True, + thread_count_consolidation=5, + ) + else: + checkpoint_save_id = checkpoint_id + + if async_mode == AsyncMode.ASYNC: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=checkpoint_save_id, + process_group=self.pg, + ) + elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: + ret = dcp.async_save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=checkpoint_save_id, + process_group=self.pg, + async_checkpointer_type=AsyncCheckpointerType.PROCESS, + async_stager=self.stager, + ) + else: + ret = dcp.save( + state_dict, + storage_writer=storage_writer, + checkpoint_id=checkpoint_save_id, + ) + + if enable_garbage_collection: + GarbageCollection.collect("GC collection invoked by checkpointer.") + + return ret + + def dcp_load( + self, + state_dict: dict[str, Any], + checkpoint_id: str, + checkpoint_type: CheckpointType, + ) -> None: + """Load the checkpoint with dcp. + Args: + state_dict (dict): The state dict to load. + checkpoint_id (str): The checkpoint id to load. + hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + """ + + if checkpoint_type == CheckpointType.SAFETENSORS: + storage_reader = HuggingFaceStorageReader(path=checkpoint_id) + dcp.load(state_dict, storage_reader=storage_reader) + else: + dcp.load(state_dict, checkpoint_id=checkpoint_id) + @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: """Save the checkpoint for the current step. @@ -354,23 +453,26 @@ def save(self, curr_step: int, last_step: bool = False) -> None: GarbageCollection.collect("GC collection invoked by checkpointer.") if self.stager is None: self.stager = DefaultStager(StagingOptions(True, True, True, True)) - result = dcp.async_save( + result = self.dcp_save( states, checkpoint_id=checkpoint_id, - process_group=self.pg, - async_checkpointer_type=AsyncCheckpointerType.PROCESS, - async_stager=self.stager, + async_mode=self.async_mode, ) self.save_future = result.upload_completion self.staging_future = result.staging_completion elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") - self.save_future = dcp.async_save( - states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + states, checkpoint_id=checkpoint_id, async_mode=self.async_mode ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: - save_with_gc(states, checkpoint_id=checkpoint_id) + self.dcp_save( + states, + checkpoint_id=checkpoint_id, + async_mode=AsyncMode.DISABLED, + enable_garbage_collection=True, + ) self._purge_stale_checkpoints() logger.info( @@ -432,10 +534,19 @@ def load(self, step: int = -1) -> bool: f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found." ) + checkpoint_type = self._find_checkpoint_type(checkpoint_id) + if checkpoint_type == CheckpointType.SAFETENSORS: + assert ( + model_only + ), "Only model weights can be loaded when loading from safetensors checkpoint." logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) - dcp.load(states, checkpoint_id=checkpoint_id) + self.dcp_load( + states, + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -470,13 +581,33 @@ def _find_load_step(self, folder: str = "") -> int: for filename in os.listdir(folder): match = re.search(pattern, filename) - metadata_probe = os.path.join(folder, filename, ".metadata") - if match and os.path.isfile(metadata_probe): + dcp_metadata_probe = os.path.join(folder, filename, ".metadata") + safetensors_metadata_probe = os.path.join( + folder, filename, "model.safetensors.index.json" + ) + if match and os.path.isfile(dcp_metadata_probe): + step_counts.append(int(match.group(1))) + elif match and os.path.isfile(safetensors_metadata_probe): step_counts.append(int(match.group(1))) if not step_counts: return -1 return max(step_counts) + def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: + """Find the checkpoint type for the given id. + + Args: + checkpoint_id (str): The folder to find the checkpoint type for. + + Returns: + CheckpointType: The checkpoint type for the given folder. + """ + + for filename in os.listdir(checkpoint_id): + if filename == "model.safetensors.index.json": + return CheckpointType.SAFETENSORS + return CheckpointType.DCP + def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") @@ -488,8 +619,8 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - self.save_future = dcp.async_save( - self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg + self.save_future = self.dcp_save( + self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC ) logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.") @@ -501,7 +632,12 @@ def _ft_load(self) -> None: begin = time.monotonic() logger.info(f"Loading the FT checkpoint at step {step}.") checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) - dcp.load(self.ft_states, checkpoint_id=checkpoint_id) + self.dcp_load( + self.ft_states, + checkpoint_id=checkpoint_id, + # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. + checkpoint_type=CheckpointType.DCP, + ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." @@ -570,7 +706,18 @@ def _save_last_step(self, curr_step: int) -> None: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") states = self._flattened_model_states_sd() - save_with_gc(states, checkpoint_id=self._create_checkpoint_id(curr_step)) + if self.last_save_in_safetensors_format: + assert ( + self.last_save_model_weights_only + ), "Only model weights can be saved when saving in safetensors format." + + self.dcp_save( + states, + checkpoint_id=self._create_checkpoint_id(curr_step), + async_mode=AsyncMode.DISABLED, + enable_garbage_collection=True, + save_in_safetensors_format=self.last_save_in_safetensors_format, + ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: if not self.enable_checkpoint: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 5f1a1e8b7f..07c92b6f92 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -475,6 +475,17 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ + last_save_in_safetensors_format: bool = False + """ + Enable the use of safetensors format for checkpointing. This will save the final checkpoints + in safetensors format instead of the default DCP format. There will be a performance + cost in using this as we need to consolidate the sharded tensors to full tensors as + a separate step. last_save_model_weights_only must be true because safetensors doesn't + support saving non tensors. On load, this argument isn't needed as we will detect + whether the loaded checkpoint is in safetensors format or not. + The default value is False. + """ + @dataclass class ActivationCheckpoint: From 27e3ad85a96be757f45d08c0428f49417d6cb8ba Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Tue, 15 Jul 2025 11:17:36 -0400 Subject: [PATCH 008/128] Add Github workflow to build and publish wheel to PyTorch Index nightly (#1392) --- .github/scripts/update_version.sh | 11 ++++++ .github/workflows/build_whl_and_publish.yaml | 40 ++++++++++++++++++++ README.md | 22 +++++++---- 3 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 .github/scripts/update_version.sh create mode 100644 .github/workflows/build_whl_and_publish.yaml diff --git a/.github/scripts/update_version.sh b/.github/scripts/update_version.sh new file mode 100644 index 0000000000..0cb6009225 --- /dev/null +++ b/.github/scripts/update_version.sh @@ -0,0 +1,11 @@ +version_file="assets/version.txt" +init_file="torchtitan/__init__.py" +if [[ -n "$BUILD_VERSION" ]]; then + # Update the version in version.txt + echo "$BUILD_VERSION" > "$version_file" + # Create a variable named __version__ at the end of __init__.py + echo "__version__ = \"$BUILD_VERSION\"" >> "$init_file" +else + echo "Error: BUILD_VERSION environment variable is not set or empty." + exit 1 +fi diff --git a/.github/workflows/build_whl_and_publish.yaml b/.github/workflows/build_whl_and_publish.yaml new file mode 100644 index 0000000000..0b9b0ebc88 --- /dev/null +++ b/.github/workflows/build_whl_and_publish.yaml @@ -0,0 +1,40 @@ +name: Build nightly wheels and publish to PyTorch Index + +on: + push: + branches: + - nightly + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + generate-matrix: + if: github.repository_owner == 'pytorch' + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: enable + with-rocm: enable + python-versions: '["3.10", "3.11", "3.12"]' + build: + needs: generate-matrix + name: ${{ matrix.repository }} + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + strategy: + fail-fast: false + with: + repository: pytorch/torchtitan + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + package-name: torchtitan + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: .github/scripts/update_version.sh + trigger-event: ${{ github.event_name }} + build-platform: 'python-build-package' diff --git a/README.md b/README.md index d75dca769e..fa9b3ba35c 100644 --- a/README.md +++ b/README.md @@ -89,9 +89,7 @@ You may want to see how the model is defined or how parallelism techniques are a ## Installation -### Nightly - -Coming soon. +> [Install PyTorch](https://pytorch.org/get-started/locally/) before proceeding. ### Stable @@ -103,14 +101,24 @@ Or via conda: ```sh conda install conda-forge::torchtitan ``` -### Sources - + +### Nightly + +> This method requires the nightly build of PyTorch. + +```sh +pip install --pre torchtitan --index-url https://download.pytorch.org/whl/nightly/cu126 +``` +You can replace `cu126` with another version of cuda (e.g. `cu128`) or an AMD GPU (e.g. `rocm6.3`). + +### From source + +> This method requires the nightly build of PyTorch or PyTorch built from source. + ```bash git clone https://github.com/pytorch/torchtitan cd torchtitan pip install -r requirements.txt -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall -[For AMD GPU] pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 --force-reinstall ``` ### Downloading a tokenizer From 53f6642237fb93321f8bf9357d6cbaab54b22da1 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 15 Jul 2025 13:49:37 -0700 Subject: [PATCH 009/128] Validator integration with current metrics processor for logging (#1395) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrated the validator together with metrics processor for better metrics logging. Key changes: - Metrics processor is passed to validator within training loop - Validator can reuse metrics processor's built-in functionalities such as memory profiling, throughput tracking, and tensorboard/wandb logging This is how the new logging looks from terminal: Screenshot 2025-07-14 at 3 22 56 PM --- torchtitan/components/metrics.py | 35 ++++++++++++++++++++++++++++++- torchtitan/components/validate.py | 14 ++++++++----- torchtitan/tools/utils.py | 1 + torchtitan/train.py | 3 ++- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 3b290addc7..0ccf9fd760 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -403,7 +403,7 @@ def log( f"{color.red}step: {step:2} " f"{color.green}loss: {global_avg_loss:7.4f} " f"{color.orange}grad_norm: {grad_norm:7.4f} " - f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" f"({device_mem_stats.max_reserved_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " f"{color.cyan}tflops: {tflops:,.2f} " @@ -415,6 +415,39 @@ def log( self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() + def log_validation(self, loss: float, step: int): + time_delta = time.perf_counter() - self.time_last_log + + device_mem_stats = self.device_memory_monitor.get_peak_stats() + + # tokens per second per device, abbreviated as tps + tps = self.ntokens_since_last_log / ( + time_delta * self.parallel_dims.non_data_parallel_size + ) + + metrics = { + "validation_metrics/loss": loss, + "validation_metrics/throughput(tps)": tps, + "validation_metrics/memory/max_active(GiB)": device_mem_stats.max_active_gib, + "validation_metrics/memory/max_active(%)": device_mem_stats.max_active_pct, + "validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + } + self.logger.log(metrics, step) + + color = self.color + logger.info( + f"{color.yellow}validate step: {step:2} " + f"{color.green}loss: {loss:7.4f} " + f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: {round(tps):,}{color.reset}" + ) + + self.ntokens_since_last_log = 0 + self.time_last_log = time.perf_counter() + self.device_memory_monitor.reset_peak_stats() + def close(self): self.logger.close() diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 7f678514b9..1c4e3dbb58 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -11,12 +11,12 @@ from torch.distributed.fsdp import FSDPModule from torchtitan.components.dataloader import BaseDataLoader from torchtitan.components.loss import LossFunction +from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config_manager import JobConfig from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.tools import utils -from torchtitan.tools.logging import logger class BaseValidator: @@ -53,6 +53,7 @@ def __init__( loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], + metrics_processor: MetricsProcessor, ): self.job_config = job_config self.parallel_dims = parallel_dims @@ -65,11 +66,13 @@ def __init__( ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp + self.metrics_processor = metrics_processor @torch.no_grad() def validate( self, model_parts: list[nn.Module], + step: int, ) -> dict[str, float]: # Set model to eval mode # TODO: currently does not support pipeline parallelism @@ -89,6 +92,7 @@ def validate( ): break + self.metrics_processor.ntokens_since_last_log += labels.numel() for k, v in input_dict.items(): input_dict[k] = v.to(device_type) inputs = input_dict["input"] @@ -124,11 +128,9 @@ def validate( loss, parallel_dims.world_mesh["dp_cp"] ) else: - global_avg_loss = loss + global_avg_loss = loss.item() - logger.info( - f"Validation completed. Average loss: {global_avg_loss:.4f} over {num_steps} batches" - ) + self.metrics_processor.log_validation(loss=global_avg_loss, step=step) # Reshard after run forward pass # This is to ensure the model weights are sharded the same way for checkpoint saving. @@ -149,6 +151,7 @@ def build_validator( loss_fn: LossFunction, validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], + metrics_processor: MetricsProcessor | None = None, ) -> BaseValidator: """Build a simple validator focused on correctness.""" return Validator( @@ -160,4 +163,5 @@ def build_validator( loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, + metrics_processor=metrics_processor, ) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 4f10a088af..0d29f9db6a 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -134,6 +134,7 @@ class Color: white = "\033[37m" reset = "\033[39m" orange = "\033[38;2;180;60;0m" + turquoise = "\033[38;2;54;234;195m" @dataclass(frozen=True) diff --git a/torchtitan/train.py b/torchtitan/train.py index ea7a8e2efb..3eae78981e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -336,6 +336,7 @@ def __init__(self, job_config: JobConfig): loss_fn=self.train_spec.build_loss_fn(job_config), validation_context=self.train_context, maybe_enable_amp=self.maybe_enable_amp, + metrics_processor=self.metrics_processor, ) logger.info( @@ -530,7 +531,7 @@ def train(self): self.job_config.validation.enabled and self.validator.should_validate(self.step) ): - self.validator.validate(self.model_parts) + self.validator.validate(self.model_parts, self.step) self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) From 23b87369f30dbef7e5cb50ebb5ba46ff61c5bede Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 15 Jul 2025 13:52:19 -0700 Subject: [PATCH 010/128] refactor FTManager (#1397) This PR refactors `FTManager` to: 1. simplify construction logic 2. expose simpler interfact to `train.py` 3. make it optional when building optimizer and some other minor improvements. --- scripts/estimate/estimation.py | 4 +- torchtitan/components/ft.py | 136 ++++++++------------- torchtitan/components/optimizer.py | 6 +- torchtitan/distributed/utils.py | 2 +- torchtitan/experiments/llama4/optimizer.py | 2 +- torchtitan/protocols/train_spec.py | 2 +- torchtitan/train.py | 30 ++--- 7 files changed, 66 insertions(+), 116 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index aade405b4d..cec91fdcdd 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -15,7 +15,6 @@ from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker from torch.testing._internal.distributed.fake_pg import FakeStore -from torchtitan.components.ft import init_ft_manager from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers from torchtitan.config_manager import ConfigManager, JobConfig @@ -116,8 +115,7 @@ def estimate_memory(job_config: JobConfig): model.train() # build optimizer after applying parallelisms to the model - ft_manager = init_ft_manager(job_config) - optimizers = build_optimizers([model], job_config, parallel_dims, ft_manager) + optimizers = build_optimizers([model], job_config, parallel_dims) lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 946fc46380..60e4d2f80d 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -4,19 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import importlib from contextlib import nullcontext from typing import ContextManager, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist -import torch.distributed._functional_collectives as funcol from torch.distributed._composable.fsdp.fully_shard import FSDPModule -from torch.distributed.device_mesh import DeviceMesh from torch.distributed.distributed_c10d import ReduceOp -from torch.distributed.tensor import DTensor -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import FaultTolerance as FTConfig if importlib.util.find_spec("torchft") is not None: import torchft as ft @@ -32,14 +28,32 @@ class FTManager: def __init__( self, - manager: Optional["ft.Manager"], - group_size: int = 1, - replica_id: int = 0, + ft_config: FTConfig, ) -> None: - self._manager = manager - self.group_size = group_size - self.replica_id = replica_id - if has_torchft and manager is not None: + if not ft_config.enable: + self._manager = None + return + + if not has_torchft: + raise ImportError("torchft is not installed. Please install it.") + + pg = ft.ProcessGroupNCCL() + + # If the training method is specific, then the quorum should be synchronous + self.use_async_quorum = ft_config.semi_sync_method is None + + self._manager = ft.Manager( + pg=pg, + min_replica_size=ft_config.min_replica_size, + load_state_dict=None, + state_dict=None, + use_async_quorum=self.use_async_quorum, + replica_id=f"torchtitan_ft_{ft_config.replica_id}", + ) + self.group_size = ft_config.group_size + self.replica_id = ft_config.replica_id + + if self.use_async_quorum: self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager) self.replicate_pg.register("dp_replicate") @@ -53,85 +67,37 @@ def manager(self) -> "ft.Manager": return self._manager def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]: - return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank - - def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None: - def all_reduce_hook(output): - dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG) - - def apply_set_all_reduce_hook(m): - if isinstance(m, FSDPModule): - m.set_all_reduce_hook(all_reduce_hook) - - for part in model_parts: - part.apply(apply_set_all_reduce_hook) - - -def init_ft_manager(job: JobConfig) -> FTManager: - """Initialize the FT manager if TorchFT is enabled. - - Args: - job (JobConfig): The job configuration. - - Returns: - FTManager: A wrapper around TorchFT.Manager - """ - if not job.fault_tolerance.enable: - return FTManager(None) - - if not has_torchft: - raise ImportError("torchft is not installed. Please install it.") - - if job.fault_tolerance.min_replica_size < 1: - raise ValueError("At least one FT replica is required.") - - pg = ft.ProcessGroupNCCL() + if self.enabled: + return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank + else: + return dp_degree, dp_rank - # If the training method is specific, then the quorum should be synchronous - use_async_quorum = job.fault_tolerance.semi_sync_method is None + def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None: + if self.enabled and self.use_async_quorum: - return FTManager( - ft.Manager( - pg=pg, - min_replica_size=job.fault_tolerance.min_replica_size, - load_state_dict=None, - state_dict=None, - use_async_quorum=use_async_quorum, - replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}", - ), - group_size=job.fault_tolerance.group_size, - replica_id=job.fault_tolerance.replica_id, - ) - - -def ft_dist_reduce( - x: torch.Tensor, reduceOp: str, mesh: DeviceMesh -) -> tuple[torch.Tensor, str, DeviceMesh]: - if has_torchft and isinstance(mesh, ft.device_mesh._FlattenDeviceMesh): - x = funcol.all_reduce( - x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg - ) - return x, reduceOp, mesh.managed_mesh.mesh - return x, reduceOp, mesh + def all_reduce_hook(output): + dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG) + def apply_set_all_reduce_hook(m): + if isinstance(m, FSDPModule): + m.set_all_reduce_hook(all_reduce_hook) -def ft_clip_grad_norm_util(total_norm: DTensor) -> torch.Tensor: - if has_torchft: - mesh = total_norm._spec.mesh - if isinstance(mesh, ft.device_mesh.ManagedDeviceMesh): - # The gradients along the replicated dim has already been reduced. - # So we don't need another reducution beforing removing the - # replicate dimension - local_tensor = total_norm.to_local() - placements = list(copy.copy(total_norm._spec.placements)) - placements.pop(mesh.replicate_dim) - return DTensor.from_local(local_tensor, mesh.mesh, placements) + for model_part in model_parts: + model_part.apply(apply_set_all_reduce_hook) - return total_norm + @property + def loss_sync_pg( + self, + ) -> Optional["ft.process_group.ManagedProcessGroup"]: + if self.enabled and self.use_async_quorum: + return self.replicate_pg + else: + # skip loss sync when using semi-sync training + return None def maybe_semi_sync_training( - config: JobConfig, + ft_config: FTConfig, ft_manager: FTManager, model_parts: list[torch.nn.Module], optimizer: torch.optim.Optimizer, @@ -139,10 +105,8 @@ def maybe_semi_sync_training( """ If TorchFT is enabled and the config is set, use semi_sync_method """ - ft_config = config.fault_tolerance semi_sync_method = ft_config.semi_sync_method - torchft_enabled = ft_config.enable - if torchft_enabled and semi_sync_method is not None: + if ft_config.enable and semi_sync_method is not None: from torchft import local_sgd assert ( diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d2ff514cff..ee87888d74 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -243,7 +243,7 @@ def build_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. @@ -273,7 +273,7 @@ def build_optimizers( raise NotImplementedError( "Optimizers in backward is not supported with Pipeline Parallel." ) - if ft_manager.enabled: + if ft_manager and ft_manager.enabled: raise NotImplementedError( "TorchFT is not supported with optimizers in backward." ) @@ -313,7 +313,7 @@ def build_optimizers( model_parts, optimizer_cls, optimizer_kwargs ) - if ft_manager.enabled: + if ft_manager and ft_manager.enabled: return FTOptimizersContainer( model_parts, optimizer_cls, diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 58c5df0ca5..e25794a240 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -29,7 +29,7 @@ def _dist_reduce( x: torch.Tensor, reduceOp: str, mesh: DeviceMesh, - extra_pg: dist.ProcessGroup | None = None, + extra_pg: dist.ProcessGroup | None, ) -> float: """Perform distributed reduction on a tensor. diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index 11870f5fef..3b20f6b1d9 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -48,7 +48,7 @@ def build_llama4_optimizers( model_parts: list[nn.Module], job_config: JobConfig, parallel_dims: ParallelDims, - ft_manager: FTManager, + ft_manager: FTManager | None = None, ) -> OptimizersContainer: optimizers = build_optimizers( model_parts=model_parts, diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 3ee8707714..0e376b2f65 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -78,7 +78,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, ParallelDims, FTManager], + [list[nn.Module], JobConfig, ParallelDims, FTManager | None], OptimizersContainer, ] LRSchedulersBuilder: TypeAlias = Callable[ diff --git a/torchtitan/train.py b/torchtitan/train.py index 3eae78981e..3fc9a2560a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,10 +13,10 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record -import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.ft import FTManager, maybe_semi_sync_training from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, @@ -50,7 +50,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # non-swappable training components checkpointer: CheckpointManager - ft_manager: ft.FTManager + ft_manager: FTManager # runtime utilities device: torch.device @@ -104,11 +104,8 @@ def __init__(self, job_config: JobConfig): else: dp_degree, dp_rank = 1, 0 - self.ft_manager = ft.init_ft_manager(job_config) - # If TorchFT is enabled, the dp_rank and dp_degree, which are used for - # dataloader must be changed. - if self.ft_manager.enabled: - dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + self.ft_manager = FTManager(job_config.fault_tolerance) + dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( @@ -259,11 +256,7 @@ def __init__(self, job_config: JobConfig): self.model_parts = [model] - if ( - self.ft_manager.enabled - and job_config.fault_tolerance.semi_sync_method is None - ): - self.ft_manager.set_all_reduce_hook(self.model_parts) + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor @@ -475,14 +468,9 @@ def train_step( if not self.metrics_processor.should_log(self.step): return - if parallel_dims.dp_cp_enabled or self.ft_manager.enabled: + if parallel_dims.dp_cp_enabled: loss = loss.detach() - # Skip ft manager communication when using semi sync training - use_ft_pg = ( - self.ft_manager.enabled - and self.job_config.fault_tolerance.semi_sync_method is None - ) - ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None + ft_pg = self.ft_manager.loss_sync_pg global_avg_loss, global_max_loss = ( dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), @@ -509,8 +497,8 @@ def train(self): maybe_enable_memory_snapshot( job_config, global_step=self.step ) as memory_profiler, - ft.maybe_semi_sync_training( - job_config, + maybe_semi_sync_training( + job_config.fault_tolerance, ft_manager=self.ft_manager, model_parts=self.model_parts, optimizer=self.optimizers, From c1c55ea6f0399999f5cae9d4035f49b3338996b1 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 15 Jul 2025 17:16:29 -0400 Subject: [PATCH 011/128] Lint (#1400) Call `pre-commit run --all-files` Lint job did not run on https://github.com/pytorch/torchtitan/pull/1361? --- benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md b/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md index 9ba1490f3e..f887450597 100644 --- a/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md +++ b/benchmarks/llama3-8b_h200_202506_trainy-whitefiber.md @@ -1,4 +1,4 @@ -This was performed by Trainy team on WhiteFiber in June 2025, to get a baseline of performance +This was performed by Trainy team on WhiteFiber in June 2025, to get a baseline of performance of the Trainy platform on H200s platform over multiple hosts. ### Models From 2906d8be7167ff02fd8e1836030c7c765ef74e42 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 15 Jul 2025 17:32:42 -0400 Subject: [PATCH 012/128] [DSV3] Add PP support for DSV3 (#1345) Changes 1. New helper method to explicitly specify the modules to include 2. Update model.py to handle `None` attributes Can run the 16B DSV3 on 8 GPUs with PP: ``` NGPU=8 LOG_RANK=0,7 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 8 --parallelism.pipeline_parallel_schedule Interleaved1F1B ``` DP + TP + PP example run: https://meta.wandb.io/howardhuang_meta/torchtitan/reports/DeepSeekV3-16B-8-GPU-DP-TP-PP---VmlldzozMTAz?accessToken=ltsuxu4atlsmtk5u1g1zt04xcb6q1cs4mm9ianq8mlqlpq4ppm3lfpu1p53ei4pg TODO: - upstream `pipeline_module_split` to `torch.distributed.pipelining`? --- torchtitan/models/deepseek_v3/README.md | 2 +- torchtitan/models/deepseek_v3/__init__.py | 3 +- .../models/deepseek_v3/infra/pipeline.py | 310 ++++++++++++++++++ torchtitan/models/deepseek_v3/model/model.py | 24 +- .../train_configs/debug_model.toml | 4 +- .../train_configs/deepseek_v3_16b.toml | 2 + .../train_configs/deepseek_v3_671b.toml | 2 + 7 files changed, 338 insertions(+), 9 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/infra/pipeline.py diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 107bd0481a..367e4e9413 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -38,6 +38,7 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml - Activation checkpointing - Tensor Parallel (TP) - Expert Parallel (EP) +- Pipeline Parallel (PP) ## To be added @@ -46,7 +47,6 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml - Attention Layer: need to pass softmax_scale to sdpa() to support scaling - Parallelism - Context Parallel support for DeepSeek-V3 - - PP support for DeepSeek-V3 - torch.compile - Quantization - Testing diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 141b740ce6..de2d26b8a3 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -15,6 +15,7 @@ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .infra.parallelize import parallelize_deepseekv3 +from .infra.pipeline import pipeline_deepseekv3 from .model.args import DeepSeekV3ModelArgs from .model.model import DeepSeekV3Model @@ -116,7 +117,7 @@ model_cls=DeepSeekV3Model, model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, - pipelining_fn=None, + pipelining_fn=pipeline_deepseekv3, build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, diff --git a/torchtitan/models/deepseek_v3/infra/pipeline.py b/torchtitan/models/deepseek_v3/infra/pipeline.py new file mode 100644 index 0000000000..7caf3ad81f --- /dev/null +++ b/torchtitan/models/deepseek_v3/infra/pipeline.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D pipeline parallelism to the Llama model. + +import copy + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + PipelineScheduleSingle, + ScheduleZBVZeroBubble, +) + +from torchtitan.components.loss import LossFunction +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank +from torchtitan.protocols.train_spec import ParallelizeFunction +from torchtitan.tools.logging import logger + +from ..model.args import DeepSeekV3ModelArgs + + +def generate_module_names_per_stage( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names per stage for pipeline parallelism with weighting. + + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (tok_embeddings) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + + Returns: + List of lists containing module names for each stage + + Example: + generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Ensure each stage gets at least the weight of input/output modules + if layers_per_stage < max(input_weight, output_weight): + raise ValueError( + f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})" + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_deepseekv3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_config: DeepSeekV3ModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.world_mesh["pp"] + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + # Generate module names per stage programmatically with weighting + num_layers = model_config.n_layers + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = 1 # Weight for tok_embeddings + output_weight = 1 # Weight for norm + output layers + + module_names_per_stage = generate_module_names_per_stage( + num_virtual_stages, num_layers, input_weight, output_weight + ) + for i, stage_ms in enumerate(module_names_per_stage): + logger.info(f"Stage {i}: {stage_ms}") + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device type + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(model, module_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 61034e4c7d..9d7c336e63 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -9,7 +9,7 @@ import torch from torch import nn -from torchtitan.models.attention import build_attention +from torchtitan.models.attention import build_attention, init_attention_mask from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs @@ -357,20 +357,32 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) - def forward(self, tokens: torch.Tensor): + def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): """ Forward pass for the Transformer model. Args: - tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. + This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ - h = self.tok_embeddings(tokens) + if self.model_args.use_flex_attn: + init_attention_mask( + input_batch if input_batch is not None else tokens, eos_id=self.eos_id + ) + + h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): h = layer(h, self.freqs_cis) - h = self.norm(h) - output = self.output(h) + h = self.norm(h) if self.norm is not None else h + output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 905aa0067e..5f66ff4c37 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -50,9 +50,11 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false expert_parallel_degree = 1 +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 84c6b5f6bf..a9316b5487 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -51,6 +51,8 @@ fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 enable_async_tensor_parallel = false expert_parallel_degree = 1 +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 26cb64fb70..b3722c08bd 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -51,6 +51,8 @@ fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 8 enable_async_tensor_parallel = false expert_parallel_degree = 1 +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" [checkpoint] enable_checkpoint = false From 972ac9fb46a05e847cfab91a44fdc3bc5af0aa38 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 16 Jul 2025 16:20:32 +0800 Subject: [PATCH 013/128] Add the missing field to NoColor (#1406) As title --- torchtitan/tools/utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 0d29f9db6a..f31c6a735e 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -149,6 +149,12 @@ class NoColor: white = "" reset = "" orange = "" + turquoise = "" + + +assert set(NoColor.__dataclass_fields__.keys()) == set( + Color.__dataclass_fields__.keys() +), "NoColor must have the same fields as Color." def check_if_feature_in_pytorch( From 9a8cb98f1a85a1c842ae3217184fde2f0a5b3fc1 Mon Sep 17 00:00:00 2001 From: Jeffrey Wan Date: Wed, 16 Jul 2025 14:53:57 -0400 Subject: [PATCH 014/128] Add option for selective op AC to filter mm shapes based on fqn (#1380) Also see discussion in https://github.com/pytorch/torchtitan/pull/1372 This PR: - Adds new config for SAC with the default such that per-op SAC automatically skips all mms with args[1].shape matching that of the Linear at fqn "moe.router.gate" - Adds general flop/act-mem/correctness tests for AC as well as the new config --- .../unit_tests/test_activation_checkpoint.py | 256 ++++++++++++++++++ torchtitan/config_manager.py | 14 + torchtitan/models/llama3/infra/parallelize.py | 41 ++- 3 files changed, 307 insertions(+), 4 deletions(-) create mode 100644 tests/unit_tests/test_activation_checkpoint.py diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py new file mode 100644 index 0000000000..fbc585f527 --- /dev/null +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn +from torch.utils.flop_counter import FlopCounterMode + +from torchtitan.config_manager import ActivationCheckpoint as ACConfig +from torchtitan.models.llama3.infra.parallelize import apply_ac + + +class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleDict({"0": TransformerBlock()}) + + def forward(self, x): + return self.layers["0"](x) + + +class TransformerBlock(nn.Module): + def __init__(self): + super().__init__() + self.moe = nn.Module() + self.moe.router = nn.Module() + self.moe.router.gate = nn.Linear(512, 512, bias=False) + self.attention = nn.Module() + self.attention.wq = nn.Linear(512, 512, bias=False) + self.output = nn.Linear(512, 1024, bias=False) + + def forward(self, x): + gate_out = self.moe.router.gate(x) + wq_out = self.attention.wq(gate_out) + final_out = self.output(wq_out) + return final_out.sum() + + +class TestApplyAC(unittest.TestCase): + def test_flops(self): + def get_bw_flops(model_fn): + x = torch.randn(512, 512, requires_grad=True) + with torch.utils.checkpoint.set_checkpoint_early_stop(False): + out = model_fn(x) + out.backward() + + x = torch.randn(512, 512, requires_grad=True) + with torch.utils.checkpoint.set_checkpoint_early_stop(False): + out = model_fn(x) + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + # 1. No AC + model_no_ac = TestModule() + flops_no_ac = get_bw_flops(model_no_ac) + + # 2. SAC + # Per-op SAC's policy is to save every other mm + model_selective_ac = TestModule() + ac_config_no_force = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list + ) + apply_ac(model_selective_ac, ac_config_no_force) + flops_selective_ac = get_bw_flops(model_selective_ac) + + # 3. Per-op SAC with force recompute "moe.router.gate" + # This leads to two mms being recomputed since they share the same shape! + model_with_force_first = TestModule() + ac_config_with_force_first = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], + ) + apply_ac(model_with_force_first, ac_config_with_force_first) + flops_with_force_first = get_bw_flops(model_with_force_first) + + # 4. Per-op SAC with force recompute "output" + model_with_force_last = TestModule() + ac_config_with_force_last = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], + ) + apply_ac(model_with_force_last, ac_config_with_force_last) + flops_with_force_last = get_bw_flops(model_with_force_last) + + # 5. Full AC + model_with_full_ac = TestModule() + ac_config_full_ac = ACConfig( + mode="full", + ) + apply_ac(model_with_full_ac, ac_config_full_ac) + flops_full_ac = get_bw_flops(model_with_full_ac) + + self.assertEqual(flops_no_ac, 8.0) + self.assertEqual(flops_selective_ac, 9.0) + self.assertEqual(flops_with_force_first, 10.0) + self.assertEqual(flops_with_force_last, 11.0) + self.assertEqual(flops_full_ac, 12.0) + + def test_mem(self): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is unavailable") + + def get_act_mem(model_fn): + x = torch.randn(512, 512, requires_grad=True, device="cuda") + out = model_fn(x) + out.backward() + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + + out = model_fn(x) + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) # → MB + out.backward() + return act_mem + + # 1. No AC + model_no_ac = TestModule().cuda() + mem_no_ac = get_act_mem(model_no_ac) + + # 2. SAC + # Per-op SAC's policy is to save every other mm + model_selective_ac = TestModule().cuda() + ac_config_no_force = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list + ) + apply_ac(model_selective_ac, ac_config_no_force) + mem_selective_ac = get_act_mem(model_selective_ac) + + # 3. Per-op SAC with force recompute "moe.router.gate" + # This leads to two mms being recomputed since they share the same shape! + model_with_force_first = TestModule().cuda() + ac_config_with_force_first = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], + ) + apply_ac(model_with_force_first, ac_config_with_force_first) + mem_with_force_first = get_act_mem(model_with_force_first) + + # 4. Per-op SAC with force recompute "output" + model_with_force_last = TestModule().cuda() + ac_config_with_force_last = ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], + ) + apply_ac(model_with_force_last, ac_config_with_force_last) + mem_with_force_last = get_act_mem(model_with_force_last) + + # 5. Full AC + model_with_full_ac = TestModule().cuda() + ac_config_full_ac = ACConfig( + mode="full", + ) + apply_ac(model_with_full_ac, ac_config_full_ac) + mem_full_ac = get_act_mem(model_with_full_ac) + + self.assertEqual(mem_no_ac, 2.0) + self.assertEqual(mem_selective_ac, 3.0) + self.assertEqual(mem_with_force_first, 2.0) + self.assertEqual(mem_with_force_last, 1.0) + self.assertEqual(mem_full_ac, 0.0) + # Note: SAC > no-AC here because it unnecessarily saves "output" + # even that is not needed for recomputaion and output is double + # the size of the other two mms. + + def test_correctness(self): + model_no_ac = TestModule() + + model_selective_ac = TestModule() + model_selective_ac.load_state_dict(model_no_ac.state_dict()) + apply_ac( + model_selective_ac, + ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=[], + ), + ) + model_force_first = TestModule() + model_force_first.load_state_dict(model_no_ac.state_dict()) + apply_ac( + model_force_first, + ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], + ), + ) + + model_force_last = TestModule() + model_force_last.load_state_dict(model_no_ac.state_dict()) + apply_ac( + model_force_last, + ACConfig( + mode="selective", + selective_ac_option="op", + per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], + ), + ) + + def run_fwd_bwd(model, batch): + model.zero_grad(set_to_none=True) + xin = batch.clone().detach().requires_grad_(True) + out = model(xin) # scalar + out.backward() + + grad_in = xin.grad.detach().clone() + grad_params = [ + p.grad.detach().clone() if isinstance(p.grad, torch.Tensor) else None + for p in model.parameters() + ] + return out.detach(), grad_in, grad_params + + batch = torch.randn(64, 512) + + out_ref, gin_ref, gparams_ref = run_fwd_bwd(model_no_ac, batch) + out_sel, gin_sel, gparams_sel = run_fwd_bwd(model_selective_ac, batch) + out_f1, gin_f1, gparams_f1 = run_fwd_bwd(model_force_first, batch) + out_fl, gin_fl, gparams_fl = run_fwd_bwd(model_force_last, batch) + + for other_out in (out_sel, out_f1, out_fl): + torch.testing.assert_close(out_ref, other_out) + + for other_gin in (gin_sel, gin_f1, gin_fl): + torch.testing.assert_close(gin_ref, other_gin) + + for g_ref, g_sel, g_f1, g_fl in zip( + gparams_ref, gparams_sel, gparams_f1, gparams_fl + ): + # Skip wrapper / missing grads + if not ( + torch.is_tensor(g_ref) + and torch.is_tensor(g_sel) + and torch.is_tensor(g_f1) + and torch.is_tensor(g_fl) + ): + continue + + torch.testing.assert_close(g_ref, g_sel) + torch.testing.assert_close(g_ref, g_f1) + torch.testing.assert_close(g_ref, g_fl) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 07c92b6f92..c6e1ccea26 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -498,6 +498,20 @@ class ActivationCheckpoint: 'int' (e.g., 2) for every nth layer, or 'op' for op level ac. """ + per_op_sac_force_recompute_mm_shapes_by_fqns: list[str] = field( + default_factory=lambda: ["moe.router.gate"] + ) + """ + When per-op selective ac is used, this list of fully qualified names is used + to determine which mm shapes to force recompute, rather than being considered + by rest of the sac policy, e.g save every other mm. Only nn.Linear modules are + supported today. + + Note: this config applies to mms not limited to those matching the specified + fqns, e.g. if "moe.router.gate", corresponding to Linear(in, out), is specified, + ANY mm with shape matching (*, in) x (in, out) will be force recomputed. + """ + @dataclass class Float8: diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index d67e283721..bbbbe71e4b 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. from collections import defaultdict +from typing import Optional import torch import torch.nn as nn @@ -27,7 +28,11 @@ SequenceParallel, ) -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config_manager import ( + ActivationCheckpoint as ACConfig, + JobConfig, + TORCH_DTYPE_MAP, +) from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -235,7 +240,9 @@ def apply_tp( } -def _apply_ac_to_transformer_block(module: nn.Module, ac_config): +def _apply_ac_to_transformer_block( + module: nn.Module, ac_config: ACConfig, *, base_fqn: Optional[str] = None +): valid_ac_modes = ("full", "selective") if ac_config.mode not in valid_ac_modes: raise ValueError( @@ -259,11 +266,35 @@ def _apply_ac_to_transformer_block(module: nn.Module, ac_config): create_selective_checkpoint_contexts, ) + mm_recompute_shapes = set() + if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: + for module_fqn, submod in module.named_modules(): + fqn = module_fqn + if base_fqn is not None: + fqn = f"{base_fqn}.{module_fqn}" + if not any( + filter_fqn in fqn + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns + ): + continue + if not isinstance(submod, nn.Linear): + raise ValueError( + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " + f"a nn.Linear, but got: {submod}" + ) + out_f, in_f = submod.weight.shape + mm_recompute_shapes.add((in_f, out_f)) + logger.debug( + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" + ) + def _get_custom_policy(meta): def _custom_policy(ctx, func, *args, **kwargs): mode = "recompute" if ctx.is_recompute else "forward" mm_count_key = f"{mode}_mm_count" if func == torch.ops.aten.mm.default: + if args[1].shape in mm_recompute_shapes: + return CheckpointPolicy.PREFER_RECOMPUTE meta[mm_count_key] += 1 # Saves output of all compute ops, except every second mm to_save = func in _save_list and not ( @@ -297,10 +328,12 @@ def selective_checkpointing_context_fn(): return module -def apply_ac(model: nn.Module, ac_config): +def apply_ac(model: nn.Module, ac_config: ACConfig): """Apply activation checkpointing to the model.""" for layer_id, transformer_block in model.layers.named_children(): - transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) + transformer_block = _apply_ac_to_transformer_block( + transformer_block, ac_config, base_fqn=f"layers.{layer_id}" + ) model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") From f062d48ba77910f11b3b0470135da1fd1abb486e Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Wed, 16 Jul 2025 12:34:11 -0700 Subject: [PATCH 015/128] [llama4] Change expert_bias and tokens_per_expert to non-persistent buffer (#1403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As titled. Tested on llama4 debugging model (dp=8, ep=2): Screenshot 2025-07-15 at 8 05 12 PM --- torchtitan/experiments/llama4/model/moe.py | 2 -- torchtitan/models/deepseek_v3/model/moe.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index d7f0ce3fd9..71ac1360c3 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -249,12 +249,10 @@ def __init__(self, model_args: TransformerModelArgs): self.register_buffer( "expert_bias", torch.zeros(num_experts, dtype=torch.float32), - persistent=True, ) self.register_buffer( "tokens_per_expert", torch.zeros(num_experts, dtype=torch.float32), - persistent=True, ) else: self.expert_bias = None diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 2554d61310..840cb1a57d 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -290,12 +290,10 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.register_buffer( "expert_bias", torch.zeros(num_experts, dtype=torch.float32), - persistent=True, ) self.register_buffer( "tokens_per_expert", torch.zeros(num_experts, dtype=torch.float32), - persistent=True, ) else: self.expert_bias = None From d69a7378cc70e93b1061c9d68ff1c8d54e1594d9 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 16 Jul 2025 14:18:17 -0700 Subject: [PATCH 016/128] create multipe outer optimizers for diloco (#1407) Summary: enable creating a separate outer optimizer for each of the parameter fragments for streaming diloco --- torchtitan/components/ft.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 60e4d2f80d..cedc374799 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -116,15 +116,19 @@ def maybe_semi_sync_training( # Create the outer optimizer based on the inner optimizer parameters. params = [group["params"] for group in optimizer.param_groups] params = [param for sublist in params for param in sublist] - outer_optimizer = torch.optim.SGD( - params, lr=0.7, momentum=0.9, nesterov=True - ) + outer_optimizers = [] + for model in model_parts: + params = [p for p in model.parameters() if p.requires_grad] + outer_optimizer = torch.optim.SGD( + params, lr=0.7, momentum=0.9, nesterov=True + ) + outer_optimizers.append(outer_optimizer) return local_sgd.DiLoCo( manager=ft_manager._manager, model_fragments=model_parts, inner_optimizer=optimizer, - outer_optimizer=outer_optimizer, + outer_optimizer=outer_optimizers, sync_every=ft_config.sync_steps, should_quantize=ft_config.should_quantize, fragment_sync_delay=ft_config.fragment_sync_delay, From 4e5265ed67aef04e3b56b0cee8665d476a3c63fd Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:16:14 -0700 Subject: [PATCH 017/128] [DSV3] Change sdpa interface to pass softmax_scale (#1394) ## Changes in this diff: 1. Pass softmax_scale to sdpa() forward. 2. Change some default parameters for debug_model.toml --- torchtitan/models/attention.py | 8 ++++++-- torchtitan/models/deepseek_v3/model/model.py | 5 +---- .../models/deepseek_v3/train_configs/debug_model.toml | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f7f13462ef..e4b3b8c683 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -207,11 +207,15 @@ def _init_backend(cls) -> None: cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, ) -> torch.Tensor: assert self.backends, "SDPA Backends should not be empty." with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True) + return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) def build_attention( diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 9d7c336e63..34332751fd 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -230,10 +230,7 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - # TODO: Need to pass softmax_scale to sdpa() interface. - # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa - # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17 - output = self.sdpa(q, k, v) + output = self.sdpa(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose(1, 2) # (bsz, seqlen, n_heads, v_head_dim) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 5f66ff4c37..c4642dc5e3 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -42,7 +42,7 @@ lr_min = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 1 +steps = 10 compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) From c004dc4310f5baa6aff73f771de998eba300feb5 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Fri, 18 Jul 2025 14:53:00 -0700 Subject: [PATCH 018/128] separate outputs for ft replicas (#1410) Summary: - log the tensorboard for ft replicas to a separate folder - log profiles for ft replicas to a separate folder --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1410). * #1411 * __->__ #1410 --- torchtitan/components/metrics.py | 6 ++++++ torchtitan/tools/profiling.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 0ccf9fd760..3fee856504 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -261,6 +261,12 @@ def _build_metric_logger( dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M") ) + if job_config.fault_tolerance.enable: + base_log_dir = os.path.join( + base_log_dir, + f"replica_{job_config.fault_tolerance.replica_id}", + ) + if metrics_config.save_for_all_ranks: base_log_dir = os.path.join( base_log_dir, f"rank_{torch.distributed.get_rank()}" diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 050b992cc8..27842bc7d4 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -34,6 +34,10 @@ def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): rank = torch.distributed.get_rank() + replica_id = None + if config.fault_tolerance.enable: + replica_id = config.fault_tolerance.replica_id + def trace_handler(prof): curr_trace_dir_name = "iteration_" + str(prof.step_num) curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) @@ -42,7 +46,13 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") + + output_file = curr_trace_dir + if replica_id is not None: + output_file = os.path.join(output_file, f"replica{replica_id}") + output_file = os.path.join(output_file, f"rank{rank}_trace.json") + + prof.export_chrome_trace(output_file) logger.info( f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" ) From c924c4458c2c3077a713d4844d4a00377ae38164 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Fri, 18 Jul 2025 14:53:14 -0700 Subject: [PATCH 019/128] allow specifying ft pg (#1411) Summary: - allow using gloo process group - add a parameter to the ft config - only nccl and gloo will be supported for now --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1411). * __->__ #1411 * #1410 --- torchtitan/components/ft.py | 11 ++++++++++- torchtitan/config_manager.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index cedc374799..ee94b238bd 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -6,6 +6,7 @@ import importlib from contextlib import nullcontext +from datetime import timedelta from typing import ContextManager, Optional, TYPE_CHECKING, Union import torch @@ -37,7 +38,15 @@ def __init__( if not has_torchft: raise ImportError("torchft is not installed. Please install it.") - pg = ft.ProcessGroupNCCL() + process_group_timeout = timedelta( + milliseconds=ft_config.process_group_timeout_ms + ) + if ft_config.process_group == "gloo": + pg = ft.ProcessGroupGloo(timeout=process_group_timeout) + elif ft_config.process_group == "nccl": + pg = ft.ProcessGroupNCCL(timeout=process_group_timeout) + else: + raise ValueError(f"Unsuported process group: {ft_config.process_group}") # If the training method is specific, then the quorum should be synchronous self.use_async_quorum = ft_config.semi_sync_method is None diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index c6e1ccea26..c209d8fb3c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -606,6 +606,17 @@ class FaultTolerance: Note that this is still an experimental feature. """ + process_group: str = "gloo" + """ + The process group to use for fault tolerance. Currently, only "gloo" and "nccl" are supported. + """ + + process_group_timeout_ms: int = 10000 + """ + The process group will abort if operations don't succeed within this duration. + Note: This currently only works with gloo process group. + """ + replica_id: int = 0 """The TorchFT replica ID of this run.""" From 183f6fce1a586ce66027deee2c4fdb823616cd75 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Sat, 19 Jul 2025 18:53:54 -0700 Subject: [PATCH 020/128] Remove flex+sac restriction (#1408) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stacked PRs: * __->__#1408 --- --- --- Remove flex+compile restriction Recently enabled this w/ https://github.com/pytorch/pytorch/pull/150080 Before the change I got: ```Shell [rank0]:[rank0]: File "/home/drisspg/meta/torchtitan/torchtitan/models/llama3/model/args.py", line 48, in update_from_config [rank0]:[rank0]: raise ValueError( [rank0]:[rank0]: ValueError: FlexAttention is not compatible with selective AC yet. See https://github.com/pytorch/pytorch/issues/147879 ``` I tried running this locally with; ``` CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=1 ./run_train.sh --model.flavor debugmodel_flex_attn --training.compile ``` Got: Screenshot 2025-07-16 at 8 25 53 PM If there are other more robust testing we can do I am down I am trying to unblock some internal users and ensure I can close this issue: https://github.com/pytorch/pytorch/issues/147879 --- tests/integration_tests.py | 13 +++++++++++++ torchtitan/experiments/llama4/model/args.py | 6 ------ torchtitan/models/llama3/model/args.py | 6 ------ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index f3000eef7e..f50c592857 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -344,6 +344,19 @@ def build_test_list(): "fsdp+flex_attn", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--parallelism.data_parallel_shard_degree=4", + "--activation_checkpoint.mode=selective", + "--activation_checkpoint.selective_ac_option=op", + "--model.flavor=debugmodel_flex_attn", + ] + ], + "FSDP + FLEX + per op SAC", + "fsdp+flex_attn+per_op_sac", + ngpu=4, + ), OverrideDefinitions( [ [ diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index a7f99e732b..bccedd7be8 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -71,12 +71,6 @@ def update_from_config( ) self.use_grouped_mm = False - if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with selective AC yet. " - "See https://github.com/pytorch/pytorch/issues/147879" - ) - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise ValueError( "FlexAttention is not compatible with CP yet. " diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 38f7e3321d..7f7b4e5a96 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -44,12 +44,6 @@ def update_from_config( self.max_seq_len = job_config.training.seq_len self.eos_id = tokenizer.eos_id - if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with selective AC yet. " - "See https://github.com/pytorch/pytorch/issues/147879" - ) - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise ValueError( "FlexAttention is not compatible with CP yet. " From beb29a1c74203410e759d5a5cf87b091fba25024 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sun, 20 Jul 2025 16:21:48 -0700 Subject: [PATCH 021/128] add infra support for HF checkpoint conversion (#1404) This PRs adds a new field `StateDictAdapter` in `TrainSpec` - Currently it only contains a pair of static methods `to_hf` and `from_hf`. Later we could add other pairs like `to_meta` / `from_meta`, `to_vllm` / `from_vllm`, etc. - It is passed to `CheckpointManager` to convert between torchtitan model and HF model during checkpoint save / load. - It could also be potentially used by downstream inference engines which only supports HF models. In order to save / load in HF format, a model is required to have a corresponding `StateDictAdapter` subclass implementation. For Llama3, I created a placeholder `Llama3StateDictAdapter` to be implemented. cc @wesleytruong This PR also renames checkpoint config options `initial_load_model_weights_only` `last_save_model_weights_only` to simply `initial_load_model_only` `last_save_model_only`, respectively. It seems to me that the original names were made corresponding to `torch.load(..., weights_only=True)`. As long as we document & test clearly when this correspondence holds, I prefer the name in torchtitan to be simple and less ambiguous. --- docs/checkpoint.md | 16 +-- tests/integration_tests.py | 22 ++-- tests/unit_tests/test_checkpoint.py | 21 ++-- torchtitan/components/checkpoint.py | 109 ++++++++++-------- torchtitan/config_manager.py | 30 ++--- .../flux/tests/integration_tests.py | 6 +- .../flux/train_configs/debug_model.toml | 2 +- .../flux/train_configs/flux_dev_model.toml | 2 +- .../train_configs/flux_schnell_model.toml | 2 +- .../llama4/train_configs/debug_model.toml | 2 +- .../llama4/train_configs/llama4_17bx128e.toml | 2 +- .../llama4/train_configs/llama4_17bx16e.toml | 2 +- .../train_configs/debug_model.toml | 2 +- .../train_configs/deepseek_v3_16b.toml | 2 +- .../train_configs/deepseek_v3_671b.toml | 2 +- torchtitan/models/llama3/__init__.py | 4 +- .../models/llama3/model/state_dict_adapter.py | 21 ++++ .../llama3/train_configs/debug_model.toml | 2 +- .../llama3/train_configs/llama3_405b.toml | 2 +- .../llama3/train_configs/llama3_70b.toml | 2 +- .../llama3/train_configs/llama3_8b.toml | 2 +- torchtitan/protocols/state_dict_adapter.py | 42 +++++++ torchtitan/protocols/train_spec.py | 4 +- torchtitan/train.py | 1 + 24 files changed, 192 insertions(+), 110 deletions(-) create mode 100644 torchtitan/models/llama3/model/state_dict_adapter.py create mode 100644 torchtitan/protocols/state_dict_adapter.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 0ffcafb02c..ecfdd67d6b 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -24,12 +24,12 @@ interval = 500 ``` -2. SAVE ONLY MODEL WEIGHTS -By setting `last_save_model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size. +2. SAVE MODEL ONLY +By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size. ``` [checkpoint] enable_checkpoint = true -last_save_model_weights_only = true +last_save_model_only = true ``` 3. CHOOSE DESIRED EXPORT PRECISION @@ -37,7 +37,7 @@ The default model states are in `float32`. You can choose to export the checkpoi ``` [checkpoint] enable_checkpoint = true -last_save_model_weights_only = true +last_save_model_only = true export_dtype = "bfloat16" ``` @@ -48,12 +48,12 @@ enable_checkpoint = true folder = "checkpoint" interval = 10 load_step = 5 -last_save_model_weights_only = true +last_save_model_only = true export_dtype = "bfloat16" ``` 5. SAVE THE FINAL CHECKPOINT\ -Once the above have been set, the final checkpoint at the end of the training step will consist of model weights only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed. +Once the above have been set, the final checkpoint at the end of the training step will consist of model only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed. 6. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\ Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune: @@ -88,5 +88,5 @@ NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoi ## How to load / save a checkpoint in HF safetensors format -For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_weights_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format). -For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_weights_only` to load the checkpoint in HF format. +For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format). +For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_only` to load the checkpoint in HF format. diff --git a/tests/integration_tests.py b/tests/integration_tests.py index f50c592857..adca9ec56e 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -123,37 +123,37 @@ def build_test_list(): [ "--checkpoint.enable_checkpoint", "--checkpoint.folder hf_checkpoint", - "--checkpoint.last_save_in_safetensors_format", - "--checkpoint.last_save_model_weights_only", + "--checkpoint.last_save_in_hf", + "--checkpoint.last_save_model_only", ], [ "--checkpoint.enable_checkpoint", - "--checkpoint.initial_load_path artifacts-to-be-uploaded/full_checkpoint_hf_safetensors/hf_checkpoint/step-10/", + "--checkpoint.initial_load_path artifacts-to-be-uploaded/model_only_hf_checkpoint/hf_checkpoint/step-10/", ], ], - "Checkpoint Integration Test - save load full checkpoint in HF safetensors format", - "full_checkpoint_hf_safetensors", + "Checkpoint Integration Test - save load model only checkpoint in HF definition and format", + "model_only_hf_checkpoint", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - "--checkpoint.last_save_model_weights_only", + "--checkpoint.last_save_model_only", ], ], - "Checkpoint Integration Test - Save Model Weights Only fp32", - "last_save_model_weights_only_fp32", + "Checkpoint Integration Test - Save Model Only fp32", + "last_save_model_only_fp32", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - "--checkpoint.last_save_model_weights_only", + "--checkpoint.last_save_model_only", "--checkpoint.export_dtype bfloat16", ], ], - "Checkpoint Integration Test - Save Model Weights Only bf16", - "last_save_model_weights_only_bf16", + "Checkpoint Integration Test - Save Model Only bf16", + "last_save_model_only_bf16", ), OverrideDefinitions( [ diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index 2f8127bfd6..e2c0e1254b 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -88,11 +88,11 @@ def __init__(self, job): folder="", interval=1, keep_latest_k=0, - last_save_model_weights_only=False, + last_save_model_only=False, export_dtype="float32", exclude_from_loading=[], initial_load_path=None, - initial_load_model_weights_only=False, + initial_load_model_only=False, ) self.fault_tolerance = SimpleNamespace(replica_id=0) @@ -119,11 +119,11 @@ def setUp(self): folder="", interval=1, keep_latest_k=2, - last_save_model_weights_only=False, + last_save_model_only=False, export_dtype="float32", exclude_from_loading=[], initial_load_path=None, - initial_load_model_weights_only=False, + initial_load_model_only=False, ) ft_ns = SimpleNamespace(replica_id=0) job_ns = SimpleNamespace(dump_folder=self.test_folder) @@ -299,7 +299,6 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank): expected = os.path.join(ckpt_folder, "step-5") mock_load.assert_called_once() args, kwargs = mock_load.call_args - self.assertEqual(args[0], manager._states_to_load(model_only=False)) self.assertEqual(kwargs.get("checkpoint_id"), expected) self.assertTrue(res) manager.close() @@ -342,13 +341,13 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank): @mock.patch("torch.distributed.get_rank", return_value=0) @mock.patch("torchtitan.components.checkpoint.dcp.save") @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_last_save_model_weights_only_and_initial_load_model_weights_only( + def test_last_save_model_only_and_initial_load_model_only( self, mock_load, mock_save, mock_rank ): mock_save.side_effect = self.fake_save mock_load.side_effect = self.fake_load # Phase 1: save model weights only - self.job_config.checkpoint.last_save_model_weights_only = True + self.job_config.checkpoint.last_save_model_only = True manager1 = CheckpointManager( dataloader=self.data_loader, model_parts=self.model_parts, @@ -363,8 +362,8 @@ def test_last_save_model_weights_only_and_initial_load_model_weights_only( self.assertTrue(os.path.isdir(path1)) # Phase 2: initial load from step-1 cfg = self.job_config.checkpoint - cfg.last_save_model_weights_only = False - cfg.initial_load_model_weights_only = True + cfg.last_save_model_only = False + cfg.initial_load_model_only = True cfg.initial_load_path = path1 cfg.folder = "" self.job_config.job.dump_folder = self.test_folder @@ -603,8 +602,8 @@ def fake_load(state_dict: dict, checkpoint_id=None): self.assertNotIn("model", state_dict) self.assertIn("optimizer", state_dict) - self.job_config.checkpoint.last_save_model_weights_only = True - self.job_config.checkpoint.initial_load_model_weights_only = False + self.job_config.checkpoint.last_save_model_only = True + self.job_config.checkpoint.initial_load_model_only = False manager = CheckpointManager( dataloader=self.data_loader, model_parts=self.model_parts, diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index f71417de80..08c2dd1067 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -37,6 +37,7 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.protocols.state_dict_adapter import StateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -54,11 +55,6 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" -class CheckpointType(str, enum.Enum): - DCP = "DCP" - SAFETENSORS = "safetensors" - - # For now, we will manually pop the freqs_cis buffer, as we made this permanent # temporarily and we don't want to include it in the exported state_dict. # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404 @@ -179,6 +175,8 @@ class CheckpointManager: states (Dict[str, Any]): The states that need to be saved, other than the previous 4 components. job_config (JobConfig): The job config used to configure the checkpointing. + sd_adapter (Optional[type[StateDictAdapter]]): The adapter used to convert model state + dicts between native format and other formats. ft_manager (Optional[ft.Manager]): The FTManager from TorchFT. """ @@ -190,17 +188,21 @@ def __init__( lr_schedulers: LRSchedulersContainer, states: dict[str, Any], job_config: JobConfig, + sd_adapter: type[StateDictAdapter] | None = None, ft_manager: FTManager | None = None, ) -> None: ckpt_config = job_config.checkpoint self.enable_checkpoint = ckpt_config.enable_checkpoint - self.last_save_in_safetensors_format = ( - ckpt_config.last_save_in_safetensors_format - ) + self.last_save_in_hf = ckpt_config.last_save_in_hf + if self.last_save_in_hf: + assert ( + sd_adapter is not None + ), "job_config.checkpoint.last_save_in_hf is True, but sd_adapter is not provided." + self.sd_adapter = sd_adapter + self.ft_manager = ( ft_manager.manager if ft_manager and ft_manager.enabled else None ) - if self.ft_manager: optimizers.init_cache_state_dict() @@ -222,7 +224,7 @@ def load_state_dict(state_dict): self.states[k].load_state_dict(v) self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) - self.ft_replica_id = job_config.fault_tolerance.replica_id + self.ft_replica_id = job_config.fault_tolerance.replica_id async_mode = ckpt_config.async_mode.lower() self.enable_staging = ( @@ -253,10 +255,8 @@ def load_state_dict(state_dict): # Checkpoint policy related fields. self.initial_load_path = ckpt_config.initial_load_path - self.initial_load_model_weights_only = ( - ckpt_config.initial_load_model_weights_only - ) - self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only + self.initial_load_model_only = ckpt_config.initial_load_model_only + self.last_save_model_only = ckpt_config.last_save_model_only self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] self.exclude_from_loading = ckpt_config.exclude_from_loading self.interval = ckpt_config.interval @@ -328,7 +328,7 @@ def dcp_save( checkpoint_id: str, async_mode: AsyncMode, enable_garbage_collection: bool = False, - save_in_safetensors_format: bool = False, + to_hf: bool = False, ) -> Future | None: """Save the checkpoint with dcp. Args: @@ -336,7 +336,7 @@ def dcp_save( checkpoint_id (str): The checkpoint id to save. async_mode (AsyncMode): Whether the checkpoint is async. enable_garbage_collection (bool): Whether to enable garbage collection after save. - save_in_safetensors_format (bool): Whether to save in safetensors format. + to_hf (bool): Whether to save in HF model definition and safetensors format. Returns: Future: The future object if the checkpoint is async, otherwise None. @@ -346,7 +346,10 @@ def dcp_save( storage_writer: HuggingFaceStorageWriter | None = None checkpoint_save_id: str | None = None - if save_in_safetensors_format: + if to_hf: + assert self.sd_adapter is not None + state_dict = self.sd_adapter.to_hf(state_dict) + fqn_to_index_mapping = {} num_fqns_per_file = 30 # the use of 30 is just a heuristic for now. @@ -398,21 +401,37 @@ def dcp_load( self, state_dict: dict[str, Any], checkpoint_id: str, - checkpoint_type: CheckpointType, + from_hf: bool, ) -> None: """Load the checkpoint with dcp. Args: state_dict (dict): The state dict to load. checkpoint_id (str): The checkpoint id to load. - hf_safetensors_format (bool): Whether to use the HuggingFace safetensors format. + from_hf (bool): Whether to load from HuggingFace checkpoint with + its own model definition and safetensors format. """ - if checkpoint_type == CheckpointType.SAFETENSORS: - storage_reader = HuggingFaceStorageReader(path=checkpoint_id) - dcp.load(state_dict, storage_reader=storage_reader) + if from_hf: + assert ( + self.sd_adapter is not None + ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." + hf_state_dict = self.sd_adapter.to_hf(state_dict) + + dcp.load( + hf_state_dict, + storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + ) + + state_dict = self.sd_adapter.from_hf(hf_state_dict) + self.states[MODEL].load_state_dict(state_dict) else: dcp.load(state_dict, checkpoint_id=checkpoint_id) + # TODO: Since we flatten the model states in state_dict, we need to + # manually call load_state_dict() for the model. Need to fix this. + if MODEL in self.states: + self.states[MODEL].load_state_dict(state_dict) + @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: """Save the checkpoint for the current step. @@ -512,16 +531,16 @@ def load(self, step: int = -1) -> bool: checkpoint_id = self.initial_load_path if not os.path.isdir(checkpoint_id): raise ValueError( - "initial_load_full_checkpoint is specified but the path is not valid." + "checkpoint.initial_load_path is specified but the path is not valid." ) - model_only = self.initial_load_model_weights_only + model_only = self.initial_load_model_only else: return False else: if self.initial_load_path: logger.info( - "`initial_load_path` is provided but the checkpoint folder exists. " - "Checkpointer will use the checkpoints from the checkpoint folder." + "checkpoint.initial_load_path is provided but the checkpoint.folder exists. " + "Checkpointer will use the checkpoints from the checkpoint.folder." ) step = self._find_load_step() if step == -1 else step if step == -1: @@ -534,18 +553,18 @@ def load(self, step: int = -1) -> bool: f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found." ) - checkpoint_type = self._find_checkpoint_type(checkpoint_id) - if checkpoint_type == CheckpointType.SAFETENSORS: + from_hf = self._load_checkpoint_in_hf_format(checkpoint_id) + if from_hf: assert ( model_only - ), "Only model weights can be loaded when loading from safetensors checkpoint." + ), "Only model can be loaded when loading from HF's safetensors checkpoint." logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) self.dcp_load( states, checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, + from_hf=from_hf, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -593,7 +612,7 @@ def _find_load_step(self, folder: str = "") -> int: return -1 return max(step_counts) - def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: + def _load_checkpoint_in_hf_format(self, checkpoint_id: str) -> bool: """Find the checkpoint type for the given id. Args: @@ -605,8 +624,8 @@ def _find_checkpoint_type(self, checkpoint_id: str) -> CheckpointType: for filename in os.listdir(checkpoint_id): if filename == "model.safetensors.index.json": - return CheckpointType.SAFETENSORS - return CheckpointType.DCP + return True + return False def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") @@ -636,7 +655,7 @@ def _ft_load(self) -> None: self.ft_states, checkpoint_id=checkpoint_id, # FT checkpoints are always DCP because FT checkpoint currently only save/load dataloader. - checkpoint_type=CheckpointType.DCP, + from_hf=False, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( @@ -668,7 +687,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: Returns: Dict[str, Any]: The states to load for the given step. """ - # For the first step, we will only load the model weights. + # For the first step, we will only load the model. if model_only: return self.states[MODEL].state_dict() @@ -688,35 +707,35 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: return states_to_load def _save_last_step(self, curr_step: int) -> None: - # We only consider saving weights only at the end of the training. So - # this won't affect preemption and training resume. We also only allow - # dtype conversion when we are checkpoint model weights only and the - # current dtype is not the same as the export dtype at the end of the training. + # We only consider saving model only at the end of the training. So this + # won't affect preemption and training resume. We also only allow dtype + # conversion when we are checkpointing model only and the current dtype + # is not the same as the export dtype at the end of the training. - if self.last_save_model_weights_only: + if self.last_save_model_only: states = self.states[MODEL].state_dict() if self.export_dtype != torch.float32: states = {k: v.to(self.export_dtype) for k, v in states.items()} logger.info( - f"Saving a model weights only checkpoint in {self.export_dtype} " + f"Saving a model only checkpoint in {self.export_dtype} " f"at last step, step {curr_step}." ) else: logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") states = self._flattened_model_states_sd() - if self.last_save_in_safetensors_format: + if self.last_save_in_hf: assert ( - self.last_save_model_weights_only - ), "Only model weights can be saved when saving in safetensors format." + self.last_save_model_only + ), "Only model can be saved when saving in HF safetensors format." self.dcp_save( states, checkpoint_id=self._create_checkpoint_id(curr_step), async_mode=AsyncMode.DISABLED, enable_garbage_collection=True, - save_in_safetensors_format=self.last_save_in_safetensors_format, + to_hf=self.last_save_in_hf, ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index c209d8fb3c..f7babeb704 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -399,31 +399,31 @@ class Checkpoint: "//pre_train/checkpoints/llama3/llama3_8b/step_10000". """ - initial_load_model_weights_only: bool = True + initial_load_model_only: bool = True """ - This option specifies if only the model weights should be loaded during the initial + This option specifies if only the model should be loaded during the initial checkpoint load. The option is only used when `initial_load_path` is specified. If False, the checkpoint at `initial_load_path` is treated as a standard training checkpoint, including optimizer and training states. The default setting for this option is True. Note that you will have to use - `--checkpoint.no_initial_load_model_weights_only` to override the default setting. + `--checkpoint.no_initial_load_model_only` to override the default setting. """ interval: int = 500 """Checkpointing interval in steps.""" - last_save_model_weights_only: bool = True + last_save_model_only: bool = True """ - When last_save_model_weights_only=True, only model weights will be saved at the end of training, + When last_save_model_only=True, only the model will be saved at the end of training, the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` - after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved. + after conversion. When last_save_model_only=False, the full checkpoint will be saved. A full checkpoint includes model, optimizer and train_state, which can be used to resume training. The default value is True. """ export_dtype: Literal["float16", "bfloat16", "float32"] = "float32" """ - Converts to the specified precision when training completes and last_save_model_weights_only=true. + Converts to the specified precision when training completes and last_save_model_only=true. """ create_seed_checkpoint: bool = False @@ -475,15 +475,15 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ - last_save_in_safetensors_format: bool = False + last_save_in_hf: bool = False """ - Enable the use of safetensors format for checkpointing. This will save the final checkpoints - in safetensors format instead of the default DCP format. There will be a performance - cost in using this as we need to consolidate the sharded tensors to full tensors as - a separate step. last_save_model_weights_only must be true because safetensors doesn't - support saving non tensors. On load, this argument isn't needed as we will detect - whether the loaded checkpoint is in safetensors format or not. - The default value is False. + Enable the use of Hugging Face's safetensors format for checkpointing. This will save the + final checkpoints in safetensors format instead of the default DCP format, after necessary + model state dict transformation. There will be a performance cost in using this as we need + to consolidate the sharded tensors to full tensors as a separate step. + last_save_model_only must be true because safetensors doesn't support saving + non-tensors. On load, this argument isn't needed as we will detect whether the loaded + checkpoint is in safetensors format or not. The default value is False. """ diff --git a/torchtitan/experiments/flux/tests/integration_tests.py b/torchtitan/experiments/flux/tests/integration_tests.py index 9ba7ee378f..cd2bec0976 100755 --- a/torchtitan/experiments/flux/tests/integration_tests.py +++ b/torchtitan/experiments/flux/tests/integration_tests.py @@ -58,11 +58,11 @@ def build_test_list(): [ [ "--checkpoint.enable_checkpoint", - "--checkpoint.last_save_model_weights_only", + "--checkpoint.last_save_model_only", ], ], - "Checkpoint Integration Test - Save Model Weights Only fp32", - "last_save_model_weights_only_fp32", + "Checkpoint Integration Test - Save Model Only fp32", + "last_save_model_only_fp32", ), # Parallelism tests. OverrideDefinitions( diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index c243608a3c..aa35e176ee 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -68,6 +68,6 @@ mode = "full" enable_checkpoint = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 97bc7b5d30..ae03281780 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -67,6 +67,6 @@ mode = "full" enable_checkpoint = false folder = "checkpoint" interval = 1_000 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index 9025a06ff1..9cfb6421b9 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -67,6 +67,6 @@ mode = "full" enable_checkpoint = false folder = "checkpoint" interval = 1_000 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index d72406d8c8..7c17cc9d9c 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -58,7 +58,7 @@ expert_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index 707fea92ef..bfaa57fa4e 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -51,7 +51,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index b4b14358c7..66d7c9dd76 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -49,7 +49,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index c4642dc5e3..1983b0611d 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -60,7 +60,7 @@ pipeline_parallel_schedule = "1F1B" enable_checkpoint = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index a9316b5487..35694e1fd8 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -58,7 +58,7 @@ pipeline_parallel_schedule = "Interleaved1F1B" enable_checkpoint = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index b3722c08bd..51acd7e72a 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -58,7 +58,7 @@ pipeline_parallel_schedule = "Interleaved1F1B" enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 26895274c2..bbfebd36c4 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -3,8 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers @@ -18,6 +16,7 @@ from .infra.pipeline import pipeline_llama from .model.args import TransformerModelArgs from .model.model import Transformer +from .model.state_dict_adapter import Llama3StateDictAdapter __all__ = [ "parallelize_llama", @@ -83,5 +82,6 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, ) ) diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py new file mode 100644 index 0000000000..2406ee3ad1 --- /dev/null +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + + +class Llama3StateDictAdapter(StateDictAdapter): + @staticmethod + def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO: implement this + return state_dict + + @staticmethod + def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO: implement this + return hf_state_dict diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index b9d26c7d96..b61520f1cf 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -59,7 +59,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 8b12113c56..d34e85d213 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -49,7 +49,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index e65d7a1ad9..3f2a0355d6 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -48,7 +48,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 5530177798..ed1335fa80 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -48,7 +48,7 @@ context_parallel_degree = 1 enable_checkpoint = false folder = "checkpoint" interval = 500 -last_save_model_weights_only = false +last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py new file mode 100644 index 0000000000..bd22c8d9ba --- /dev/null +++ b/torchtitan/protocols/state_dict_adapter.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any + + +class StateDictAdapter(ABC): + """Abstract base class for state dict transformations. + + This class defines the interface for converting between native model + state dict format and other model state dict formats. + """ + + @staticmethod + @abstractmethod + def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: + """Convert from native model state dict to HuggingFace format. + + Args: + state_dict: The native model state dict + + Returns: + The converted HuggingFace format state dict + """ + pass + + @staticmethod + @abstractmethod + def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """Obtain native model state dict from HuggingFace format. + + Args: + hf_state_dict: The HuggingFace format state dict + + Returns: + The converted native model state dict + """ + pass diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 0e376b2f65..edf3cc4b93 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -3,8 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. from abc import abstractmethod from collections.abc import Callable @@ -25,6 +23,7 @@ from torchtitan.components.validate import BaseValidator from torchtitan.config_manager import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.protocols.state_dict_adapter import StateDictAdapter @dataclass @@ -102,6 +101,7 @@ class TrainSpec: build_loss_fn: LossFunctionBuilder build_validator_fn: ValidatorBuilder | None = None build_metrics_processor_fn: MetricsProcessorBuilder | None = None + state_dict_adapter: type[StateDictAdapter] | None = None _train_specs = {} diff --git a/torchtitan/train.py b/torchtitan/train.py index 3fc9a2560a..4ccb4a45b4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -297,6 +297,7 @@ def __init__(self, job_config: JobConfig): lr_schedulers=self.lr_schedulers, states={"train_state": self}, job_config=job_config, + sd_adapter=self.train_spec.state_dict_adapter, ft_manager=self.ft_manager, ) From 93a236ccd48a3602a31045674b618d59a9a7b9c8 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:46:23 -0700 Subject: [PATCH 022/128] [DSV3] Add normalization for topk router scores (#1419) Fix issue mentioned in #1418 . Deepseek reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L592 Huggingface reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py#L468 --- torchtitan/models/deepseek_v3/model/moe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py index 840cb1a57d..02a094686c 100644 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ b/torchtitan/models/deepseek_v3/model/moe.py @@ -220,6 +220,10 @@ def forward( scores, k=self.top_k, dim=1 ) + if self.use_sigmoid: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), From 4e73af3e2c5f99ad3cb5a21612e615a64b0b75e7 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:47:17 -0700 Subject: [PATCH 023/128] update documentation for release (#1425) as titled --- README.md | 19 +++++++++---------- docs/release.md | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 docs/release.md diff --git a/README.md b/README.md index fa9b3ba35c..58c437335d 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ To use the latest features of `torchtitan`, we recommend using the most recent P ## Latest News +- [2025/07] We released `torchtitan` [v0.1.0](https://github.com/pytorch/torchtitan/releases), and also set up nightly builds. - [2025/04] Our paper was accepted by [ICLR 2025](https://iclr.cc/virtual/2025/poster/29620). - [2025/04] [Llama 4](torchtitan/experiments/llama4/) initial support is available as an experiment. - [2025/04] Training the diffusion model [FLUX](torchtitan/experiments/flux/) with FSDP/HSDP is available as an experiment. @@ -89,31 +90,29 @@ You may want to see how the model is defined or how parallelism techniques are a ## Installation -> [Install PyTorch](https://pytorch.org/get-started/locally/) before proceeding. +One can choose to install `torchtitan` from a stable release, a nightly build, or directly run the source code. Please [install PyTorch](https://pytorch.org/get-started/locally/) before proceeding. -### Stable - -Via pip: +### Stable releases +One can install the latest [stable release]((https://github.com/pytorch/torchtitan/releases)) of `torchtitan` via `pip` or `conda`. ```sh pip install torchtitan ``` -Or via conda: ```sh conda install conda-forge::torchtitan ``` +Note that each stable release pins the nightly versions of `torch` and `torchao`. Please see [release.md](docs/release.md) for more details. -### Nightly +### Nightly builds -> This method requires the nightly build of PyTorch. +This method requires the nightly build of PyTorch. You can replace `cu126` with another version of cuda (e.g. `cu128`) or an AMD GPU (e.g. `rocm6.3`). ```sh pip install --pre torchtitan --index-url https://download.pytorch.org/whl/nightly/cu126 ``` -You can replace `cu126` with another version of cuda (e.g. `cu128`) or an AMD GPU (e.g. `rocm6.3`). ### From source -> This method requires the nightly build of PyTorch or PyTorch built from source. +This method requires the nightly build of PyTorch or the latest PyTorch built from source. ```bash git clone https://github.com/pytorch/torchtitan @@ -123,7 +122,7 @@ pip install -r requirements.txt ### Downloading a tokenizer -`torchtitan` currently supports training Llama 3.1 (8B, 70B, 405B) out of the box. To get started training these models, we need to download a tokenizer.model. Follow the instructions on the official [meta-llama](https://huggingface.co/meta-llama/Llama-3.1-8B) repository to ensure you have access to the Llama model weights. +`torchtitan` currently supports training Llama 3.1 (8B, 70B, 405B) out of the box. To get started training these models, we need to download the tokenizer. Follow the instructions on the official [meta-llama](https://huggingface.co/meta-llama/Llama-3.1-8B) repository to ensure you have access to the Llama model weights. Once you have confirmed access, you can run the following command to download the Llama 3.1 tokenizer to your local machine. diff --git a/docs/release.md b/docs/release.md new file mode 100644 index 0000000000..598734960e --- /dev/null +++ b/docs/release.md @@ -0,0 +1,26 @@ +## Stable Releases +Currently we follow a lightweight release process. +- Update the version number in `assets/version.txt` with a PR. The version numbering should follow https://semver.org/. + - E.g. for a pre-release `0.y.z` + - if major features are added, increment `y` + - if minor fixes are added, increment `z` +- Create a new release at https://github.com/pytorch/torchtitan/releases/new + - In the release notes + - include proper nightly versions for `torch` and `torchao`, which can be found in [latest CI](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu.yaml) test log "Run script in container" section. E.g. + - "Successfully installed ... `torch-2.8.0.dev20250605+cu126`" + - "Successfully installed `torchao-0.12.0.dev20250605+cu126`" + - describe the release at a high level, compared with the last release, e.g. + - "added an experiment for multimodal LLM training" + - or just "this is a regular release" + - For now, choose "Set as a pre-release". +- As we set up the GitHub workflow [release.yml](/.github/workflows/release.yml), it should trigger a [GitHub action](https://github.com/pytorch/torchtitan/actions/workflows/release.yml) to update the [torchtitan package on PyPI](https://pypi.org/project/torchtitan/), which requires approval from one the the maintainers to run. + +The general instruction on managing releases can be found [here](https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository). + + +## Nightly Builds +Nightly builds are automatically triggered by a [nightly GitHub workflow](/.github/workflows/build_whl_and_publish.yaml) and can be installed by +```bash +pip install --pre torchtitan --index-url https://download.pytorch.org/whl/nightly/cu126 +``` +You can replace `cu126` with another version of cuda (e.g. `cu128`) or an AMD GPU (e.g. `rocm6.3`). From 415f83445a9eadaa571a2f4a370995c9244181d1 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 21 Jul 2025 18:12:14 -0700 Subject: [PATCH 024/128] Add Flex debug config (#1437) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stacked PRs: * __->__#1437 --- --- --- Add flex as impl, debugging https://github.com/pytorch/torchtitan/issues/1412 ### Running debug model ``` ❯ NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.flavor debugmodel_flex_attn + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml + overrides= + '[' 2 -ne 0 ']' + overrides='--model.flavor debugmodel_flex_attn' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/debug_model.toml --model.flavor debugmodel_flex_attn ***************************************** Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. ***************************************** [rank0]:[titan] 2025-07-21 15:57:07,261 - root - INFO - Starting job: DeepSeek-V3 debug training [rank0]:[titan] 2025-07-21 15:57:08,850 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-21 15:57:08,852 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-21 15:57:08,853 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-21 15:57:12,369 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-21 15:57:12,371 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-07-21 15:57:12,449 - root - INFO - Building deepseek_v3 debugmodel_flex_attn with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=2048, dtype='bf16', vocab_size=2000, dim=256, inter_dim=1024, moe_inter_dim=256, n_layers=3, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=8, n_shared_experts=2, n_activated_experts=3, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0) [rank0]:[titan] 2025-07-21 15:57:12,471 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Total parameter count: dense 12,479,744, sparse 3,936,256, active 14,449,920 [rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Model deepseek_v3 debugmodel_flex_attn size: 16,416,000 total parameters [rank0]:[titan] 2025-07-21 15:57:12,615 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%) [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2). [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Training starts at step 1. [rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.) [rank0]: return torch._C._get_cublas_allow_tf32() [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 361472 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 263168 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 328704 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048) [rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_4 2.7331 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_0 5.9175 ms 46.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_2 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8 [rank0]: triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 0.1680 seconds and 2.7997 seconds precompiling for 5 choices [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 248832 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention_backward(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x16x2048, 8x16x2048x128, 8x16x2048x192, 8x16x2048x128, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048) [rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [32768, 2048, 1], [4194304, 262144, 128, 1], [6291456, 192, 3072, 1], [4194304, 128, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_backward_16 6.9023 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8 [rank0]: triton_flex_attention_backward_18 7.2956 ms 94.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8 [rank0]: triton_flex_attention_backward_20 7.9480 ms 86.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8 [rank0]: triton_flex_attention_backward_14 8.2263 ms 83.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_9 10.0881 ms 68.4% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4 [rank0]: triton_flex_attention_backward_10 10.2952 ms 67.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_backward_11 10.6428 ms 64.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4 [rank0]: triton_flex_attention_backward_12 11.6085 ms 59.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4 [rank0]: triton_flex_attention_backward_26 12.3810 ms 55.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_33 13.3964 ms 51.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 4.0514 seconds and 7.8970 seconds precompiling for 29 choices [rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - step: 1 loss: 7.9132 grad_norm: 2.5572 memory: 0.00GiB(0.00%) tps: 562 tflops: 0.06 mfu: 0.01% [rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-21 15:57:42,241 - root - INFO - step: 2 loss: 6.6131 grad_norm: 3.5341 memory: 0.00GiB(0.00%) tps: 34,387 tflops: 3.52 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:42,774 - root - INFO - step: 3 loss: 4.9925 grad_norm: 2.6422 memory: 0.00GiB(0.00%) tps: 30,768 tflops: 3.15 mfu: 0.32% [rank0]:[titan] 2025-07-21 15:57:43,242 - root - INFO - step: 4 loss: 4.6671 grad_norm: 2.5457 memory: 0.00GiB(0.00%) tps: 35,051 tflops: 3.59 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:43,713 - root - INFO - step: 5 loss: 4.4394 grad_norm: 2.2714 memory: 0.00GiB(0.00%) tps: 34,824 tflops: 3.57 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:44,183 - root - INFO - step: 6 loss: 4.2412 grad_norm: 2.0883 memory: 0.00GiB(0.00%) tps: 34,852 tflops: 3.57 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:44,656 - root - INFO - step: 7 loss: 4.0939 grad_norm: 1.9913 memory: 0.00GiB(0.00%) tps: 34,692 tflops: 3.56 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:45,126 - root - INFO - step: 8 loss: 3.9942 grad_norm: 1.8520 memory: 0.00GiB(0.00%) tps: 34,879 tflops: 3.58 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:45,616 - root - INFO - step: 9 loss: 4.0337 grad_norm: 1.6383 memory: 0.00GiB(0.00%) tps: 33,464 tflops: 3.43 mfu: 0.35% [rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - step: 10 loss: 3.9119 grad_norm: 1.6558 memory: 0.00GiB(0.00%) tps: 34,574 tflops: 3.54 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-07-21 15:57:48,090 - root - INFO - Training completed [rank0]:[titan] 2025-07-21 15:57:48,407 - r ``` With these changes On H100 ```Shell NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ***************************************** Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. ***************************************** [rank0]:[titan] 2025-07-21 14:32:24,245 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank0]:[titan] 2025-07-21 14:32:25,879 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-21 14:32:25,880 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-21 14:32:25,882 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-21 14:32:29,510 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-21 14:32:29,716 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-21 14:32:33,715 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=128815, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0) [rank0]:[titan] 2025-07-21 14:32:33,892 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Total parameter count: dense 966,581,760, sparse 14,848,098,304, active 2,769,346,048 [rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Model deepseek_v3 16B size: 15,814,680,064 total parameters [rank0]:[titan] 2025-07-21 14:32:33,947 - root - INFO - Applied full activation checkpointing to the model [rank0]:[titan] 2025-07-21 14:32:34,034 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%) [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200). [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Training starts at step 1. [rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.) [rank0]: return torch._C._get_cublas_allow_tf32() [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/uu/cuuzcmwjztg3bl3ujzslc4ma26j6dinag6nocsnpg3dtmquqyno2.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/rg/crg27sodhizvda2jjyznte4u2sxyexstosyzgyfj2lfuekd2u7fx.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/7c/c7cznn55lunqd7ln454lhdbpef3s6vnbrkieghrh4l74inbpmkgo.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32) [rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_4 9.7972 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_0 18.6366 ms 52.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_2 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8 [rank0]: triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 0.2523 seconds and 4.4028 seconds precompiling for 5 choices [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention_backward(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 8x16x4096, 8x16x4096x128, 8x16x4096x192, 8x16x4096x128, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32) [rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [65536, 4096, 1], [8388608, 524288, 128, 1], [12582912, 192, 3072, 1], [8388608, 128, 2048, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_backward_16 28.0839 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8 [rank0]: triton_flex_attention_backward_18 28.5011 ms 98.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8 [rank0]: triton_flex_attention_backward_20 30.6411 ms 91.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8 [rank0]: triton_flex_attention_backward_14 31.6448 ms 88.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_10 36.6753 ms 76.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_backward_11 39.1248 ms 71.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4 [rank0]: triton_flex_attention_backward_9 39.5692 ms 71.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4 [rank0]: triton_flex_attention_backward_12 44.0589 ms 63.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4 [rank0]: triton_flex_attention_backward_26 44.6748 ms 62.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_33 49.2283 ms 57.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 13.1756 seconds and 7.4547 seconds precompiling for 29 choices [rank0]:[titan] 2025-07-21 14:33:33,285 - root - INFO - step: 1 loss: 12.2446 grad_norm: 1.2270 memory: 0.00GiB(0.00%) tps: 552 tflops: 9.80 mfu: 0.99% [rank0]:[titan] 2025-07-21 14:33:33,286 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-21 14:36:41,972 - root - INFO - step: 10 loss: 11.4945 grad_norm: 1.5629 memory: 0.00GiB(0.00%) tps: 1,563 tflops: 27.74 mfu: 2.81% [rank0]:[titan] 2025-07-21 14:40:12,773 - root - INFO - step: 20 loss: 9.8251 grad_norm: 7.4642 memory: 0.00GiB(0.00%) tps: 1,554 tflops: 27.59 mfu: 2.79% [rank0]:[titan] 2025-07-21 14:43:42,445 - root - INFO - step: 30 loss: 8.9616 grad_norm: 2.4638 memory: 0.00GiB(0.00%) tps: 1,563 tflops: 27.74 mfu: 2.81% ... ``` --- torchtitan/models/attention.py | 8 ++++++-- torchtitan/models/deepseek_v3/__init__.py | 21 ++++++++++++++++++++ torchtitan/models/deepseek_v3/model/args.py | 1 + torchtitan/models/deepseek_v3/model/model.py | 1 + 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index e4b3b8c683..ca6545a24a 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -77,10 +77,14 @@ def mask_key(self) -> FLEX_ATTN_MASK_T: return (self.attn_mask_type, self.fixed_block_size) def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, ) -> torch.Tensor: block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask) + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index de2d26b8a3..7d7ebd8a7c 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -47,6 +47,27 @@ v_head_dim=128, mscale=0.70, ), + "debugmodel_flex_attn": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=3, + n_dense_layers=1, + n_heads=16, + n_routed_experts=8, + n_shared_experts=2, + n_activated_experts=3, + route_scale=1.0, + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), "16B": DeepSeekV3ModelArgs( vocab_size=102400, dim=2048, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index ea469c6725..cfa396410c 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -92,6 +92,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1.0 + eos_id: int = 0 def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: """ diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 34332751fd..e68d6ba838 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -330,6 +330,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): bias=False, ) self.model_args = model_args + self.eos_id = model_args.eos_id self.init_weights() def init_weights(self, buffer_device: torch.device | None = None) -> None: From ad7f6446c3e97726484aa70d2422b7ccb4dd6624 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Jul 2025 21:51:40 -0400 Subject: [PATCH 025/128] [FT] Add torchft to CI (#1398) based off of @fegin previous PR: https://github.com/pytorch/torchtitan/pull/907 This sets up CI for torchft in torchtitan, it only runs when `torchtitan/components/ft.py` is changed, but we can expand it as necessary. This also makes it easier to run torchft tests / configs by just calling: `python tests/integration_tests_ft.py --test_id ...` TODO: - There is an issue with our CI where we cannot set `CUDA_VISIBLE_DEVICES` to partition the devices per replica group (I get error Cuda failure 217 'peer access is not supported between these two devices'), as a result this PR only runs with 1 replica group. In follow up PRs need to look into manually setting the cuda device when using torchft? --- .../integration_test_8gpu_torchft.yaml | 53 ++++++ tests/integration_tests_ft.py | 159 ++++++++++++++++++ 2 files changed, 212 insertions(+) create mode 100644 .github/workflows/integration_test_8gpu_torchft.yaml create mode 100644 tests/integration_tests_ft.py diff --git a/.github/workflows/integration_test_8gpu_torchft.yaml b/.github/workflows/integration_test_8gpu_torchft.yaml new file mode 100644 index 0000000000..b06201ae7d --- /dev/null +++ b/.github/workflows/integration_test_8gpu_torchft.yaml @@ -0,0 +1,53 @@ +name: TorchFT 8 GPU Integration Test + +on: + push: + branches: [ main ] + paths: + - 'torchtitan/components/ft.py' + pull_request: + paths: + - 'torchtitan/components/ft.py' + schedule: + # Runs every 6 hours + - cron: '0 */6 * * *' +concurrency: + group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l -eo pipefail {0} + +jobs: + build-test: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + runner: linux.g5.48xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + # This image is faster to clone than the default, but it lacks CC needed by triton + # (1m25s vs 2m37s). + docker-image: torchtitan-ubuntu-20.04-clang12 + repository: pytorch/torchtitan + upload-artifact: outputs + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + pip config --user set global.progress_bar off + + python -m pip install torchft-nightly + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + + mkdir artifacts-to-be-uploaded + echo "torchft_lighthouse" + RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 & + echo "ft_integration_test" + # Getting error - Cuda failure 217 'peer access is not supported between these two devices' + python ./tests/integration_tests_ft.py artifacts-to-be-uploaded --ngpu 8 + # pkill -9 torchft_lighthouse diff --git a/tests/integration_tests_ft.py b/tests/integration_tests_ft.py new file mode 100644 index 0000000000..75005e7387 --- /dev/null +++ b/tests/integration_tests_ft.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import concurrent.futures +import logging +import os +import subprocess +from collections import defaultdict +from dataclasses import dataclass +from typing import Sequence + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib + + +@dataclass +class OverrideDefinitions: + """ + This class is used to define the override definitions for the integration tests. + """ + + override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) + test_descr: str = "default" + test_name: str = "default" + ngpu: int = 4 + model_flavor: str = "debugmodel" + + def __repr__(self): + return self.test_descr + + +def build_test_list(): + """ + key is the config file name and value is a list of OverrideDefinitions + that is used to generate variations of integration tests based on the + same root config file. + """ + integration_tests_flavors = defaultdict(list) + integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + ["--training.steps 10", "--checkpoint.enable_checkpoint"], + ], + "Default TorchFT integration test", + "default_torchft", + ) + ] + return integration_tests_flavors + + +def _run_cmd(cmd): + return subprocess.run([cmd], text=True, shell=True) + + +def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): + # run_test supports sequence of tests. + test_name = test_flavor.test_name + dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}" + model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}" + + # Use all 8 GPUs in a single replica + # TODO: Use two replica groups + # Right now when passing CUDA_VISIBLE_DEVICES=0,1,2,3 and 4,5,6,7 for 2 RGs I get + # Cuda failure 217 'peer access is not supported between these two devices' + all_ranks = [",".join(map(str, range(0, 8)))] + + for test_idx, override_arg in enumerate(test_flavor.override_args): + cmds = [] + + for replica_id, ranks in enumerate(all_ranks): + cmd = ( + f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + + f"CUDA_VISIBLE_DEVICES={ranks}" + + f"CONFIG_FILE={full_path} NGPU={len(ranks)} ./run_train.sh " + + "--fault_tolerance.enable " + + f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={len(all_ranks)}" + ) + + cmd += " " + dump_folder_arg + cmd += " " + model_flavor_arg + if override_arg: + cmd += " " + " ".join(override_arg) + + logger.info( + "=====TorchFT Integration test, flavor : " + f"{test_flavor.test_descr}, command : {cmd}=====" + ) + cmds.append((replica_id, cmd)) + + with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: + futures = [executor.submit(_run_cmd, cmd) for _, cmd in cmds] + results = [future.result() for future in futures] + + for i, result in enumerate(results): + logger.info(result.stdout) + + if result.returncode == 0: + continue + + raise Exception( + f"Integration test {test_idx} failed, flavor : {test_flavor.test_descr}, command : {cmds[i]}" + ) + + +def run_tests(args): + integration_tests_flavors = build_test_list() + + if args.ngpu < 8: + logger.info("Skipping TorchFT integration tests as we need 8 GPUs.") + return + + for config_file in os.listdir(args.config_dir): + if not config_file.endswith(".toml"): + continue + + full_path = os.path.join(args.config_dir, config_file) + with open(full_path, "rb") as f: + config = tomllib.load(f) + is_integration_test = config["job"].get("use_for_integration_test", False) + if not is_integration_test: + continue + + for test_flavor in integration_tests_flavors[config_file]: + if not (args.test == "all" or test_flavor.test_name == args.test): + continue + + run_test(test_flavor, full_path, args.output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_dir", default="./torchtitan/models/llama3/train_configs" + ) + parser.add_argument( + "--test", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + run_tests(args) + + +if __name__ == "__main__": + main() From 16273f8da39b00db08ce152993fdcc5666aac906 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 14:46:58 -0700 Subject: [PATCH 026/128] [deepseek] fix FlexAttention + TP (#1440) Without this PR, FlexAttention mask function will receive a mixed of plain tensors and DTensors. Given that FlexAttention + DTensor story is not clear, let's always convert to plain tensors when feeding things into FlexAttention / SDPA. --- torchtitan/models/deepseek_v3/infra/parallelize.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1ba45f86d4..5220405950 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -188,10 +188,18 @@ def apply_non_moe_tp( input_layouts=(Shard(1), Replicate()), desired_input_layouts=(Replicate(), Replicate()), ), - # use_local_output=False make the output to be a DTensor instead of a plain Tensor + # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor + # so that the intermedidate results k is generated as a DTensor and its gradient is + # correctly handled by the autograd engine. "attention.wkv_a": NoParallel(use_local_output=False), "attention.wkv_b": colwise_parallel(use_local_output=False), "attention.kv_norm": NoParallel(use_local_output=False), + # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors + "attention.sdpa": prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } From 177b295d23c850aaedbaecc0422338b76716e01b Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 17:19:30 -0700 Subject: [PATCH 027/128] refactor so that ModelArgs does not depend on tokenizer (#1424) Previously we initialize ModelArgs, and then update it dynamically to 1. get `vocab_size` from tokenizer (used for specifying shapes of embedding / output module) 2. get `eos_id` from tokenizer (used for generating block causal attention mask) 3. update `max_seq_len` according to training job `seq_len` (used for precomputing `freqs_cis`) ------------- 1 is causing troubles as we found `vocab_size` for model checkpoints (embedding / output layer) loaded from HF may not be always the same as `tokenizer.get_vocab_size()`. In fact, as long as `vocab_size` in embedding / output layer is larger than tokenizer's `vocab_size`, the training is still OK. In addition, there have been requests to not let `ModeArgs` and model init depend on tokenizers, so this PR removes 2 and instead send `eos_id` as input to model. ---------------- For 3, there is a caveat that when torchtitan is used for continuing training from a checkpoint, users should be aware that the original model has an intrinsic limit on `max_seq_len`. E.g. for llama 3 it's 8k, for llama 4 Scout it's 1M. Currently torchtitan users could break this limit by specifying an arbitrarily large `--training.seq_len`, whether intentionally or not. This PR keeps this flexibility for two reasons: 1. when `seq_len` is less than intrinsic `max_seq_len`, we only need to generate `freqs_cis` of `seq_len` both because of resource consideration and because of CP compatibility. 2. when users intentionally want to test long context training / extend to longer context capability, they could still do that. Instead, this PR adds a warning when getting a `seq_len` larger than the original `max_seq_len` I noticed that llama "official" implementation also allows this https://github.com/meta-llama/llama-models/blob/main/models/llama4/generation.py#L72 --- scripts/estimate/estimation.py | 5 +-- scripts/generate/test_generate.py | 2 +- torchtitan/components/checkpoint.py | 6 ++-- torchtitan/experiments/llama4/__init__.py | 4 +++ torchtitan/experiments/llama4/model/args.py | 34 +++++++++++------- torchtitan/experiments/llama4/model/model.py | 12 ++++--- torchtitan/models/attention.py | 4 +-- torchtitan/models/deepseek_v3/__init__.py | 4 +-- torchtitan/models/deepseek_v3/model/args.py | 36 +++++++++++++++----- torchtitan/models/deepseek_v3/model/model.py | 10 ++++-- torchtitan/models/llama3/__init__.py | 3 +- torchtitan/models/llama3/infra/pipeline.py | 8 ++--- torchtitan/models/llama3/model/args.py | 33 +++++++++++------- torchtitan/models/llama3/model/model.py | 12 ++++--- torchtitan/protocols/train_spec.py | 4 +-- torchtitan/train.py | 27 ++++++++------- 16 files changed, 129 insertions(+), 75 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index cec91fdcdd..0c8a9ccd6f 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -76,9 +76,6 @@ def estimate_memory(job_config: JobConfig): train_spec = get_train_spec(job_config.model.name) - # build tokenizer - tokenizer = train_spec.build_tokenizer_fn(job_config) - loss_parallel_enabled = ( parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel ) @@ -89,7 +86,7 @@ def estimate_memory(job_config: JobConfig): # build model (using meta init) model_args = train_spec.model_args[job_config.model.flavor] - model_args.update_from_config(job_config, tokenizer) + model_args.update_from_config(job_config) with ( FakeTensorMode() diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 07966c2763..dfdc859ec8 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -107,7 +107,7 @@ def test_generate( tokenizer = train_spec.build_tokenizer_fn(config) model_args = train_spec.model_args[config.model.flavor] - model_args.update_from_config(config, tokenizer) + model_args.update_from_config(config) init_device = "meta" if world_size > 1 else device with torch.device(init_device): diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 08c2dd1067..73489fb512 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -347,7 +347,9 @@ def dcp_save( storage_writer: HuggingFaceStorageWriter | None = None checkpoint_save_id: str | None = None if to_hf: - assert self.sd_adapter is not None + assert ( + self.sd_adapter is not None + ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided." state_dict = self.sd_adapter.to_hf(state_dict) fqn_to_index_mapping = {} @@ -623,7 +625,7 @@ def _load_checkpoint_in_hf_format(self, checkpoint_id: str) -> bool: """ for filename in os.listdir(checkpoint_id): - if filename == "model.safetensors.index.json": + if filename.endswith(".safetensors"): return True return False diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 798555ae43..7e3dd8f07c 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -28,6 +28,7 @@ dim=256, n_layers=6, n_heads=16, + vocab_size=2000, rope_theta=500000, ), "17bx16e": TransformerModelArgs( @@ -38,6 +39,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + max_seq_len=10485760, num_experts=16, interleave_moe_layer_step=1, ), @@ -55,6 +57,7 @@ dim=256, n_layers=6, n_heads=16, + vocab_size=2000, rope_theta=500000, every_n_layers_nope=4, fixed_attn_block_size=256, @@ -69,6 +72,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, + max_seq_len=10485760, num_experts=16, interleave_moe_layer_step=1, every_n_layers_nope=4, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index bccedd7be8..d4c71a1268 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -8,9 +8,8 @@ from dataclasses import dataclass from torch import nn -from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability @@ -22,13 +21,13 @@ class TransformerModelArgs(BaseModelArgs): n_layers: int = 32 n_heads: int = 32 n_kv_heads: int | None = None - vocab_size: int = -1 # defined later by tokenizer + vocab_size: int = 202048 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 - max_seq_len: int = 2048 + max_seq_len: int = 1048576 # If `True`, then each transformer block init uses its layer ID, and if # `False`, each uses the total number of transformer blocks depth_init: bool = True @@ -58,12 +57,13 @@ class TransformerModelArgs(BaseModelArgs): use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation load_balance_coeff: float | None = 1e-3 - def update_from_config( - self, job_config: JobConfig, tokenizer: BaseTokenizer - ) -> None: - self.vocab_size = tokenizer.get_vocab_size() - self.max_seq_len = job_config.training.seq_len - self.eos_id = tokenizer.eos_id + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len if self.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( @@ -72,9 +72,17 @@ def update_from_config( self.use_grouped_mm = False if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with CP yet. " - "We are still working on this." + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + if ( + job_config.parallelism.pipeline_parallel_degree > 1 + and self.use_flex_attn + and self.attn_mask_type == "block_causal" + ): + raise RuntimeError( + "PP + block causal FlexAttention support will be fixed soon." ) def get_nparams_and_flops( diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 02085d73e6..c6f410b38e 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -365,7 +365,7 @@ class Transformer(nn.Module, ModelProtocol): tok_embeddings (ParallelEmbedding): Token embeddings. layers (torch.nn.ModuleList): List of Transformer blocks. norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. + output (Linear): Linear layer for final output. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. """ @@ -375,7 +375,6 @@ def __init__(self, model_args: TransformerModelArgs): self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers - self.eos_id = model_args.eos_id self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) @@ -441,7 +440,12 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_theta, ) - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + ): """ Perform a forward pass through the Transformer model. @@ -461,7 +465,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None) """ if self.model_args.use_flex_attn: init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=self.eos_id + input_batch if input_batch is not None else tokens, eos_id=eos_id ) # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index ca6545a24a..570d894f51 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -150,7 +150,7 @@ def blocked_mask_mod( @staticmethod @torch.no_grad() - def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None: + def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: # batch is [b, s, h, d] shape for mask_key in FlexAttention.used_attn_mask_types: attn_mask_type, fixed_block_size = mask_key @@ -239,5 +239,5 @@ def build_attention( return ScaledDotProductAttention(attn_mask_type) -def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None: +def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: FlexAttention.init_attention_mask(batch, eos_id) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7d7ebd8a7c..af95492b82 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -29,7 +29,7 @@ deepseekv3_configs = { "debugmodel": DeepSeekV3ModelArgs( - vocab_size=102400, + vocab_size=2000, dim=256, inter_dim=1024, moe_inter_dim=256, @@ -48,7 +48,7 @@ mscale=0.70, ), "debugmodel_flex_attn": DeepSeekV3ModelArgs( - vocab_size=102400, + vocab_size=2000, dim=256, inter_dim=1024, moe_inter_dim=256, diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index cfa396410c..8d9e705275 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -12,10 +12,10 @@ from torch import nn -from torchtitan.components.tokenizer import Tokenizer from torchtitan.config_manager import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger +from torchtitan.tools.utils import has_cuda_capability # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @@ -92,14 +92,34 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1.0 - eos_id: int = 0 - def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: - """ - Update the model_config config from the given job config. - """ - self.vocab_size = tokenizer.vocab_size - self.max_seq_len = job_config.training.seq_len + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if self.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + if ( + job_config.parallelism.pipeline_parallel_degree > 1 + and self.use_flex_attn + and self.attn_mask_type == "block_causal" + ): + raise RuntimeError( + "PP + block causal FlexAttention support will be fixed soon." + ) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: """ diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index e68d6ba838..e13eb2bf4f 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -330,7 +330,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): bias=False, ) self.model_args = model_args - self.eos_id = model_args.eos_id self.init_weights() def init_weights(self, buffer_device: torch.device | None = None) -> None: @@ -355,7 +354,12 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + ): """ Forward pass for the Transformer model. @@ -374,7 +378,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None) """ if self.model_args.use_flex_attn: init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=self.eos_id + input_batch if input_batch is not None else tokens, eos_id=eos_id ) h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index bbfebd36c4..a34b4463f8 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -29,12 +29,13 @@ llama3_configs = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=6, n_heads=16, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, n_layers=6, n_heads=16, + vocab_size=2000, rope_theta=500000, use_flex_attn=True, attn_mask_type="block_causal", diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index dfb424b5b5..2ca2e67e3e 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -37,14 +37,14 @@ def pipeline_llama( parallel_dims: ParallelDims, job_config: JobConfig, device: torch.device, - model_config: TransformerModelArgs, + model_args: TransformerModelArgs, parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: pp_mesh = parallel_dims.world_mesh["pp"] stages, model_parts = pipeline_llama_manual_split( - model, pp_mesh, parallel_dims, job_config, device, model_config + model, pp_mesh, parallel_dims, job_config, device, model_args ) # For PP with looped schedules, each item in model_parts is one stage-model-chunk. @@ -78,7 +78,7 @@ def pipeline_llama_manual_split( parallel_dims: ParallelDims, job_config: JobConfig, device: torch.device, - model_config: TransformerModelArgs, + model_args: TransformerModelArgs, ) -> tuple[list[PipelineStage], list[nn.Module]]: """ This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. @@ -95,7 +95,7 @@ def pipeline_llama_manual_split( splits = parallelism_config.pipeline_parallel_split_points or generate_split_points( parallelism_config.pipeline_parallel_schedule, parallel_dims.pp, - model_config.n_layers, + model_args.n_layers, parallelism_config.pipeline_parallel_layers_per_stage, ) diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 7f7b4e5a96..f9e141d18a 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -11,9 +11,9 @@ from torch import nn -from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config_manager import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools.logging import logger @dataclass @@ -22,13 +22,13 @@ class TransformerModelArgs(BaseModelArgs): n_layers: int = 32 n_heads: int = 32 n_kv_heads: int | None = None - vocab_size: int = -1 # defined later by tokenizer + vocab_size: int = 128256 multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: float | None = None norm_eps: float = 1e-5 rope_theta: float = 10000 - max_seq_len: int = 2048 + max_seq_len: int = 131072 # If `True`, then each transformer block init uses its layer ID, and if # `False`, each uses the total number of transformer blocks depth_init: bool = True @@ -37,17 +37,26 @@ class TransformerModelArgs(BaseModelArgs): attn_mask_type: str = "causal" eos_id: int = 0 - def update_from_config( - self, job_config: JobConfig, tokenizer: BaseTokenizer - ) -> None: - self.vocab_size = tokenizer.get_vocab_size() - self.max_seq_len = job_config.training.seq_len - self.eos_id = tokenizer.eos_id + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise ValueError( - "FlexAttention is not compatible with CP yet. " - "We are still working on this." + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + if ( + job_config.parallelism.pipeline_parallel_degree > 1 + and self.use_flex_attn + and self.attn_mask_type == "block_causal" + ): + raise RuntimeError( + "PP + block causal FlexAttention support will be fixed soon." ) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index a52820939f..e45af90ba2 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -322,7 +322,7 @@ class Transformer(nn.Module, ModelProtocol): tok_embeddings (ParallelEmbedding): Token embeddings. layers (torch.nn.ModuleList): List of Transformer blocks. norm (RMSNorm): Layer normalization for the model output. - output (ColumnParallelLinear): Linear layer for final output. + output (Linear): Linear layer for final output. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. """ @@ -332,7 +332,6 @@ def __init__(self, model_args: TransformerModelArgs): self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers - self.eos_id = model_args.eos_id self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) @@ -398,7 +397,12 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_theta, ) - def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None): + def forward( + self, + tokens: torch.Tensor, + eos_id: int | None = None, + input_batch: torch.Tensor | None = None, + ): """ Perform a forward pass through the Transformer model. @@ -418,7 +422,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None) """ if self.model_args.use_flex_attn: init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=self.eos_id + input_batch if input_batch is not None else tokens, eos_id=eos_id ) # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index edf3cc4b93..0f33602fa4 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -37,9 +37,7 @@ class BaseModelArgs: _enforced: str = "This field is used to enforce all fields have defaults." @abstractmethod - def update_from_config( - self, job_config: JobConfig, tokenizer: BaseTokenizer - ) -> None: + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: pass @abstractmethod diff --git a/torchtitan/train.py b/torchtitan/train.py index 4ccb4a45b4..cce9776fd3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -40,6 +40,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): train_spec: train_spec_module.TrainSpec # swappable training components in TrainSpec + tokenizer: train_spec_module.BaseTokenizer dataloader: train_spec_module.BaseDataLoader model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction @@ -122,24 +123,21 @@ def __init__(self, job_config: JobConfig): ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - # build dataloader - tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) + # build tokenizer and dataloader + if self.train_spec.build_tokenizer_fn is not None: + self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) self.dataloader = self.train_spec.build_dataloader_fn( dp_world_size=dp_degree, dp_rank=dp_rank, - tokenizer=tokenizer, + tokenizer=self.tokenizer, job_config=job_config, ) # build model (using meta init) model_args = self.train_spec.model_args[job_config.model.flavor] # set the model args from training job configs - model_args.update_from_config(job_config, tokenizer) + model_args.update_from_config(job_config) logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" @@ -325,7 +323,7 @@ def __init__(self, job_config: JobConfig): job_config=job_config, dp_world_size=dp_degree, dp_rank=dp_rank, - tokenizer=tokenizer, + tokenizer=self.tokenizer, parallel_dims=parallel_dims, loss_fn=self.train_spec.build_loss_fn(job_config), validation_context=self.train_context, @@ -401,11 +399,16 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, target=targets, losses=losses, input_batch=inputs + inputs, + target=targets, + losses=losses, + input_batch=inputs, ) else: self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs + target=targets, + losses=losses, + input_batch=inputs, ) # accumulate losses across pipeline microbatches @@ -420,7 +423,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs) + pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred From 171a88350eb79d40918d2ea4d95aee256a34d0a0 Mon Sep 17 00:00:00 2001 From: ebsmothers Date: Tue, 22 Jul 2025 17:44:10 -0700 Subject: [PATCH 028/128] Take job config out of checkpoint manager (#1433) This PR takes job_config out of the CheckpointManager class. Why? JobConfig is a monolith -- it has knowledge of every part of a titan training job. As a result, it is hard to actually use CheckpointManager in a standalone fashion. In practice the job config is mostly only used for its checkpoint config, plus two other usages as far as I can tell: 1) Getting the replica_id from the FTManager 2) Taking the dump_folder from the job field and joining it with the checkpoint folder For (1) we can just get this directly from FTManager without accessing the JobConfig field. For (2) we can pass `job_config.job.dump_folder` explicitly as a base folder, then join to `checkpoint_config.folder`. Personally I would try to consolidate `job.dump_folder` and `checkpoint.folder` (though I understand there are cases where only the former is needed) under Checkpoint, but not sure if this is preferable from titan's pov. --- tests/unit_tests/test_checkpoint.py | 52 +++++++++++++++++++---------- torchtitan/components/checkpoint.py | 45 ++++++++++++++----------- torchtitan/train.py | 3 +- 3 files changed, 62 insertions(+), 38 deletions(-) diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index e2c0e1254b..b39a65c261 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -175,7 +175,8 @@ def test_save_load_restores_state(self, mock_load, mock_save, mock_rank): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -207,7 +208,8 @@ def test_save_and_purge_keeps_last_k_checkpoints( optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -247,7 +249,8 @@ def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_ra optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) manager.save(curr_step=1) @@ -269,7 +272,8 @@ def test_load_returns_false_when_no_checkpoint_folder(self): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) self.assertFalse(manager.load(step=-1)) @@ -292,7 +296,8 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) res = manager.load(step=-1) @@ -321,7 +326,8 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) manager.save(curr_step=1) @@ -354,7 +360,8 @@ def test_last_save_model_only_and_initial_load_model_only( optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) manager1.save(curr_step=1, last_step=True) @@ -373,7 +380,8 @@ def test_last_save_model_only_and_initial_load_model_only( optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) r1 = manager2.load(step=1) @@ -404,7 +412,8 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group): """ # Configure async mode job_config = DummyJobConfig(job=self.job_config.job) - job_config.checkpoint.async_mode = "async" + checkpoint_config = job_config.checkpoint + checkpoint_config.async_mode = "async" ft_manager = DummyFTManager() states = {"trainer": torch.tensor([0])} manager = CheckpointManager( @@ -413,8 +422,9 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=states, - job_config=job_config, - ft_manager=ft_manager, + checkpoint_config=checkpoint_config, + base_folder=self.job_config.job.dump_folder, + ft_manager=self.ft_manager, ) # First save schedules async @@ -445,7 +455,8 @@ def test_ft_async_save_calls_async_wait( Test that with FT enabled, AsyncMode.ASYNC via FT triggers correct waits. """ job_config = DummyJobConfig(job=self.job_config.job) - job_config.checkpoint.async_mode = "async" + checkpoint_config = job_config.checkpoint + checkpoint_config.async_mode = "async" ft_manager = mock.Mock() ft_manager.manager.return_value = mock.Mock() ft_manager.manager.participating_rank = mock.Mock(return_value=0) @@ -456,8 +467,9 @@ def test_ft_async_save_calls_async_wait( optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=job_config, - ft_manager=ft_manager, + checkpoint_config=checkpoint_config, + base_folder=self.job_config.job.dump_folder, + ft_manager=self.ft_manager, ) # Initially no future @@ -491,7 +503,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -516,7 +529,8 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -561,7 +575,8 @@ def __init__(self): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -610,7 +625,8 @@ def fake_load(state_dict: dict, checkpoint_id=None): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states=self.states, - job_config=self.job_config, + checkpoint_config=self.job_config.checkpoint, + base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 73489fb512..c87c61461a 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -36,7 +36,7 @@ from torchtitan.components.ft import FTManager from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config_manager import Checkpoint, TORCH_DTYPE_MAP from torchtitan.protocols.state_dict_adapter import StateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -174,10 +174,13 @@ class CheckpointManager: lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model. states (Dict[str, Any]): The states that need to be saved, other than the previous 4 components. - job_config (JobConfig): The job config used to configure the checkpointing. + checkpoint_config (Checkpoint): The config used to configure the checkpointing. + base_folder (str): The base folder to save the checkpoint. Will be concatenated + with checkpoint_config.folder sd_adapter (Optional[type[StateDictAdapter]]): The adapter used to convert model state dicts between native format and other formats. ft_manager (Optional[ft.Manager]): The FTManager from TorchFT. + """ def __init__( @@ -187,13 +190,13 @@ def __init__( optimizers: OptimizersContainer, lr_schedulers: LRSchedulersContainer, states: dict[str, Any], - job_config: JobConfig, + checkpoint_config: Checkpoint, + base_folder: str, sd_adapter: type[StateDictAdapter] | None = None, ft_manager: FTManager | None = None, ) -> None: - ckpt_config = job_config.checkpoint - self.enable_checkpoint = ckpt_config.enable_checkpoint - self.last_save_in_hf = ckpt_config.last_save_in_hf + self.enable_checkpoint = checkpoint_config.enable_checkpoint + self.last_save_in_hf = checkpoint_config.last_save_in_hf if self.last_save_in_hf: assert ( sd_adapter is not None @@ -224,9 +227,9 @@ def load_state_dict(state_dict): self.states[k].load_state_dict(v) self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) - self.ft_replica_id = job_config.fault_tolerance.replica_id + self.ft_replica_id = ft_manager.replica_id - async_mode = ckpt_config.async_mode.lower() + async_mode = checkpoint_config.async_mode.lower() self.enable_staging = ( self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM ) or self.ft_manager @@ -251,19 +254,21 @@ def load_state_dict(state_dict): self.cpu_offload_state_dict = None self.stager = None - self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) + self.folder = os.path.join(base_folder, checkpoint_config.folder) # Checkpoint policy related fields. - self.initial_load_path = ckpt_config.initial_load_path - self.initial_load_model_only = ckpt_config.initial_load_model_only - self.last_save_model_only = ckpt_config.last_save_model_only - self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] - self.exclude_from_loading = ckpt_config.exclude_from_loading - self.interval = ckpt_config.interval - self.enable_first_step_checkpoint = ckpt_config.enable_first_step_checkpoint + self.initial_load_path = checkpoint_config.initial_load_path + self.initial_load_model_only = checkpoint_config.initial_load_model_only + self.last_save_model_only = checkpoint_config.last_save_model_only + self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype] + self.exclude_from_loading = checkpoint_config.exclude_from_loading + self.interval = checkpoint_config.interval + self.enable_first_step_checkpoint = ( + checkpoint_config.enable_first_step_checkpoint + ) # Async checkpoint related fields. - async_mode = ckpt_config.async_mode.lower() + async_mode = checkpoint_config.async_mode.lower() if ( async_mode == AsyncMode.ASYNC or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM @@ -271,7 +276,7 @@ def load_state_dict(state_dict): ): self.pg = dist.new_group(backend="gloo") - self.keep_latest_k = ckpt_config.keep_latest_k + self.keep_latest_k = checkpoint_config.keep_latest_k if self.keep_latest_k > 0: if self.keep_latest_k == 1: raise ValueError( @@ -296,7 +301,9 @@ def load_state_dict(state_dict): elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM else: - raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") + raise ValueError( + f"Unkown checkpoint async_mode {checkpoint_config.async_mode}" + ) logger.info( f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" diff --git a/torchtitan/train.py b/torchtitan/train.py index cce9776fd3..c79b985966 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -294,7 +294,8 @@ def __init__(self, job_config: JobConfig): optimizers=self.optimizers, lr_schedulers=self.lr_schedulers, states={"train_state": self}, - job_config=job_config, + checkpoint_config=job_config.checkpoint, + base_folder=job_config.job.dump_folder, sd_adapter=self.train_spec.state_dict_adapter, ft_manager=self.ft_manager, ) From 34d815c026acc7f9984e9c39097da26de73e6caf Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 20:59:31 -0700 Subject: [PATCH 029/128] [refactor] split JobConfig and ConfigManager into two files (#1442) This PR creates a new folder `torchtitan/config` to host `job_config.py` and `manager.py`, for the reasons below: - Both are complicated enough to worth their own files. - The convention in torchtitan to extend custom `JobConfig` is to create a file under model folder called `job_config.py` (see https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/job_config.py). This PR makes the origin `JobConfig` consistent with that convention. - (minor) Creating a more succinct `torchtitan.config` namespace is more readable than importing from `torchtitan.config_manager`. --- .github/workflows/integration_test_8gpu.yaml | 2 +- .../workflows/integration_test_8gpu_h100.yaml | 2 +- .../integration_test_8gpu_torchft.yaml | 2 +- scripts/estimate/estimation.py | 2 +- scripts/generate/test_generate.py | 2 +- tests/integration_tests_ft.py | 29 +-- tests/integration_tests_h100.py | 19 +- .../unit_tests/test_activation_checkpoint.py | 32 +-- tests/unit_tests/test_checkpoint.py | 16 +- .../unit_tests/test_dataset_checkpointing.py | 2 +- tests/unit_tests/test_job_config.py | 2 +- tests/unit_tests/test_lr_scheduler.py | 3 +- tests/unit_tests/test_model_converter.py | 2 +- tests/unit_tests/test_train_spec.py | 2 +- torchtitan/components/checkpoint.py | 9 +- torchtitan/components/ft.py | 2 +- torchtitan/components/loss.py | 2 +- torchtitan/components/lr_scheduler.py | 2 +- torchtitan/components/metrics.py | 2 +- torchtitan/components/optimizer.py | 2 +- torchtitan/components/quantization/float8.py | 2 +- torchtitan/components/quantization/mx.py | 2 +- torchtitan/components/tokenizer.py | 2 +- torchtitan/components/validate.py | 2 +- torchtitan/config/__init__.py | 18 ++ .../job_config.py} | 235 +---------------- torchtitan/config/manager.py | 236 ++++++++++++++++++ torchtitan/datasets/hf_datasets.py | 2 +- torchtitan/distributed/pipeline.py | 2 +- torchtitan/distributed/utils.py | 2 +- .../experiments/flux/dataset/flux_dataset.py | 2 +- .../experiments/flux/dataset/tokenizer.py | 2 +- .../experiments/flux/infra/parallelize.py | 2 +- torchtitan/experiments/flux/loss.py | 2 +- torchtitan/experiments/flux/sampling.py | 2 +- .../flux/tests/test_generate_image.py | 2 +- .../tests/unit_tests/test_flux_dataloader.py | 2 +- torchtitan/experiments/flux/train.py | 2 +- .../experiments/llama4/infra/parallelize.py | 8 +- torchtitan/experiments/llama4/model/args.py | 2 +- torchtitan/experiments/llama4/optimizer.py | 2 +- .../scripts/convert_hf_to_dcp_with_gpus.py | 2 +- .../scripts/convert_meta_to_dcp_with_gpus.py | 2 +- .../experiments/simple_fsdp/parallelize.py | 2 +- .../models/deepseek_v3/infra/parallelize.py | 8 +- .../models/deepseek_v3/infra/pipeline.py | 2 +- torchtitan/models/deepseek_v3/model/args.py | 2 +- torchtitan/models/llama3/infra/parallelize.py | 13 +- torchtitan/models/llama3/infra/pipeline.py | 2 +- torchtitan/models/llama3/model/args.py | 3 +- torchtitan/protocols/model_converter.py | 2 +- torchtitan/protocols/train_spec.py | 9 +- torchtitan/tools/profiling.py | 2 +- torchtitan/train.py | 13 +- 54 files changed, 374 insertions(+), 355 deletions(-) create mode 100644 torchtitan/config/__init__.py rename torchtitan/{config_manager.py => config/job_config.py} (76%) create mode 100644 torchtitan/config/manager.py diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml index a2469a9133..ecec8190a5 100644 --- a/.github/workflows/integration_test_8gpu.yaml +++ b/.github/workflows/integration_test_8gpu.yaml @@ -46,4 +46,4 @@ jobs: USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded - python ./tests/integration_tests.py artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests artifacts-to-be-uploaded --ngpu 8 diff --git a/.github/workflows/integration_test_8gpu_h100.yaml b/.github/workflows/integration_test_8gpu_h100.yaml index 813669748d..4648c661e8 100644 --- a/.github/workflows/integration_test_8gpu_h100.yaml +++ b/.github/workflows/integration_test_8gpu_h100.yaml @@ -47,4 +47,4 @@ jobs: USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded - python ./tests/integration_tests_h100.py artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests_h100 artifacts-to-be-uploaded --ngpu 8 diff --git a/.github/workflows/integration_test_8gpu_torchft.yaml b/.github/workflows/integration_test_8gpu_torchft.yaml index b06201ae7d..2268170ac2 100644 --- a/.github/workflows/integration_test_8gpu_torchft.yaml +++ b/.github/workflows/integration_test_8gpu_torchft.yaml @@ -49,5 +49,5 @@ jobs: RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 > /dev/null 2>&1 & echo "ft_integration_test" # Getting error - Cuda failure 217 'peer access is not supported between these two devices' - python ./tests/integration_tests_ft.py artifacts-to-be-uploaded --ngpu 8 + python -m tests.integration_tests_ft artifacts-to-be-uploaded --ngpu 8 # pkill -9 torchft_lighthouse diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 0c8a9ccd6f..82d306e692 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -17,7 +17,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.protocols.train_spec import get_train_spec diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index dfdc859ec8..9b21b3e57b 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -26,7 +26,7 @@ ) from torchtitan.components.checkpoint import excluded_parameters_for_model_only from torchtitan.components.metrics import build_device_memory_monitor -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools import utils diff --git a/tests/integration_tests_ft.py b/tests/integration_tests_ft.py index 75005e7387..6430a54dd5 100644 --- a/tests/integration_tests_ft.py +++ b/tests/integration_tests_ft.py @@ -10,8 +10,8 @@ import os import subprocess from collections import defaultdict -from dataclasses import dataclass -from typing import Sequence + +from tests.integration_tests import OverrideDefinitions logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -22,22 +22,6 @@ import tomli as tomllib -@dataclass -class OverrideDefinitions: - """ - This class is used to define the override definitions for the integration tests. - """ - - override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) - test_descr: str = "default" - test_name: str = "default" - ngpu: int = 4 - model_flavor: str = "debugmodel" - - def __repr__(self): - return self.test_descr - - def build_test_list(): """ key is the config file name and value is a list of OverrideDefinitions @@ -52,6 +36,7 @@ def build_test_list(): ], "Default TorchFT integration test", "default_torchft", + ngpu=8, ) ] return integration_tests_flavors @@ -65,7 +50,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): # run_test supports sequence of tests. test_name = test_flavor.test_name dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}" - model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}" # Use all 8 GPUs in a single replica # TODO: Use two replica groups @@ -79,14 +63,13 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): for replica_id, ranks in enumerate(all_ranks): cmd = ( f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' - + f"CUDA_VISIBLE_DEVICES={ranks}" - + f"CONFIG_FILE={full_path} NGPU={len(ranks)} ./run_train.sh " + + f"CUDA_VISIBLE_DEVICES={ranks} " + + f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} ./run_train.sh " + "--fault_tolerance.enable " - + f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={len(all_ranks)}" + + f"--fault_tolerance.replica_id={replica_id} --fault_tolerance.group_size={test_flavor.ngpu}" ) cmd += " " + dump_folder_arg - cmd += " " + model_flavor_arg if override_arg: cmd += " " + " ".join(override_arg) diff --git a/tests/integration_tests_h100.py b/tests/integration_tests_h100.py index da9539957a..f12f3c07b8 100755 --- a/tests/integration_tests_h100.py +++ b/tests/integration_tests_h100.py @@ -9,8 +9,8 @@ import os import subprocess from collections import defaultdict -from dataclasses import dataclass -from typing import Sequence + +from .integration_tests import OverrideDefinitions logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -21,21 +21,6 @@ import tomli as tomllib -@dataclass -class OverrideDefinitions: - """ - This class is used to define the override definitions for the integration tests. - """ - - override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) - test_descr: str = "default" - test_name: str = "default" - ngpu: int = 4 - - def __repr__(self): - return self.test_descr - - def build_test_list(): """ key is the config file name and value is a list of OverrideDefinitions diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index fbc585f527..a253c4fb5b 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -10,11 +10,11 @@ import torch.nn as nn from torch.utils.flop_counter import FlopCounterMode -from torchtitan.config_manager import ActivationCheckpoint as ACConfig +from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.models.llama3.infra.parallelize import apply_ac -class TestModule(nn.Module): +class ToyModule(nn.Module): def __init__(self): super().__init__() self.layers = nn.ModuleDict({"0": TransformerBlock()}) @@ -56,12 +56,12 @@ def get_bw_flops(model_fn): return mode.get_total_flops() / (512**3 * 2) # 1. No AC - model_no_ac = TestModule() + model_no_ac = ToyModule() flops_no_ac = get_bw_flops(model_no_ac) # 2. SAC # Per-op SAC's policy is to save every other mm - model_selective_ac = TestModule() + model_selective_ac = ToyModule() ac_config_no_force = ACConfig( mode="selective", selective_ac_option="op", @@ -72,7 +72,7 @@ def get_bw_flops(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! - model_with_force_first = TestModule() + model_with_force_first = ToyModule() ac_config_with_force_first = ACConfig( mode="selective", selective_ac_option="op", @@ -82,7 +82,7 @@ def get_bw_flops(model_fn): flops_with_force_first = get_bw_flops(model_with_force_first) # 4. Per-op SAC with force recompute "output" - model_with_force_last = TestModule() + model_with_force_last = ToyModule() ac_config_with_force_last = ACConfig( mode="selective", selective_ac_option="op", @@ -92,7 +92,7 @@ def get_bw_flops(model_fn): flops_with_force_last = get_bw_flops(model_with_force_last) # 5. Full AC - model_with_full_ac = TestModule() + model_with_full_ac = ToyModule() ac_config_full_ac = ACConfig( mode="full", ) @@ -122,12 +122,12 @@ def get_act_mem(model_fn): return act_mem # 1. No AC - model_no_ac = TestModule().cuda() + model_no_ac = ToyModule().cuda() mem_no_ac = get_act_mem(model_no_ac) # 2. SAC # Per-op SAC's policy is to save every other mm - model_selective_ac = TestModule().cuda() + model_selective_ac = ToyModule().cuda() ac_config_no_force = ACConfig( mode="selective", selective_ac_option="op", @@ -138,7 +138,7 @@ def get_act_mem(model_fn): # 3. Per-op SAC with force recompute "moe.router.gate" # This leads to two mms being recomputed since they share the same shape! - model_with_force_first = TestModule().cuda() + model_with_force_first = ToyModule().cuda() ac_config_with_force_first = ACConfig( mode="selective", selective_ac_option="op", @@ -148,7 +148,7 @@ def get_act_mem(model_fn): mem_with_force_first = get_act_mem(model_with_force_first) # 4. Per-op SAC with force recompute "output" - model_with_force_last = TestModule().cuda() + model_with_force_last = ToyModule().cuda() ac_config_with_force_last = ACConfig( mode="selective", selective_ac_option="op", @@ -158,7 +158,7 @@ def get_act_mem(model_fn): mem_with_force_last = get_act_mem(model_with_force_last) # 5. Full AC - model_with_full_ac = TestModule().cuda() + model_with_full_ac = ToyModule().cuda() ac_config_full_ac = ACConfig( mode="full", ) @@ -175,9 +175,9 @@ def get_act_mem(model_fn): # the size of the other two mms. def test_correctness(self): - model_no_ac = TestModule() + model_no_ac = ToyModule() - model_selective_ac = TestModule() + model_selective_ac = ToyModule() model_selective_ac.load_state_dict(model_no_ac.state_dict()) apply_ac( model_selective_ac, @@ -187,7 +187,7 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=[], ), ) - model_force_first = TestModule() + model_force_first = ToyModule() model_force_first.load_state_dict(model_no_ac.state_dict()) apply_ac( model_force_first, @@ -198,7 +198,7 @@ def test_correctness(self): ), ) - model_force_last = TestModule() + model_force_last = ToyModule() model_force_last.load_state_dict(model_no_ac.state_dict()) apply_ac( model_force_last, diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index b39a65c261..4d4c942c86 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -16,7 +16,7 @@ import torch.nn as nn from torch.utils.data import DataLoader from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.config_manager import Checkpoint as CheckpointConfig +from torchtitan.config.job_config import Checkpoint as CheckpointConfig class FakeOptimizersContainer: @@ -176,6 +176,7 @@ def test_save_load_restores_state(self, mock_load, mock_save, mock_rank): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -209,6 +210,7 @@ def test_save_and_purge_keeps_last_k_checkpoints( lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -250,6 +252,7 @@ def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_ra lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -273,6 +276,7 @@ def test_load_returns_false_when_no_checkpoint_folder(self): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -297,6 +301,7 @@ def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -327,6 +332,7 @@ def test_interval_respects_interval(self, mock_load, mock_save, mock_rank): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -361,6 +367,7 @@ def test_last_save_model_only_and_initial_load_model_only( lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -381,6 +388,7 @@ def test_last_save_model_only_and_initial_load_model_only( lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -423,6 +431,7 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group): lr_schedulers=self.lr_schedulers, states=states, checkpoint_config=checkpoint_config, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -468,6 +477,7 @@ def test_ft_async_save_calls_async_wait( lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=checkpoint_config, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -504,6 +514,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -530,6 +541,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -576,6 +588,7 @@ def __init__(self): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -626,6 +639,7 @@ def fake_load(state_dict: dict, checkpoint_id=None): lr_schedulers=self.lr_schedulers, states=self.states, checkpoint_config=self.job_config.checkpoint, + sd_adapter=None, base_folder=self.job_config.job.dump_folder, ft_manager=self.ft_manager, ) diff --git a/tests/unit_tests/test_dataset_checkpointing.py b/tests/unit_tests/test_dataset_checkpointing.py index 36dcd8f866..0d4529e6e2 100644 --- a/tests/unit_tests/test_dataset_checkpointing.py +++ b/tests/unit_tests/test_dataset_checkpointing.py @@ -9,7 +9,7 @@ import torch from datasets import load_dataset from torchtitan.components.tokenizer import HuggingFaceTokenizer -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.datasets.hf_datasets import build_hf_dataloader, DatasetConfig, DATASETS diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py index 2a64b38e55..039981dbed 100644 --- a/tests/unit_tests/test_job_config.py +++ b/tests/unit_tests/test_job_config.py @@ -10,7 +10,7 @@ import pytest import tomli_w -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig class TestJobConfig(unittest.TestCase): diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index 1fdbb45935..3e5473f51a 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -12,6 +12,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.config import ConfigManager class TestLRScheduler(unittest.TestCase): @@ -39,8 +40,6 @@ def create_job_config( lr_min=None, ): # Create a job config with the specified parameters - from torchtitan.config_manager import ConfigManager - args = [ "--training.steps", str(training_steps), diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index 6b9d9515f4..572a269a93 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from torchtitan.components.quantization.float8 import Float8Converter -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.distributed import ParallelDims from torchtitan.protocols.model_converter import ( build_model_converters, diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 5b01454771..411ba1439e 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -14,7 +14,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers, OptimizersContainer from torchtitan.components.tokenizer import build_hf_tokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.llama3 import parallelize_llama, pipeline_llama diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index c87c61461a..5a1b40ba88 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -36,7 +36,8 @@ from torchtitan.components.ft import FTManager from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config_manager import Checkpoint, TORCH_DTYPE_MAP +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.config.job_config import Checkpoint as CheckpointConfig from torchtitan.protocols.state_dict_adapter import StateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -190,9 +191,9 @@ def __init__( optimizers: OptimizersContainer, lr_schedulers: LRSchedulersContainer, states: dict[str, Any], - checkpoint_config: Checkpoint, - base_folder: str, - sd_adapter: type[StateDictAdapter] | None = None, + checkpoint_config: CheckpointConfig, + sd_adapter: type[StateDictAdapter] | None, + base_folder: str = "", ft_manager: FTManager | None = None, ) -> None: self.enable_checkpoint = checkpoint_config.enable_checkpoint diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index ee94b238bd..70b814f3aa 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -13,7 +13,7 @@ import torch.distributed as dist from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.distributed_c10d import ReduceOp -from torchtitan.config_manager import FaultTolerance as FTConfig +from torchtitan.config.job_config import FaultTolerance as FTConfig if importlib.util.find_spec("torchft") is not None: import torchft as ft diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 6262564064..6aa1dd5699 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -9,7 +9,7 @@ import torch -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger LossFunction: TypeAlias = Callable[..., torch.Tensor] diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index da582ea7a1..bccaf2b96c 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger __all__ = [ diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 3fee856504..732d4f709f 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -14,7 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools import utils from torchtitan.tools.logging import logger diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index ee87888d74..f6fd02a4d5 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -18,7 +18,7 @@ from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims __all__ = [ diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 91d42164a6..ca0b38e660 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from torchtitan.config_manager import Float8, JobConfig +from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims from torchtitan.protocols.model_converter import ( ModelConverter, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f2b1bdb5f0..f22ac4bd04 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -11,7 +11,7 @@ import torch.nn as nn -from torchtitan.config_manager import JobConfig, MX +from torchtitan.config.job_config import JobConfig, MX from torchtitan.distributed import ParallelDims from torchtitan.protocols.model_converter import ( ModelConverter, diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 6ca11d6711..f6908b7772 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -12,7 +12,7 @@ from typing import Any, Optional, Union from tokenizers import AddedToken, Tokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger from typing_extensions import override diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 1c4e3dbb58..7357cc8ed0 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -13,7 +13,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.tools import utils diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py new file mode 100644 index 0000000000..9bbcac7456 --- /dev/null +++ b/torchtitan/config/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +TORCH_DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +from torchtitan.config.job_config import JobConfig +from torchtitan.config.manager import ConfigManager + +__all__ = ["JobConfig", "ConfigManager", "TORCH_DTYPE_MAP"] diff --git a/torchtitan/config_manager.py b/torchtitan/config/job_config.py similarity index 76% rename from torchtitan/config_manager.py rename to torchtitan/config/job_config.py index f7babeb704..b5c167e131 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config/job_config.py @@ -4,30 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import importlib -import os -import sys - -from dataclasses import asdict, dataclass, field, fields, is_dataclass, make_dataclass -from typing import Any, Literal, Type - -import torch -import tyro - -try: - import tomllib -except ModuleNotFoundError: - import tomli as tomllib - -from torchtitan.tools.logging import logger - -TORCH_DTYPE_MAP = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - -custom_registry = tyro.constructors.ConstructorRegistry() +from dataclasses import asdict, dataclass, field +from typing import Any, Literal @dataclass @@ -758,212 +736,3 @@ class JobConfig: def to_dict(self) -> dict[str, Any]: return asdict(self) - - -class ConfigManager: - """ - Parses, merges, and validates a JobConfig from TOML and CLI sources. - - Configuration precedence: - CLI args > TOML file > JobConfig defaults - - CLI arguments use the format
. to map to TOML entries. - Example: - model.name → - - [model] - name - """ - - def __init__(self, config_cls: Type[JobConfig] = JobConfig): - self.config_cls = config_cls - self.config: JobConfig = config_cls() - self.register_tyro_rules(custom_registry) - - def parse_args(self, args: list[str] = sys.argv[1:]) -> JobConfig: - toml_values = self._maybe_load_toml(args) - config_cls = self._maybe_add_custom_args(args, toml_values) - - base_config = ( - self._dict_to_dataclass(config_cls, toml_values) - if toml_values - else config_cls() - ) - - self.config = tyro.cli( - config_cls, args=args, default=base_config, registry=custom_registry - ) - - self._validate_config() - - return self.config - - def _maybe_load_toml(self, args: list[str]) -> dict[str, Any] | None: - # 1. Check CLI - valid_keys = {"--job.config-file", "--job.config_file"} - for i, arg in enumerate(args): - if "=" in arg: - key, value = arg.split("=", 1) - if key in valid_keys: - file_path = value - break - elif i < len(args) - 1 and arg in valid_keys: - file_path = args[i + 1] - break - else: - return None - - try: - with open(file_path, "rb") as f: - return tomllib.load(f) - except (FileNotFoundError, tomllib.TOMLDecodeError) as e: - logger.exception(f"Error while loading config file: {file_path}") - raise e - - def _maybe_add_custom_args( - self, args: list[str], toml_values: dict[str, Any] | None - ) -> Type[JobConfig]: # noqa: B006 - """Find and merge custom arguments module with current JobConfig class""" - module_path = None - - # 1. Check CLI - valid_keys = { - "--experimental.custom_args_module", - "--experimental.custom-args-module", - } - for i, arg in enumerate(args): - key = arg.split("=")[0] - if key in valid_keys: - module_path = arg.split("=", 1)[1] if "=" in arg else args[i + 1] - break - - # 2. If not found in CLI, check TOML - if not module_path and toml_values: - experimental = toml_values.get("experimental", {}) - if isinstance(experimental, dict): - module_path = experimental.get("custom_args_module") - - if not module_path: - return self.config_cls - - JobConfigExtended = importlib.import_module(module_path).JobConfig - return self._merge_configs(self.config_cls, JobConfigExtended) - - @staticmethod - def _merge_configs(base, custom) -> Type: - """ - Merges a base JobConfig class with user-defined extensions. - - This method creates a new dataclass type that combines fields from both `base` and `custom`, - allowing users to extend or override JobConfig configuration structure. - - Merge behavior: - - If a field exists in both `base` and `custom`: - - If both field types are dataclasses, they are merged recursively. - - Otherwise, the field from `custom` overrides the one in `base` (type, default, etc.). - - Fields only present in `base` or `custom` are preserved as-is. - """ - result = [] - b_map = {f.name: f for f in fields(base)} - c_map = {f.name: f for f in fields(custom)} - - for name, f in b_map.items(): - if ( - name in c_map - and is_dataclass(f.type) - and is_dataclass(c_map[name].type) - ): - m_type = ConfigManager._merge_configs(f.type, c_map[name].type) - result.append((name, m_type, field(default_factory=m_type))) - - # Custom field overrides base type - elif name in c_map: - result.append((name, c_map[name].type, c_map[name])) - - # Only in Base - else: - result.append((name, f.type, f)) - - # Only in Custom - for name, f in c_map.items(): - if name not in b_map: - result.append((name, f.type, f)) - - return make_dataclass(f"Merged{base.__name__}", result, bases=(base,)) - - def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: - """Convert dictionary to dataclass, handling nested structures.""" - if not is_dataclass(cls): - return data - - result = {} - for f in fields(cls): - if f.name in data: - value = data[f.name] - if is_dataclass(f.type) and isinstance(value, dict): - result[f.name] = self._dict_to_dataclass(f.type, value) - else: - result[f.name] = value - return cls(**result) - - def _validate_config(self) -> None: - # TODO: temporary mitigation of BC breaking change in - # tokenizer default path, need to remove later - if not os.path.exists(self.config.model.tokenizer_path): - logger.warning( - f"Tokenizer path {self.config.model.tokenizer_path} does not exist!" - ) - old_tokenizer_path = ( - "torchtitan/datasets/tokenizer/original/tokenizer.model" - ) - if os.path.exists(old_tokenizer_path): - self.config.model.tokenizer_path = old_tokenizer_path - logger.warning( - f"Temporarily switching to previous default tokenizer path {old_tokenizer_path}. " - "Please download the new tokenizer model (python scripts/download_tokenizer.py) and update your config." - ) - else: - # Check if we are using tokenizer.model, if so then we need to alert users to redownload the tokenizer - if self.config.model.tokenizer_path.endswith("tokenizer.model"): - raise Exception( - "You are using the old tokenizer.model, please redownload the tokenizer ", - "(python scripts/download_tokenizer.py --repo_id meta-llama/Llama-3.1-8B) ", - " and update your config to the directory of the downloaded tokenizer.", - ) - - @staticmethod - def register_tyro_rules(registry: tyro.constructors.ConstructorRegistry) -> None: - @registry.primitive_rule - def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): - """Support for comma seperated string parsing""" - if type_info.type != list[str]: - return None - return tyro.constructors.PrimitiveConstructorSpec( - nargs=1, - metavar="A,B,C,...", - instance_from_str=lambda args: args[0].split(","), - is_instance=lambda instance: all(isinstance(i, str) for i in instance), - str_from_instance=lambda instance: [",".join(instance)], - ) - - -if __name__ == "__main__": - # ----------------------------------------------------------------------------- - # Run this module directly to debug or inspect configuration parsing. - # - # Examples: - # Show help message: - # > python -m torchtitan.config_manager --help - # - # Parse and print a config with CLI arguments: - # > python -m torchtitan.config_manager --profiling.enable_memory_snapshot - # - # ----------------------------------------------------------------------------- - - from rich import print as rprint - from rich.pretty import Pretty - - config_manager = ConfigManager() - config = config_manager.parse_args() - - rprint(Pretty(config)) diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py new file mode 100644 index 0000000000..d22b3d21fa --- /dev/null +++ b/torchtitan/config/manager.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os +import sys + +from dataclasses import field, fields, is_dataclass, make_dataclass +from typing import Any, Type + +import tyro + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib + +from torchtitan.tools.logging import logger + +from .job_config import JobConfig + + +class ConfigManager: + """ + Parses, merges, and validates a JobConfig from TOML and CLI sources. + + Configuration precedence: + CLI args > TOML file > JobConfig defaults + + CLI arguments use the format
. to map to TOML entries. + Example: + model.name → + + [model] + name + """ + + def __init__(self, config_cls: Type[JobConfig] = JobConfig): + self.config_cls = config_cls + self.config: JobConfig = config_cls() + self.register_tyro_rules(custom_registry) + + def parse_args(self, args: list[str] = sys.argv[1:]) -> JobConfig: + toml_values = self._maybe_load_toml(args) + config_cls = self._maybe_add_custom_args(args, toml_values) + + base_config = ( + self._dict_to_dataclass(config_cls, toml_values) + if toml_values + else config_cls() + ) + + self.config = tyro.cli( + config_cls, args=args, default=base_config, registry=custom_registry + ) + + self._validate_config() + + return self.config + + def _maybe_load_toml(self, args: list[str]) -> dict[str, Any] | None: + # 1. Check CLI + valid_keys = {"--job.config-file", "--job.config_file"} + for i, arg in enumerate(args): + if "=" in arg: + key, value = arg.split("=", 1) + if key in valid_keys: + file_path = value + break + elif i < len(args) - 1 and arg in valid_keys: + file_path = args[i + 1] + break + else: + return None + + try: + with open(file_path, "rb") as f: + return tomllib.load(f) + except (FileNotFoundError, tomllib.TOMLDecodeError) as e: + logger.exception(f"Error while loading config file: {file_path}") + raise e + + def _maybe_add_custom_args( + self, args: list[str], toml_values: dict[str, Any] | None + ) -> Type[JobConfig]: # noqa: B006 + """Find and merge custom arguments module with current JobConfig class""" + module_path = None + + # 1. Check CLI + valid_keys = { + "--experimental.custom_args_module", + "--experimental.custom-args-module", + } + for i, arg in enumerate(args): + key = arg.split("=")[0] + if key in valid_keys: + module_path = arg.split("=", 1)[1] if "=" in arg else args[i + 1] + break + + # 2. If not found in CLI, check TOML + if not module_path and toml_values: + experimental = toml_values.get("experimental", {}) + if isinstance(experimental, dict): + module_path = experimental.get("custom_args_module") + + if not module_path: + return self.config_cls + + JobConfigExtended = importlib.import_module(module_path).JobConfig + return self._merge_configs(self.config_cls, JobConfigExtended) + + @staticmethod + def _merge_configs(base, custom) -> Type: + """ + Merges a base JobConfig class with user-defined extensions. + + This method creates a new dataclass type that combines fields from both `base` and `custom`, + allowing users to extend or override JobConfig configuration structure. + + Merge behavior: + - If a field exists in both `base` and `custom`: + - If both field types are dataclasses, they are merged recursively. + - Otherwise, the field from `custom` overrides the one in `base` (type, default, etc.). + - Fields only present in `base` or `custom` are preserved as-is. + """ + result = [] + b_map = {f.name: f for f in fields(base)} + c_map = {f.name: f for f in fields(custom)} + + for name, f in b_map.items(): + if ( + name in c_map + and is_dataclass(f.type) + and is_dataclass(c_map[name].type) + ): + m_type = ConfigManager._merge_configs(f.type, c_map[name].type) + result.append((name, m_type, field(default_factory=m_type))) + + # Custom field overrides base type + elif name in c_map: + result.append((name, c_map[name].type, c_map[name])) + + # Only in Base + else: + result.append((name, f.type, f)) + + # Only in Custom + for name, f in c_map.items(): + if name not in b_map: + result.append((name, f.type, f)) + + return make_dataclass(f"Merged{base.__name__}", result, bases=(base,)) + + def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: + """Convert dictionary to dataclass, handling nested structures.""" + if not is_dataclass(cls): + return data + + result = {} + for f in fields(cls): + if f.name in data: + value = data[f.name] + if is_dataclass(f.type) and isinstance(value, dict): + result[f.name] = self._dict_to_dataclass(f.type, value) + else: + result[f.name] = value + return cls(**result) + + def _validate_config(self) -> None: + # TODO: temporary mitigation of BC breaking change in + # tokenizer default path, need to remove later + if not os.path.exists(self.config.model.tokenizer_path): + logger.warning( + f"Tokenizer path {self.config.model.tokenizer_path} does not exist!" + ) + old_tokenizer_path = ( + "torchtitan/datasets/tokenizer/original/tokenizer.model" + ) + if os.path.exists(old_tokenizer_path): + self.config.model.tokenizer_path = old_tokenizer_path + logger.warning( + f"Temporarily switching to previous default tokenizer path {old_tokenizer_path}. " + "Please download the new tokenizer model (python scripts/download_tokenizer.py) and update your config." + ) + else: + # Check if we are using tokenizer.model, if so then we need to alert users to redownload the tokenizer + if self.config.model.tokenizer_path.endswith("tokenizer.model"): + raise Exception( + "You are using the old tokenizer.model, please redownload the tokenizer ", + "(python scripts/download_tokenizer.py --repo_id meta-llama/Llama-3.1-8B) ", + " and update your config to the directory of the downloaded tokenizer.", + ) + + @staticmethod + def register_tyro_rules(registry: tyro.constructors.ConstructorRegistry) -> None: + @registry.primitive_rule + def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): + """Support for comma seperated string parsing""" + if type_info.type != list[str]: + return None + return tyro.constructors.PrimitiveConstructorSpec( + nargs=1, + metavar="A,B,C,...", + instance_from_str=lambda args: args[0].split(","), + is_instance=lambda instance: all(isinstance(i, str) for i in instance), + str_from_instance=lambda instance: [",".join(instance)], + ) + + +# Initialize the custom registry for tyro +custom_registry = tyro.constructors.ConstructorRegistry() + + +if __name__ == "__main__": + # ----------------------------------------------------------------------------- + # Run this module directly to debug or inspect configuration parsing. + # + # Examples: + # Show help message: + # > python -m torchtitan.config.manager --help + # + # Parse and print a config with CLI arguments: + # > python -m torchtitan.config.manager --profiling.enable_memory_snapshot + # + # ----------------------------------------------------------------------------- + + from rich import print as rprint + from rich.pretty import Pretty + + config_manager = ConfigManager() + config = config_manager.parse_args() + + rprint(Pretty(config)) diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index dbef80a6ee..0e30f8fe51 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -18,7 +18,7 @@ from torchtitan.components.dataloader import ParallelAwareDataloader from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py index 366021a7fc..9526a7e3b7 100644 --- a/torchtitan/distributed/pipeline.py +++ b/torchtitan/distributed/pipeline.py @@ -15,7 +15,7 @@ ) from torch.distributed.pipelining.stage import PipelineStage -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index e25794a240..ecda3f9b6f 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -18,7 +18,7 @@ from torch.distributed.tensor import DTensor from torch.nn.attention import SDPBackend -from torchtitan.config_manager import TORCH_DTYPE_MAP +from torchtitan.config import TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.attention import ScaledDotProductAttention from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index ea47d46abc..bd0fc715c1 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -21,7 +21,7 @@ from torchtitan.components.dataloader import ParallelAwareDataloader from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.experiments.flux.dataset.tokenizer import ( build_flux_tokenizer, FluxTokenizer, diff --git a/torchtitan/experiments/flux/dataset/tokenizer.py b/torchtitan/experiments/flux/dataset/tokenizer.py index 3d69b0ac57..bf90bdcb26 100644 --- a/torchtitan/experiments/flux/dataset/tokenizer.py +++ b/torchtitan/experiments/flux/dataset/tokenizer.py @@ -12,7 +12,7 @@ import torch from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from transformers import CLIPTokenizer, T5Tokenizer diff --git a/torchtitan/experiments/flux/infra/parallelize.py b/torchtitan/experiments/flux/infra/parallelize.py index 69fef68c50..c2bdb98b30 100644 --- a/torchtitan/experiments/flux/infra/parallelize.py +++ b/torchtitan/experiments/flux/infra/parallelize.py @@ -14,7 +14,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/flux/loss.py b/torchtitan/experiments/flux/loss.py index e3d2f000be..9159b40b8a 100644 --- a/torchtitan/experiments/flux/loss.py +++ b/torchtitan/experiments/flux/loss.py @@ -8,7 +8,7 @@ import torch -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger LossFunction: TypeAlias = Callable[..., torch.Tensor] diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index 1dd733fc55..8e4e8589ef 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -15,7 +15,7 @@ from torch import Tensor from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger from .model.autoencoder import AutoEncoder diff --git a/torchtitan/experiments/flux/tests/test_generate_image.py b/torchtitan/experiments/flux/tests/test_generate_image.py index 56bfc7877f..2583b24349 100755 --- a/torchtitan/experiments/flux/tests/test_generate_image.py +++ b/torchtitan/experiments/flux/tests/test_generate_image.py @@ -9,7 +9,7 @@ import torch -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.experiments.flux import flux_configs from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer diff --git a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py index 84b2a4bb6d..093deb71e5 100644 --- a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py +++ b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py @@ -10,7 +10,7 @@ from datasets import load_dataset -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.experiments.flux.dataset.flux_dataset import ( _cc12m_wds_data_processor, build_flux_dataloader, diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index c328d12b71..bc3db244dd 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -10,7 +10,7 @@ import torch from torch.distributed.fsdp import FSDPModule -from torchtitan.config_manager import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import utils as dist_utils from torchtitan.tools.logging import init_logger, logger from torchtitan.train import Trainer diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 1b62011286..33ff71a985 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -18,7 +18,7 @@ RowwiseParallel, SequenceParallel, ) -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import ( @@ -59,6 +59,12 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index d4c71a1268..89818812c9 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -9,7 +9,7 @@ from torch import nn -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index 3b20f6b1d9..4a997dd817 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -9,7 +9,7 @@ from torchtitan.components.ft import FTManager from torchtitan.components.optimizer import build_optimizers, OptimizersContainer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py index 99b58395d3..03bb3706e3 100644 --- a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -17,7 +17,7 @@ from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from torchtitan.components.checkpoint import MODEL -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig from torchtitan.tools.logging import init_logger, logger from torchtitan.train import Trainer diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py index bfde1d3220..7b32fce845 100644 --- a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py @@ -15,7 +15,7 @@ from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from torchtitan.components.checkpoint import MODEL -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig from torchtitan.tools.logging import init_logger, logger from torchtitan.train import Trainer diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index 7a94adea39..ef02a4bf63 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp from torchtitan.tools.logging import logger diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 5220405950..532358b2da 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -15,7 +15,7 @@ SequenceParallel, ) -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp @@ -40,6 +40,12 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: # TODO(jianiw): This branch needs to be tested and enabled diff --git a/torchtitan/models/deepseek_v3/infra/pipeline.py b/torchtitan/models/deepseek_v3/infra/pipeline.py index 7caf3ad81f..b28ed39ee4 100644 --- a/torchtitan/models/deepseek_v3/infra/pipeline.py +++ b/torchtitan/models/deepseek_v3/infra/pipeline.py @@ -20,7 +20,7 @@ ) from torchtitan.components.loss import LossFunction -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank from torchtitan.protocols.train_spec import ParallelizeFunction diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 8d9e705275..cd94104cdb 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -12,7 +12,7 @@ from torch import nn -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index bbbbe71e4b..9e6e1a85d0 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -28,11 +28,8 @@ SequenceParallel, ) -from torchtitan.config_manager import ( - ActivationCheckpoint as ACConfig, - JobConfig, - TORCH_DTYPE_MAP, -) +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger @@ -60,6 +57,12 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index 2ca2e67e3e..bf88f74322 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -19,7 +19,7 @@ ) from torchtitan.components.loss import LossFunction -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.pipeline import ( build_pipeline_schedule, diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index f9e141d18a..73c8e27700 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -11,7 +11,7 @@ from torch import nn -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger @@ -58,6 +58,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: raise RuntimeError( "PP + block causal FlexAttention support will be fixed soon." ) + self.max_seq_len = seq_len def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: nparams = sum(p.numel() for p in model.parameters()) diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index ea6ed8e12a..300c4231c3 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -7,7 +7,7 @@ import torch.nn as nn -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.tools.logging import logger diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 0f33602fa4..afbf0a560a 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -14,15 +14,13 @@ from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader -from torchtitan.components.ft import FTManager from torchtitan.components.loss import LossFunction from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.optimizer import OptimizersContainer from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.components.validate import BaseValidator -from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims +from torchtitan.config import JobConfig from torchtitan.protocols.state_dict_adapter import StateDictAdapter @@ -74,10 +72,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader] TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] -OptimizersBuilder: TypeAlias = Callable[ - [list[nn.Module], JobConfig, ParallelDims, FTManager | None], - OptimizersContainer, -] +OptimizersBuilder: TypeAlias = Callable[..., OptimizersContainer] LRSchedulersBuilder: TypeAlias = Callable[ [OptimizersContainer, JobConfig], LRSchedulersContainer ] diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 27842bc7d4..1e9c67ea69 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -11,7 +11,7 @@ import torch -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger # the number of warmup steps before the active step in each profiling cycle diff --git a/torchtitan/train.py b/torchtitan/train.py index c79b985966..9b9f5d4115 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -22,7 +22,7 @@ build_metrics_processor, ensure_pp_loss_visible, ) -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils @@ -40,7 +40,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): train_spec: train_spec_module.TrainSpec # swappable training components in TrainSpec - tokenizer: train_spec_module.BaseTokenizer + tokenizer: train_spec_module.BaseTokenizer | None dataloader: train_spec_module.BaseDataLoader model_parts: list[torch.nn.Module] loss_fn: train_spec_module.LossFunction @@ -124,8 +124,11 @@ def __init__(self, job_config: JobConfig): self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # build tokenizer and dataloader - if self.train_spec.build_tokenizer_fn is not None: - self.tokenizer = self.train_spec.build_tokenizer_fn(job_config) + self.tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) self.dataloader = self.train_spec.build_dataloader_fn( dp_world_size=dp_degree, @@ -295,8 +298,8 @@ def __init__(self, job_config: JobConfig): lr_schedulers=self.lr_schedulers, states={"train_state": self}, checkpoint_config=job_config.checkpoint, - base_folder=job_config.job.dump_folder, sd_adapter=self.train_spec.state_dict_adapter, + base_folder=job_config.job.dump_folder, ft_manager=self.ft_manager, ) From 2e6ab377366bbdb7314b7f6871083d9ff769da45 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 21:01:07 -0700 Subject: [PATCH 030/128] add the forge folder (#1387) depending on https://github.com/pytorch/torchtitan/pull/1384 and https://github.com/pytorch/torchtitan/pull/1397 --- torchtitan/experiments/forge/README.md | 14 + torchtitan/experiments/forge/__init__.py | 11 + torchtitan/experiments/forge/engine.py | 233 ++++++++++++ torchtitan/experiments/forge/example_train.py | 351 ++++++++++++++++++ torchtitan/experiments/forge/job_config.py | 38 ++ torchtitan/experiments/forge/train_spec.py | 76 ++++ 6 files changed, 723 insertions(+) create mode 100644 torchtitan/experiments/forge/README.md create mode 100644 torchtitan/experiments/forge/__init__.py create mode 100644 torchtitan/experiments/forge/engine.py create mode 100644 torchtitan/experiments/forge/example_train.py create mode 100644 torchtitan/experiments/forge/job_config.py create mode 100644 torchtitan/experiments/forge/train_spec.py diff --git a/torchtitan/experiments/forge/README.md b/torchtitan/experiments/forge/README.md new file mode 100644 index 0000000000..83f0e4eb96 --- /dev/null +++ b/torchtitan/experiments/forge/README.md @@ -0,0 +1,14 @@ +## `ForgeEngine` + +The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). This engine provides only the essential constructor method, making it highly flexible for various downstream applications. + +The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to +- Initialize an SPMD distributed training environment +- Construct and scale models via n-D parallelisms and meta-device initialization +- Provide necessary training components and utilities + +**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together. + +Additionally, the folder provides a train spec registration method [`register_train_spec`](train_spec.py) that allows users to extend beyond the core set of models and training components available in torchtitan, enabling greater flexibility and customization for specific training requirements. + +The [example_train.py](./example_train.py) demonstrates how to use `ForgeEngine` for pretraining, achieving the same functionality as [torchtitan/train.py](/torchtitan/train.py) (except for quantization or fault tolerance). diff --git a/torchtitan/experiments/forge/__init__.py b/torchtitan/experiments/forge/__init__.py new file mode 100644 index 0000000000..1654959cee --- /dev/null +++ b/torchtitan/experiments/forge/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .engine import ForgeEngine +from .job_config import ForgeJobConfig +from .train_spec import ForgeTrainSpec, register_train_spec + +__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"] diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py new file mode 100644 index 0000000000..0875e83d3c --- /dev/null +++ b/torchtitan/experiments/forge/engine.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Generator + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.loss import rescale_accumulated_loss +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools import utils + +from .job_config import ForgeJobConfig +from .train_spec import ForgeTrainSpec, get_train_spec + + +class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful): + # core configs + job_config: ForgeJobConfig + parallel_dims: ParallelDims + train_spec: ForgeTrainSpec + + # swappable training components in ForgeTrainSpec + model_parts: list[torch.nn.Module] + loss_fn: train_spec_module.LossFunction + optimizers: train_spec_module.OptimizersContainer + lr_schedulers: train_spec_module.LRSchedulersContainer + + # non-swappable training components + checkpointer: CheckpointManager + + # runtime utilities + device: torch.device + gc_handler: utils.GarbageCollection + gradient_accumulation_steps: int + train_context: Generator[None, None, None] + pp_has_first_stage: bool + pp_has_last_stage: bool + + # Fields in ForgeEngine which are not in original Trainer + # for dataloading + dp_degree: int + dp_rank: int + # for logging + model_args: BaseModelArgs + num_flops_per_token: float + model_param_count: int + global_batch_size: int + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: ForgeJobConfig): + torch._C._log_api_usage_once("torchtitan.train") + + self.job_config = job_config + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + # Device has to be set before creating TorchFT manager. + device_module.set_device(self.device) + + # init distributed and build meshes + dist_utils.init_distributed(job_config) + world_size = int(os.environ["WORLD_SIZE"]) + parallelism_config = job_config.parallelism + self.parallel_dims = parallel_dims = ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + world_size=world_size, + ) + + world_mesh = parallel_dims.world_mesh + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + self.dp_degree, self.dp_rank = dp_degree, dp_rank + + # take control of garbage collection to avoid stragglers + self.gc_handler = utils.GarbageCollection( + gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug + ) + + # Set random seed, and maybe enable deterministic mode + # (mainly for debugging, expect perf loss). + dist_utils.set_determinism( + world_mesh, + self.device, + job_config.training.seed, + job_config.training.deterministic, + ) + self.train_spec = get_train_spec(job_config.model.name) + + # build model (using meta init) + self.model_args = model_args = self.train_spec.model_args[ + job_config.model.flavor + ] + # set the model args from training job configs + model_args.update_from_config(job_config) + + with torch.device("meta"): + model = self.train_spec.model_cls(model_args) + + # calculate model size and flops per token + ( + self.model_param_count, + self.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = device_type + else: + init_device = device_type + buffer_device = None + + self.loss_fn = self.train_spec.build_loss_fn(job_config) + + # verify batch sizes + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert ( + global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + ) + self.global_batch_size = global_batch_size + + # calculate gradient accumulation steps + self.gradient_accumulation_steps = global_batch_size // ( + job_config.training.local_batch_size * dp_degree + ) + assert self.gradient_accumulation_steps > 0 + self.loss_fn = rescale_accumulated_loss( + self.loss_fn, self.gradient_accumulation_steps + ) + + # apply parallelisms and initialization + if parallel_dims.pp_enabled: + if not self.train_spec.pipelining_fn: + raise RuntimeError( + f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"does not support pipelining" + ) + + # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques + ( + self.pp_schedule, + self.model_parts, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) = self.train_spec.pipelining_fn( + model, + parallel_dims, + job_config, + self.device, + model_args, + self.train_spec.parallelize_fn, + self.loss_fn, + ) + # when PP is enabled, `model` obj is no longer used after this point, + # model_parts is used instead + del model + + for m in self.model_parts: + m.to_empty(device=init_device) + with torch.no_grad(): + m.init_weights(buffer_device=buffer_device) + m.train() + else: + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights(buffer_device=buffer_device) + model.train() + + self.model_parts = [model] + + # build optimizer after applying parallelisms to the model + self.optimizers = self.train_spec.build_optimizers_fn( + self.model_parts, job_config, parallel_dims + ) + self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.optimizers, job_config + ) + + self.checkpointer = CheckpointManager( + dataloader=None, + model_parts=self.model_parts, + optimizers=self.optimizers, + lr_schedulers=self.lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=self.train_spec.state_dict_adapter, + ) + + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) + self.train_context = dist_utils.get_train_context( + loss_parallel_enabled, + parallelism_config.enable_compiled_autograd, + ) + self.maybe_enable_amp = dist_utils.maybe_enable_amp( + parallel_dims, + job_config.training.mixed_precision_param, + device_type, + ) + + def close(self) -> None: + if self.checkpointer: + self.checkpointer.close() diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py new file mode 100644 index 0000000000..a0846c8ca3 --- /dev/null +++ b/torchtitan/experiments/forge/example_train.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import time +from datetime import timedelta +from typing import Any, Iterable, Optional + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.metrics import build_metrics_processor +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.config import ConfigManager, JobConfig +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.distributed import utils as dist_utils +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) + +from .engine import ForgeEngine + + +class Trainer(ForgeEngine): + tokenizer: train_spec_module.BaseTokenizer | None + dataloader: train_spec_module.BaseDataLoader + validator: train_spec_module.BaseValidator + metrics_processor: train_spec_module.MetricsProcessor + + # additional training states + step: int + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: JobConfig): + logger.info(f"Starting job: {job_config.job.description}") + + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + # NOTE: Here we are passing in JobConfig as a superset of ForgeJobConfig + super().__init__(job_config) + + # build tokenizer + self.tokenizer = build_hf_tokenizer(job_config) + + # build dataloader + self.dataloader = build_hf_dataloader( + dp_world_size=self.dp_degree, + dp_rank=self.dp_rank, + tokenizer=self.tokenizer, + job_config=job_config, + ) + + model_args = self.model_args + logger.info( + f"Built {self.train_spec.name} {job_config.model.flavor} with {model_args}" + ) + + # metrics logging + self.metrics_processor = build_metrics_processor( + job_config, self.parallel_dims, model_args + ) + color = self.metrics_processor.color + + self.metrics_processor.num_flops_per_token = self.num_flops_per_token + + logger.info( + f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.red}size: {self.model_param_count:,} total parameters{color.reset}" + ) + + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = self.metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{utils.device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + self.metrics_processor.optimizers = self.optimizers + + # Initialize trainer states that will be saved in checkpoint. + # These attributes must be initialized before checkpoint loading. + self.step = 0 + + # Build validator if validation is configured + if job_config.validation.enabled: + self.validator = build_validator( + job_config=job_config, + dp_world_size=self.dp_degree, + dp_rank=self.dp_rank, + tokenizer=self.tokenizer, + parallel_dims=self.parallel_dims, + loss_fn=self.train_spec.build_loss_fn(job_config), + validation_context=self.train_context, + maybe_enable_amp=self.maybe_enable_amp, + ) + + logger.info( + "Trainer is initialized with " + f"local batch size {job_config.training.local_batch_size}, " + f"global batch size {self.global_batch_size}, " + f"gradient accumulation steps {self.gradient_accumulation_steps}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.lr_scheduler.warmup_steps})." + ) + + def batch_generator( + self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]: + """Returns an iterator that processes batches from the data iterator.""" + device_type = utils.device_type + data_iterator = iter(data_iterable) + + while True: + try: + batch = next(data_iterator) + except StopIteration as ex: + # If data runs out during gradient accumulation, that + # entire step will not be executed. + raise DataloaderStopIteration() from ex + data_load_start = time.perf_counter() + input_dict, labels = batch + self.metrics_processor.ntokens_since_last_log += labels.numel() + self.metrics_processor.data_loading_times.append( + time.perf_counter() - data_load_start + ) + + # Move tensors to the appropriate device + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + input_dict[k] = v.to(device_type) + labels = labels.to(device_type) + + yield input_dict, labels + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + # apply context parallelism if cp is enabled + # ensure CP handles the separate freqs_cis buffer for each pp stage + inputs = input_dict["input"] + optional_context_parallel_ctx = ( + dist_utils.create_context_parallel_ctx( + cp_mesh=parallel_dims.world_mesh["cp"], + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], + cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_no_restore_buffers={inputs, labels}, + cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + ) + if parallel_dims.cp_enabled + else None + ) + + if parallel_dims.pp_enabled: + # Pipeline Parallel forward / backward inside step() call + with self.train_context(optional_context_parallel_ctx): + targets, losses = ( + (labels, []) if self.pp_has_last_stage else (None, None) + ) + if self.pp_has_first_stage: + self.pp_schedule.step( + inputs, target=targets, losses=losses, input_batch=inputs + ) + else: + self.pp_schedule.step( + target=targets, losses=losses, input_batch=inputs + ) + + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + loss = ( + torch.mean(torch.stack(losses)).to(self.device) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=self.device) + ) + else: + # Non-PP forward / backward + with self.train_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + with self.maybe_enable_amp: + pred = model_parts[0](inputs) + loss = self.loss_fn(pred, labels) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() + + return loss + + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): + self.optimizers.zero_grad() + + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. + parallel_dims = self.parallel_dims + + accumulated_losses = [] + # If data runs out during gradient accumulation, that + # entire step will not be executed. + for microbatch in range(self.gradient_accumulation_steps): + input_dict, labels = next(data_iterator) + loss = self.forward_backward_step(input_dict, labels) + accumulated_losses.append(loss.detach()) + + grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=True, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_dense_params_mesh_ndim=( + parallel_dims.dense_params_mesh_ndim + if parallel_dims.ep_enabled + else None + ), + ) + self.checkpointer.maybe_wait_for_staging() + self.optimizers.step() + self.lr_schedulers.step() + + # Reduce the data collected over gradient accumulation steps. + loss = torch.sum(torch.stack(accumulated_losses)) + + # log metrics + if not self.metrics_processor.should_log(self.step): + return + + if parallel_dims.dp_cp_enabled: + loss = loss.detach() + global_avg_loss, global_max_loss = ( + dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"]), + dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"]), + ) + else: + global_avg_loss = global_max_loss = loss.detach().item() + + self.metrics_processor.log( + self.step, + global_avg_loss, + global_max_loss, + grad_norm.item(), + ) + + @record + def train(self): + job_config = self.job_config + + self.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info(f"Training starts at step {self.step + 1}.") + + with ( + maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + maybe_enable_memory_snapshot( + job_config, global_step=self.step + ) as memory_profiler, + ): + data_iterator = self.batch_generator(self.dataloader) + while self.step < job_config.training.steps: + self.step += 1 + self.gc_handler.run(self.step) + try: + self.train_step(data_iterator) + except DataloaderStopIteration: + logger.warning("Ran out of data; last step was canceled.") + break + + # Run validation if validator is available + if ( + self.job_config.validation.enabled + and self.validator.should_validate(self.step) + ): + self.validator.validate(self.model_parts) + + self.checkpointer.save( + self.step, last_step=(self.step == job_config.training.steps) + ) + + # signal the profiler that the next profiling step has started + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) + if self.step == 1: + dist_utils.set_pg_timeouts( + timeout=timedelta( + seconds=job_config.comm.train_timeout_seconds + ), + world_mesh=self.parallel_dims.world_mesh, + ) + + if torch.distributed.get_rank() == 0: + logger.info("Sleeping 2 seconds for other ranks to complete") + time.sleep(2) + + logger.info("Training completed") + + def state_dict(self) -> dict[str, Any]: + return {"step": self.step} + + def load_state_dict(self, state_dict: dict[str, Any]): + self.step = state_dict["step"] + + def close(self) -> None: + if self.metrics_processor: + self.metrics_processor.close() + super().close() + + +if __name__ == "__main__": + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + trainer.train() + except Exception: + if trainer: + trainer.close() + raise + else: + trainer.close() + torch.distributed.destroy_process_group() + logger.info("Process group destroyed.") diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py new file mode 100644 index 0000000000..56602e3520 --- /dev/null +++ b/torchtitan/experiments/forge/job_config.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import asdict, dataclass, field +from typing import Any + +from torchtitan.config.job_config import ( + ActivationCheckpoint, + Checkpoint, + Comm, + Float8, + LRScheduler, + Model, + Optimizer, + Parallelism, + Training, +) + + +@dataclass +class ForgeJobConfig: + model: Model = field(default_factory=Model) + optimizer: Optimizer = field(default_factory=Optimizer) + lr_scheduler: LRScheduler = field(default_factory=LRScheduler) + training: Training = field(default_factory=Training) + parallelism: Parallelism = field(default_factory=Parallelism) + checkpoint: Checkpoint = field(default_factory=Checkpoint) + activation_checkpoint: ActivationCheckpoint = field( + default_factory=ActivationCheckpoint + ) + float8: Float8 = field(default_factory=Float8) + comm: Comm = field(default_factory=Comm) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py new file mode 100644 index 0000000000..e7dc1077c2 --- /dev/null +++ b/torchtitan/experiments/forge/train_spec.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +# Import torchtitan.models to ensure all train specs are registered +import torchtitan.models # noqa: F401 +from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.protocols.train_spec import ( + _train_specs, + BaseModelArgs, + LossFunctionBuilder, + LRSchedulersBuilder, + ModelProtocol, + OptimizersBuilder, + ParallelizeFunction, + PipeliningFunction, + TrainSpec, +) + + +@dataclass +class ForgeTrainSpec: + name: str + model_cls: type[ModelProtocol] + model_args: dict[str, BaseModelArgs] + parallelize_fn: ParallelizeFunction + pipelining_fn: PipeliningFunction | None + build_optimizers_fn: OptimizersBuilder + build_lr_schedulers_fn: LRSchedulersBuilder + build_loss_fn: LossFunctionBuilder + state_dict_adapter: type[StateDictAdapter] | None = None + + +# Copy and transform train specs from torchtitan.protocols.train_spec._train_specs +# This happens during import after all models have been registered +_forge_train_specs = {} + + +def register_train_spec(train_spec: ForgeTrainSpec) -> None: + global _forge_train_specs + if train_spec.name in _forge_train_specs: + raise ValueError(f"Model {train_spec.name} is already registered.") + + _forge_train_specs[train_spec.name] = train_spec + + +def get_train_spec(name: str) -> ForgeTrainSpec: + global _forge_train_specs + if name not in _forge_train_specs: + raise ValueError(f"Model {name} is not registered.") + return _forge_train_specs[name] + + +def _transform_train_spec(original_spec: TrainSpec): + """Transform the original train spec to ForgeTrainSpec format.""" + # Create a new TrainSpec with only the fields we need in forge + return ForgeTrainSpec( + name=original_spec.name, + model_cls=original_spec.model_cls, + model_args=original_spec.model_args, + parallelize_fn=original_spec.parallelize_fn, + pipelining_fn=original_spec.pipelining_fn, + build_optimizers_fn=original_spec.build_optimizers_fn, + build_lr_schedulers_fn=original_spec.build_lr_schedulers_fn, + build_loss_fn=original_spec.build_loss_fn, + state_dict_adapter=original_spec.state_dict_adapter, + ) + + +# Populate _forge_train_specs with transformed specs +for name, spec in _train_specs.items(): + register_train_spec(_transform_train_spec(spec)) From 2f1c814da071cc8ad165d00be6f9c1a66f8e1cce Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 21:23:01 -0700 Subject: [PATCH 031/128] add back torch nightly install instruction (#1444) It seems user are getting confused https://github.com/pytorch/torchtitan/issues/1423 https://github.com/pytorch/torchtitan/issues/1443 --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 58c437335d..8e8ea9b194 100644 --- a/README.md +++ b/README.md @@ -107,12 +107,13 @@ Note that each stable release pins the nightly versions of `torch` and `torchao` This method requires the nightly build of PyTorch. You can replace `cu126` with another version of cuda (e.g. `cu128`) or an AMD GPU (e.g. `rocm6.3`). ```sh +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall pip install --pre torchtitan --index-url https://download.pytorch.org/whl/nightly/cu126 ``` ### From source -This method requires the nightly build of PyTorch or the latest PyTorch built from source. +This method requires the nightly build of PyTorch or the latest PyTorch built [from source](https://github.com/pytorch/pytorch?tab=readme-ov-file#from-source). ```bash git clone https://github.com/pytorch/torchtitan From d282cf2ce9ca8049b4b8423c1d7578c80426576f Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Wed, 23 Jul 2025 21:13:01 -0700 Subject: [PATCH 032/128] [refactor] moving gloabl dependence on JobConfig to fine-grained configs (#1449) In general it is anti-pattern to globally depend on a monolithic `JobConfig` everywhere. This PR takes the first step to remove this pattern. After this PR, the issue remains for parallelize & pipeline functions, metrics processor, tokenizer, dataloader, validator, etc. A note is that when refactoring tokenizer, dataloader, validator, we should still allow users to extend easily beyond. This may require the signature of builder functions to take additional args / kwargs. I'm also disabling some PP tests, which starts to fail on the latest torch nightly https://github.com/pytorch/torchtitan/actions/runs/16484633329/job/46606801605 --- README.md | 2 +- scripts/estimate/estimation.py | 6 +- scripts/generate/test_generate.py | 2 +- tests/integration_tests.py | 93 ++++++++++--------- tests/unit_tests/test_lr_scheduler.py | 24 +++-- tests/unit_tests/test_train_spec.py | 8 +- torchtitan/components/checkpoint.py | 3 +- torchtitan/components/lr_scheduler.py | 21 +++-- torchtitan/components/optimizer.py | 26 +++--- torchtitan/config/__init__.py | 42 ++++++++- torchtitan/config/job_config.py | 3 + torchtitan/distributed/utils.py | 20 ++-- torchtitan/experiments/forge/engine.py | 9 +- torchtitan/experiments/forge/example_train.py | 10 +- torchtitan/experiments/llama4/optimizer.py | 6 +- torchtitan/protocols/train_spec.py | 4 +- torchtitan/tools/profiling.py | 54 ++++++----- torchtitan/train.py | 27 +++++- 18 files changed, 225 insertions(+), 135 deletions(-) diff --git a/README.md b/README.md index 8e8ea9b194..5c5c30720a 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ You may want to see how the model is defined or how parallelism techniques are a One can choose to install `torchtitan` from a stable release, a nightly build, or directly run the source code. Please [install PyTorch](https://pytorch.org/get-started/locally/) before proceeding. ### Stable releases -One can install the latest [stable release]((https://github.com/pytorch/torchtitan/releases)) of `torchtitan` via `pip` or `conda`. +One can install the latest [stable release](https://github.com/pytorch/torchtitan/releases) of `torchtitan` via `pip` or `conda`. ```sh pip install torchtitan ``` diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 82d306e692..218e7a4c6e 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -112,8 +112,10 @@ def estimate_memory(job_config: JobConfig): model.train() # build optimizer after applying parallelisms to the model - optimizers = build_optimizers([model], job_config, parallel_dims) - lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) + optimizers = build_optimizers([model], job_config.optimizer, parallel_dims) + lr_schedulers = build_lr_schedulers( + optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps + ) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 9b21b3e57b..ae20d11826 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -117,7 +117,7 @@ def test_generate( world_mesh = None # Init distributed env if world_size > 1: - dist_utils.init_distributed(config) + dist_utils.init_distributed(config.comm) parallel_dims = ParallelDims( dp_replicate=1, dp_shard=-1, diff --git a/tests/integration_tests.py b/tests/integration_tests.py index adca9ec56e..8f19086096 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -155,40 +155,41 @@ def build_test_list(): "Checkpoint Integration Test - Save Model Only bf16", "last_save_model_only_bf16", ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 4", - "--parallelism.pipeline_parallel_schedule InterleavedZeroBubble", - ], - ], - "PP looped zero bubble test", - "pp_looped_zero_bubble", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule ZBVZeroBubble", - ], - ], - "PP zero bubble test (v shaped)", - "pp_zbv", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule 1F1B", - "--parallelism.data_parallel_shard_degree 1", - ], - ], - "PP 1D test 1F1B", - "pp_1f1b", - ngpu=2, - ), + # TODO: re-enable PP tests once the issue is fixed + # OverrideDefinitions( + # [ + # [ + # "--parallelism.pipeline_parallel_degree 4", + # "--parallelism.pipeline_parallel_schedule InterleavedZeroBubble", + # ], + # ], + # "PP looped zero bubble test", + # "pp_looped_zero_bubble", + # ngpu=4, + # ), + # OverrideDefinitions( + # [ + # [ + # "--parallelism.pipeline_parallel_degree 2", + # "--parallelism.pipeline_parallel_schedule ZBVZeroBubble", + # ], + # ], + # "PP zero bubble test (v shaped)", + # "pp_zbv", + # ngpu=2, + # ), + # OverrideDefinitions( + # [ + # [ + # "--parallelism.pipeline_parallel_degree 2", + # "--parallelism.pipeline_parallel_schedule 1F1B", + # "--parallelism.data_parallel_shard_degree 1", + # ], + # ], + # "PP 1D test 1F1B", + # "pp_1f1b", + # ngpu=2, + # ), OverrideDefinitions( [ [ @@ -288,18 +289,18 @@ def build_test_list(): "pp_looped_1f1b", ngpu=4, ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti", - "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv", - ], - ], - "PP with custom pipeline schedule loaded from CSV file", - "pp_custom_csv", - ngpu=2, - ), + # OverrideDefinitions( + # [ + # [ + # "--parallelism.pipeline_parallel_degree 2", + # "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti", + # "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv", + # ], + # ], + # "PP with custom pipeline schedule loaded from CSV file", + # "pp_custom_csv", + # ngpu=2, + # ), OverrideDefinitions( [ [ diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index 3e5473f51a..3d57bbd0cf 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -78,7 +78,9 @@ def test_linear_warmup_decay(self): ) # Build the lr scheduler - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Expected adjustment factors for each step expected_factors = [ @@ -118,7 +120,9 @@ def test_warmup_stable_decay(self): ) # Build the lr scheduler - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Expected adjustment factors for each step expected_factors = [ @@ -157,7 +161,9 @@ def test_min_lr(self): ) # Build the lr scheduler - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Step through all steps for _ in range(10): @@ -178,7 +184,9 @@ def test_warmup_exceeds_training(self): ) # Build the lr scheduler - should adjust warmup steps - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Expected adjustment factors for each step expected_factors = [ @@ -212,7 +220,9 @@ def test_warmup_stable_only(self): ) # Build the lr scheduler - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Expected adjustment factors for each step expected_factors = [ @@ -252,7 +262,9 @@ def test_warmup_plus_decay_exceeds_training(self): ) # Build the lr scheduler - should adjust warmup steps - lr_scheduler = build_lr_schedulers(self.optimizer_container, config) + lr_scheduler = build_lr_schedulers( + self.optimizer_container, config.lr_scheduler, config.training.steps + ) # Expected adjustment factors for each step expected_factors = [ diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 411ba1439e..1f0f9fb574 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -14,7 +14,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers, OptimizersContainer from torchtitan.components.tokenizer import build_hf_tokenizer -from torchtitan.config import JobConfig +from torchtitan.config import Optimizer as OptimizerConfig from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.llama3 import parallelize_llama, pipeline_llama @@ -42,7 +42,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: def fake_build_optimizers( model_parts: list[nn.Module], - job_config: JobConfig, + optimizer_config: OptimizerConfig, parallel_dims: ParallelDims, ft_manager: FTManager, ) -> OptimizersContainer: @@ -117,12 +117,12 @@ def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec: def my_build_optimizer_fn( model_parts: list[nn.Module], - job_config: JobConfig, + optimizer_config: OptimizerConfig, parallel_dims: ParallelDims, ft_manager: FTManager, ) -> OptimizersContainer: optimizers = original_build_optimizers_fn( - model_parts, job_config, parallel_dims, ft_manager + model_parts, optimizer_config, parallel_dims, ft_manager ) optimizers.register_step_post_hook( partial(my_hook, model_parts=model_parts) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 5a1b40ba88..8eede5fada 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -36,8 +36,7 @@ from torchtitan.components.ft import FTManager from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config import TORCH_DTYPE_MAP -from torchtitan.config.job_config import Checkpoint as CheckpointConfig +from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP from torchtitan.protocols.state_dict_adapter import StateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index bccaf2b96c..8829431887 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torchtitan.components.optimizer import OptimizersContainer -from torchtitan.config import JobConfig +from torchtitan.config import LRScheduler as LRSchedulerConfig from torchtitan.tools.logging import logger __all__ = [ @@ -82,12 +82,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: def build_lr_schedulers( - optimizers: OptimizersContainer, job_config: JobConfig + optimizers: OptimizersContainer, + lr_scheduler_config: LRSchedulerConfig, + training_steps: int, ) -> LRSchedulersContainer: """Create a LRSchedulerContainer for the given optimizers and job config. This function creates a ``LRSchedulersContainer`` for the given optimizers. - ``job_config`` should define the correct lr scheduler parameters. + ``lr_scheduler_config`` should define the correct lr scheduler parameters. **Note** Users who want to customize the lr scheduler behavior can create their own @@ -99,9 +101,10 @@ def build_lr_schedulers( Args: optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers. + lr_scheduler_config (LRSchedulerConfig): The lr scheduler config. + training_steps (int): The total number of training steps. """ - training_steps = job_config.training.steps - warmup_steps = int(job_config.lr_scheduler.warmup_steps) + warmup_steps = int(lr_scheduler_config.warmup_steps) if warmup_steps > training_steps: logger.warning( @@ -110,8 +113,8 @@ def build_lr_schedulers( ) warmup_steps = training_steps - if job_config.lr_scheduler.decay_ratio is not None: - decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio) + if lr_scheduler_config.decay_ratio is not None: + decay_steps = round(training_steps * lr_scheduler_config.decay_ratio) if warmup_steps + decay_steps > training_steps: logger.warning( f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed " @@ -123,8 +126,8 @@ def build_lr_schedulers( decay_steps = training_steps - warmup_steps # Add a vitual last step to prevent the learning rate from dropping to 0 stable_steps = training_steps + 1 - warmup_steps - decay_steps - lr_decay_type = job_config.lr_scheduler.decay_type - lr_min = job_config.lr_scheduler.lr_min + lr_decay_type = lr_scheduler_config.decay_type + lr_min = lr_scheduler_config.lr_min def linear_warmup_stable_decay( current_step: int, diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index f6fd02a4d5..2a112177e0 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -18,7 +18,7 @@ from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft -from torchtitan.config import JobConfig +from torchtitan.config import Optimizer as OptimizerConfig from torchtitan.distributed import ParallelDims __all__ = [ @@ -241,14 +241,14 @@ def zero_grad(self, *args, **kwargs) -> None: def build_optimizers( model_parts: list[nn.Module], - job_config: JobConfig, + optimizer_config: OptimizerConfig, parallel_dims: ParallelDims, ft_manager: FTManager | None = None, ) -> OptimizersContainer: """Create a OptimizersContainer for the given model parts and job config. This function creates a ``OptimizersContainer`` for the given model parts. - ``job_config`` should define the correct optimizer name and parameters. + ``optimizer_config`` should define the correct optimizer name and parameters. This function currently supports creating ``OptimizersContainer`` and ``OptimizersInBackwardContainer``. @@ -260,10 +260,10 @@ def build_optimizers( Args: model_parts (List[nn.Module]): List of model parts to be optimized. - job_config (JobConfig): Job config containing the optimizer name and parameters. + optimizer_config (OptimizerConfig): Optimizer config containing the optimizer name and parameters. parallel_dims (ParallelDims): Parallel dimensions for the model. """ - optim_in_bwd = job_config.optimizer.early_step_in_backward + optim_in_bwd = optimizer_config.early_step_in_backward if optim_in_bwd: if parallel_dims.ep_enabled: raise NotImplementedError( @@ -278,14 +278,14 @@ def build_optimizers( "TorchFT is not supported with optimizers in backward." ) - name = job_config.optimizer.name - lr = job_config.optimizer.lr - beta1 = job_config.optimizer.beta1 - beta2 = job_config.optimizer.beta2 - eps = job_config.optimizer.eps - weight_decay = job_config.optimizer.weight_decay + name = optimizer_config.name + lr = optimizer_config.lr + beta1 = optimizer_config.beta1 + beta2 = optimizer_config.beta2 + eps = optimizer_config.eps + weight_decay = optimizer_config.weight_decay - optim_implementation = job_config.optimizer.implementation + optim_implementation = optimizer_config.implementation assert optim_implementation in ["fused", "foreach", "for-loop"] fused = optim_implementation == "fused" @@ -319,7 +319,7 @@ def build_optimizers( optimizer_cls, optimizer_kwargs, ft_manager.manager, - use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None, + use_ft_optimizer=ft_manager.use_async_quorum, ) return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py index 9bbcac7456..79e56c8887 100644 --- a/torchtitan/config/__init__.py +++ b/torchtitan/config/__init__.py @@ -12,7 +12,43 @@ "bfloat16": torch.bfloat16, } -from torchtitan.config.job_config import JobConfig -from torchtitan.config.manager import ConfigManager +from .job_config import ( + ActivationCheckpoint, + Checkpoint, + Comm, + FaultTolerance, + Float8, + Job, + JobConfig, + LRScheduler, + Metrics, + Model, + MX, + Optimizer, + Parallelism, + Profiling, + Training, + Validation, +) +from .manager import ConfigManager -__all__ = ["JobConfig", "ConfigManager", "TORCH_DTYPE_MAP"] +__all__ = [ + "JobConfig", + "ConfigManager", + "TORCH_DTYPE_MAP", + "Job", + "Model", + "MX", + "Optimizer", + "LRScheduler", + "Metrics", + "Checkpoint", + "ActivationCheckpoint", + "FaultTolerance", + "Float8", + "Parallelism", + "Comm", + "Profiling", + "Training", + "Validation", +] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index b5c167e131..fdf38c63fe 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -563,6 +563,9 @@ class Comm: trace_buf_size: int = 20000 """Flight recorder ring buffer size, >0 means recording by default, 0 means disabled""" + save_traces_folder: str = "comm_traces" + """Flight recorder trace files location""" + @dataclass class MemoryEstimation: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index ecda3f9b6f..3c9e20ffb5 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -18,7 +18,7 @@ from torch.distributed.tensor import DTensor from torch.nn.attention import SDPBackend -from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.attention import ScaledDotProductAttention from torchtitan.tools.logging import logger @@ -227,7 +227,9 @@ def maybe_enable_amp( ) -def init_distributed(job_config): +def init_distributed( + comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "" +): def _warn_overwrite_env(env, val): if env in os.environ: logger.warning( @@ -235,13 +237,13 @@ def _warn_overwrite_env(env, val): ) os.environ[env] = val - def _get_distributed_backend(job_config): + def _get_distributed_backend(enable_cpu_backend): backend = "nccl" if device_type in torch.distributed.Backend.default_device_backend_map: backend = torch.distributed.Backend.default_device_backend_map.get( device_type ) - if job_config.training.enable_cpu_offload: + if enable_cpu_backend: backend = f"{device_type}:{backend},cpu:gloo" return backend @@ -258,17 +260,17 @@ def _get_distributed_backend(job_config): _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP) # enable torch nccl flight recorder in the mode that would dump files if timeout is detected - _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size)) - if job_config.comm.trace_buf_size > 0: + _warn_overwrite_env(TRACE_BUFFER_SIZE, str(comm_config.trace_buf_size)) + if comm_config.trace_buf_size > 0: # dump on timeout by default if trace buffer is enabled _warn_overwrite_env(DUMP_ON_TIMEOUT, "1") - dump_dir = f"{job_config.job.dump_folder}/comm_trace" + dump_dir = os.path.join(base_folder, comm_config.save_traces_folder) os.makedirs(dump_dir, exist_ok=True) _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_") torch.distributed.init_process_group( - backend=_get_distributed_backend(job_config), - timeout=timedelta(seconds=job_config.comm.init_timeout_seconds), + backend=_get_distributed_backend(enable_cpu_backend), + timeout=timedelta(seconds=comm_config.init_timeout_seconds), ) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 0875e83d3c..75c1b67c12 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -67,7 +67,10 @@ def __init__(self, job_config: ForgeJobConfig): device_module.set_device(self.device) # init distributed and build meshes - dist_utils.init_distributed(job_config) + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=job_config.training.enable_cpu_offload, + ) world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism self.parallel_dims = parallel_dims = ParallelDims( @@ -199,10 +202,10 @@ def __init__(self, job_config: ForgeJobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, parallel_dims + self.model_parts, job_config.optimizer, parallel_dims ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( - self.optimizers, job_config + self.optimizers, job_config.lr_scheduler, job_config.training.steps ) self.checkpointer = CheckpointManager( diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index a0846c8ca3..c54fc645c4 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -272,9 +272,15 @@ def train(self): logger.info(f"Training starts at step {self.step + 1}.") with ( - maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + maybe_enable_profiling( + job_config.profiling, + global_step=self.step, + base_folder=job_config.job.dump_folder, + ) as torch_profiler, maybe_enable_memory_snapshot( - job_config, global_step=self.step + job_config.profiling, + global_step=self.step, + base_folder=job_config.job.dump_folder, ) as memory_profiler, ): data_iterator = self.batch_generator(self.dataloader) diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py index 4a997dd817..0986452fae 100644 --- a/torchtitan/experiments/llama4/optimizer.py +++ b/torchtitan/experiments/llama4/optimizer.py @@ -9,7 +9,7 @@ from torchtitan.components.ft import FTManager from torchtitan.components.optimizer import build_optimizers, OptimizersContainer -from torchtitan.config import JobConfig +from torchtitan.config import Optimizer as OptimizerConfig from torchtitan.distributed import ParallelDims @@ -46,13 +46,13 @@ def _update_expert_bias( def build_llama4_optimizers( model_parts: list[nn.Module], - job_config: JobConfig, + optimizer_config: OptimizerConfig, parallel_dims: ParallelDims, ft_manager: FTManager | None = None, ) -> OptimizersContainer: optimizers = build_optimizers( model_parts=model_parts, - job_config=job_config, + optimizer_config=optimizer_config, parallel_dims=parallel_dims, ft_manager=ft_manager, ) diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index afbf0a560a..db5c923011 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -20,7 +20,7 @@ from torchtitan.components.optimizer import OptimizersContainer from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.components.validate import BaseValidator -from torchtitan.config import JobConfig +from torchtitan.config import JobConfig, LRScheduler from torchtitan.protocols.state_dict_adapter import StateDictAdapter @@ -74,7 +74,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] OptimizersBuilder: TypeAlias = Callable[..., OptimizersContainer] LRSchedulersBuilder: TypeAlias = Callable[ - [OptimizersContainer, JobConfig], LRSchedulersContainer + [OptimizersContainer, LRScheduler, int], LRSchedulersContainer ] LossFunctionBuilder: TypeAlias = Callable[..., LossFunction] ValidatorBuilder: TypeAlias = Callable[..., BaseValidator] diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 1e9c67ea69..843c13a746 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -11,7 +11,7 @@ import torch -from torchtitan.config import JobConfig +from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger # the number of warmup steps before the active step in each profiling cycle @@ -22,22 +22,22 @@ @contextlib.contextmanager -def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): +def maybe_enable_profiling( + profiling_config: ProfilingConfig, + *, + global_step: int = 0, + base_folder: str = "", + leaf_folder: str = "", +): # get user defined profiler settings - enable_profiling = config.profiling.enable_profiling + enable_profiling = profiling_config.enable_profiling if enable_profiling: - dump_dir = config.job.dump_folder - save_trace_dir = config.profiling.save_traces_folder - trace_dir = os.path.join(dump_dir, save_trace_dir) - profile_freq = config.profiling.profile_freq + trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder) + profile_freq = profiling_config.profile_freq rank = torch.distributed.get_rank() - replica_id = None - if config.fault_tolerance.enable: - replica_id = config.fault_tolerance.replica_id - def trace_handler(prof): curr_trace_dir_name = "iteration_" + str(prof.step_num) curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) @@ -47,11 +47,9 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - output_file = curr_trace_dir - if replica_id is not None: - output_file = os.path.join(output_file, f"replica{replica_id}") - output_file = os.path.join(output_file, f"rank{rank}_trace.json") - + output_file = os.path.join( + curr_trace_dir, leaf_folder, f"rank{rank}_trace.json" + ) prof.export_chrome_trace(output_file) logger.info( f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" @@ -89,11 +87,18 @@ def trace_handler(prof): @contextlib.contextmanager -def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0): - enable_snapshot = config.profiling.enable_memory_snapshot +def maybe_enable_memory_snapshot( + profiling_config: ProfilingConfig, + *, + global_step: int = 0, + base_folder: str = "", + leaf_folder: str = "", +): + enable_snapshot = profiling_config.enable_memory_snapshot if enable_snapshot: - snapshot_folder = config.profiling.save_memory_snapshot_folder - snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder) + snapshot_dir = os.path.join( + base_folder, profiling_config.save_memory_snapshot_folder + ) if not os.path.exists(snapshot_dir): os.makedirs(snapshot_dir, exist_ok=True) rank = torch.distributed.get_rank() @@ -123,16 +128,17 @@ def step(self, exit_ctx: bool = False): os.makedirs(curr_snapshot_dir, exist_ok=True) logger.info(f"Dumping memory snapshot at step {curr_step}") begin = time.monotonic() - with open( - f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb" - ) as output: + output_file = os.path.join( + curr_snapshot_dir, leaf_folder, f"rank{rank}_memory_snapshot.pickle" + ) + with open(output_file, "wb") as output: pickle.dump(torch.cuda.memory._snapshot(), output) logger.info( f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" ) logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") - profiler = MemoryProfiler(global_step, config.profiling.profile_freq) + profiler = MemoryProfiler(global_step, profiling_config.profile_freq) try: yield profiler except torch.OutOfMemoryError as e: diff --git a/torchtitan/train.py b/torchtitan/train.py index 9b9f5d4115..4123dd9234 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -85,7 +85,11 @@ def __init__(self, job_config: JobConfig): device_module.set_device(self.device) # init distributed and build meshes - dist_utils.init_distributed(job_config) + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=job_config.training.enable_cpu_offload, + base_folder=job_config.job.dump_folder, + ) world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism self.parallel_dims = parallel_dims = ParallelDims( @@ -272,10 +276,10 @@ def __init__(self, job_config: JobConfig): # build optimizer after applying parallelisms to the model self.optimizers = self.train_spec.build_optimizers_fn( - self.model_parts, job_config, parallel_dims, self.ft_manager + self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager ) self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( - self.optimizers, job_config + self.optimizers, job_config.lr_scheduler, job_config.training.steps ) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 @@ -500,10 +504,23 @@ def train(self): self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}.") + leaf_folder = ( + "" + if not self.ft_manager.enabled + else f"replica_{self.ft_manager.replica_id}" + ) with ( - maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + maybe_enable_profiling( + job_config.profiling, + global_step=self.step, + base_folder=job_config.job.dump_folder, + leaf_folder=leaf_folder, + ) as torch_profiler, maybe_enable_memory_snapshot( - job_config, global_step=self.step + job_config.profiling, + global_step=self.step, + base_folder=job_config.job.dump_folder, + leaf_folder=leaf_folder, ) as memory_profiler, maybe_semi_sync_training( job_config.fault_tolerance, From 70592cb07dd5f12cdb2d209c2578ff8a5cb18bd3 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Thu, 24 Jul 2025 15:51:39 -0700 Subject: [PATCH 033/128] added model definition conversion for llama3 (#1441) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## This pr adds a model state dict conversion between TT and HF. It includes to and from huggingface, and importantly performs a permutation on the q and k attention matrices to address the differences in RoPE implementation between native LLaMA and HuggingFace. Thanks to @rlrs and @vwxyzjn for finding and helping to fix this issue #335, https://github.com/pytorch/torchtitan/issues/1291#issuecomment-2997077080 ### Testing I tested the correctness of the model conversion by using the two methods greedy decoding, and kl_divergence for thorough comparison. To test the from_hf script I downloaded a model from HF hub, converted it using the script, and ran forward passes using torchtitan. To test the to_hf script I obtained original llama weights and used the verified llama->dcp script. Then I used the convert_to_hf script to convert these weights to safetensors checkpoint. For kl_divergence I tested each to_hf and from_hf against the baseline hf model, and compared these to the to_hf and from_hf weights when not performing the permutation. | permuted wq and wk | kl_div (hf->tt) | kl_div (tt->hf) | | --- | --- | --- | | ✅ | -3.8356e-15 | -1.4431e-14 | | ❌ | 3.0658e-06 | 9.6463e-06 | When comparing, we can clearly see the kl div loss is many orders of magnitude higher when not permuted, showing that these probability distributions don't accurately represent the baseline hf's probability distribution. However, due to the small amount of weights that need to be permuted in this case, the loss is still not very high in the incorrect case, and can be deceiving if only using this as the evaluation metric. Therefore we also use greedy decoding with long generated sequences, calculating the loss as the exact match ratio of generated tokens. | permuted wq and wk | kl_div (hf->tt) | kl_div (tt->hf) | | --- | --- | --- | | ❌ | | | | ✅ | | | ### Usage The model conversion can be done in two ways. The first direct way is to use the new convert_from_hf.py or convert_to_hf.py script, but requires loading the entire model weights into cpu memory. The second way is to use the training config options to load/save in hf format during training. This should bring us one step closer to https://github.com/pytorch/torchtitan/issues/1210 --- docs/checkpoint.md | 102 +++++++++------ .../checkpoint_conversion/convert_from_hf.py | 62 +++++++++ .../convert_from_llama.py} | 4 +- .../checkpoint_conversion/convert_to_hf.py | 80 ++++++++++++ tests/unit_tests/test_train_spec.py | 3 +- torchtitan/components/checkpoint.py | 5 +- torchtitan/components/metrics.py | 2 +- .../experiments/deepseek_v3/model_args.py | 2 +- torchtitan/experiments/flux/model/args.py | 2 +- torchtitan/experiments/flux/model/model.py | 2 +- torchtitan/experiments/forge/engine.py | 2 +- torchtitan/experiments/forge/train_spec.py | 4 +- torchtitan/experiments/llama4/model/args.py | 2 +- torchtitan/experiments/llama4/model/model.py | 2 +- .../models/llama3/model/state_dict_adapter.py | 119 ++++++++++++++++-- torchtitan/protocols/__init__.py | 17 +++ torchtitan/protocols/model.py | 55 ++++++++ torchtitan/protocols/state_dict_adapter.py | 12 +- torchtitan/protocols/train_spec.py | 49 +------- torchtitan/train.py | 17 ++- 20 files changed, 422 insertions(+), 121 deletions(-) create mode 100644 scripts/checkpoint_conversion/convert_from_hf.py rename scripts/{convert_llama_to_dcp.py => checkpoint_conversion/convert_from_llama.py} (97%) create mode 100644 scripts/checkpoint_conversion/convert_to_hf.py create mode 100644 torchtitan/protocols/__init__.py create mode 100644 torchtitan/protocols/model.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index ecfdd67d6b..00736a7558 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -1,29 +1,17 @@ -## How to convert a Llama 3 checkpoint for use in torchtitan +# How to use checkpointing in `torchtitan` -If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager. -An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`. - -The script expects a path to the original checkpoint files, and a path to an output directory: -```bash -python -m scripts.convert_llama_to_dcp -``` +You may want to enable checkpointing in `torchtitan` for better fault tolerance during training, or to enable easier importing and exporting of weights between `torchtitan` and other libraries. `torchtitan` offers varying degrees of support for other checkpoint formats which are listed further below. +## A general guide to use checkpoints during training -## How to convert a torchtitan checkpoint for use in torchtune - -This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune. - -### Steps 1. ENABLE CHECKPOINTING -In your torchtitan training config, ensure that `enable_checkpoint` is set to True. +In your `torchtitan` training config, ensure that `enable_checkpoint` is set to True. ``` [checkpoint] enable_checkpoint = true folder = "checkpoint" interval = 500 ``` - - 2. SAVE MODEL ONLY By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size. ``` @@ -41,7 +29,17 @@ last_save_model_only = true export_dtype = "bfloat16" ``` -4. EXAMPLE CHECKPOINT CONFIGURATION +4. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING +In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading. +This parameter takes a list of string that should be excluded from loading. +``` +[checkpoint] +enable_checkpoint = true +exclude_from_loading = ["data_loader", "lr_scheduler"] +``` +When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`. + +5. EXAMPLE CHECKPOINT CONFIGURATION ``` [checkpoint] enable_checkpoint = true @@ -52,41 +50,67 @@ last_save_model_only = true export_dtype = "bfloat16" ``` -5. SAVE THE FINAL CHECKPOINT\ -Once the above have been set, the final checkpoint at the end of the training step will consist of model only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed. +A more exhaustive and up-to-date list of checkpoint config options can be found in `torchtitan/config/job_config.py` -6. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\ -Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune: +## Creating a seed checkpoint +Sometimes one needs to create a seed checkpoint to initialize a model from step 0. +E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device. +A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding. +To create a seed checkpoint, use the same model config as you use for training. +e.g. +```bash +NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` -python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt + +## Conversion support + +### HuggingFace +`torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. + +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, simply enable `--checkpoint.initial_load_model_only` and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. + +2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. + +```bash +python ./scripts/checkpoint_conversion/convert_from_hf.py --model_name --model_flavor +python ./scripts/checkpoint_conversion/convert_to_hf.py --model_name --model_flavor +# e.g. +python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./outputs/checkpoint/step-0 --model_name llama3 --model_flavor 8B ``` -7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING -In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading. -This parameter takes a list of string that should be excluded from loading. +### Torch + +This guide will walk you through the steps required to convert a checkpoint from `torchtitan` so that it can be loaded into pt format. + +1. CHECKPOINT CONFIGURATION ``` [checkpoint] enable_checkpoint = true -exclude_from_loading = ["data_loader", "lr_scheduler"] +folder = "checkpoint" +interval = 10 +last_save_model_only = true +export_dtype = "bfloat16" ``` -When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`. - -That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune. +2. SAVE THE FINAL CHECKPOINT\ +Once the above have been set, the final checkpoint at the end of the training step will consist of model only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed. -## How to create a seed checkpoint -Sometimes one needs to create a seed checkpoint to initialize a model from step 0. -E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device. -A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding. +3. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\ +Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file. -To create a seed checkpoint, use the same model config as you use for training. -e.g. ```bash -NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 +python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt ``` -## How to load / save a checkpoint in HF safetensors format -For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format). -For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_only` to load the checkpoint in HF format. +That's it. You have now successfully converted a sharded `torchtitan` checkpoint for use with pytorch formats. + +### PyTorch Meta Llama + +An example script for converting the original Llama3 checkpoints into DCP format to be used with `torchtitan` can be found in `scripts/convert_from_llama.py`. + +The script expects a path to the original checkpoint files, and a path to an output directory: +```bash +python -m scripts.convert_from_llama +``` diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py new file mode 100644 index 0000000000..42ed00bf27 --- /dev/null +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as dcp +import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.checkpoint import HuggingFaceStorageReader +from torchtitan.components.checkpoint import ModelWrapper + + +@torch.inference_mode() +def convert_from_hf(input_dir, output_dir, model_name, model_flavor): + # initialize model to allocate memory for state dict + train_spec = train_spec_module.get_train_spec(model_name) + model_args = train_spec.model_args[model_flavor] + + with torch.device("cpu"): + model = train_spec.model_cls(model_args) + model = ModelWrapper(model) + + sd_adapter = train_spec.state_dict_adapter(model_args) + assert ( + sd_adapter is not None + ), "trying to convert checkpoint from HF to DCP safetensors format, but sd_adapter is not provided." + # get state dict in tt format with allocated memory + state_dict = model._get_state_dict() + # convert empty state dict to hf format so that hf weights can be loaded into it + hf_state_dict = sd_adapter.to_hf(state_dict) + dcp.load( + hf_state_dict, + storage_reader=HuggingFaceStorageReader(path=input_dir), + ) + # convert state dict format back hf->tt and save + state_dict = sd_adapter.from_hf(hf_state_dict) + dcp.save( + state_dict, + checkpoint_id=output_dir, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert HF checkpoint to DCP format.") + parser.add_argument( + "input_dir", type=Path, help="Input directory with HF checkpoint" + ) + parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") + parser.add_argument("--model_name", type=str, nargs="?", default="llama3") + parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") + args = parser.parse_args() + + convert_from_hf( + args.input_dir, + args.output_dir, + args.model_name, + args.model_flavor, + ) diff --git a/scripts/convert_llama_to_dcp.py b/scripts/checkpoint_conversion/convert_from_llama.py similarity index 97% rename from scripts/convert_llama_to_dcp.py rename to scripts/checkpoint_conversion/convert_from_llama.py index 02f371c0c8..9a6e1b1db3 100644 --- a/scripts/convert_llama_to_dcp.py +++ b/scripts/checkpoint_conversion/convert_from_llama.py @@ -14,7 +14,7 @@ @torch.inference_mode() -def convert_llama_weights(input_dir, output_dir, max_seq_len: int): +def convert_from_llama(input_dir, output_dir, max_seq_len: int): with open(input_dir / "params.json", "r") as f: params = json.load(f) n_layers = params["n_layers"] @@ -143,4 +143,4 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int): ) args = parser.parse_args() - convert_llama_weights(args.input_dir, args.output_dir, max_seq_len=args.max_seq_len) + convert_from_llama(args.input_dir, args.output_dir, max_seq_len=args.max_seq_len) diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py new file mode 100644 index 0000000000..800b350789 --- /dev/null +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as dcp +import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.checkpoint import HuggingFaceStorageWriter +from torchtitan.components.checkpoint import ModelWrapper + + +@torch.inference_mode() +def convert_to_hf(input_dir, output_dir, model_name, model_flavor): + # load model and model args so that we can get the state dict shape + train_spec = train_spec_module.get_train_spec(model_name) + model_args = train_spec.model_args[model_flavor] + + with torch.device("cpu"): + model = train_spec.model_cls(model_args) + model = ModelWrapper(model) + + sd_adapter = train_spec.state_dict_adapter(model_args) + assert ( + sd_adapter is not None + ), "trying to convert checkpoint from DCP to HF safetensors format, but sd_adapter is not provided." + + # allocate state dict memory with empty weights to load checkpoint + state_dict = model._get_state_dict() + dcp.load( + state_dict, + checkpoint_id=input_dir, + ) + + # convert state dict tt->hf + hf_state_dict = sd_adapter.to_hf(state_dict) + + fqn_to_index_mapping = {} + num_fqns_per_file = 30 + + for i, key in enumerate(hf_state_dict.keys()): + group_num = (i // num_fqns_per_file) + 1 + fqn_to_index_mapping[key] = group_num + + storage_writer = HuggingFaceStorageWriter( + path=output_dir, + save_distributed=True, + fqn_to_index_mapping=fqn_to_index_mapping, + enable_consolidation=True, + thread_count_consolidation=5, + ) + + dcp.save( + hf_state_dict, + storage_writer=storage_writer, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert DCP weights to HF format.") + parser.add_argument( + "input_dir", type=Path, help="Input directory with DCP weights." + ) + parser.add_argument( + "output_dir", type=Path, help="Output directory for HF checkpoint." + ) + parser.add_argument("--model_name", type=str, nargs="?", default="llama3") + parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") + args = parser.parse_args() + + convert_to_hf( + args.input_dir, + args.output_dir, + args.model_name, + args.model_flavor, + ) diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 1f0f9fb574..03b5efab44 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -18,11 +18,10 @@ from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed.parallel_dims import ParallelDims from torchtitan.models.llama3 import parallelize_llama, pipeline_llama +from torchtitan.protocols import BaseModelArgs, ModelProtocol from torchtitan.protocols.train_spec import ( apply_to_train_specs, - BaseModelArgs, get_train_spec, - ModelProtocol, register_train_spec, TrainSpec, ) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 8eede5fada..5e718a355c 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -37,7 +37,7 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP -from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.protocols import StateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -191,7 +191,7 @@ def __init__( lr_schedulers: LRSchedulersContainer, states: dict[str, Any], checkpoint_config: CheckpointConfig, - sd_adapter: type[StateDictAdapter] | None, + sd_adapter: StateDictAdapter | None, base_folder: str = "", ft_manager: FTManager | None = None, ) -> None: @@ -375,6 +375,7 @@ def dcp_save( enable_consolidation=True, thread_count_consolidation=5, ) + else: checkpoint_save_id = checkpoint_id diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 732d4f709f..dcd8782810 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -21,7 +21,7 @@ from torchtitan.tools.utils import Color, device_module, device_type if TYPE_CHECKING: - from torchtitan.protocols.train_spec import BaseModelArgs + from torchtitan.protocols import BaseModelArgs # named tuple for passing device memory stats for logging diff --git a/torchtitan/experiments/deepseek_v3/model_args.py b/torchtitan/experiments/deepseek_v3/model_args.py index b7fd7f1a72..3672c70194 100644 --- a/torchtitan/experiments/deepseek_v3/model_args.py +++ b/torchtitan/experiments/deepseek_v3/model_args.py @@ -12,7 +12,7 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config_manager import JobConfig -from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability diff --git a/torchtitan/experiments/flux/model/args.py b/torchtitan/experiments/flux/model/args.py index 3786a255a8..3ea643053d 100644 --- a/torchtitan/experiments/flux/model/args.py +++ b/torchtitan/experiments/flux/model/args.py @@ -10,7 +10,7 @@ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams -from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/flux/model/model.py b/torchtitan/experiments/flux/model/model.py index 1908831e95..b8429878d6 100644 --- a/torchtitan/experiments/flux/model/model.py +++ b/torchtitan/experiments/flux/model/model.py @@ -16,7 +16,7 @@ timestep_embedding, ) -from torchtitan.protocols.train_spec import ModelProtocol +from torchtitan.protocols import ModelProtocol from .args import FluxModelArgs diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 75c1b67c12..392e14c94f 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -14,7 +14,7 @@ from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.protocols import BaseModelArgs from torchtitan.tools import utils from .job_config import ForgeJobConfig diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py index e7dc1077c2..f3ab820535 100644 --- a/torchtitan/experiments/forge/train_spec.py +++ b/torchtitan/experiments/forge/train_spec.py @@ -8,13 +8,11 @@ # Import torchtitan.models to ensure all train specs are registered import torchtitan.models # noqa: F401 -from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter from torchtitan.protocols.train_spec import ( _train_specs, - BaseModelArgs, LossFunctionBuilder, LRSchedulersBuilder, - ModelProtocol, OptimizersBuilder, ParallelizeFunction, PipeliningFunction, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index 89818812c9..741f00fd4e 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -10,7 +10,7 @@ from torch import nn from torchtitan.config import JobConfig -from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index c6f410b38e..4e276efbbc 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -10,7 +10,7 @@ from torch import nn from torchtitan.models.attention import build_attention, init_attention_mask -from torchtitan.protocols.train_spec import ModelProtocol +from torchtitan.protocols import ModelProtocol from .args import TransformerModelArgs from .moe import MoE diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2406ee3ad1..9305c1b4d3 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -4,18 +4,123 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import re from typing import Any from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from .args import TransformerModelArgs + class Llama3StateDictAdapter(StateDictAdapter): - @staticmethod - def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: - # TODO: implement this - return state_dict + def __init__(self, model_args: TransformerModelArgs): + self.model_args = model_args + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + # HuggingFace permutation function (exact copy from their conversion script) + def _permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + .clone() + ) + + def _reverse_permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. + if abstract_key == "layers.{}.attention.wq.weight": + value = self._permute(value, n_heads) + if abstract_key == "layers.{}.attention.wk.weight": + key_value_dim = head_dim * n_kv_heads + value = self._permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = to_hf_map[key] + + hf_state_dict[new_key] = value - @staticmethod - def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]: - # TODO: implement this return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. + if abstract_key == "model.layers.{}.self_attn.q_proj.weight": + value = self._reverse_permute(value, n_heads) + if abstract_key == "model.layers.{}.self_attn.k_proj.weight": + key_value_dim = head_dim * n_kv_heads + value = self._reverse_permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = self.from_hf_map[key] + + state_dict[new_key] = value + return state_dict diff --git a/torchtitan/protocols/__init__.py b/torchtitan/protocols/__init__.py new file mode 100644 index 0000000000..2d1b283f11 --- /dev/null +++ b/torchtitan/protocols/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .model import BaseModelArgs, ModelProtocol +from .model_converter import ModelConverter, ModelConvertersContainer +from .state_dict_adapter import StateDictAdapter + +__all__ = [ + "BaseModelArgs", + "ModelProtocol", + "ModelConverter", + "ModelConvertersContainer", + "StateDictAdapter", +] diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py new file mode 100644 index 0000000000..a4f28bc895 --- /dev/null +++ b/torchtitan/protocols/model.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Protocol + +import torch +import torch.nn as nn + +from torchtitan.config import JobConfig + + +@dataclass +class BaseModelArgs: + """All ModelArgs should inherit from this class. + + The only usage of this class is type checking but allows us to extend common + arguments to all models in the future. + """ + + _enforced: str = "This field is used to enforce all fields have defaults." + + @abstractmethod + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + pass + + @abstractmethod + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + pass + + +class ModelProtocol(Protocol): + """Defines the interface for a model class. + + This is used to enforce that all model classes have some methods that are + required by the trainer. + """ + + def __init__(self, model_args: BaseModelArgs) -> None: + pass + + @abstractmethod + def init_weights(self, buffer_device: torch.device | None = None) -> None: + """Initialize model weights. + + Args: + buffer_device: Optional device to place buffers on during initialization. + """ + pass diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index bd22c8d9ba..f72efdea32 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -7,6 +7,8 @@ from abc import ABC, abstractmethod from typing import Any +from torchtitan.protocols import BaseModelArgs + class StateDictAdapter(ABC): """Abstract base class for state dict transformations. @@ -15,9 +17,12 @@ class StateDictAdapter(ABC): state dict format and other model state dict formats. """ - @staticmethod @abstractmethod - def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: + def __init__(self, model_args: BaseModelArgs): + pass + + @abstractmethod + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """Convert from native model state dict to HuggingFace format. Args: @@ -28,9 +33,8 @@ def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: """ pass - @staticmethod @abstractmethod - def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]: + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """Obtain native model state dict from HuggingFace format. Args: diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index db5c923011..8420abc555 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -4,12 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from abc import abstractmethod from collections.abc import Callable from dataclasses import dataclass -from typing import Protocol, TypeAlias +from typing import TypeAlias -import torch import torch.nn as nn from torch.distributed.pipelining.schedules import _PipelineSchedule @@ -20,49 +18,8 @@ from torchtitan.components.optimizer import OptimizersContainer from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.components.validate import BaseValidator -from torchtitan.config import JobConfig, LRScheduler -from torchtitan.protocols.state_dict_adapter import StateDictAdapter - - -@dataclass -class BaseModelArgs: - """All ModelArgs should inherit from this class. - - The only usage of this class is type checking but allows us to extend common - arguments to all models in the future. - """ - - _enforced: str = "This field is used to enforce all fields have defaults." - - @abstractmethod - def update_from_config(self, job_config: JobConfig, **kwargs) -> None: - pass - - @abstractmethod - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: - pass - - -class ModelProtocol(Protocol): - """Defines the interface for a model class. - - This is used to enforce that all model classes have some methods that are - required by the trainer. - """ - - def __init__(self, model_args: BaseModelArgs) -> None: - pass - - @abstractmethod - def init_weights(self, buffer_device: torch.device | None = None) -> None: - """Initialize model weights. - - Args: - buffer_device: Optional device to place buffers on during initialization. - """ - pass +from torchtitan.config import LRScheduler +from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter ParallelizeFunction: TypeAlias = Callable[..., nn.Module] diff --git a/torchtitan/train.py b/torchtitan/train.py index 4123dd9234..de2ef71a34 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -302,7 +302,11 @@ def __init__(self, job_config: JobConfig): lr_schedulers=self.lr_schedulers, states={"train_state": self}, checkpoint_config=job_config.checkpoint, - sd_adapter=self.train_spec.state_dict_adapter, + sd_adapter=( + self.train_spec.state_dict_adapter(model_args) + if self.train_spec.state_dict_adapter + else None + ), base_folder=job_config.job.dump_folder, ft_manager=self.ft_manager, ) @@ -407,16 +411,11 @@ def forward_backward_step( ) if self.pp_has_first_stage: self.pp_schedule.step( - inputs, - target=targets, - losses=losses, - input_batch=inputs, + inputs, target=targets, losses=losses, input_batch=inputs ) else: self.pp_schedule.step( - target=targets, - losses=losses, - input_batch=inputs, + target=targets, losses=losses, input_batch=inputs ) # accumulate losses across pipeline microbatches @@ -431,7 +430,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) + pred = model_parts[0](inputs, self.tokenizer.eos_id) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred From 58afc5f6e9adc023adcfeaa93474f1beb0a3500e Mon Sep 17 00:00:00 2001 From: Sangmin Bae <50742281+raymin0223@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:44:40 -0700 Subject: [PATCH 034/128] Fix incorrect mapping of ffn_norm and attention_norm in HF Llama4 conversion script (#1455) This PR aims to fix minor, wrong mapping between HF and Titan. Refer to #1454. --- .../experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py index 03bb3706e3..bad69c0f7a 100644 --- a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -62,7 +62,7 @@ def convert_to_titan_fqns(fqn: str) -> list[str]: return [f"layers.{layer}.moe.shared_expert.w3"] elif "feed_forward.shared_expert.up_proj.weight" in fqn: return [f"layers.{layer}.moe.shared_expert.w1"] - elif "input_layernorm.weight" in fqn: + elif "post_attention_layernorm.weight" in fqn: return [f"layers.{layer}.ffn_norm.weight"] elif "self_attn.k_proj" in fqn: return [f"layers.{layer}.attention.wk.weight"] @@ -72,7 +72,7 @@ def convert_to_titan_fqns(fqn: str) -> list[str]: return [f"layers.{layer}.attention.wq.weight"] elif "self_attn.v_proj" in fqn: return [f"layers.{layer}.attention.wv.weight"] - elif "post_attention_layernorm.weight" in fqn: + elif "input_layernorm.weight" in fqn: return [f"layers.{layer}.attention_norm.weight"] else: raise ValueError(f"Unknown fqn {fqn}") From 38a9d302b44364587e312bde73411d0ce772fdc3 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 24 Jul 2025 18:29:25 -0700 Subject: [PATCH 035/128] publish instructions on adding a new model (#1451) as titled --- README.md | 3 +- torchtitan/experiments/README.md | 4 +- torchtitan/models/README.md | 78 ++++++++++++++++++++++ torchtitan/protocols/state_dict_adapter.py | 2 +- torchtitan/protocols/train_spec.py | 4 +- 5 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 torchtitan/models/README.md diff --git a/README.md b/README.md index 5c5c30720a..be0dbebc02 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ To use the latest features of `torchtitan`, we recommend using the most recent P ## Latest News +- [2025/07] We published [instructions](/torchtitan/models/README.md) on how to add a model to `torchtitan`. - [2025/07] We released `torchtitan` [v0.1.0](https://github.com/pytorch/torchtitan/releases), and also set up nightly builds. - [2025/04] Our paper was accepted by [ICLR 2025](https://iclr.cc/virtual/2025/poster/29620). - [2025/04] [Llama 4](torchtitan/experiments/llama4/) initial support is available as an experiment. @@ -37,7 +38,7 @@ To use the latest features of `torchtitan`, we recommend using the most recent P Our mission is to accelerate innovation in the field of generative AI by empowering researchers and developers to explore new modeling architectures and infrastructure techniques. -The guiding principles when building `torchtitan` +The Guiding Principles when building `torchtitan` * Designed to be easy to understand, use and extend for different training purposes. * Minimal changes to the model code when applying multi-dimensional parallelism. * Bias towards a clean, minimal codebase while providing basic reusable / swappable components. diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 53bf12ca27..2eadf7521e 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -4,8 +4,8 @@ To accelerate contributions to and innovations around `torchtitan`, we are addin We provide this `experiments/` folder to host experiments that add significant value to `torchtitan`, with the following principles. We refer to the part of `torchtitan` outside `experiments` as `core`. 1. Each subfolder in `experiments` will be an experiment, with a clear theme which can be flexible, such as - - a new model, or preferably a new model architecture, with its training infrastructure including parallelization functions; - - an enhancement or addition to the existing infrastructure of `torchtitan`. + - A new model, or preferably a new model architecture, with its training infrastructure including parallelization functions. Please see the [instructions](/torchtitan/models/README.md) on how to contribute a new model. + - An enhancement or addition to the existing infrastructure of `torchtitan`. 2. It is the contributors' responsibility to justify the value of an experiment. `torchtitan` team will review proposals on a case-by-case basis. As part of the contribution, the contributors should provide documentation that clearly showcases the motivation and innovation of an experiment, including reports on performance and loss convergence. 3. An experiment should reuse existing `torchtitan` code as much as possible, such as modules in [`components/`](../components/) (via a new [`TrainSpec`](../protocols/train_spec.py)) and [`train.py`](../train.py). For a list of extension points we provide, please refer to [docs/extension.md](../../docs/extension.md). - The extension points are subject to change. We kindly request that contributors provide feedback if they encounter issues reusing any components, rather than simply using a copy-and-paste approach. diff --git a/torchtitan/models/README.md b/torchtitan/models/README.md new file mode 100644 index 0000000000..a007c6cb94 --- /dev/null +++ b/torchtitan/models/README.md @@ -0,0 +1,78 @@ +This note outlines the process of adding a new model in the `torchtitan` repo. In most cases, new models should be added first under the `torchtitan/experiments` folder. For criteria of contributions, please see the [Contributing Guidelines](/torchtitan/experiments/README.md) therein. In general, please adhere to the [Guiding Principles](/README.md#overview) of `torchtitan`. + +For offline explorations, we recommend the same steps, unless otherwise noted. + +## Adding the model + +Please refer to the [Llama 3 folder](.llama3) as an example. + +The folder should be organized as follows +- `model` folder: a self-contained folder of model definition and args + - `args.py` + - Inherit [`BaseModelArgs`](/torchtitan/protocols/model.py) and implement the interfaces. + - `get_nparams_and_flops()` will be used to understand model size and compute throughput. + - `update_from_config()` updates the model args from training configs. To extend training configs, see the bullet point below on `job_config.py`. + - `model.py` + - NOTE: Please adhere to the guiding principles and write single-device model code. + - NOTE: We prioritize readability over flexibility. The preferred style is to not share modules among different models, except for the most common and complicated ones. + - Inherit [`ModelProtocol`](/torchtitan/protocols/model.py) and implement the interfaces. + - `__init__()` consumes a `ModelArgs` input to build the model + - `init_weights()` is used to properly initialize the parameters and buffers in the model. Please define it in a recursive way so that every submodule has its own `init_weights()`. + - Add additional files to reduce the complexity of `model.py` if it grows too large or complex, e.g. moe.py to host the `MoE`, `Router`, and `GroupedExperts` modules. + - `state_dict_adapter.py` + - Inherit [`StateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between `torchtitan` model definition and other model definitions (e.g. from HuggingFace so that we can save / load model checkpoints in HF formats). + - There are multiple ways such adapters could be used + - Checkpoint conversion scripts in `scripts/checkpoint_conversion/` will use them to adapt state dicts containing non-sharded `torch.Tensor` on CPU. + - During training, [`CheckpointManager`](/torchtitan/components/checkpoint.py) will use them to adapt state dicts containing (potentially sharded) `DTensor` on GPUs to save / load checkpoints in HF format. + - In post-training, `to_hf()` helps convert a torchtitan model to HF model, which can be used for inference by other frameworks. + - This is optional for offline exploration. +- `infra` folder: containing the functions used to parallelize the model using PyTorch native techniques + - `parallelize.py` + - apply training techniques in the following order + - TP (and EP if the model has MoE architecture) + - activation checkpointing + - `torch.compile` + - FSDP / HSDP + - NOTE: currently CP support for language models is enabled via a context manager in `torchtitan/train.py`. Ideally no extra work is needed to enable CP. + - `pipeline.py` (optional if model size is small) + - apply PP + - Include other util files if necessary. +- `__init__.py` + - A dictionary of the actual model configurations, of the type `[str: ModelArgs]`. + - Call `register_train_spec` to specify a [`TrainSpec`](/torchtitan/protocols/train_spec.py), consisting a tuple of + - model name, model class, model args + - parallelizing function, pipelining function + - builder functions for optimizer, lr scheduler, data loader, tokenizer, and loss function + - More often than not, existing components can be reused. + - Adding new datasets requires the `torchtitan` team’s review and legal approval. + - Try to have minimal dependency on external libraries, if any. + - state dict adapter + - Read [more](/docs/extension.md#trainspec) on `TrainSpec`. +- `README.md` + - Include [instructions](/README.md#downloading-a-tokenizer) to download tokenizers / encoders. + - Include instructions to download model checkpoints for continued pretraining or post training. + - Update the current status of development, including the supported features and coming features. + - This is optional for offline exploration. +- `job_config.py` (if necessary) + - Sometimes a new model needs to access additional configs, to be consumed by various training components. Read the [guidance](/docs/extension.md#train-script) on extending `JobConfig`. +- `train.py` (only if absolutely necessary) + - Sometimes `torchtitan/train.py` may not be enough to run the model. There is a [tradeoff](/docs/extension.md#train-script) between extending the existing one vs. having a new one. + - Even if a new one needs to be added, it should reuse `torchtitan/train.py` as much as possible. See `torchtitan/experiments/flux/train.py` as an example. +- `train_configs` folder + - There should be one `.toml` file for each model variant (e.g. Llama 3.1 8B / 70B / 405B) as well as a `debug_model.toml`. + - They should be verified with real training jobs, in terms of optimized throughput and loss converging. + +## Testing and Benchmarking +- Numerics testing + - One way of doing this E2E is to load the same model checkpoint into the `torchtitan` model and the HF model, and compare the model output given the same input. This assumes + - HF implementation is correct. + - The correctness of a `torchtitan` model and the corresponding state dict adapter together indicates the correctness of both. +- Loss converging + - If there is a verified baseline, compare the loss curves with the baseline. + - For comparisons within `torchtitan`, see the [guidelines](/docs/converging.md). +- Performance benchmarking + - Please refer to the [benchmarks](/benchmarks/) folder. +- CI tests + - Including unit tests and integration tests, see [examples](/tests/). + - If the model folder is under the experiments folder, put the tests under the model folder. Otherwise, put the tests under the `/tests` folder. + - Add necessary GitHub [workflows](/.github/workflows/). diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index f72efdea32..9bcbfc0463 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import Any -from torchtitan.protocols import BaseModelArgs +from .model import BaseModelArgs class StateDictAdapter(ABC): diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 8420abc555..8a782f8b42 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -19,7 +19,9 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.components.validate import BaseValidator from torchtitan.config import LRScheduler -from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter + +from .model import BaseModelArgs, ModelProtocol +from .state_dict_adapter import StateDictAdapter ParallelizeFunction: TypeAlias = Callable[..., nn.Module] From f3e2a75e151872b3dc346d3bfa37d121e413242d Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 24 Jul 2025 18:31:02 -0700 Subject: [PATCH 036/128] make mxfp8 dim1 cast kernel configurable (#1427) Stacked PRs: * __->__#1427 --- --- --- make mxfp8 dim1 cast kernel configurable ## Summary - We recently added a new CUDA kernel for the mxfp8 dim1 cast which is ~1.4x faster than the existing Triton kernel or torch.compile, and using it results in an e2e training speedup of +1.5-2.5% TPS with Llama3 8b using FSDP=4/8 (https://github.com/pytorch/ao/pull/2513). The integration work for composability with torch.compile + FSDP is complete as well: https://github.com/pytorch/ao/pull/2564 - This PR updates the mxfp8 user facing API to replace the boolean flag `"--mx.use_triton_for_dim1_cast=[true|false]` to `mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` ## Test plan - Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"` - Cuda: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"` - Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"` ## Limitations - TP is currently not supported yet, as both the Triton kernel and CUDA kernel are affected by an issue: `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we were talking to Brian about, will continue following up on it. --- torchtitan/components/quantization/mx.py | 31 +++++++++++++++--------- torchtitan/config/job_config.py | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f22ac4bd04..f2c6820a70 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -40,10 +40,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): "torchao is not installed. Please install it to use MXFP8 linear layers." ) torchao_version = version("torchao") - mxfp8_min_version = "0.11.0" - if torchao_version < mxfp8_min_version: + + # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git... + is_nightly_build = torchao_version.startswith("0.13.0") + if not is_nightly_build: raise ImportError( - f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again" + f"torchao version {torchao_version} is too old, please install torchao nightly build and try again" ) # Can be removed if we enable the emulated versions @@ -51,19 +53,26 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 10, 0 ), "MXFP8 is only supported on SM100 or architectures" - self.enabled = True - mx_job_config: MX = job_config.mx - self.filter_fqns = mx_job_config.filter_fqns + # TP not yet supported with torch.compile + assert not ( + job_config.training.compile + and job_config.parallelism.tensor_parallel_degree > 1 + ), "TP not yet supported with torch.compile for mxfp8" # Configure MXFP8 - from torchao.prototype.mx_formats.config import MXLinearConfig + from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, + MXLinearConfig, + ) + mx_job_config: MX = job_config.mx config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name]) - config.use_fp8_dim1_cast_triton_kernel = ( - mx_job_config.use_fp8_dim1_cast_triton_kernel - ) + config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ + mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() + ] + self.filter_fqns = mx_job_config.filter_fqns self.config = config - + self.enabled = True logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}") def convert(self, model: nn.Module): diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index fdf38c63fe..d673999810 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -534,7 +534,7 @@ class Float8: @dataclass class MX: - use_fp8_dim1_cast_triton_kernel: bool = True + mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton" """Temp work around for inductor performance gap""" recipe_name: Literal["mxfp8"] = "mxfp8" From 8a7b4aa6dabbac207d56f100da34fc739f3ed138 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Sun, 27 Jul 2025 17:20:59 -0700 Subject: [PATCH 037/128] Fix a none pointer exception in checkpoint.py (#1465) This PR fixes a potential `NoneType` attribute error in the checkpoint purging code. Previously, when the regex did not find a match (i.e., there exists a file of other format), the code would attempt to access `match.group(1)` and raise an none pointer exception: ``` [rank0]: Traceback (most recent call last): [rank0]: File "~/.conda/envs/forge/lib/python3.10/runpy.py", line 196, in _run_module_as_main [rank0]: return _run_code(code, main_globals, None, [rank0]: File "~/.conda/envs/forge/lib/python3.10/runpy.py", line 86, in _run_code [rank0]: exec(code, run_globals) [rank0]: File "~/forge/apps/sft/main.py", line 271, in [rank0]: sys.exit(recipe_main()) [rank0]: File "~/forge/forge/config/parse.py", line 174, in wrapper [rank0]: sys.exit(recipe_main(conf)) [rank0]: File "~/forge/apps/sft/main.py", line 266, in recipe_main [rank0]: recipe.train() [rank0]: File "~/forge/apps/sft/main.py", line 233, in train [rank0]: self.checkpointer.save( [rank0]: File "~/.conda/envs/forge/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: File "~/torchtitan/torchtitan/components/checkpoint.py", line 507, in save [rank0]: self._purge_stale_checkpoints() [rank0]: File "~/torchtitan/torchtitan/components/checkpoint.py", line 795, in _purge_stale_checkpoints [rank0]: discovered_checkpoints.append((int(match.group(1)), path)) [rank0]: AttributeError: 'NoneType' object has no attribute 'group' ``` --- torchtitan/components/checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 5e718a355c..68143a35a4 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -787,8 +787,9 @@ def _purge_stale_checkpoints(self): discovered_checkpoints = [] for filename in os.listdir(self.folder): match = re.search(r"step-(\d+)", filename) - path = os.path.join(self.folder, filename) - discovered_checkpoints.append((int(match.group(1)), path)) + if match: + path = os.path.join(self.folder, filename) + discovered_checkpoints.append((int(match.group(1)), path)) discovered_checkpoints.sort() to_delete = discovered_checkpoints[: -1 * self.keep_latest_k] From 1fefaee086adc018ab84566552647e84067c165a Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 28 Jul 2025 09:26:30 -0400 Subject: [PATCH 038/128] remove float8 force_recompute_fp8_weight_in_bwd flag (#1452) Summary: This flag has been deprecated in https://github.com/pytorch/ao/pull/2356, deleting it from torchtitan to prepare for future deletion from torchao. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- docs/float8.md | 3 +-- tests/integration_tests.py | 1 - tests/integration_tests_h100.py | 3 --- torchtitan/components/quantization/float8.py | 6 ------ torchtitan/config/job_config.py | 8 -------- 5 files changed, 1 insertion(+), 20 deletions(-) diff --git a/docs/float8.md b/docs/float8.md index a3d806c928..5d90e0617e 100644 --- a/docs/float8.md +++ b/docs/float8.md @@ -11,12 +11,11 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --training.compile ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth. * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. -* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. * `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 8f19086096..19617ae558 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -517,7 +517,6 @@ def build_test_list(): "--model.converters float8", "--float8.enable_fsdp_float8_all_gather", "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", "--float8.emulate", ], ], diff --git a/tests/integration_tests_h100.py b/tests/integration_tests_h100.py index f12f3c07b8..29c11476b5 100755 --- a/tests/integration_tests_h100.py +++ b/tests/integration_tests_h100.py @@ -46,7 +46,6 @@ def build_test_list(): "--model.converters float8", "--float8.enable_fsdp_float8_all_gather", "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", ], ], "Float8 test", @@ -63,7 +62,6 @@ def build_test_list(): "--model.converters float8", "--float8.enable_fsdp_float8_all_gather", "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", ] ], "FSDP+async TP+PP+torch.compile+Float8", @@ -80,7 +78,6 @@ def build_test_list(): "--model.converters float8", "--float8.enable_fsdp_float8_all_gather", "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", ] ], "HSDP+CP+torch.compile+Float8", diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index ca0b38e660..863ea266fc 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -72,11 +72,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): "with `float8_config.recipe_name` is not supported" ) - assert not float8_config.force_recompute_fp8_weight_in_bwd, ( - "using `float8_config.force_recompute_fp8_weight_in_bwd` together " - "with `float8_config.recipe_name` is not supported" - ) - self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name) self.precompute_scale = False logger.info( @@ -97,7 +92,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) self.config = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd, emulate=float8_config.emulate, ) # for precompute_float8_dynamic_scale_for_fsdp diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index d673999810..4d204793a8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -499,14 +499,6 @@ class Float8: precompute_float8_dynamic_scale_for_fsdp: bool = False """Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling""" - force_recompute_fp8_weight_in_bwd: bool = False - """ - Whether to force the recomputation of FP8 weights during backward pass. - When using FSDP with tensorwise scaling, it is recommended to enable - `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights - for backward computation. - """ - recipe_name: Literal["tensorwise", "rowwise", "rowwise_with_gw_hp"] | None = None """If specified, creates float8 config from recipe name""" From a44dff1a41f6c0d8e504919ce4b1b50d05102f01 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 28 Jul 2025 16:55:16 -0700 Subject: [PATCH 039/128] [checkpoint] let user specify `intial_load_path` and `initial_load_in_hf` when using HF checkpoints (#1466) This PR - removes `_load_checkpoint_in_hf_format` and always let user enable `initial_load_in_hf`, which makes it symmetric to `last_save_in_hf` and less ambiguous - require `intial_load_path` to be used, instead of letting user figure out they can download HF checkpoint to `checkpoint.folder / step-0 / ` which is not intuitive. --- docs/checkpoint.md | 4 +-- tests/integration_tests.py | 4 ++- torchtitan/components/checkpoint.py | 45 ++++++++++------------------- torchtitan/config/job_config.py | 44 +++++++++++++++++----------- 4 files changed, 48 insertions(+), 49 deletions(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 00736a7558..3986e3dade 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -68,7 +68,7 @@ NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoi ### HuggingFace `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. -1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, simply enable `--checkpoint.initial_load_model_only` and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_model_only` and `--checkpoint.initial_load_in_hf`, and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. @@ -76,7 +76,7 @@ NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoi python ./scripts/checkpoint_conversion/convert_from_hf.py --model_name --model_flavor python ./scripts/checkpoint_conversion/convert_to_hf.py --model_name --model_flavor # e.g. -python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./outputs/checkpoint/step-0 --model_name llama3 --model_flavor 8B +python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./initial_load_path/ --model_name llama3 --model_flavor 8B ``` ### Torch diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 19617ae558..7e1cbb7630 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -123,12 +123,14 @@ def build_test_list(): [ "--checkpoint.enable_checkpoint", "--checkpoint.folder hf_checkpoint", - "--checkpoint.last_save_in_hf", "--checkpoint.last_save_model_only", + "--checkpoint.last_save_in_hf", ], [ "--checkpoint.enable_checkpoint", "--checkpoint.initial_load_path artifacts-to-be-uploaded/model_only_hf_checkpoint/hf_checkpoint/step-10/", + "--checkpoint.initial_load_model_only", + "--checkpoint.initial_load_in_hf", ], ], "Checkpoint Integration Test - save load model only checkpoint in HF definition and format", diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 68143a35a4..5b649f5a80 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -196,12 +196,6 @@ def __init__( ft_manager: FTManager | None = None, ) -> None: self.enable_checkpoint = checkpoint_config.enable_checkpoint - self.last_save_in_hf = checkpoint_config.last_save_in_hf - if self.last_save_in_hf: - assert ( - sd_adapter is not None - ), "job_config.checkpoint.last_save_in_hf is True, but sd_adapter is not provided." - self.sd_adapter = sd_adapter self.ft_manager = ( ft_manager.manager if ft_manager and ft_manager.enabled else None @@ -257,9 +251,16 @@ def load_state_dict(state_dict): self.folder = os.path.join(base_folder, checkpoint_config.folder) # Checkpoint policy related fields. - self.initial_load_path = checkpoint_config.initial_load_path self.initial_load_model_only = checkpoint_config.initial_load_model_only + self.initial_load_in_hf = checkpoint_config.initial_load_in_hf + self.initial_load_path = checkpoint_config.initial_load_path self.last_save_model_only = checkpoint_config.last_save_model_only + self.last_save_in_hf = checkpoint_config.last_save_in_hf + if self.last_save_in_hf: + assert ( + sd_adapter is not None + ), "job_config.checkpoint.last_save_in_hf is True, but sd_adapter is not provided." + self.sd_adapter = sd_adapter self.export_dtype = TORCH_DTYPE_MAP[checkpoint_config.export_dtype] self.exclude_from_loading = checkpoint_config.exclude_from_loading self.interval = checkpoint_config.interval @@ -536,6 +537,7 @@ def load(self, step: int = -1) -> bool: return False model_only = False + from_hf = False if not os.path.exists(self.folder): if self.initial_load_path: checkpoint_id = self.initial_load_path @@ -544,13 +546,18 @@ def load(self, step: int = -1) -> bool: "checkpoint.initial_load_path is specified but the path is not valid." ) model_only = self.initial_load_model_only + from_hf = self.initial_load_in_hf + if from_hf: + assert ( + model_only + ), "Only model can be loaded when loading from HF's safetensors checkpoint." else: return False else: if self.initial_load_path: - logger.info( + logger.warning( "checkpoint.initial_load_path is provided but the checkpoint.folder exists. " - "Checkpointer will use the checkpoints from the checkpoint.folder." + f"Checkpointer will use the checkpoints from the checkpoint.folder {self.folder}." ) step = self._find_load_step() if step == -1 else step if step == -1: @@ -563,11 +570,6 @@ def load(self, step: int = -1) -> bool: f"--checkpoint.load_step={step} but checkpoint {checkpoint_id} is not found." ) - from_hf = self._load_checkpoint_in_hf_format(checkpoint_id) - if from_hf: - assert ( - model_only - ), "Only model can be loaded when loading from HF's safetensors checkpoint." logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) @@ -622,21 +624,6 @@ def _find_load_step(self, folder: str = "") -> int: return -1 return max(step_counts) - def _load_checkpoint_in_hf_format(self, checkpoint_id: str) -> bool: - """Find the checkpoint type for the given id. - - Args: - checkpoint_id (str): The folder to find the checkpoint type for. - - Returns: - CheckpointType: The checkpoint type for the given folder. - """ - - for filename in os.listdir(checkpoint_id): - if filename.endswith(".safetensors"): - return True - return False - def _ft_folder(self) -> str: return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 4d204793a8..f61eee1495 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -361,6 +361,9 @@ class Checkpoint: When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}. """ + interval: int = 500 + """Checkpointing interval in steps.""" + initial_load_path: str | None = None """ This option specifies the path to the initial checkpoint to load, which is @@ -382,13 +385,20 @@ class Checkpoint: This option specifies if only the model should be loaded during the initial checkpoint load. The option is only used when `initial_load_path` is specified. If False, the checkpoint at `initial_load_path` is treated as a standard training - checkpoint, including optimizer and training states. + checkpoint, including optimizer, lr scheduler, training states, etc. The default setting for this option is True. Note that you will have to use `--checkpoint.no_initial_load_model_only` to override the default setting. """ - interval: int = 500 - """Checkpointing interval in steps.""" + initial_load_in_hf: bool = False + """ + Enable the use of HuggingFace's safetensors format for checkpointing. The option + is only used when `initial_load_path` is specified. This will load checkpoints + in HF's model definition and safetensors format instead of the default torchtitan + model definition and DCP format, after necessary model state dict transformation. + `initial_load_model_only` must be true because safetensors doesn't support saving + non-tensors. The default value is False. + """ last_save_model_only: bool = True """ @@ -399,16 +409,20 @@ class Checkpoint: The default value is True. """ - export_dtype: Literal["float16", "bfloat16", "float32"] = "float32" + last_save_in_hf: bool = False """ - Converts to the specified precision when training completes and last_save_model_only=true. + Enable the use of Hugging Face's safetensors format for checkpointing. This will save the + final checkpoints in safetensors format instead of the default DCP format, after necessary + model state dict transformation. There will be a performance cost in using this as we need + to consolidate the sharded tensors to full tensors as a separate step. + last_save_model_only must be true because safetensors doesn't support saving + non-tensors. On load, this argument isn't needed as we will detect whether the loaded + checkpoint is in safetensors format or not. The default value is False. """ - create_seed_checkpoint: bool = False + export_dtype: Literal["float16", "bfloat16", "float32"] = "float32" """ - Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint. - Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1. - Could be implemented as a separate script, but this way shares more code. + Converts to the specified precision when training completes and last_save_model_only=true. """ async_mode: Literal["disabled", "async", "async_with_pinned_mem"] = "disabled" @@ -453,15 +467,11 @@ class Checkpoint: for many steps or checkpointing too frequently. The default value is False. """ - last_save_in_hf: bool = False + create_seed_checkpoint: bool = False """ - Enable the use of Hugging Face's safetensors format for checkpointing. This will save the - final checkpoints in safetensors format instead of the default DCP format, after necessary - model state dict transformation. There will be a performance cost in using this as we need - to consolidate the sharded tensors to full tensors as a separate step. - last_save_model_only must be true because safetensors doesn't support saving - non-tensors. On load, this argument isn't needed as we will detect whether the loaded - checkpoint is in safetensors format or not. The default value is False. + Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint. + Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1. + Could be implemented as a separate script, but this way shares more code. """ From f26179e74b274972fef224fc2023c088f6231ed1 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 28 Jul 2025 21:51:48 -0400 Subject: [PATCH 040/128] Re-enable pipeline parallel tests (#1477) These should be fixed now that https://github.com/pytorch/pytorch/pull/159084 has landed --- tests/integration_tests.py | 93 +++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 47 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 7e1cbb7630..c3ebf64afc 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -157,41 +157,40 @@ def build_test_list(): "Checkpoint Integration Test - Save Model Only bf16", "last_save_model_only_bf16", ), - # TODO: re-enable PP tests once the issue is fixed - # OverrideDefinitions( - # [ - # [ - # "--parallelism.pipeline_parallel_degree 4", - # "--parallelism.pipeline_parallel_schedule InterleavedZeroBubble", - # ], - # ], - # "PP looped zero bubble test", - # "pp_looped_zero_bubble", - # ngpu=4, - # ), - # OverrideDefinitions( - # [ - # [ - # "--parallelism.pipeline_parallel_degree 2", - # "--parallelism.pipeline_parallel_schedule ZBVZeroBubble", - # ], - # ], - # "PP zero bubble test (v shaped)", - # "pp_zbv", - # ngpu=2, - # ), - # OverrideDefinitions( - # [ - # [ - # "--parallelism.pipeline_parallel_degree 2", - # "--parallelism.pipeline_parallel_schedule 1F1B", - # "--parallelism.data_parallel_shard_degree 1", - # ], - # ], - # "PP 1D test 1F1B", - # "pp_1f1b", - # ngpu=2, - # ), + OverrideDefinitions( + [ + [ + "--parallelism.pipeline_parallel_degree 4", + "--parallelism.pipeline_parallel_schedule InterleavedZeroBubble", + ], + ], + "PP looped zero bubble test", + "pp_looped_zero_bubble", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule ZBVZeroBubble", + ], + ], + "PP zero bubble test (v shaped)", + "pp_zbv", + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule 1F1B", + "--parallelism.data_parallel_shard_degree 1", + ], + ], + "PP 1D test 1F1B", + "pp_1f1b", + ngpu=2, + ), OverrideDefinitions( [ [ @@ -291,18 +290,18 @@ def build_test_list(): "pp_looped_1f1b", ngpu=4, ), - # OverrideDefinitions( - # [ - # [ - # "--parallelism.pipeline_parallel_degree 2", - # "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti", - # "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv", - # ], - # ], - # "PP with custom pipeline schedule loaded from CSV file", - # "pp_custom_csv", - # ngpu=2, - # ), + OverrideDefinitions( + [ + [ + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti", + "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv", + ], + ], + "PP with custom pipeline schedule loaded from CSV file", + "pp_custom_csv", + ngpu=2, + ), OverrideDefinitions( [ [ From 83e6941055fd36434ba50c1a4c7c39cb02327429 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 28 Jul 2025 19:40:21 -0700 Subject: [PATCH 041/128] Improve reshard_after_forward logic (#1094) according to discussions in https://github.com/pytorch/torchtitan/issues/1091 The CI failure is because `FSDPMemTracker` is not compatible of `fully_shard` on a list of modules. @sanketpurandare will help address this soon. Let's land it after the feature is available. --------- Co-authored-by: Jiani Wang --- torchtitan/models/llama3/infra/parallelize.py | 39 ++++++++++++------- torchtitan/tools/utils.py | 12 +++--- torchtitan/train.py | 6 +-- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 9e6e1a85d0..6d9bf60c11 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -385,30 +385,41 @@ def apply_fsdp( if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() - for layer_id, transformer_block in model.layers.items(): - if reshard_after_forward_policy == "always": + match reshard_after_forward_policy: + case "always": reshard_after_forward = True - elif reshard_after_forward_policy == "never": + case "never": reshard_after_forward = False - elif reshard_after_forward_policy == "default": - if pp_enabled: - # For PP, do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = False - else: - # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 - else: + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: raise ValueError( f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.layers.items(): fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) - fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + fully_shard(model, **fsdp_config) def apply_ddp( diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index f31c6a735e..1ef19c123f 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -39,7 +39,7 @@ def __init__(self, gc_freq: int = 1000, debug: bool = False): self.gc_freq = gc_freq self.debug = debug gc.disable() - self.collect("Initial GC collection.") + self.collect("Initial GC collection") if debug: from torch.utils.viz._cycles import warn_tensor_cycles @@ -49,18 +49,18 @@ def __init__(self, gc_freq: int = 1000, debug: bool = False): def run(self, step_count: int): if self.debug: self.collect( - "Force GC to perform collection to obtain debug information.", + "Force GC to perform collection to obtain debug information", generation=2, ) gc.collect() elif step_count > 1 and step_count % self.gc_freq == 0: - self.collect("Peforming periodical GC collection.") + self.collect("Peforming periodical GC collection") @staticmethod def collect(reason: str, generation: int = 1): begin = time.monotonic() gc.collect(generation) - logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin) + logger.info("[GC] %s %.2f seconds", reason, time.monotonic() - begin) # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC @@ -165,12 +165,12 @@ def check_if_feature_in_pytorch( if "git" in torch.__version__: # pytorch is built from source # notify users to check if the pull request is included in their pytorch logger.warning( - "detected that the pytorch is built from source. Please make sure the PR " + "Detected that the pytorch is built from source. Please make sure the PR " f"({pull_request_link}) is included in pytorch for correct {feature_name}." ) elif min_nightly_version is not None and torch.__version__ < min_nightly_version: logger.warning( - f"detected that the pytorch version {torch.__version__} is older than " + f"Detected that the pytorch version {torch.__version__} is older than " f"{min_nightly_version}. Please upgrade a newer version to include the " f"change in ({pull_request_link}) for correct {feature_name}." ) diff --git a/torchtitan/train.py b/torchtitan/train.py index de2ef71a34..5f04bdd104 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -350,7 +350,7 @@ def __init__(self, job_config: JobConfig): f"gradient accumulation steps {self.gradient_accumulation_steps}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " - f"(warmup {job_config.lr_scheduler.warmup_steps})." + f"(warmup {job_config.lr_scheduler.warmup_steps})" ) def batch_generator( @@ -501,7 +501,7 @@ def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) - logger.info(f"Training starts at step {self.step + 1}.") + logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( "" @@ -611,4 +611,4 @@ def close(self) -> None: else: trainer.close() torch.distributed.destroy_process_group() - logger.info("Process group destroyed.") + logger.info("Process group destroyed") From 942661ce4ea5b466869a2379d2ad89c40df3dc99 Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen <33333409+runame@users.noreply.github.com> Date: Mon, 28 Jul 2025 20:04:26 -0700 Subject: [PATCH 042/128] Log total number of tokens seen (#1474) I added logging of the total number of tokens that have been used for training up to the current `log` call. This is useful for plotting metrics like the loss against the number of tokens instead of steps. Since `metrics_processor` is part of the `Trainer`'s state, this change should be compatible with checkpointing and resuming the count of `ntokens_seen`. --- torchtitan/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 5f04bdd104..58fc69ac2b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -63,6 +63,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # additional training states step: int + ntokens_seen: int # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record @@ -294,6 +295,7 @@ def __init__(self, job_config: JobConfig): # Initialize trainer states that will be saved in checkpoint. # These attributes must be initialized before checkpoint loading. self.step = 0 + self.ntokens_seen = 0 self.checkpointer = CheckpointManager( dataloader=self.dataloader, @@ -369,7 +371,9 @@ def batch_generator( raise DataloaderStopIteration() from ex data_load_start = time.perf_counter() input_dict, labels = batch - self.metrics_processor.ntokens_since_last_log += labels.numel() + ntokens_batch = labels.numel() + self.ntokens_seen += ntokens_batch + self.metrics_processor.ntokens_since_last_log += ntokens_batch self.metrics_processor.data_loading_times.append( time.perf_counter() - data_load_start ) @@ -494,6 +498,7 @@ def train_step( global_avg_loss, global_max_loss, grad_norm.item(), + extra_metrics={"ntokens_seen": self.ntokens_seen}, ) @record @@ -572,10 +577,11 @@ def train(self): logger.info("Training completed") def state_dict(self) -> dict[str, Any]: - return {"step": self.step} + return {"step": self.step, "ntokens_seen": self.ntokens_seen} def load_state_dict(self, state_dict: dict[str, Any]): self.step = state_dict["step"] + self.ntokens_seen = state_dict["ntokens_seen"] def close(self) -> None: if self.checkpointer: From 5bab356c29dfababd8f16ab7d8e3d50cba6326e5 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Mon, 28 Jul 2025 21:05:26 -0700 Subject: [PATCH 043/128] Temporarily Disable Memory Tracking Test for FSDP2 (#1480) As a followup of #1094 , it breaks memory tracking test for FSDP2 --- tests/integration_tests.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index c3ebf64afc..e8ffad1bf7 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -461,16 +461,16 @@ def build_test_list(): "cpu_offload+opt_in_bwd+TP+DP+CP", ngpu=8, ), - OverrideDefinitions( - [ - [ - "--memory_estimation.enabled", - ] - ], - "FSDP2 Memory Tracking and Estimation", - "fsdp2_memory_estimation", - ngpu=2, - ), + # OverrideDefinitions( + # [ + # [ + # "--memory_estimation.enabled", + # ] + # ], + # "FSDP2 Memory Tracking and Estimation", + # "fsdp2_memory_estimation", + # ngpu=2, + # ), OverrideDefinitions( [ [ From 8dd5a7e9b1ed52f195142a21391c1b61c0c3295c Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 29 Jul 2025 12:31:33 -0400 Subject: [PATCH 044/128] Fix tokenizer error message (#1476) We don't use `tokenizer.model` anymore --- torchtitan/components/tokenizer.py | 2 +- torchtitan/config/job_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index f6908b7772..b0b7146945 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -161,7 +161,7 @@ def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: raise FileNotFoundError( f"No supported tokenizer files found in '{tokenizer_path}'. " f"Available files: {available_files}. " - "Looking for: tokenizer.json, tokenizer.model, vocab.txt+merges.txt, or vocab.json+merges.txt" + "Looking for: tokenizer.json, vocab.txt+merges.txt, or vocab.json+merges.txt" ) def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]: diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index f61eee1495..eaf73bdff8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -78,7 +78,7 @@ class Model: flavor: str = "debugmodel" """Which model config to train""" - tokenizer_path: str = "./torchtitan/datasets/tokenizer/tokenizer.model" + tokenizer_path: str = "./tests/assets/tokenizer" """Tokenizer path""" converters: list[str] = field(default_factory=list) From 3aa09b9417021fa061c0b1f14e77f7dc0274ca3c Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 29 Jul 2025 11:03:40 -0700 Subject: [PATCH 045/128] log cuda driver version for debugging (#1479) This PR adds: - CUDA driver version logging to the GPU CI jobs. I'm currently looking into a "CUDA driver: invalid argument" issue (#1475) in which it would be helpful to know the driver version. - TORCH_SHOW_CPP_STACKTRACES=1 for the H100 job where async TP tests run, so if the symmetric memory initialization fails we can see exactly where, rather than the inscrutable python level error message. --- .github/workflows/integration_test_8gpu.yaml | 4 ++++ .github/workflows/integration_test_8gpu_flux.yaml | 4 ++++ .github/workflows/integration_test_8gpu_h100.yaml | 8 +++++++- .github/workflows/integration_test_8gpu_simple_fsdp.yaml | 4 ++++ .github/workflows/integration_test_8gpu_torchft.yaml | 4 ++++ 5 files changed, 23 insertions(+), 1 deletion(-) diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml index ecec8190a5..8436854e7f 100644 --- a/.github/workflows/integration_test_8gpu.yaml +++ b/.github/workflows/integration_test_8gpu.yaml @@ -39,6 +39,10 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + echo "CUDA driver version: ${DRIVER_VERSION}" + pip config --user set global.progress_bar off python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 diff --git a/.github/workflows/integration_test_8gpu_flux.yaml b/.github/workflows/integration_test_8gpu_flux.yaml index 5ee0cb3534..ffa0dc86be 100644 --- a/.github/workflows/integration_test_8gpu_flux.yaml +++ b/.github/workflows/integration_test_8gpu_flux.yaml @@ -41,6 +41,10 @@ jobs: pip config --user set global.progress_bar off + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + echo "CUDA driver version: ${DRIVER_VERSION}" + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded diff --git a/.github/workflows/integration_test_8gpu_h100.yaml b/.github/workflows/integration_test_8gpu_h100.yaml index 4648c661e8..a170f338d8 100644 --- a/.github/workflows/integration_test_8gpu_h100.yaml +++ b/.github/workflows/integration_test_8gpu_h100.yaml @@ -40,6 +40,10 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + echo "CUDA driver version: ${DRIVER_VERSION}" + pip config --user set global.progress_bar off python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 @@ -47,4 +51,6 @@ jobs: USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 mkdir artifacts-to-be-uploaded - python -m tests.integration_tests_h100 artifacts-to-be-uploaded --ngpu 8 + + # Enable CPP stacktraces for debugging symmetric memory initialization errors. + TORCH_SHOW_CPP_STACKTRACES=1 python -m tests.integration_tests_h100 artifacts-to-be-uploaded --ngpu 8 diff --git a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml index 44b29df555..6bca7dac9a 100644 --- a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml +++ b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml @@ -38,6 +38,10 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + echo "CUDA driver version: ${DRIVER_VERSION}" + pip config --user set global.progress_bar off python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 diff --git a/.github/workflows/integration_test_8gpu_torchft.yaml b/.github/workflows/integration_test_8gpu_torchft.yaml index 2268170ac2..d249aeaaa6 100644 --- a/.github/workflows/integration_test_8gpu_torchft.yaml +++ b/.github/workflows/integration_test_8gpu_torchft.yaml @@ -38,6 +38,10 @@ jobs: CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" + # Log CUDA driver version for debugging. + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + echo "CUDA driver version: ${DRIVER_VERSION}" + pip config --user set global.progress_bar off python -m pip install torchft-nightly From 327a99cc2371964a1160ff4aec15e37806993139 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Tue, 29 Jul 2025 15:50:20 -0500 Subject: [PATCH 046/128] Fixes the sd adapter in forge experiments (#1484) Copied from github.com/pytorch/torchtitan/pull/1441, tested manually via forge --------- Co-authored-by: Allen Wang --- torchtitan/experiments/forge/engine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 392e14c94f..398a1c5d5e 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -215,7 +215,11 @@ def __init__(self, job_config: ForgeJobConfig): lr_schedulers=self.lr_schedulers, states={"train_state": self}, checkpoint_config=job_config.checkpoint, - sd_adapter=self.train_spec.state_dict_adapter, + sd_adapter=( + self.train_spec.state_dict_adapter(model_args) + if self.train_spec.state_dict_adapter + else None + ), ) loss_parallel_enabled = ( From 881f0ca465d26ab87ccfc5c89572b8a21c4e9707 Mon Sep 17 00:00:00 2001 From: Shuaipeng Li Date: Thu, 31 Jul 2025 01:10:23 +0800 Subject: [PATCH 047/128] Change `lr_min` to `min_lr_factor ` (#1471) [fixed] https://github.com/pytorch/torchtitan/issues/1457 --------- Co-authored-by: shuaipengli --- tests/unit_tests/test_lr_scheduler.py | 20 +++++++++++-------- torchtitan/components/lr_scheduler.py | 10 +++++----- torchtitan/config/job_config.py | 6 +++--- .../train_configs/deepseek_v2.toml | 2 +- .../llama4/train_configs/debug_model.toml | 2 +- .../llama4/train_configs/llama4_17bx128e.toml | 2 +- .../llama4/train_configs/llama4_17bx16e.toml | 2 +- .../train_configs/debug_model.toml | 2 +- .../train_configs/deepseek_v3_16b.toml | 4 ++-- .../train_configs/deepseek_v3_671b.toml | 4 ++-- .../llama3/train_configs/debug_model.toml | 2 +- 11 files changed, 30 insertions(+), 26 deletions(-) diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index 3d57bbd0cf..dfa51751dc 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -37,7 +37,7 @@ def create_job_config( warmup_steps=None, decay_ratio=None, decay_type=None, - lr_min=None, + min_lr_factor=None, ): # Create a job config with the specified parameters args = [ @@ -58,7 +58,11 @@ def create_job_config( args += ( ["--lr_scheduler.decay_type", decay_type] if decay_type is not None else [] ) - args += ["--lr_scheduler.lr_min", str(lr_min)] if lr_min is not None else [] + args += ( + ["--lr_scheduler.min_lr_factor", str(min_lr_factor)] + if min_lr_factor is not None + else [] + ) config_manager = ConfigManager() # Create base config with parameters passed directly @@ -74,7 +78,7 @@ def test_linear_warmup_decay(self): warmup_steps=2, decay_ratio=None, # Use default decay: start decay immediately decay_type=None, - lr_min=None, + min_lr_factor=None, ) # Build the lr scheduler @@ -116,7 +120,7 @@ def test_warmup_stable_decay(self): warmup_steps=2, decay_ratio=0.5, # 50% of steps for decay decay_type="linear", - lr_min=0.0, + min_lr_factor=0.0, ) # Build the lr scheduler @@ -157,7 +161,7 @@ def test_min_lr(self): warmup_steps=2, decay_ratio=None, decay_type="linear", - lr_min=0.2, # 20% of base LR as minimum + min_lr_factor=0.2, # 20% of base LR as minimum ) # Build the lr scheduler @@ -180,7 +184,7 @@ def test_warmup_exceeds_training(self): warmup_steps=10, # More than training steps decay_ratio=None, decay_type="linear", - lr_min=0.0, + min_lr_factor=0.0, ) # Build the lr scheduler - should adjust warmup steps @@ -216,7 +220,7 @@ def test_warmup_stable_only(self): warmup_steps=2, decay_ratio=0.0, # 0% of steps for decay (no decay) decay_type="linear", - lr_min=0.0, + min_lr_factor=0.0, ) # Build the lr scheduler @@ -258,7 +262,7 @@ def test_warmup_plus_decay_exceeds_training(self): warmup_steps=5, decay_ratio=0.8, # 80% of steps for decay (8 steps) decay_type="linear", - lr_min=0.0, + min_lr_factor=0.0, ) # Build the lr scheduler - should adjust warmup steps diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 8829431887..9bdccf7981 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -127,7 +127,7 @@ def build_lr_schedulers( # Add a vitual last step to prevent the learning rate from dropping to 0 stable_steps = training_steps + 1 - warmup_steps - decay_steps lr_decay_type = lr_scheduler_config.decay_type - lr_min = lr_scheduler_config.lr_min + min_lr_factor = lr_scheduler_config.min_lr_factor def linear_warmup_stable_decay( current_step: int, @@ -135,7 +135,7 @@ def linear_warmup_stable_decay( stable_steps: int, decay_steps: int, lr_decay_type: str, - lr_min: float, + min_lr_factor: float, ): """ Computes linear warmup followed by stable learning rate for a while, @@ -150,7 +150,7 @@ def linear_warmup_stable_decay( 2. `sqrt`: decays as 1 minus the square root of the decay progress. 3. `cosine`: follows a cosine curve, decaying according to the values of the half-period of the cosine function. - If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` + If `min_lr_factor` is specified, the decay range is scaled from 1 to `min_lr_factor` to ensure the learning rate does not drop below this minimum value. """ warmup_stable_steps = warmup_steps + stable_steps @@ -176,7 +176,7 @@ def linear_warmup_stable_decay( curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) - curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment + curr_adjustment = min_lr_factor + (1 - min_lr_factor) * curr_adjustment return curr_adjustment lr_lambda = functools.partial( @@ -185,6 +185,6 @@ def linear_warmup_stable_decay( stable_steps=stable_steps, decay_steps=decay_steps, lr_decay_type=lr_decay_type, - lr_min=lr_min, + min_lr_factor=min_lr_factor, ) return LRSchedulersContainer(optimizers, lr_lambda) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index eaf73bdff8..5e4df35eff 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -155,11 +155,11 @@ class LRScheduler: - 'cosine': smoothly decays learning rate following a cosine curve """ - lr_min: float = 0.0 + min_lr_factor: float = 0.0 """ Min lr ratio for lr scheduler. - If provided, the range of decay factor is scaled from 1 to `lr_min` - to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`. + If provided, the range of decay factor is scaled from 1 to `min_lr_factor` + to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.min_lr_factor`. """ diff --git a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml index 6b8390178d..cb0bfa72e9 100644 --- a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml +++ b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml @@ -35,7 +35,7 @@ implementation = "foreach" warmup_steps = 100 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" -lr_min = 0.1 +min_lr_factor = 0.1 [training] local_batch_size = 2 # 8 diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 7c17cc9d9c..fb672cc4c7 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -34,7 +34,7 @@ eps = 1e-15 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" -lr_min = 0.1 +min_lr_factor = 0.1 [training] local_batch_size = 8 diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index bfaa57fa4e..f316cd8380 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -27,7 +27,7 @@ eps = 1e-15 [lr_scheduler] warmup_steps = 600 -lr_min = 0.1 +min_lr_factor = 0.1 [training] local_batch_size = 1 diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 66d7c9dd76..725bbe903d 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -27,7 +27,7 @@ eps = 1e-15 [lr_scheduler] warmup_steps = 600 -lr_min = 0.1 +min_lr_factor = 0.1 [training] local_batch_size = 8 diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 1983b0611d..64d080126b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -36,7 +36,7 @@ eps = 1e-8 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" -lr_min = 0.0 +min_lr_factor = 0.0 [training] local_batch_size = 8 diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 35694e1fd8..290d4aa6d0 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -33,8 +33,8 @@ eps = 1e-8 [lr_scheduler] warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "linear" -lr_min = 2.2e-5 +decay_type = "cosine" +min_lr_factor = 0.1 [training] local_batch_size = 8 diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 51acd7e72a..51e7ddbb50 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -33,8 +33,8 @@ eps = 1e-8 [lr_scheduler] warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "linear" -lr_min = 2.2e-5 +decay_type = "cosine" +min_lr_factor = 0.1 [training] local_batch_size = 4 diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index b61520f1cf..00a688dcf5 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -36,7 +36,7 @@ eps = 1e-8 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" -lr_min = 0.0 +min_lr_factor = 0.0 [training] local_batch_size = 8 From f1c8c2cd07edbd8c2d256e64016830d3670ea9b3 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 30 Jul 2025 16:13:28 -0700 Subject: [PATCH 048/128] guard against nvidia-smi command exit code 1 (#1496) Somehow, despite the logs showing echo-ing the driver version successfully, it exits with code 1. This PR attempts guard against the case where the command writes to stdout successfully but exits with code 1 for some reason. https://github.com/pytorch/torchtitan/blob/881f0ca465d26ab87ccfc5c89572b8a21c4e9707/.github/workflows/integration_test_8gpu.yaml#L44 --- .github/workflows/integration_test_8gpu.yaml | 2 +- .github/workflows/integration_test_8gpu_flux.yaml | 2 +- .github/workflows/integration_test_8gpu_h100.yaml | 2 +- .github/workflows/integration_test_8gpu_simple_fsdp.yaml | 2 +- .github/workflows/integration_test_8gpu_torchft.yaml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml index 8436854e7f..63d6c85ba8 100644 --- a/.github/workflows/integration_test_8gpu.yaml +++ b/.github/workflows/integration_test_8gpu.yaml @@ -40,7 +40,7 @@ jobs: conda activate "${CONDA_ENV}" # Log CUDA driver version for debugging. - DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) echo "CUDA driver version: ${DRIVER_VERSION}" pip config --user set global.progress_bar off diff --git a/.github/workflows/integration_test_8gpu_flux.yaml b/.github/workflows/integration_test_8gpu_flux.yaml index ffa0dc86be..36e31fe690 100644 --- a/.github/workflows/integration_test_8gpu_flux.yaml +++ b/.github/workflows/integration_test_8gpu_flux.yaml @@ -42,7 +42,7 @@ jobs: pip config --user set global.progress_bar off # Log CUDA driver version for debugging. - DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) echo "CUDA driver version: ${DRIVER_VERSION}" python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 diff --git a/.github/workflows/integration_test_8gpu_h100.yaml b/.github/workflows/integration_test_8gpu_h100.yaml index a170f338d8..4b005d4793 100644 --- a/.github/workflows/integration_test_8gpu_h100.yaml +++ b/.github/workflows/integration_test_8gpu_h100.yaml @@ -41,7 +41,7 @@ jobs: conda activate "${CONDA_ENV}" # Log CUDA driver version for debugging. - DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) echo "CUDA driver version: ${DRIVER_VERSION}" pip config --user set global.progress_bar off diff --git a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml index 6bca7dac9a..6b6405f497 100644 --- a/.github/workflows/integration_test_8gpu_simple_fsdp.yaml +++ b/.github/workflows/integration_test_8gpu_simple_fsdp.yaml @@ -39,7 +39,7 @@ jobs: conda activate "${CONDA_ENV}" # Log CUDA driver version for debugging. - DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) echo "CUDA driver version: ${DRIVER_VERSION}" pip config --user set global.progress_bar off diff --git a/.github/workflows/integration_test_8gpu_torchft.yaml b/.github/workflows/integration_test_8gpu_torchft.yaml index d249aeaaa6..4ad284afff 100644 --- a/.github/workflows/integration_test_8gpu_torchft.yaml +++ b/.github/workflows/integration_test_8gpu_torchft.yaml @@ -39,7 +39,7 @@ jobs: conda activate "${CONDA_ENV}" # Log CUDA driver version for debugging. - DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1) + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) echo "CUDA driver version: ${DRIVER_VERSION}" pip config --user set global.progress_bar off From 3c84ce095f76fdc3dd55dea9984c6e9e2bfc80ad Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Wed, 30 Jul 2025 19:34:37 -0400 Subject: [PATCH 049/128] Refactor PP splitting (#1416) This refactors the PP splitting logic to consolidate around settings FQNs for each model chunk. For example: ``` [ ['tok_embeddings', 'layers.0'], # stage0 ['layers.1', 'layers.2'], # stage1 ['layers.3', 'layers.4'], # stage2 ... # so on... ] ``` This is better because it can generally be applied to all models, and the code can be re-used for cases that don't explicitly require pipelined execution (for example, streaming diloco needs to communicate model chunks) Changes: - Refactor deepseekv3 and llama to share the same pipeline util functions - Add module_names_per_model_chunk config, deprecate pipeline_parallel_split_points TODO (follow up PRs): - `pipeline_module_split` will be upstreamed to PyTorch as a `torch.distributed.pipelining` utility since it contains no model specific code. - Additional changes are needed to get this to work for torchft streaming diloco including updating the training loop to not execute if the pipeline schedule isn't set and making sure the pipelining_fn return the correct model chunks. cc @tushar00jain --- tests/unit_tests/test_job_config.py | 73 ++-- torchtitan/config/job_config.py | 25 +- torchtitan/distributed/pipeline.py | 347 +++++++++++++----- torchtitan/models/deepseek_v3/__init__.py | 4 +- .../models/deepseek_v3/infra/pipeline.py | 310 ---------------- torchtitan/models/llama3/infra/pipeline.py | 191 +++++----- 6 files changed, 400 insertions(+), 550 deletions(-) delete mode 100644 torchtitan/models/deepseek_v3/infra/pipeline.py diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py index 039981dbed..9325ea1861 100644 --- a/tests/unit_tests/test_job_config.py +++ b/tests/unit_tests/test_job_config.py @@ -52,40 +52,34 @@ def test_job_config_file_cmd_overrides(self): ) assert config.job.dump_folder == "/tmp/test_tt/" - def test_parse_pp_split_points(self): - toml_splits = ["layers.2", "layers.4", "layers.6"] - cmdline_splits = ["layers.1", "layers.3", "layers.5"] - # no split points specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - ] - ) - assert config.parallelism.pipeline_parallel_split_points == [] + def test_parse_module_fqns_per_model_part(self): + toml_chunks = [ + ["tok_embeddings", "layers.0"], + ["layers.1", "layers.2"], + ["layers.3", "norm", "output"], + ] + cmdline_chunks = [ + ["tok_embeddings", "layers.0", "layers.1"], + ["layers.2", "layers.3", "norm", "output"], + ] - # toml has no split points, but cmdline splits are specified + # no module names specified config_manager = ConfigManager() config = config_manager.parse_args( [ "--job.config_file", "./torchtitan/models/llama3/train_configs/debug_model.toml", - "--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), ] ) - assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points + assert config.parallelism.module_fqns_per_model_part is None - # toml has split points, cmdline does not + # toml has module names, cmdline does not with tempfile.NamedTemporaryFile() as fp: with open(fp.name, "wb") as f: tomli_w.dump( { "parallelism": { - "pipeline_parallel_split_points": toml_splits, + "module_fqns_per_model_part": toml_chunks, } }, f, @@ -93,32 +87,43 @@ def test_parse_pp_split_points(self): config_manager = ConfigManager() config = config_manager.parse_args(["--job.config_file", fp.name]) assert ( - config.parallelism.pipeline_parallel_split_points == toml_splits - ), config.parallelism.pipeline_parallel_split_points + config.parallelism.module_fqns_per_model_part == toml_chunks + ), config.parallelism.module_fqns_per_model_part - # toml has split points, cmdline overrides them + # test that the field accepts list of lists structure with tempfile.NamedTemporaryFile() as fp: with open(fp.name, "wb") as f: tomli_w.dump( { "parallelism": { - "pipeline_parallel_split_points": toml_splits, + "module_fqns_per_model_part": cmdline_chunks, } }, f, ) config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - fp.name, - "--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), - ] - ) + config = config_manager.parse_args(["--job.config_file", fp.name]) + assert ( + config.parallelism.module_fqns_per_model_part == cmdline_chunks + ), config.parallelism.module_fqns_per_model_part + + # test empty chunks are handled correctly + empty_chunks = [[], ["tok_embeddings"], []] + with tempfile.NamedTemporaryFile() as fp: + with open(fp.name, "wb") as f: + tomli_w.dump( + { + "parallelism": { + "module_fqns_per_model_part": empty_chunks, + } + }, + f, + ) + config_manager = ConfigManager() + config = config_manager.parse_args(["--job.config_file", fp.name]) assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points + config.parallelism.module_fqns_per_model_part == empty_chunks + ), config.parallelism.module_fqns_per_model_part def test_parse_exclude_from_loading(self): toml_splits = ["optimizer", "dataloader"] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5e4df35eff..5255de3da4 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -290,6 +290,7 @@ class Parallelism: pipeline_parallel_split_points: list[str] = field(default_factory=list) """ + DEPRECATED: Use module_fqns_per_model_part instead. Specify comma-separated names of modules to use as the beginning of a split point. e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, the first containing all the layers up to layers.0, @@ -299,9 +300,31 @@ class Parallelism: but currently the split points must be specified manually. """ + module_fqns_per_model_part: list[list[str]] | None = None + """ + Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk. + Each inner list represents one model chunk and contains the module names that belong to that chunk. + e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] + will create 3 chunks: the first containing tok_embeddings and layers.0, + the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. + This provides more explicit control over which modules belong to each chunk compared to split points. + """ + + pipeline_parallel_first_stage_less_layers: int = 1 + """ + The number of layers to reduce in the first stage of pipeline parallelism. This is because + the first stage has the extra overhead of the embedding layer, which is not present in the other stages. + """ + + pipeline_parallel_last_stage_less_layers: int = 1 + """ + The number of layers to reduce in the last stage of pipeline parallelism. This is because + the last stage has the extra overhead of the output layer, which is not present in the other stages. + """ + pipeline_parallel_layers_per_stage: int | None = None """ - The number of layers per (virtual) pipeline stage. If specified, the split points will be + The number of layers per (virtual) pipeline stage. If specified, the module_fqns_per_model_part will be calculated from the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will be inferred from the model, schedule, and pipeline_parallel_degree. """ diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py index 9526a7e3b7..96cf2ed790 100644 --- a/torchtitan/distributed/pipeline.py +++ b/torchtitan/distributed/pipeline.py @@ -3,122 +3,34 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import os from typing import Callable +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage + from torch.distributed.pipelining.schedules import ( _PipelineSchedule, _PipelineScheduleRuntime, get_schedule_class, PipelineScheduleMulti, PipelineScheduleSingle, + ScheduleZBVZeroBubble, ) -from torch.distributed.pipelining.stage import PipelineStage from torchtitan.config import JobConfig from torchtitan.tools.logging import logger -__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] - - -# TODO: It's unclear if this API is general enough to be used by other models. -# If not, we should move it to a Transformer-specific directory. -def generate_split_points( - schedule_str: str, - pp_degree: int, - num_layers: int, - num_layers_per_stage: int | None, - input_weight: int = 1, - output_weight: int = 1, -) -> list[str]: - """ - Generate a list of split points based on the input configs. In this function, - the number of effective layers considered is the summation of num_layers, - input_weight, and output_weight. - - If num_layers_per_virtual_stage is given, we require rigid fit of the - effective layers (regular layers + weighted input + weighted output) - onto pipeline stages and ranks, with several assertions. It is the users' - responsibility to figure out the input weight, output weight, and the - number of regular layers, so that they can be arranged neatly. - - If num_layers_per_virtual_stage is None, we by default set each pipeline rank - to have 1 stage if schedule_str is a single-stage schedule, or 2 virtual stages - if it is a multi-stage schedule, and try to distribute all effective layers - evenly onto the PP stages. If there are extra layers, we disperse them in - the starting stages. - - Args: - schedule_str (str): The string of the schedule name. - pp_degree (int): The pipeline parallel dimension. - num_layers (int): The number of layers in the model. - input_weight (int): The number of layers to consider the input modules in layer calculation. - output_weight (int): The number of layers to consider the output modules in layer calculation. - num_layers_per_stage (int): The number of layers per (virtual) pipeline stage. - - Returns: - list[str]: A list of split point FQNs. - """ - - schedule_class = get_schedule_class(schedule_str) - is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) - - num_effective_layers = num_layers + input_weight + output_weight - - if num_layers_per_stage is not None: - # If num_layers_per_stage is provided, we require a rigid fit of the effective layers - assert num_effective_layers % pp_degree == 0 - num_layers_per_pipeline_rank = num_effective_layers // pp_degree - - assert num_layers_per_pipeline_rank % num_layers_per_stage == 0 - num_stages_per_rank = num_layers_per_pipeline_rank // num_layers_per_stage - - num_total_virtual_stages = num_stages_per_rank * pp_degree - num_extra_layers = 0 - - if is_single_stage_schedule: - assert ( - num_stages_per_rank == 1 - ), f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single-stage schedules." - else: - assert ( - num_stages_per_rank >= 2 - ), f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi-stage schedules." - else: - # In a multi-stage schedule, if num_layers_per_stage is not - # provided, by default each pipeline rank has 2 virtual stages. - num_stages_per_rank = 1 if is_single_stage_schedule else 2 - num_total_virtual_stages = pp_degree * num_stages_per_rank - - if num_total_virtual_stages > num_effective_layers: - raise ValueError( - "The number of total stages cannot be greater than the number of effective layers." - ) - - num_layers_per_stage = num_effective_layers // num_total_virtual_stages - num_extra_layers = num_effective_layers % num_total_virtual_stages - - assert num_layers_per_stage >= max(input_weight, output_weight) - - splits = [] - current_layer = 0 - for i in range(num_total_virtual_stages - 1): - if i == 0: - current_layer += num_layers_per_stage - input_weight - else: - current_layer += num_layers_per_stage - # extra layers will be dispersed to the first stages - if num_extra_layers > 0: - current_layer += 1 - num_extra_layers -= 1 - splits.append("layers." + str(current_layer)) - - logger.info( - "No 'pipeline_parallel_split_points' provided. Here is the auto-generated split, " - f"which may be sub-optimal: {splits}." - ) - return splits +__all__ = [ + "build_pipeline_schedule", + "stage_ids_this_rank", + "generate_llm_fqn_per_model_part", + "pipeline_module_split", +] def build_pipeline_schedule( @@ -154,7 +66,7 @@ def build_pipeline_schedule( # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training if batch_size % microbatch_size != 0: raise ValueError( - f"Batch size {job_config.training.local_batch_size} must be divisible by number of microbatches {n_microbatches}. " + f"Batch size {job_config.training.local_batch_size} must be divisible by microbatch_size {microbatch_size}. " "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size." ) n_microbatches = batch_size // microbatch_size @@ -209,3 +121,234 @@ def stage_ids_this_rank( zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1)) ) return stage_v_pairs[pp_rank] + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (tok_embeddings) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + + Returns: + List of lists containing module names for each model part + + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. " + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(model, module_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index af95492b82..7585b4e03e 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -11,11 +11,11 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers +from torchtitan.models.llama3.infra.pipeline import pipeline_llama from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .infra.parallelize import parallelize_deepseekv3 -from .infra.pipeline import pipeline_deepseekv3 from .model.args import DeepSeekV3ModelArgs from .model.model import DeepSeekV3Model @@ -138,7 +138,7 @@ model_cls=DeepSeekV3Model, model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, - pipelining_fn=pipeline_deepseekv3, + pipelining_fn=pipeline_llama, build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, diff --git a/torchtitan/models/deepseek_v3/infra/pipeline.py b/torchtitan/models/deepseek_v3/infra/pipeline.py deleted file mode 100644 index b28ed39ee4..0000000000 --- a/torchtitan/models/deepseek_v3/infra/pipeline.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# This file applies the PT-D pipeline parallelism to the Llama model. - -import copy - -import torch -import torch.nn as nn -from torch.distributed import DeviceMesh -from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import ( - _PipelineSchedule, - get_schedule_class, - PipelineScheduleSingle, - ScheduleZBVZeroBubble, -) - -from torchtitan.components.loss import LossFunction -from torchtitan.config import JobConfig -from torchtitan.distributed import ParallelDims -from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank -from torchtitan.protocols.train_spec import ParallelizeFunction -from torchtitan.tools.logging import logger - -from ..model.args import DeepSeekV3ModelArgs - - -def generate_module_names_per_stage( - num_stages: int, - num_layers: int, - input_weight: int = 1, - output_weight: int = 1, -) -> list[list[str]]: - """ - Programmatically generates module names per stage for pipeline parallelism with weighting. - - Args: - num_stages: Number of pipeline stages - num_layers: Total number of transformer layers in the model - input_weight: Weight for input modules (tok_embeddings) in layer calculation - output_weight: Weight for output modules (norm + output) in layer calculation - - Returns: - List of lists containing module names for each stage - - Example: - generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2) - treats embeddings as 2 layers and norm+output as 2 layers for distribution - """ - if num_stages < 1: - raise ValueError("Number of stages must be at least 1") - - if num_stages == 1: - # Single stage gets everything - layer_names = [f"layers.{i}" for i in range(num_layers)] - return [["tok_embeddings"] + layer_names + ["norm", "output"]] - - # Calculate effective layers including weights - num_effective_layers = num_layers + input_weight + output_weight - - if num_stages > num_effective_layers: - raise ValueError( - f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" - ) - - # Calculate layers per stage (distribute evenly) - layers_per_stage = num_effective_layers // num_stages - extra_layers = num_effective_layers % num_stages - - # Ensure each stage gets at least the weight of input/output modules - if layers_per_stage < max(input_weight, output_weight): - raise ValueError( - f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})" - ) - - module_names_per_stage = [] - current_layer = 0 - - for stage_idx in range(num_stages): - stage_modules = [] - - # Calculate effective layers for this stage - effective_layers_for_stage = layers_per_stage - if stage_idx < extra_layers: - effective_layers_for_stage += 1 - - # First stage: handle input modules with weighting - if stage_idx == 0: - stage_modules.append("tok_embeddings") - # Account for input weight in layer distribution - remaining_layers_for_stage = effective_layers_for_stage - input_weight - - # Add transformer layers - for _ in range(remaining_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - # Last stage: handle output modules with weighting - elif stage_idx == num_stages - 1: - # Account for output weight in layer distribution - remaining_layers_for_stage = effective_layers_for_stage - output_weight - - # Add transformer layers - for _ in range(remaining_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - # Add output modules - stage_modules.extend(["norm", "output"]) - - # Middle stages: only transformer layers - else: - for _ in range(effective_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - module_names_per_stage.append(stage_modules) - - return module_names_per_stage - - -def pipeline_deepseekv3( - model: nn.Module, - parallel_dims: ParallelDims, - job_config: JobConfig, - device: torch.device, - model_config: DeepSeekV3ModelArgs, - parallelize_fn: ParallelizeFunction, - loss_fn: LossFunction, -) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] - - # Determine the number of virtual stages based on schedule type - schedule_class = get_schedule_class( - job_config.parallelism.pipeline_parallel_schedule - ) - is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) - - # For multi-stage schedules, default is 2 virtual stages per rank - # For single-stage schedules, default is 1 virtual stage per rank - stages_per_rank = 1 if is_single_stage_schedule else 2 - num_virtual_stages = parallel_dims.pp * stages_per_rank - - # Generate module names per stage programmatically with weighting - num_layers = model_config.n_layers - - # You can adjust these weights based on the computational cost of embeddings and output layers - # Higher weights mean these modules are treated as "heavier" in the distribution - input_weight = 1 # Weight for tok_embeddings - output_weight = 1 # Weight for norm + output layers - - module_names_per_stage = generate_module_names_per_stage( - num_virtual_stages, num_layers, input_weight, output_weight - ) - for i, stage_ms in enumerate(module_names_per_stage): - logger.info(f"Stage {i}: {stage_ms}") - - stages, model_parts = pipeline_module_split( - model, - pp_mesh, - job_config.parallelism.pipeline_parallel_schedule, - device, - module_names_per_stage, - ) - - # For PP with looped schedules, each item in model_parts is one stage-model-chunk. - # We need to iterate through model_parts to apply SPMD parallelisms, compilation, - # optimizer, and checkpointing - for i, m in enumerate(model_parts): - # apply SPMD-style PT-D techniques - m = parallelize_fn(m, parallel_dims, job_config) - model_parts[i] = m - # NOTE: this is to update the model in the stage - # in case the model is modified e.g. by torch.compile - stages[i].submod = m - - pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) - - # This is used in the train loop to determine whether to pass in the input_ids and labels - has_first_stage = False - has_last_stage = False - for stage in stages: - if stage.is_first: - has_first_stage = True - if stage.is_last: - has_last_stage = True - - return pp_schedule, model_parts, has_first_stage, has_last_stage - - -def pipeline_module_split( - whole_model: nn.Module, - pp_mesh: DeviceMesh, - pp_schedule: str, - device: torch.device, - module_names_per_stage: list[list[str]], -) -> tuple[list[PipelineStage], list[nn.Module]]: - """ - This API creates pipeline stages based on specified module names for each stage. - - Args: - whole_model: The complete model to be split - pp_mesh: Pipeline parallel device mesh - pp_schedule: Name of pipeline parallelism schedule - device: Device type - module_names_per_stage: List of lists, where each inner list contains the module names - that should be included in that stage. Module names should be - dot-separated paths. Examples: - - "tok_embeddings" for token embeddings - - "layers.0", "layers.1" for specific transformer layers - - "norm" for the final normalization layer - - "output" for the output projection layer - - Returns: - Tuple of (stages, models) where stages are PipelineStage objects and models are the - corresponding model chunks - - Example usage: - module_names_per_stage = [ - ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer - ["layers.1", "layers.2"], # Stage 1: middle layers - ["norm", "output"] # Stage 2: final norm + output - ] - """ - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - - def _build_stage_from_modules( - stage_idx: int, module_names: list[str], num_stages: int - ) -> tuple[PipelineStage, nn.Module]: - model = copy.deepcopy(whole_model) - - # Create a set of modules to keep for faster lookup - modules_to_keep = set(module_names) - print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") - for module_name, module_value in model.named_children(): - # Handle layer-like structures (e.g., "layers.0", "layers.1") - if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): - layers_to_keep = { - name.split(".", 1)[1] - for name in modules_to_keep - if name.startswith(f"{module_name}.") - } - if layers_to_keep: - # Keep only specified layers - if isinstance(module_value, nn.ModuleDict): - for layer_name in list(module_value.keys()): - if layer_name not in layers_to_keep: - del module_value[layer_name] - elif isinstance(module_value, nn.ModuleList): - indices_to_keep = { - int(idx) for idx in layers_to_keep if idx.isdigit() - } - new_layers = nn.ModuleList( - [ - layer - for i, layer in enumerate(module_value) - if i in indices_to_keep - ] - ) - setattr(model, module_name, new_layers) - else: - # No layers from this structure needed, set to empty structure - if isinstance(module_value, nn.ModuleDict): - setattr(model, module_name, nn.ModuleDict()) - elif isinstance(module_value, nn.ModuleList): - setattr(model, module_name, nn.ModuleList()) - # Handle simple module attributes (e.g., "linear", "norm") - elif module_name not in modules_to_keep: - # Replace with None - setattr(model, module_name, None) - - stage = PipelineStage( - model, - stage_idx, - num_stages, - device, - group=pp_mesh.get_group("pp"), - ) - return stage, model - - num_stages = len(module_names_per_stage) - stages = [] - models = [] - - schedule_class = get_schedule_class(pp_schedule) - style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" - - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): - module_names = module_names_per_stage[stage_idx] - stage, model_chunk = _build_stage_from_modules( - stage_idx, - module_names, - num_stages, - ) - logger.info( - f"PP rank {pp_rank} is building stage_idx {stage_idx} " - f"with modules {module_names}" - ) - stages.append(stage) - models.append(model_chunk) - - return stages, models diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index bf88f74322..db3d6465e6 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -6,16 +6,14 @@ # This file applies the PT-D pipeline parallelism to the Llama model. -import copy +import math import torch import torch.nn as nn -from torch.distributed import DeviceMesh -from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.schedules import ( _PipelineSchedule, get_schedule_class, - ScheduleZBVZeroBubble, + PipelineScheduleSingle, ) from torchtitan.components.loss import LossFunction @@ -23,13 +21,12 @@ from torchtitan.distributed import ParallelDims from torchtitan.distributed.pipeline import ( build_pipeline_schedule, - generate_split_points, - stage_ids_this_rank, + generate_llm_fqn_per_model_part, + pipeline_module_split, ) -from torchtitan.protocols.train_spec import ParallelizeFunction -from torchtitan.tools.logging import logger -from ..model.args import TransformerModelArgs +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger def pipeline_llama( @@ -37,14 +34,95 @@ def pipeline_llama( parallel_dims: ParallelDims, job_config: JobConfig, device: torch.device, - model_args: TransformerModelArgs, + model_args: BaseModelArgs, parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + if job_config.parallelism.pipeline_parallel_split_points != []: + raise ValueError( + "pipeline_parallel_split_points is deprecated. Please use module_fqns_per_model_part instead." + "You can generate module_fqns_per_model_part programmatically with generate_llm_fqn_per_model_part" + ) + pp_mesh = parallel_dims.world_mesh["pp"] - stages, model_parts = pipeline_llama_manual_split( - model, pp_mesh, parallel_dims, job_config, device, model_args + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + for i, stage_ms in enumerate(module_names_per_stage): + logger.debug(f"Stage {i}: {stage_ms}") + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, ) # For PP with looped schedules, each item in model_parts is one stage-model-chunk. @@ -70,92 +148,3 @@ def pipeline_llama( has_last_stage = True return pp_schedule, model_parts, has_first_stage, has_last_stage - - -def pipeline_llama_manual_split( - whole_model: nn.Module, - pp_mesh: DeviceMesh, - parallel_dims: ParallelDims, - job_config: JobConfig, - device: torch.device, - model_args: TransformerModelArgs, -) -> tuple[list[PipelineStage], list[nn.Module]]: - """ - This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. - - It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. - - The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD - parallelism. - """ - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - parallelism_config = job_config.parallelism - - splits = parallelism_config.pipeline_parallel_split_points or generate_split_points( - parallelism_config.pipeline_parallel_schedule, - parallel_dims.pp, - model_args.n_layers, - parallelism_config.pipeline_parallel_layers_per_stage, - ) - - def _build_stage( - stage_idx: int, - start_layer: str | None, - stop_layer: str | None, - is_first: bool = False, - is_last: bool = False, - ) -> tuple[PipelineStage, nn.Module]: - model = copy.deepcopy(whole_model) - if not is_first: - model.tok_embeddings = None - - drop_layers = start_layer is not None - for name in list(model.layers.keys()): - # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if f"layers.{name}" == start_layer: - drop_layers = False - if f"layers.{name}" == stop_layer: - drop_layers = True - if drop_layers: - del model.layers[name] - - if not is_last: - model.norm = None - model.output = None - - stage = PipelineStage( - model, - stage_idx, - num_stages, - device, - group=pp_mesh.get_group("pp"), - ) - return stage, model - - num_stages = len(splits) + 1 - stage_idx = pp_rank - - stages = [] - models = [] - - schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule) - style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" - - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): - start_layer = splits[stage_idx - 1] if stage_idx > 0 else None - stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None - stage, model_chunk = _build_stage( - stage_idx, - start_layer, - stop_layer, - is_first=stage_idx == 0, - is_last=stage_idx == num_stages - 1, - ) - logger.info( - f"PP rank {pp_rank} is building stage_idx {stage_idx}" - f" with start_layer {start_layer}, stop_layer {stop_layer}" - ) - stages.append(stage) - models.append(model_chunk) - return stages, models From be49c02feeaa516958bc6cd123b5c151eba9550c Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 30 Jul 2025 17:19:45 -0700 Subject: [PATCH 050/128] [deepseek] integrate 16B tokenizer to match 16B official model (#1497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on the discussion in this PR (https://github.com/pytorch/torchtitan/pull/1495), the conclusion was to ensure that 16B uses the proper tokenizer to avoid the cudaAssertError in the current config which comes from mismatch between embeddings and tokenizer vocab. Thus, this PR; 1 - adds additional line to the readme for enabling users to pull the 16B-chat tokenizer, 2- updates the 16_toml config to point to the 16B tokenizer under /assets/tokenizer/deepseek-moe-16b-chat With that, the vocab size of 102400 already in the toml now works flawlessly. **Testing:** run download tokenizer run 20 iters with 16B without issue. Screenshot 2025-07-30 at 12 46
38 PM --- torchtitan/models/deepseek_v3/README.md | 7 ++++++- .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 367e4e9413..5e6c97e28d 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -7,10 +7,15 @@ DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with Multi-head Late ### Download Tokenizer ```bash -# DeepSeek tokenizer (automatically downloads tokenizer.json and tokenizer_config.json) +# DeepSeek 671B tokenizer (automatically downloads tokenizer.json and tokenizer_config.json) python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3 ``` +```bash +# For 16B model support: +python scripts/download_tokenizer.py --repo_id deepseek-ai/deepseek-moe-16b-chat +``` + ## Training ### Debug Training diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 290d4aa6d0..4f646c8d0f 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -22,7 +22,7 @@ enable_wandb = false [model] name = "deepseek_v3" flavor = "16B" -tokenizer_path = "./assets/tokenizer/DeepSeek-V3" +tokenizer_path = "./assets/tokenizer/deepseek-moe-16b-chat" # converters = ["float8"] [optimizer] From 82b593e6cdaa31b783cdbe36f80a7e7c41e1252f Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 30 Jul 2025 22:24:07 -0700 Subject: [PATCH 051/128] remove dead code (#1501) Summary: remove some stale code that determines parameters to pass to outer optimizer --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1501). * #1446 * #1502 * __->__ #1501 --- torchtitan/components/ft.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py index 70b814f3aa..76f2da3ae5 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft.py @@ -123,8 +123,6 @@ def maybe_semi_sync_training( ), "FTManager must be enabled to use semi-sync training." if semi_sync_method.lower() == "diloco": # Create the outer optimizer based on the inner optimizer parameters. - params = [group["params"] for group in optimizer.param_groups] - params = [param for sublist in params for param in sublist] outer_optimizers = [] for model in model_parts: params = [p for p in model.parameters() if p.requires_grad] From 5961c759071c27bc1c5eb52d99e84c03b3952bdb Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 30 Jul 2025 22:24:20 -0700 Subject: [PATCH 052/128] fix creating leaf folder (#1502) Summary: the leaf folder wasn't being created so and no profiles were being written, so create it if it doesn't exist --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1502). * #1446 * __->__ #1502 * #1501 --- torchtitan/tools/profiling.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 843c13a746..0e851d335a 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -40,16 +40,14 @@ def maybe_enable_profiling( def trace_handler(prof): curr_trace_dir_name = "iteration_" + str(prof.step_num) - curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) + curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name, leaf_folder) if not os.path.exists(curr_trace_dir): os.makedirs(curr_trace_dir, exist_ok=True) logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - output_file = os.path.join( - curr_trace_dir, leaf_folder, f"rank{rank}_trace.json" - ) + output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json") prof.export_chrome_trace(output_file) logger.info( f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" @@ -123,13 +121,13 @@ def step(self, exit_ctx: bool = False): # dump as iteration_0_exit if OOM at iter 1 curr_step = self.step_num - 1 dir_name = f"iteration_{curr_step}_exit" - curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) + curr_snapshot_dir = os.path.join(snapshot_dir, dir_name, leaf_folder) if not os.path.exists(curr_snapshot_dir): os.makedirs(curr_snapshot_dir, exist_ok=True) logger.info(f"Dumping memory snapshot at step {curr_step}") begin = time.monotonic() output_file = os.path.join( - curr_snapshot_dir, leaf_folder, f"rank{rank}_memory_snapshot.pickle" + curr_snapshot_dir, f"rank{rank}_memory_snapshot.pickle" ) with open(output_file, "wb") as output: pickle.dump(torch.cuda.memory._snapshot(), output) From 1080c8fb6b2c726bd1f497e41b42b589966f2f5c Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Wed, 30 Jul 2025 22:53:08 -0700 Subject: [PATCH 053/128] validation support for pipeline parallelism [WIP] (#1490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With recent api change to pipeline schedule https://github.com/pytorch/pytorch/pull/157795, we can now schedule forward pass and calculate loss, allowing us to use validation and pp together. To test correctness we train from a seed checkpoint with training.seed and training.determinism set with varying degrees of parallelism and different pipeline schedules to compare if loss remains the same: | Parallelism | Loss | | --- | --- | | FSDP=2 | Screenshot 2025-07-29 at 5
12 49 PM | | FSDP=2, TP=2, PP=2, PP_schedule="1F1B" | Screenshot 2025-07-29 at 5 17 18 PM | | FSDP=2, PP=4, PP_schedule="1F1B" | Screenshot 2025-07-29 at 5 15 53 PM | | FSDP=2, PP=4, PP_schedule="Interleaved1F1B" |Screenshot 2025-07-29 at 5 39 39 PM | | FSDP=2, PP=4, PP_schedule="GPipe" | Screenshot 2025-07-29 at 5 49 36 PM | FSDP=2, PP=4, PP_schedule="LoopedBFS" | Screenshot 2025-07-29 at 5 54 55 PM | FSDP=2, PP=4, PP_schedule="InterleavedZeroBubble" | Screenshot 2025-07-30 at 2 30 53 PM --- tests/integration_tests.py | 7 ++-- torchtitan/components/validate.py | 53 +++++++++++++++++++++++++++---- torchtitan/train.py | 18 ++++++++--- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index e8ffad1bf7..f7512836c6 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -544,13 +544,14 @@ def build_test_list(): [ "--validation.enabled", "--validation.dataset c4_test", - "--parallelism.data_parallel_replicate_degree=2", "--parallelism.tensor_parallel_degree=2", "--parallelism.context_parallel_degree=2", + "--parallelism.pipeline_parallel_degree=2", + "--parallelism.pipeline_parallel_schedule Interleaved1F1B", ], ], - "Validation test with fsdp, tp, cp", - "validation_fsdp_tp_cp", + "Validation test with tp, cp, pp", + "validation_tp_cp_pp", ngpu=8, ), ] diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 7357cc8ed0..7f8b848c7f 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.distributed.fsdp import FSDPModule +from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader from torchtitan.components.loss import LossFunction from torchtitan.components.metrics import MetricsProcessor @@ -54,6 +55,9 @@ def __init__( validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], metrics_processor: MetricsProcessor, + pp_schedule: _PipelineSchedule | None = None, + pp_has_first_stage: bool | None = None, + pp_has_last_stage: bool | None = None, ): self.job_config = job_config self.parallel_dims = parallel_dims @@ -67,6 +71,9 @@ def __init__( self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp self.metrics_processor = metrics_processor + self.pp_schedule = pp_schedule + self.pp_has_first_stage = pp_has_first_stage + self.pp_has_last_stage = pp_has_last_stage @torch.no_grad() def validate( @@ -75,7 +82,6 @@ def validate( step: int, ) -> dict[str, float]: # Set model to eval mode - # TODO: currently does not support pipeline parallelism model = model_parts[0] model.eval() @@ -110,11 +116,40 @@ def validate( else None ) - with self.validation_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 - with self.maybe_enable_amp: - predictions = model(inputs) - loss = self.loss_fn(predictions, labels) + if parallel_dims.pp_enabled: + assert self.pp_schedule is not None + assert self.pp_has_first_stage is not None + assert self.pp_has_last_stage is not None + # Pipeline Parallel forward inside eval() call + with self.validation_context(optional_context_parallel_ctx): + targets, losses = ( + (labels, []) if self.pp_has_last_stage else (None, None) + ) + if self.pp_has_first_stage: + self.pp_schedule.eval( + inputs, + target=targets, + losses=losses, + input_batch=inputs, + ) + else: + self.pp_schedule.eval( + target=targets, losses=losses, input_batch=inputs + ) + + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + loss = ( + torch.mean(torch.stack(losses)).to(device_type) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=device_type) + ) + else: + with self.validation_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + with self.maybe_enable_amp: + predictions = model(inputs) + loss = self.loss_fn(predictions, labels) accumulated_losses.append(loss.detach()) @@ -152,6 +187,9 @@ def build_validator( validation_context: Generator[None, None, None], maybe_enable_amp: Generator[None, None, None], metrics_processor: MetricsProcessor | None = None, + pp_schedule: _PipelineSchedule | None = None, + pp_has_first_stage: bool | None = None, + pp_has_last_stage: bool | None = None, ) -> BaseValidator: """Build a simple validator focused on correctness.""" return Validator( @@ -164,4 +202,7 @@ def build_validator( validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, metrics_processor=metrics_processor, + pp_schedule=pp_schedule, + pp_has_first_stage=pp_has_first_stage, + pp_has_last_stage=pp_has_last_stage, ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 58fc69ac2b..64665ca651 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -329,9 +329,16 @@ def __init__(self, job_config: JobConfig): # Build validator if validation is configured if job_config.validation.enabled: assert self.train_spec.build_validator_fn is not None - assert ( - not parallel_dims.pp_enabled - ), "pp is enabled but validation doesn't support pipeline parallelism yet" + + pp_schedule, pp_has_first_stage, pp_has_last_stage = ( + ( + self.pp_schedule, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) + if parallel_dims.pp_enabled + else (None, None, None) + ) self.validator = self.train_spec.build_validator_fn( job_config=job_config, @@ -343,6 +350,9 @@ def __init__(self, job_config: JobConfig): validation_context=self.train_context, maybe_enable_amp=self.maybe_enable_amp, metrics_processor=self.metrics_processor, + pp_schedule=pp_schedule, + pp_has_first_stage=pp_has_first_stage, + pp_has_last_stage=pp_has_last_stage, ) logger.info( @@ -434,7 +444,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, self.tokenizer.eos_id) + pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred From ad9849c3cd6123e58ccffdd7811e9f51aba8be96 Mon Sep 17 00:00:00 2001 From: speed <69028964+speed1313@users.noreply.github.com> Date: Thu, 31 Jul 2025 14:53:31 +0900 Subject: [PATCH 054/128] Fix data_load_start position (#1481) # Fix incorrect data loading time measurement This PR fixes the timing of data_loading_times measurement in batch_generator. Previously, the timer started after calling next(data_iterator), which excluded the actual data fetching time from the measurement. Now, the timer starts before the next() call to correctly capture the full DataLoader latency. --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 64665ca651..d7f770c3a3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -373,13 +373,13 @@ def batch_generator( data_iterator = iter(data_iterable) while True: + data_load_start = time.perf_counter() try: batch = next(data_iterator) except StopIteration as ex: # If data runs out during gradient accumulation, that # entire step will not be executed. raise DataloaderStopIteration() from ex - data_load_start = time.perf_counter() input_dict, labels = batch ntokens_batch = labels.numel() self.ntokens_seen += ntokens_batch From b1dc33067f2e79af9bcd4f888b475b19733d7795 Mon Sep 17 00:00:00 2001 From: Ido Hakimi <5303103+idoh@users.noreply.github.com> Date: Thu, 31 Jul 2025 07:54:14 +0200 Subject: [PATCH 055/128] Refactor script to use 'overwrites' variable for command-line arguments in training scripts (#1473) The goal of this PR is to add support for command line arguments to the bash training scripts. The `run_train.sh` had support for `overrides`, however, the `multinode_trainer.slurm` script did not. This `overrides` flag add supports for commands like: `sbatch multinode_trainer.slurm --job.description="TEST_RUN"` However, there is a problem with the current `overrides` implementation, when passing arguments with space such as `"TEST RUN"` instead of `"TEST_RUN"` then the variable `job.description` would only get `TEST` as input and the training script throws an error for unrecognizing the argument `RUN` which is passed in a different line. To address this I simplify the code and directly pass the additional overrides through `$@`. This solves the issue for commands such as: `sbatch multinode_trainer.slurm --job.description="TEST RUN"` --- multinode_trainer.slurm | 4 ++-- run_train.sh | 9 ++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/multinode_trainer.slurm b/multinode_trainer.slurm index 5330186b09..3b123fb62f 100644 --- a/multinode_trainer.slurm +++ b/multinode_trainer.slurm @@ -34,7 +34,7 @@ export LOGLEVEL=INFO export FI_PROVIDER="efa" # Ensure that P2P is available # export NCCL_P2P_DISABLE=1 -export NCCL_IB_DISABLE=1 +# export NCCL_IB_DISABLE=1 # debugging flags (optional) export NCCL_DEBUG=WARN @@ -59,5 +59,5 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.t dcgmi profile --pause # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below # to your specific node count, and update target launch file. -srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} +srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} "$@" dcgmi profile --resume diff --git a/run_train.sh b/run_train.sh index fbed394ebb..01dddd0abd 100755 --- a/run_train.sh +++ b/run_train.sh @@ -7,22 +7,17 @@ set -ex -# use envs as local overrides for convenience +# use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} -overrides="" -if [ $# -ne 0 ]; then - overrides="$*" -fi - TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ --m torchtitan.train --job.config_file ${CONFIG_FILE} $overrides +-m torchtitan.train --job.config_file ${CONFIG_FILE} "$@" From cf30b2902718790cbe91900414c3201b6d7680b0 Mon Sep 17 00:00:00 2001 From: Ido Hakimi <5303103+idoh@users.noreply.github.com> Date: Thu, 31 Jul 2025 07:55:11 +0200 Subject: [PATCH 056/128] Add logging for learning rates in MetricsProcessor (#1413) This PR adds learning rate logging. There was a previous attempt to implement this in an [earlier PR](https://github.com/pytorch/torchtitan/pull/937), but that one was ultimately **closed**. This version ensures that LR logging works properly, I verified it using the WSD scheduler that was recently added in [another PR](https://github.com/pytorch/torchtitan/pull/938). image One design consideration here is that torchtitan supports multiple optimizers and learning rate schedules, each potentially having its own LR. However, in practice, I believe that 99.9999% of use cases will use a single LR. Given that, the logging works as follows: - If there is only one learning rate, it gets logged directly under the main charts as `lr`. - If there are multiple learning rates, they are logged under a separate section, each with its corresponding label. Alternatively, we could have ignored the multi-LR case and always logged a single LR, but I prefer this approach since it handles both scenarios robustly with minimal extra code. Happy to adjust if others have a strong preference for simplicity over robustness. --- torchtitan/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index d7f770c3a3..369c409a81 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -456,6 +456,8 @@ def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() + # Save the current step learning rate for logging + lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. @@ -503,12 +505,16 @@ def train_step( else: global_avg_loss = global_max_loss = loss.detach().item() + extra_metrics = { + "n_tokens_seen": self.ntokens_seen, + "lr": lr, + } self.metrics_processor.log( self.step, global_avg_loss, global_max_loss, grad_norm.item(), - extra_metrics={"ntokens_seen": self.ntokens_seen}, + extra_metrics=extra_metrics, ) @record From d655e1673d5f15d31a9514092f281e744107d821 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 31 Jul 2025 19:16:49 -0700 Subject: [PATCH 057/128] Make token group alignment size configurable (#1503) ## Summary - For mxfp8, token group sizes must be multiples of "block_size" because in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor, so we scale them separately. This means token groups contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879 - To solve this, this PR makes the token group M aligment configurable. ## Test plan - Integration test with torchao passes: https://github.com/pytorch/ao/pull/2642 - Did manual test run with llama4 debug model using bf16 --- torchtitan/components/quantization/float8.py | 7 ++++ torchtitan/components/quantization/mx.py | 10 ++++++ torchtitan/config/job_config.py | 9 ++++- .../llama4/infra/expert_parallel.py | 35 ++++++++++++++++--- 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 863ea266fc..58699b92ee 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -10,6 +10,9 @@ from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.llama4.infra.expert_parallel import ( + set_token_group_alignment_size_m, +) from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, @@ -66,6 +69,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): job_config.parallelism.context_parallel_degree == 1 ), "Float8 MoE training prototype does not yet support context parallelism" + # For fp8 grouped GEMM, token group sizes must be multiples of 16 + # (16 byte alignment / 1 byte per elem = 16 elements) + set_token_group_alignment_size_m(16) + if float8_config.recipe_name is not None: assert not float8_config.enable_fsdp_float8_all_gather, ( "using `float8_config.enable_fsdp_float8_all_gather` together " diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f2c6820a70..ce4d89ffe6 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -59,6 +59,16 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): and job_config.parallelism.tensor_parallel_degree > 1 ), "TP not yet supported with torch.compile for mxfp8" + # For MoE training with mxfp8, token group sizes must be multiples of 32 + if job_config.mx.moe_fqns_prototype: + from torchtitan.experiments.llama4.infra.expert_parallel import ( + set_token_group_alignment_size, + ) + + mxfp8_block_size = 32 + set_token_group_alignment_size(mxfp8_block_size) + logger.info(f"Setting token group alignment size to {mxfp8_block_size}") + # Configure MXFP8 from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5255de3da4..39e81f7a99 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -567,12 +567,19 @@ class MX: filter_fqns: list[str] = field(default_factory=lambda: ["output"]) """ - Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to. + Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to. nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements. By default we always skip the output layer. Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output" """ + moe_fqns_prototype: list[str] | str = field(default_factory=list) + """ + Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to. + This is a prototype feature that requires the torchao nightly build. + Example: --mx.moe_fqns_prototype="experts" + """ + @dataclass class Comm: diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 0e8aef8ee2..f40dbae2bc 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -6,7 +6,7 @@ from functools import partial -from typing import Callable +from typing import Callable, Literal import torch import torch.distributed as dist @@ -24,6 +24,33 @@ from torch.distributed.tensor.placement_types import Placement +TOKEN_GROUP_ALIGN_SIZE_M = 8 +ValidTokenGroupAlignmentSize = Literal[8, 16, 32] + + +def set_token_group_alignment_size_m( + alignment_size: ValidTokenGroupAlignmentSize, +) -> None: + """ + Set the token group alignment size for token groups in MoE. This is implemented by + padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M. + + Valid values are: 8, 16, or 32. + Different values are needed for different cases: + + * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements). + * For fp8, 16 byte alignment / 1 byte per elem = 16 elements. + * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32), + so when doing per-token-group quantization on each logically distinct subtensor, + we need to ensure the contracting dim is divisible by block_size. + In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims + of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M, + so we need 32 element alignment. + """ + global TOKEN_GROUP_ALIGN_SIZE_M + TOKEN_GROUP_ALIGN_SIZE_M = alignment_size + + # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): def _partition_fn(self, name, module, device_mesh): @@ -251,6 +278,7 @@ def wrapper( x: torch.Tensor, num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: + global TOKEN_GROUP_ALIGN_SIZE_M if isinstance(w1, DTensor): w1 = w1.to_local() w2 = w2.to_local() @@ -264,7 +292,6 @@ def wrapper( experts_per_ep_rank = w1.shape[0] num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank - ALIGN_SIZE_M = 16 with torch.no_grad(): ( permuted_indices, @@ -274,8 +301,8 @@ def wrapper( num_tokens_per_expert, experts_per_ep_rank, num_ep_ranks, - x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, - ALIGN_SIZE_M, + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + TOKEN_GROUP_ALIGN_SIZE_M, ) x = torch.vstack((x, x.new_zeros((x.shape[-1])))) From b109f7d713febb12231390b2c051986007e6b48a Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Fri, 1 Aug 2025 13:35:42 -0700 Subject: [PATCH 058/128] [DSV3] Add output.contiguous() in model to match llama3 (#1504) (#1513) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #1504 **Summary** ## ~~Change tokenizer size~~ This is resolved by downloading the right tokenizer Before the change: ``` File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/xilunwu/pytorch/torch/nn/modules/normalization.py", line 414, in forward return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/xilunwu/pytorch/torch/nn/functional.py", line 2924, in rms_norm return torch.rms_norm(input, normalized_shape, weight, eps) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ torch.AcceleratorError: CUDA error: device-side assert triggered Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. ``` Adding CUDA_LAUNCH_BLOCKING=1 to launch command shows the real error is in embedding. After fixing the tokenizer size the training works fine. ## Add `.contiguous()` to output after calling transpose() Command: `NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2` Error: ``` [rank0]:[rank0]: File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl [rank0]:[rank0]: return forward_call(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 237, in forward [rank0]:[rank0]: output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. ``` The model code didn't match with llama3. After adding `.contiguous()` it runs correctly. ``` NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2 + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + overrides= + '[' 2 -ne 0 ']' + overrides='--parallelism.context-parallel-degree 2' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --parallelism.context-parallel-degree 2 W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] ***************************************** W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] ***************************************** [rank0]:[titan] 2025-07-31 11:31:25,671 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank0]:[titan] 2025-07-31 11:31:27,890 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-31 11:31:27,891 - root - INFO - Building 2-D device mesh with ['dp_shard', 'cp'], [4, 2] [rank0]:[titan] 2025-07-31 11:31:27,897 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-07-31 11:31:32,956 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-31 11:31:33,170 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-31 11:31:38,681 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=129280, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=False, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7) [rank0]:[titan] 2025-07-31 11:31:38,855 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Total parameter count: dense 968,486,400, sparse 14,848,098,304, active 2,771,250,688 [rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Model deepseek_v3 16B size: 15,816,584,704 total parameters [rank0]:[titan] 2025-07-31 11:31:38,930 - root - INFO - Applied full activation checkpointing to the model [rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied Context Parallel to the model [rank0]:[titan] 2025-07-31 11:31:39,398 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-31 11:31:39,399 - root - INFO - CUDA memory usage for model: 8.84GiB(9.30%) [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Trainer is initialized with local batch size 8, global batch size 32, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200) [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - step: 1 loss: 12.2584 grad_norm: 1.2466 memory: 53.49GiB(56.30%) tps: 1,589 tflops: 28.21 mfu: 2.85% [rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-31 11:32:13,707 - root - INFO - step: 10 loss: 11.5358 grad_norm: 1.4495 memory: 71.08GiB(74.82%) tps: 6,027 tflops: 107.02 mfu: 10.82% [rank0]:[titan] 2025-07-31 11:32:40,848 - root - INFO - step: 20 loss: 10.0093 grad_norm: 7.7745 memory: 71.08GiB(74.82%) tps: 6,037 tflops: 107.20 mfu: 10.84% ``` --- torchtitan/models/deepseek_v3/model/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index e13eb2bf4f..1d92c12545 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -233,7 +233,9 @@ def forward( output = self.sdpa(q, k, v, scale=self.softmax_scale) # Reshape and project output - output = output.transpose(1, 2) # (bsz, seqlen, n_heads, v_head_dim) + output = output.transpose( + 1, 2 + ).contiguous() # (bsz, seqlen, n_heads, v_head_dim) output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) return self.wo(output) # (bsz, seqlen, dim) From 43fa980929ffc77594b2375f85cf57ddd0b052e0 Mon Sep 17 00:00:00 2001 From: Ruisi Zhang Date: Fri, 1 Aug 2025 14:28:40 -0700 Subject: [PATCH 059/128] fix small deepseekv3 typo (#1514) as titled, found there is a small typo in importing deepseekv3 model functions. --- torchtitan/models/deepseek_v3/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 7585b4e03e..8243a0a84a 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -21,8 +21,8 @@ __all__ = [ "parallelize_deepseekv3", - "DeepseekV3ModelArgs", - "DeepseekV3Model", + "DeepSeekV3ModelArgs", + "DeepSeekV3Model", "deepseekv3_configs", ] From 48d8dcde1875c9a8473ab59aa78476c7d30756b7 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 1 Aug 2025 18:55:42 -0400 Subject: [PATCH 060/128] make mx recipe name more generic (#1512) Summary: Instead of maintaining a mapping in torchtitan with valid mx recipe names, just pass the string recipe directly to torchao. This way torchao can iterate on recipes without any changes to torchtitan to use those recipes. Note that appropriate error messages will be thrown from torchao if user specifies an invalid config name, so there is no need to duplicate them in torchtitan. Test Plan: ```bash with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8_cublas_rceil" ``` Reviewers: Subscribers: Tasks: Tags: --- torchtitan/components/quantization/mx.py | 5 +---- torchtitan/config/job_config.py | 7 +++++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index ce4d89ffe6..276208c9a8 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -22,9 +22,6 @@ from .utils import module_filter_fn -# Maps titan recipe names to torchao mx recipe names -NAME_MAP = {"mxfp8": "mxfp8_cublas"} - class MXConverter(ModelConverter): """Converts the linear layers of `model` to `MXLinear`.""" @@ -76,7 +73,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) mx_job_config: MX = job_config.mx - config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name]) + config = MXLinearConfig.from_recipe_name(mx_job_config.recipe_name) config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() ] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 39e81f7a99..fa92545e44 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -562,8 +562,11 @@ class MX: mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton" """Temp work around for inductor performance gap""" - recipe_name: Literal["mxfp8"] = "mxfp8" - """If specified, creates float8 config from recipe name""" + recipe_name: str = "mxfp8_cublas" + """ + If specified, creates MX config from recipe name. See + https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats for more information. + """ filter_fqns: list[str] = field(default_factory=lambda: ["output"]) """ From a0fdaa31b3a4c36505daf7d40b4f44ea95de9f5a Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen <33333409+runame@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:05:19 -0700 Subject: [PATCH 061/128] All-reduce `ntokens_seen` before logging (#1509) Currently, `ntokens_seen` is only locally logged. I think it is almost always desirable to only track the global quantity (the only use case I can see for per-device tracking is for debugging?). Therefore, I propose to all-reduce `ntokens_seen` before logging. --- torchtitan/distributed/utils.py | 10 ++++++++++ torchtitan/train.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 3c9e20ffb5..aa25149d66 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -62,6 +62,16 @@ def dist_max( ) +def dist_sum( + x: torch.Tensor, + mesh: DeviceMesh, + extra_pg: dist.ProcessGroup | None = None, +) -> float: + return _dist_reduce( + x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg + ) + + def dist_mean( x: torch.Tensor, mesh: DeviceMesh, diff --git a/torchtitan/train.py b/torchtitan/train.py index 369c409a81..0955bbb2cb 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -498,15 +498,23 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg - global_avg_loss, global_max_loss = ( + global_avg_loss, global_max_loss, global_ntokens_seen = ( dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_sum( + torch.tensor( + self.ntokens_seen, dtype=torch.int64, device=self.device + ), + parallel_dims.world_mesh["dp_cp"], + ft_pg, + ), ) else: global_avg_loss = global_max_loss = loss.detach().item() + global_ntokens_seen = self.ntokens_seen extra_metrics = { - "n_tokens_seen": self.ntokens_seen, + "n_tokens_seen": global_ntokens_seen, "lr": lr, } self.metrics_processor.log( From 2429e0b7653aa7e28d319b7c1028c3f233e3968d Mon Sep 17 00:00:00 2001 From: Runa Eschenhagen <33333409+runame@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:25:14 -0700 Subject: [PATCH 062/128] Compute validation metrics at first step (#1508) Currently, the first time validation metrics are computed is when `step == job_config.validation.freq`. I think it is preferable to always compute them for the first step as well. --- torchtitan/components/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 7f8b848c7f..12694b25b6 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -28,7 +28,7 @@ def validate(self, model_parts: list[nn.Module]) -> dict[str, float]: raise NotImplementedError("validate method not implemented") def should_validate(self, step: int) -> bool: - return step % self.job_config.validation.freq == 0 + return step == 1 or step % self.job_config.validation.freq == 0 class Validator(BaseValidator): From 004162a4335208bbc36a8bb84073f7ad717214b0 Mon Sep 17 00:00:00 2001 From: Shoufa Chen Date: Fri, 1 Aug 2025 21:06:55 -0400 Subject: [PATCH 063/128] minor fix (#1494) --- .../experiments/flux/infra/parallelize.py | 2 +- torchtitan/experiments/flux/sampling.py | 2 +- torchtitan/experiments/flux/train.py | 25 ++++++++++--------- torchtitan/tools/utils.py | 4 +-- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/torchtitan/experiments/flux/infra/parallelize.py b/torchtitan/experiments/flux/infra/parallelize.py index c2bdb98b30..263655a28f 100644 --- a/torchtitan/experiments/flux/infra/parallelize.py +++ b/torchtitan/experiments/flux/infra/parallelize.py @@ -143,7 +143,7 @@ def parallelize_encoders( fully_shard(t5_model.hf_module, **fsdp_config) if parallel_dims.dp_replicate_enabled: - logger.info("Applied FSDP to the T5 encoder model") + logger.info("Applied HSDP to the T5 encoder model") else: logger.info("Applied FSDP to the T5 encoder model") diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index 8e4e8589ef..2ea0caf8c9 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -172,7 +172,7 @@ def denoise( _, latent_channels, latent_height, latent_width = latents.shape # create denoising schedule - timesteps = get_schedule(denoising_steps, latent_channels, shift=True) + timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True) # create positional encodings POSITION_DIM = 3 diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index bc3db244dd..add15ad540 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -126,19 +126,20 @@ def forward_backward_step( # Patchify: Convert latent into a sequence of patches latents = pack_latents(latents) - latent_noise_pred = model( - img=latents, - img_ids=latent_pos_enc, - txt=t5_encodings, - txt_ids=text_pos_enc, - y=clip_encodings, - timesteps=timesteps, - ) + with self.maybe_enable_amp: + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=timesteps, + ) - # Convert sequence of patches to latent shape - pred = unpack_latents(latent_noise_pred, latent_height, latent_width) - target = noise - labels - loss = self.loss_fn(pred, target) + # Convert sequence of patches to latent shape + pred = unpack_latents(latent_noise_pred, latent_height, latent_width) + target = noise - labels + loss = self.loss_fn(pred, target) # pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory del (pred, noise, target) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 1ef19c123f..45bbd4ab83 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -166,11 +166,11 @@ def check_if_feature_in_pytorch( # notify users to check if the pull request is included in their pytorch logger.warning( "Detected that the pytorch is built from source. Please make sure the PR " - f"({pull_request_link}) is included in pytorch for correct {feature_name}." + f"({pull_request}) is included in pytorch for correct {feature_name}." ) elif min_nightly_version is not None and torch.__version__ < min_nightly_version: logger.warning( f"Detected that the pytorch version {torch.__version__} is older than " f"{min_nightly_version}. Please upgrade a newer version to include the " - f"change in ({pull_request_link}) for correct {feature_name}." + f"change in ({pull_request}) for correct {feature_name}." ) From ed288bc9f28700b992cb7e50465648cc21aced28 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sun, 3 Aug 2025 09:13:44 -0700 Subject: [PATCH 064/128] [llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout (#1517) # Summary Rather than store experts weights pre-transposed (E, in_dim, out_dim), we should store the expert weights non-transposed (E, out_dim, in_dim) then transpose before grouped gemm for (1) compatible dims for gemm, and (2) column-major memory layout required for right operand in grouped gemm. Doing this simple transpose (metadata change only) is must more efficient than doing this [inefficient memory layout transformation before every GEMM in fp8](https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96). # Eager Performance Llama4 debug model with FSDP=8, using config: ```python "debugmodel": TransformerModelArgs( dim=5120, n_layers=4, n_heads=40, n_kv_heads=8, ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, max_seq_len=10485760, num_experts=16, interleave_moe_layer_step=1, ), ``` ### bfloat16 With change: ``` ===================================================== Calculating training performance metrics ===================================================== Median Tokens/Second (excluding step 1): 2147.0 Max Memory Usage: 92.67 GiB ``` Without change: ``` ===================================================== Calculating training performance metrics ===================================================== Median Tokens/Second (excluding step 1): 1711.0 Max Memory Usage: 92.67 GiB ``` ### fp8 rowwise With change: ``` (torchtitan) [danvm@devgpu007.eag6 ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh ===================================================== Calculating training performance metrics ===================================================== Median Tokens/Second (excluding step 1): 2675.0 Max Memory Usage: 90.35 GiB ``` Without change: ``` (torchtitan) [danvm@devgpu007.eag6 ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh ===================================================== Calculating training performance metrics ===================================================== Median Tokens/Second (excluding step 1): 2360.0 Max Memory Usage: 90.35 GiB ``` --- .../llama4/infra/expert_parallel.py | 22 +++++++++---- torchtitan/experiments/llama4/model/moe.py | 32 ++++++++++++------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index f40dbae2bc..9a9dad66ae 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -54,16 +54,21 @@ def set_token_group_alignment_size_m( # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): def _partition_fn(self, name, module, device_mesh): + # w1 shape = (experts, out_dim, in_dim) module.register_parameter( - "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) + "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)])) ) # Column-wise sharding + + # w2 shape = (experts, in_dim, out_dim) module.register_parameter( "w2", - nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])), + nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])), ) # Row-wise sharding + + # w3 shape = (experts, out_dim, in_dim) module.register_parameter( "w3", - nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), + nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(1)])), ) # Column-wise sharding def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: @@ -223,17 +228,22 @@ def _token_dispatch(self, mod, inputs, device_mesh): return super()._token_dispatch(mod, inputs, self.ep_mesh) def _partition_fn_2d(self, name, mod, ep_tp_mesh): + # w1 shape = (experts, out_dim, in_dim) mod.register_parameter( "w1", - nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])), + nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding + + # w2 shape = (experts, in_dim, out_dim) mod.register_parameter( "w2", - nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])), + nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])), ) # Row-wise sharding + + # w3 shape = (experts, out_dim, in_dim) mod.register_parameter( "w3", - nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(2)])), + nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding def _token_combine(self, mod, routed_output, device_mesh): diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index 71ac1360c3..73a5d0a205 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -23,9 +23,9 @@ def __init__( ): super().__init__() self.num_experts = num_experts - self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) - self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) - self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) self.use_grouped_mm = use_grouped_mm def forward( @@ -69,9 +69,9 @@ def _run_experts_for_loop( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx])) - h = h * torch.matmul(x_expert, w3[expert_idx]) - h = torch.matmul(h, w2[expert_idx]) + h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) + h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) + h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) # h shape (tokens_per_expert(varying), dim) out_experts_splits.append(h) out = torch.cat(out_experts_splits, dim=0) @@ -80,10 +80,10 @@ def _run_experts_for_loop( out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) else: # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1)) - h = h * torch.bmm(x, w3) + h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) + h = h * torch.bmm(x, w3.transpose(-2, -1)) # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2) + out = torch.bmm(h, w2.transpose(-2, -1)) return out @@ -105,9 +105,17 @@ def _run_experts_grouped_mm( # fall back to regular bmm between 3D tensors assert x.dim() == 3 - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) - out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) + h = F.silu( + torch._grouped_mm( + x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets + ) + ) + h = h * torch._grouped_mm( + x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets + ) + out = torch._grouped_mm( + h, w2.bfloat16().transpose(-2, -1), offs=offsets + ).type_as(x) return out From 28440294122836ecc2bc35d15314d7b9bd33bad6 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 4 Aug 2025 16:04:00 -0700 Subject: [PATCH 065/128] [llama4] add apply_compile for moe, where fullgraph=False for moe layers (#1519) We should add an `apply_compile` function for llama4 that uses fullgraph=False for MoE layers and fullgraph=True for dense layers. I keep manually applying this hack during development to test compile composability, but IMO we should have this merged and update to use fullgraph=True everywhere once that is supported. cc @xmfan @tianyu-l any thoughts? --- .../experiments/llama4/infra/parallelize.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 33ff71a985..b1e60f9962 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -21,11 +21,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.models.llama3.infra.parallelize import ( - apply_ac, - apply_compile, - apply_ddp, -) +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger from .expert_parallel import ( @@ -385,3 +381,19 @@ def apply_moe_ep_tp( device_mesh=experts_mesh, parallelize_plan=experts_plan, ) + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + # TODO: remove when torch.compile supports fullgraph=True for llama4 moe + fullgraph = True + if transformer_block.moe_enabled: + fullgraph = False + transformer_block = torch.compile(transformer_block, fullgraph=fullgraph) + model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") From 92bea07501005b644dff3dd166f573f4991eaaa7 Mon Sep 17 00:00:00 2001 From: Less Wright Date: Mon, 4 Aug 2025 22:38:10 -0700 Subject: [PATCH 066/128] [deepseek] update to 16b base tokenizer (#1499) This PR updates to use base rather than chat (they are the same but name is different) and makes it clear we are not loading the model weights for 16b. Testing: download via script run 20 iters with 16b_base tokenizer. --- torchtitan/models/deepseek_v3/README.md | 4 ++-- .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 5e6c97e28d..54aa8f8d28 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -12,8 +12,8 @@ python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3 ``` ```bash -# For 16B model support: -python scripts/download_tokenizer.py --repo_id deepseek-ai/deepseek-moe-16b-chat +# DeepSeek 16B tokenizer: +python scripts/download_tokenizer.py --repo_id deepseek-ai/deepseek-moe-16b-base ``` ## Training diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 4f646c8d0f..1cedc590d2 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -22,7 +22,7 @@ enable_wandb = false [model] name = "deepseek_v3" flavor = "16B" -tokenizer_path = "./assets/tokenizer/deepseek-moe-16b-chat" +tokenizer_path = "./assets/tokenizer/deepseek-moe-16b-base" # converters = ["float8"] [optimizer] From 90cfba48c6fe9aa1d266a2282dac23c2cbd5c5ef Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Tue, 5 Aug 2025 11:30:21 -0700 Subject: [PATCH 067/128] Add description for 16B model tokenizer for deepseek-v3 model (#1530) As titled, quick followup for #1499 --- torchtitan/models/deepseek_v3/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 54aa8f8d28..6698852b47 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -16,6 +16,9 @@ python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3 python scripts/download_tokenizer.py --repo_id deepseek-ai/deepseek-moe-16b-base ``` +> **Note:** We are reusing the tokenizer from deepseek-ai/deepseek-moe-16b-base to help users test and run the 16B model. This is not the official tokenizer for the DeepSeek-V3-16B model. The DeepSeek-V3 model has a different architecture from the deepseek-moe models (different attention implementation, MoE router implementation, etc.), making it not feasible to load deepseek-moe-16b model weights into DeepSeek-V3-16B. + + ## Training ### Debug Training From a204e3188e9257cc3ce6be08c37f7845d7428e30 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 5 Aug 2025 15:00:48 -0700 Subject: [PATCH 068/128] Flux Validation (#1518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # This pr implements the validator class for flux following the method discussed in Stable Diffusion 3 paper. The paper shows that creating 8 equidistant timesteps and calculating the average loss on them will result in a highly correlated loss to external validation methods such as CLIP or FID score. This pr's implementation rather than creating 8 stratified timesteps per sample, only applies one of these equidistant timesteps to each sample in a round-robin fashion. Aggregated over many samples in a validation set, this should give a similar validation score as the full timestep method, but will process more validation samples quickly. ### Implementations - Integrates the image generation evaluation in the validation step, users can - Refactors and combines eval job_config with validation - Adds an `all_timesteps` option to the job_config to choose whether to use round robin timesteps or full timesteps per sample - Creates validator class and validation dataloader for flux, validator dataloader handles generating timesteps for round-robin method of validation ### Enabling all timesteps Developers can enable the full timestamp method of validation by setting `all_timesteps = True` in the flux validation job config. Enabling all_timesteps may require tweaking some hyperparams `validation.local_batch_size, validation.steps` to prevent spiking memory and optimizing throughput. By using a ratio of around 1/4 for `validation.local_batch_size` to `training.local_batch_size` will not spike the memory higher than training when `fsdp = 8`. Below we can see the difference between round robin and all timesteps. In the comparison the total number of validation samples processed is the same, but in `all_timesteps=True` configuration we have to lower the batch size to prevent memory spiking. All timesteps also achieves a higher throughput (tps) but still processes total samples of validation set more slowly. | Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps (batch_size=8, steps=4, fsdp=8) | | ---- | --- | | Screenshot 2025-08-01 at 3 46
42 PM | Screenshot 2025-08-01 at 3 30
10 PM | --- torchtitan/components/validate.py | 2 +- torchtitan/experiments/flux/__init__.py | 2 + .../experiments/flux/dataset/flux_dataset.py | 125 +++++++- torchtitan/experiments/flux/job_config.py | 9 +- torchtitan/experiments/flux/sampling.py | 8 +- .../flux/tests/integration_tests.py | 3 + .../flux/tests/test_generate_image.py | 10 +- .../tests/unit_tests/test_flux_dataloader.py | 4 +- torchtitan/experiments/flux/train.py | 72 +---- .../flux/train_configs/debug_model.toml | 20 +- .../flux/train_configs/flux_dev_model.toml | 20 +- .../train_configs/flux_schnell_model.toml | 18 +- torchtitan/experiments/flux/validate.py | 276 ++++++++++++++++++ .../multimodal/tokenizer/tiktoken.py | 2 +- 14 files changed, 475 insertions(+), 96 deletions(-) create mode 100644 torchtitan/experiments/flux/validate.py diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 12694b25b6..1bdb854e80 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -80,7 +80,7 @@ def validate( self, model_parts: list[nn.Module], step: int, - ) -> dict[str, float]: + ) -> None: # Set model to eval mode model = model_parts[0] model.eval() diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index 12613a7930..b7c55c4ee4 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -17,6 +17,7 @@ from .model.args import FluxModelArgs from .model.autoencoder import AutoEncoderParams from .model.model import FluxModel +from .validate import build_flux_validator __all__ = [ "FluxModelArgs", @@ -117,5 +118,6 @@ build_dataloader_fn=build_flux_dataloader, build_tokenizer_fn=None, build_loss_fn=build_mse_loss, + build_validator_fn=build_flux_validator, ) ) diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index bd0fc715c1..df266496c2 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools import math from dataclasses import dataclass from typing import Any, Callable, Optional @@ -103,6 +104,38 @@ def _cc12m_wds_data_processor( "image": img, "clip_tokens": clip_tokens, # type: List[int] "t5_tokens": t5_tokens, # type: List[int] + "prompt": sample["txt"], # type: str + } + + +def _coco_data_processor( + sample: dict[str, Any], + t5_tokenizer: FluxTokenizer, + clip_tokenizer: FluxTokenizer, + output_size: int = 256, +) -> dict[str, Any]: + """ + Preprocess COCO dataset sample image and text for Flux model. + + Args: + sample: A sample from dataset + t5_encoder: T5 encoder + clip_encoder: CLIP encoder + output_size: The output image size + + """ + img = _process_cc12m_image(sample["image"], output_size=output_size) + prompt = sample["caption"] + if isinstance(prompt, list): + prompt = prompt[0] + t5_tokens = t5_tokenizer.encode(prompt) + clip_tokens = clip_tokenizer.encode(prompt) + + return { + "image": img, + "clip_tokens": clip_tokens, # type: List[int] + "t5_tokens": t5_tokens, # type: List[int] + "prompt": prompt, # type: str } @@ -126,6 +159,11 @@ class TextToImageDatasetConfig: ), data_processor=_cc12m_wds_data_processor, ), + "coco-validation": TextToImageDatasetConfig( + path="howard-hou/COCO-Text", + loader=lambda path: load_dataset(path, split="validation", streaming=True), + data_processor=_coco_data_processor, + ), } @@ -242,8 +280,9 @@ def __iter__(self): # skip low quality image or image with color channel = 1 if sample_dict["image"] is None: + sample = sample.get("__key__", "unknown") logger.warning( - f"Low quality image {sample['__key__']} is skipped in Flux Dataloader." + f"Low quality image {sample} is skipped in Flux Dataloader." ) continue @@ -308,3 +347,87 @@ def build_flux_dataloader( dp_world_size=dp_world_size, batch_size=batch_size, ) + + +class FluxValidationDataset(FluxDataset): + """ + Adds logic to generate timesteps for flux validation method described in SD3 paper + + Args: + generate_timesteps (bool): Generate stratified timesteps in round-robin style for validation + """ + + def __init__( + self, + dataset_name: str, + dataset_path: Optional[str], + t5_tokenizer: BaseTokenizer, + clip_tokenizer: BaseTokenizer, + job_config: Optional[JobConfig] = None, + dp_rank: int = 0, + dp_world_size: int = 1, + generate_timesteps: bool = True, + ) -> None: + # Call parent constructor correctly + super().__init__( + dataset_name=dataset_name, + dataset_path=dataset_path, + t5_tokenizer=t5_tokenizer, + clip_tokenizer=clip_tokenizer, + job_config=job_config, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=False, + ) + + # Initialize timestep generation for validation + self.generate_timesteps = generate_timesteps + if self.generate_timesteps: + # Generate stratified timesteps as described in SD3 paper + val_timesteps = [1 / 8 * (i + 0.5) for i in range(8)] + self.timestep_cycle = itertools.cycle(val_timesteps) + + def __iter__(self): + # Get parent iterator and add timesteps to each sample + parent_iterator = super().__iter__() + + for sample_dict, labels in parent_iterator: + # Add timestep to the sample dict if timestep generation is enabled + if self.generate_timesteps: + sample_dict["timestep"] = next(self.timestep_cycle) + + yield sample_dict, labels + + +def build_flux_validation_dataloader( + dp_world_size: int, + dp_rank: int, + job_config: JobConfig, + # This parameter is not used, keep it for compatibility + tokenizer: BaseTokenizer | None, + generate_timestamps: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for HuggingFace datasets.""" + dataset_name = job_config.validation.dataset + dataset_path = job_config.validation.dataset_path + batch_size = job_config.validation.local_batch_size + + t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config) + + ds = FluxValidationDataset( + dataset_name=dataset_name, + dataset_path=dataset_path, + t5_tokenizer=t5_tokenizer, + clip_tokenizer=clip_tokenizer, + job_config=job_config, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + generate_timesteps=generate_timestamps, + ) + + return ParallelAwareDataloader( + dataset=ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, + ) diff --git a/torchtitan/experiments/flux/job_config.py b/torchtitan/experiments/flux/job_config.py index be84c8160b..8c7589d045 100644 --- a/torchtitan/experiments/flux/job_config.py +++ b/torchtitan/experiments/flux/job_config.py @@ -36,7 +36,7 @@ class Encoder: @dataclass -class Eval: +class Validation: enable_classifier_free_guidance: bool = False """Whether to use classifier-free guidance during sampling""" classifier_free_guidance_scale: float = 5.0 @@ -45,8 +45,13 @@ class Eval: """How many denoising steps to sample when generating an image""" eval_freq: int = 100 """Frequency of evaluation/sampling during training""" + save_img_count: int = 1 + """ How many images to generate and save during validation, starting from + the beginning of validation set, -1 means generate on all samples""" save_img_folder: str = "img" """Directory to save image generated/sampled from the model""" + all_timesteps: bool = False + """Whether to generate all stratified timesteps per sample or use round robin""" @dataclass @@ -57,4 +62,4 @@ class JobConfig: training: Training = field(default_factory=Training) encoder: Encoder = field(default_factory=Encoder) - eval: Eval = field(default_factory=Eval) + validation: Validation = field(default_factory=Validation) diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index 2ea0caf8c9..445e8c85fd 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -93,7 +93,9 @@ def generate_image( img_height = 16 * (job_config.training.img_size // 16) img_width = 16 * (job_config.training.img_size // 16) - enable_classifier_free_guidance = job_config.eval.enable_classifier_free_guidance + enable_classifier_free_guidance = ( + job_config.validation.enable_classifier_free_guidance + ) # Tokenize the prompt. Unsqueeze to add a batch dimension. clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0) @@ -132,7 +134,7 @@ def generate_image( model=model, img_width=img_width, img_height=img_height, - denoising_steps=job_config.eval.denoising_steps, + denoising_steps=job_config.validation.denoising_steps, clip_encodings=batch["clip_encodings"], t5_encodings=batch["t5_encodings"], enable_classifier_free_guidance=enable_classifier_free_guidance, @@ -142,7 +144,7 @@ def generate_image( empty_clip_encodings=( empty_batch["clip_encodings"] if enable_classifier_free_guidance else None ), - classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale, + classifier_free_guidance_scale=job_config.validation.classifier_free_guidance_scale, ) img = autoencoder.decode(img) diff --git a/torchtitan/experiments/flux/tests/integration_tests.py b/torchtitan/experiments/flux/tests/integration_tests.py index cd2bec0976..ae4e688266 100755 --- a/torchtitan/experiments/flux/tests/integration_tests.py +++ b/torchtitan/experiments/flux/tests/integration_tests.py @@ -64,6 +64,9 @@ def build_test_list(): "Checkpoint Integration Test - Save Model Only fp32", "last_save_model_only_fp32", ), + OverrideDefinitions( + [["--validation.enabled"]], "Flux Validation Test", "validation" + ), # Parallelism tests. OverrideDefinitions( [ diff --git a/torchtitan/experiments/flux/tests/test_generate_image.py b/torchtitan/experiments/flux/tests/test_generate_image.py index 2583b24349..0f06568954 100755 --- a/torchtitan/experiments/flux/tests/test_generate_image.py +++ b/torchtitan/experiments/flux/tests/test_generate_image.py @@ -57,12 +57,12 @@ def test_generate_image(self): "--training.img_size", str(img_width), # eval params - "--eval.denoising_steps", + "--validation.denoising_steps", str(num_steps), - "--eval.enable_classifier_free_guidance", - "--eval.classifier_free_guidance_scale", + "--validation.enable_classifier_free_guidance", + "--validation.classifier_free_guidance_scale", str(classifier_free_guidance_scale), - "--eval.save_img_folder", + "--validation.save_img_folder", "img", ] ) @@ -120,7 +120,7 @@ def test_generate_image(self): save_image( name=f"img_unit_test_{config.training.seed}.jpg", output_dir=os.path.join( - config.job.dump_folder, config.eval.save_img_folder + config.job.dump_folder, config.validation.save_img_folder ), x=image, add_sampling_metadata=True, diff --git a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py index 093deb71e5..3d7fb9902a 100644 --- a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py +++ b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py @@ -79,7 +79,9 @@ def test_load_dataset(self): for i in range(0, num_steps): input_data, labels = next(it) - assert len(input_data) == 2 # (clip_encodings, t5_encodings) + assert ( + len(input_data) == 3 + ) # (clip_encodings, t5_encodings, prompt) assert labels.shape == (batch_size, 3, 256, 256) assert input_data["clip_tokens"].shape == ( batch_size, diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index add15ad540..7af97cff35 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -5,21 +5,18 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Iterable, Optional +from typing import Optional import torch -from torch.distributed.fsdp import FSDPModule from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import utils as dist_utils from torchtitan.tools.logging import init_logger, logger from torchtitan.train import Trainer -from .dataset.tokenizer import build_flux_tokenizer from .infra.parallelize import parallelize_encoders from .model.autoencoder import load_ae from .model.hf_embedder import FluxEmbedder -from .sampling import generate_image, save_image from .utils import ( create_position_encoding_for_latents, pack_latents, @@ -81,6 +78,15 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ) + if job_config.validation.enabled: + self.validator.flux_init( + device=self.device, + _dtype=self._dtype, + autoencoder=self.autoencoder, + t5_encoder=self.t5_encoder, + clip_encoder=self.clip_encoder, + ) + def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: @@ -147,64 +153,6 @@ def forward_backward_step( return loss - def eval_step(self, prompt: str = "A photo of a cat"): - """ - Evaluate the Flux model. - 1) generate and save images every few steps. Currently, we run the eval and on the same - prompts across all DP ranks. We will change this behavior to run on validation set prompts. - Due to random noise generation, results could be different across DP ranks cause we assign - different random seeds to each DP rank. - 2) [TODO] Calculate loss with fixed t value on validation set. - """ - - t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config) - - image = generate_image( - device=self.device, - dtype=self._dtype, - job_config=self.job_config, - model=self.model_parts[0], - prompt=prompt, # TODO(jianiw): change this to a prompt from validation set - autoencoder=self.autoencoder, - t5_tokenizer=t5_tokenizer, - clip_tokenizer=clip_tokenizer, - t5_encoder=self.t5_encoder, - clip_encoder=self.clip_encoder, - ) - - save_image( - name=f"image_rank{str(torch.distributed.get_rank())}_{self.step}.png", - output_dir=os.path.join( - self.job_config.job.dump_folder, self.job_config.eval.save_img_folder - ), - x=image, - add_sampling_metadata=True, - prompt=prompt, - ) - - # Reshard after run forward pass in eval_step. - # This is to ensure the model weights are sharded the same way for checkpoint saving. - for module in self.model_parts[0].modules(): - if isinstance(module, FSDPModule): - module.reshard() - - def train_step( - self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] - ): - super().train_step(data_iterator) - - # Evaluate the model during training - if ( - self.step % self.job_config.eval.eval_freq == 0 - or self.step == self.job_config.training.steps - ): - model = self.model_parts[0] - model.eval() - # We need to set reshard_after_forward before last forward pass. - # So the model wieghts are sharded the same way for checkpoint saving. - self.eval_step() - model.train() - if __name__ == "__main__": init_logger() diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index aa35e176ee..aad2580218 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -47,13 +47,6 @@ clip_encoder = "openai/clip-vit-large-patch14" max_t5_encoding_len = 256 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image -[eval] -enable_classifier_free_guidance = true -classifier_free_guidance_scale = 5.0 -denoising_steps = 4 -save_img_folder = "img" -eval_freq = 5 - [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 @@ -71,3 +64,16 @@ interval = 10 last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[validation] +enabled = false +dataset = "coco-validation" +freq = 5 +local_batch_size = 8 +steps = 1 +enable_classifier_free_guidance = true +classifier_free_guidance_scale = 5.0 +denoising_steps = 4 +save_img_count = 1 +save_img_folder = "img" +all_timesteps = false diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index ae03281780..5fbdcb6fca 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -46,13 +46,6 @@ clip_encoder = "openai/clip-vit-large-patch14" max_t5_encoding_len = 512 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image -[eval] -enable_classifier_free_guidance = true -classifier_free_guidance_scale = 5.0 -denoising_steps = 50 -save_img_folder = "img" -eval_freq = 1000 - [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 @@ -70,3 +63,16 @@ interval = 1_000 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[validation] +enabled = false +dataset = "coco-validation" +local_batch_size = 32 +steps = 1 +freq = 1000 +enable_classifier_free_guidance = true +classifier_free_guidance_scale = 5.0 +denoising_steps = 50 +save_img_count = 50 +save_img_folder = "img" +all_timesteps = false diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index 9cfb6421b9..d479710e62 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -46,12 +46,6 @@ clip_encoder = "openai/clip-vit-large-patch14" max_t5_encoding_len = 256 autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image -[eval] -enable_classifier_free_guidance = true -classifier_free_guidance_scale = 5.0 -denoising_steps = 50 -save_img_folder = "img" -eval_freq = 1000 [parallelism] data_parallel_replicate_degree = 1 @@ -70,3 +64,15 @@ interval = 1_000 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[validation] +enabled = false +dataset = "coco-validation" +local_batch_size=64 +freq = 1000 +enable_classifier_free_guidance = true +classifier_free_guidance_scale = 5.0 +denoising_steps = 50 +save_img_count = 50 +save_img_folder = "img" +all_timesteps = false diff --git a/torchtitan/experiments/flux/validate.py b/torchtitan/experiments/flux/validate.py new file mode 100644 index 0000000000..059faf5b65 --- /dev/null +++ b/torchtitan/experiments/flux/validate.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Generator + +import torch +import torch.nn as nn +from torch.distributed.fsdp import FSDPModule +from torch.distributed.pipelining.schedules import _PipelineSchedule + +from torchtitan.components.dataloader import BaseDataLoader +from torchtitan.components.loss import LossFunction +from torchtitan.components.metrics import MetricsProcessor +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.components.validate import Validator +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.experiments.flux.dataset.flux_dataset import ( + build_flux_validation_dataloader, +) + +from torchtitan.experiments.flux.dataset.tokenizer import build_flux_tokenizer +from torchtitan.experiments.flux.model.autoencoder import AutoEncoder +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder +from torchtitan.experiments.flux.sampling import generate_image, save_image +from torchtitan.experiments.flux.utils import ( + create_position_encoding_for_latents, + pack_latents, + preprocess_data, + unpack_latents, +) + + +class FluxValidator(Validator): + """ + Simple validator focused on correctness and integration. + + Args: + job_config: Job configuration + validation_dataloader: The validation dataloader + loss_fn: Loss function to use for validation + model: The model to validate (single model, no parallelism) + """ + + validation_dataloader: BaseDataLoader + + def __init__( + self, + job_config: JobConfig, + dp_world_size: int, + dp_rank: int, + tokenizer: BaseTokenizer, + parallel_dims: ParallelDims, + loss_fn: LossFunction, + validation_context: Generator[None, None, None], + maybe_enable_amp: Generator[None, None, None], + metrics_processor: MetricsProcessor | None = None, + pp_schedule: _PipelineSchedule | None = None, + pp_has_first_stage: bool | None = None, + pp_has_last_stage: bool | None = None, + ): + self.job_config = job_config + self.parallel_dims = parallel_dims + self.loss_fn = loss_fn + self.all_timesteps = self.job_config.validation.all_timesteps + self.validation_dataloader = build_flux_validation_dataloader( + job_config=job_config, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tokenizer, + generate_timestamps=not self.all_timesteps, + ) + self.validation_context = validation_context + self.maybe_enable_amp = maybe_enable_amp + self.metrics_processor = metrics_processor + self.t5_tokenizer, self.clip_tokenizer = build_flux_tokenizer(self.job_config) + + def flux_init( + self, + device: torch.device, + _dtype: torch.dtype, + autoencoder: AutoEncoder, + t5_encoder: FluxEmbedder, + clip_encoder: FluxEmbedder, + ): + self.device = device + self._dtype = _dtype + self.autoencoder = autoencoder + self.t5_encoder = t5_encoder + self.clip_encoder = clip_encoder + + @torch.no_grad() + def validate( + self, + model_parts: list[nn.Module], + step: int, + ) -> None: + # Set model to eval mode + # TODO: currently does not support pipeline parallelism + model = model_parts[0] + model.eval() + + save_img_count = self.job_config.validation.save_img_count + + parallel_dims = self.parallel_dims + + accumulated_losses = [] + device_type = dist_utils.device_type + num_steps = 0 + + for input_dict, labels in self.validation_dataloader: + if ( + self.job_config.validation.steps != -1 + and num_steps >= self.job_config.validation.steps + ): + break + + prompt = input_dict.pop("prompt") + if not isinstance(prompt, list): + prompt = [prompt] + for p in prompt: + if save_img_count != -1 and save_img_count <= 0: + break + image = generate_image( + device=self.device, + dtype=self._dtype, + job_config=self.job_config, + model=model, + prompt=p, + autoencoder=self.autoencoder, + t5_tokenizer=self.t5_tokenizer, + clip_tokenizer=self.clip_tokenizer, + t5_encoder=self.t5_encoder, + clip_encoder=self.clip_encoder, + ) + + save_image( + name=f"image_rank{str(torch.distributed.get_rank())}_{step}.png", + output_dir=os.path.join( + self.job_config.job.dump_folder, + self.job_config.validation.save_img_folder, + ), + x=image, + add_sampling_metadata=True, + prompt=p, + ) + save_img_count -= 1 + + # generate t5 and clip embeddings + input_dict["image"] = labels + input_dict = preprocess_data( + device=self.device, + dtype=self._dtype, + autoencoder=self.autoencoder, + clip_encoder=self.clip_encoder, + t5_encoder=self.t5_encoder, + batch=input_dict, + ) + labels = input_dict["img_encodings"].to(device_type) + clip_encodings = input_dict["clip_encodings"] + t5_encodings = input_dict["t5_encodings"] + + bsz = labels.shape[0] + + # If using all_timesteps we generate all 8 timesteps and expand our batch inputs here + if self.all_timesteps: + stratified_timesteps = torch.tensor( + [1 / 8 * (i + 0.5) for i in range(8)], + dtype=torch.float32, + device=self.device, + ).repeat(bsz) + clip_encodings = clip_encodings.repeat_interleave(8, dim=0) + t5_encodings = t5_encodings.repeat_interleave(8, dim=0) + labels = labels.repeat_interleave(8, dim=0) + else: + stratified_timesteps = input_dict.pop("timestep") + + # Note the tps may be inaccurate due to the generating image step not being counted + self.metrics_processor.ntokens_since_last_log += labels.numel() + + # Apply timesteps here and update our bsz to efficiently compute all timesteps and samples in a single forward pass + with torch.no_grad(), torch.device(self.device): + noise = torch.randn_like(labels) + timesteps = stratified_timesteps.to(labels) + sigmas = timesteps.view(-1, 1, 1, 1) + latents = (1 - sigmas) * labels + sigmas * noise + + bsz, _, latent_height, latent_width = latents.shape + + POSITION_DIM = 3 # constant for Flux flow model + with torch.no_grad(), torch.device(self.device): + # Create positional encodings + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM) + + # Patchify: Convert latent into a sequence of patches + latents = pack_latents(latents) + + with self.maybe_enable_amp: + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=timesteps, + ) + + # Convert sequence of patches to latent shape + pred = unpack_latents(latent_noise_pred, latent_height, latent_width) + target = noise - labels + loss = self.loss_fn(pred, target) + + del pred, noise, target, latent_noise_pred, latents + + accumulated_losses.append(loss.detach()) + + num_steps += 1 + + # Compute average loss + loss = torch.sum(torch.stack(accumulated_losses)) + loss /= num_steps + if parallel_dims.dp_cp_enabled: + global_avg_loss = dist_utils.dist_mean( + loss, parallel_dims.world_mesh["dp_cp"] + ) + else: + global_avg_loss = loss.item() + + self.metrics_processor.log_validation(loss=global_avg_loss, step=step) + + # Reshard after run forward pass + # This is to ensure the model weights are sharded the same way for checkpoint saving. + for module in model.modules(): + if isinstance(module, FSDPModule): + module.reshard() + + # Set model back to train mode + model.train() + + +def build_flux_validator( + job_config: JobConfig, + dp_world_size: int, + dp_rank: int, + tokenizer: BaseTokenizer, + parallel_dims: ParallelDims, + loss_fn: LossFunction, + validation_context: Generator[None, None, None], + maybe_enable_amp: Generator[None, None, None], + metrics_processor: MetricsProcessor | None = None, + pp_schedule: _PipelineSchedule | None = None, + pp_has_first_stage: bool | None = None, + pp_has_last_stage: bool | None = None, +) -> FluxValidator: + """Build a simple validator focused on correctness.""" + return FluxValidator( + job_config=job_config, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tokenizer, + parallel_dims=parallel_dims, + loss_fn=loss_fn, + validation_context=validation_context, + maybe_enable_amp=maybe_enable_amp, + metrics_processor=metrics_processor, + pp_schedule=pp_schedule, + pp_has_first_stage=pp_has_first_stage, + pp_has_last_stage=pp_has_last_stage, + ) diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py index b6de11e522..239cf3d339 100644 --- a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -32,7 +32,7 @@ from tiktoken.load import load_tiktoken_bpe from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger IMAGE_TOKEN_ID = 128256 From 3065a2aebbf2dc0e635deb35c175f28eedd94690 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Tue, 5 Aug 2025 18:08:44 -0700 Subject: [PATCH 069/128] model fragments for diloco (#1446) Summary: - add a configuration option for users to provide how they want to partition the model - if this is provided, the model needs to implement `FaultTolerantTrainingSpec` that defines the framentation function to split the model based on the configuration - determine the model fragments in training script to pass to ft manager Test Plan: Running llama3 8b parameters with 2 fragments, 1 step delay, each fragment gets synced every 20 steps image --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1446). * #1516 * __->__ #1446 --- torchtitan/components/ft/__init__.py | 18 +++ torchtitan/components/ft/diloco/__init__.py | 13 ++ torchtitan/components/ft/diloco/protocol.py | 19 +++ torchtitan/components/ft/diloco/utils.py | 130 ++++++++++++++++++ .../components/{ft.py => ft/manager.py} | 16 ++- torchtitan/config/job_config.py | 16 +++ torchtitan/models/llama3_ft/__init__.py | 49 +++++++ torchtitan/protocols/train_spec.py | 4 +- torchtitan/train.py | 14 +- 9 files changed, 272 insertions(+), 7 deletions(-) create mode 100644 torchtitan/components/ft/__init__.py create mode 100644 torchtitan/components/ft/diloco/__init__.py create mode 100644 torchtitan/components/ft/diloco/protocol.py create mode 100644 torchtitan/components/ft/diloco/utils.py rename torchtitan/components/{ft.py => ft/manager.py} (93%) create mode 100644 torchtitan/models/llama3_ft/__init__.py diff --git a/torchtitan/components/ft/__init__.py b/torchtitan/components/ft/__init__.py new file mode 100644 index 0000000000..308025d39d --- /dev/null +++ b/torchtitan/components/ft/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.ft.manager import ( + FTManager, + has_torchft, + maybe_semi_sync_training, +) + + +__all__ = [ + "FTManager", + "has_torchft", + "maybe_semi_sync_training", +] diff --git a/torchtitan/components/ft/diloco/__init__.py b/torchtitan/components/ft/diloco/__init__.py new file mode 100644 index 0000000000..d99772a274 --- /dev/null +++ b/torchtitan/components/ft/diloco/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.ft.diloco.protocol import FaultTolerantTrainSpec +from torchtitan.components.ft.diloco.utils import fragment_llm + +__all__ = [ + "FaultTolerantTrainSpec", + "fragment_llm", +] diff --git a/torchtitan/components/ft/diloco/protocol.py b/torchtitan/components/ft/diloco/protocol.py new file mode 100644 index 0000000000..15c218ffe2 --- /dev/null +++ b/torchtitan/components/ft/diloco/protocol.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Callable, TypeAlias + +import torch.nn as nn +from torchtitan.protocols.train_spec import TrainSpec + + +FragmentFunction: TypeAlias = Callable[..., list[nn.Module]] + + +@dataclass +class FaultTolerantTrainSpec(TrainSpec): + fragment_fn: FragmentFunction | None = None diff --git a/torchtitan/components/ft/diloco/utils.py b/torchtitan/components/ft/diloco/utils.py new file mode 100644 index 0000000000..f83759cff6 --- /dev/null +++ b/torchtitan/components/ft/diloco/utils.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from torchtitan.config.job_config import FaultTolerance as FTConfig +from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part + + +def module_split( + model: nn.Module, + module_fqns_per_model_fragment: list[list[str]], +) -> list[nn.Module]: + """ + This API creates fragments based on specified module names for each fragment. + This method updates the model in place. + + Args: + model: The complete model to be split + module_fqns_per_model_fragment: List of lists, where each inner list contains the module names + that should be included in that fragment. Module names should be + dot-separated paths. Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + List of model fragments + + Example usage: + module_fqns_per_model_fragment = [ + ["tok_embeddings", "layers.0"], # fragment 0: embeddings + first layer + ["layers.1", "layers.2"], # fragment 1: middle layers + ["norm", "output"] # fragment 2: final norm + output + ] + """ + + def _build_fragment_from_modules( + fragment_idx: int, module_names: list[str] + ) -> nn.Module: + fragment_model = nn.Module() + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + print(f"fragment {fragment_idx}: Modules to keep: {modules_to_keep}") + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + + if not layers_to_keep: + continue + + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name in layers_to_keep: + setattr( + fragment_model, + f"{module_name}.{layer_name}", + module_value[layer_name], + ) + else: + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(fragment_model, module_name, new_layers) + + continue + + # Handle simple module attributes (e.g., "linear", "norm") + if module_name not in modules_to_keep: + continue + + setattr(fragment_model, module_name, module_value) + + return fragment_model + + num_fragments = len(module_fqns_per_model_fragment) + model_fragments = [] + + for fragment_idx in range(num_fragments): + module_names = module_fqns_per_model_fragment[fragment_idx] + model_fragment = _build_fragment_from_modules( + fragment_idx, + module_names, + ) + print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}") + model_fragments.append(model_fragment) + + return model_fragments + + +def fragment_llm( + model: nn.Module, + ft_config: FTConfig, + n_layers: int, +) -> list[nn.Module]: + assert ft_config.num_fragments > 0 + + module_fqns_per_model_fragment = ft_config.module_fqns_per_model_fragment + + input_weight = 1 # Weight for tok_embeddings + output_weight = 1 # Weight for norm + output layers + + if module_fqns_per_model_fragment == []: + if ft_config.num_fragments == 1: + return [model] + + module_fqns_per_model_fragment = generate_llm_fqn_per_model_part( + ft_config.num_fragments, n_layers, input_weight, output_weight + ) + + model_fragments = module_split(model, module_fqns_per_model_fragment) + print(f"Created {len(model_fragments)} model fragments") + + return model_fragments diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft/manager.py similarity index 93% rename from torchtitan/components/ft.py rename to torchtitan/components/ft/manager.py index 76f2da3ae5..1a33222c1e 100644 --- a/torchtitan/components/ft.py +++ b/torchtitan/components/ft/manager.py @@ -7,10 +7,12 @@ import importlib from contextlib import nullcontext from datetime import timedelta -from typing import ContextManager, Optional, TYPE_CHECKING, Union +from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist + +import torch.nn as nn from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.distributed_c10d import ReduceOp from torchtitan.config.job_config import FaultTolerance as FTConfig @@ -108,8 +110,10 @@ def loss_sync_pg( def maybe_semi_sync_training( ft_config: FTConfig, ft_manager: FTManager, - model_parts: list[torch.nn.Module], + model: torch.nn.Module, + n_layers: int, optimizer: torch.optim.Optimizer, + fragment_fn: Optional[Callable[..., list[nn.Module]]] = None, ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]: """ If TorchFT is enabled and the config is set, use semi_sync_method @@ -122,6 +126,11 @@ def maybe_semi_sync_training( ft_manager._manager is not None ), "FTManager must be enabled to use semi-sync training." if semi_sync_method.lower() == "diloco": + if fragment_fn: + model_parts = fragment_fn(model, ft_config, n_layers) + else: + model_parts = [model] + # Create the outer optimizer based on the inner optimizer parameters. outer_optimizers = [] for model in model_parts: @@ -142,10 +151,9 @@ def maybe_semi_sync_training( fragment_update_alpha=ft_config.fragment_update_alpha, ) elif semi_sync_method.lower() == "local_sgd": - assert len(model_parts) == 1 return local_sgd.LocalSGD( manager=ft_manager._manager, - model=model_parts[0], + model=model, optimizer=optimizer, sync_every=ft_config.sync_steps, ) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index fa92545e44..1e13484291 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -696,6 +696,22 @@ class FaultTolerance: This is only used when "semi_sync_method" is set. """ + module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list) + """ + Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment. + Each inner list represents one model fragment and contains the module names that belong to that fragment. + e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] + will create 3 chunks: the first containing tok_embeddings and layers.0, + the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. + """ + + num_fragments: int = 1 + """ + Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco". + This is used to automatically split the model into fragments provided that the model + implements FaultTolerantTrainSpec + """ + @dataclass class Experimental: diff --git a/torchtitan/models/llama3_ft/__init__.py b/torchtitan/models/llama3_ft/__init__.py new file mode 100644 index 0000000000..1dc277051b --- /dev/null +++ b/torchtitan/models/llama3_ft/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.ft.diloco import FaultTolerantTrainSpec, fragment_llm +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.protocols.train_spec import register_train_spec +from ..llama3 import ( + llama3_configs, + Llama3StateDictAdapter, + parallelize_llama, + pipeline_llama, + Transformer, + TransformerModelArgs, +) + +__all__ = [ + "parallelize_llama", + "pipeline_llama", + "TransformerModelArgs", + "Transformer", + "llama3_configs", +] + + +register_train_spec( + FaultTolerantTrainSpec( + name="llama3_ft", + model_cls=Transformer, + model_args=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + fragment_fn=fragment_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) +) diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index 8a782f8b42..fc1ed1b279 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import TypeAlias +from typing import Mapping, TypeAlias import torch.nn as nn from torch.distributed.pipelining.schedules import _PipelineSchedule @@ -43,7 +43,7 @@ class TrainSpec: name: str model_cls: type[ModelProtocol] - model_args: dict[str, BaseModelArgs] + model_args: Mapping[str, BaseModelArgs] parallelize_fn: ParallelizeFunction pipelining_fn: PipeliningFunction | None build_optimizers_fn: OptimizersBuilder diff --git a/torchtitan/train.py b/torchtitan/train.py index 0955bbb2cb..04ad969046 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -48,6 +48,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): lr_schedulers: train_spec_module.LRSchedulersContainer validator: train_spec_module.BaseValidator metrics_processor: train_spec_module.MetricsProcessor + model_args: train_spec_module.BaseModelArgs # non-swappable training components checkpointer: CheckpointManager @@ -146,6 +147,7 @@ def __init__(self, job_config: JobConfig): model_args = self.train_spec.model_args[job_config.model.flavor] # set the model args from training job configs model_args.update_from_config(job_config) + self.model_args = model_args logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" @@ -553,8 +555,18 @@ def train(self): maybe_semi_sync_training( job_config.fault_tolerance, ft_manager=self.ft_manager, - model_parts=self.model_parts, + model=self.model_parts[0], + n_layers=( + self.model_args.n_layers + if hasattr(self.model_args, "n_layers") + else 0 + ), optimizer=self.optimizers, + fragment_fn=( + self.train_spec.fragment_fn + if hasattr(self.train_spec, "fragment_fn") + else None + ), ), ): data_iterator = self.batch_generator(self.dataloader) From cc558277c486e22e60928fb0fa45c2365fbdac7c Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 5 Aug 2025 18:12:00 -0700 Subject: [PATCH 070/128] checkpoint.md (#1533) fix typo in chekcpoint.md --- docs/checkpoint.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 3986e3dade..b662c52fda 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -60,7 +60,7 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l To create a seed checkpoint, use the same model config as you use for training. e.g. ```bash -NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 +NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` ## Conversion support From f2830b65db87a3b32ba422b9b9ae85277970554b Mon Sep 17 00:00:00 2001 From: Ali Date: Tue, 5 Aug 2025 18:18:33 -0700 Subject: [PATCH 071/128] Fix config manager directories (#1532) After changing the setup for the JobConfig and ConfigManager, some files had the old structure for paths and directories. Those directories affected the import libs paths. This PR fixes those paths and directories. Mainly these changes are related to this line: `from torchtitan.config_manager import ...` which should be `from torchtitan.config import ....`. --------- Co-authored-by: Ali Sol --- CONTRIBUTING.md | 2 +- docs/debugging.md | 6 +++--- docs/extension.md | 2 +- torchtitan/experiments/deepseek_v3/model_args.py | 2 +- torchtitan/experiments/deepseek_v3/train_ds_real.py | 2 +- torchtitan/experiments/multimodal/check_padding_mm.py | 2 +- torchtitan/experiments/multimodal/mm_dataset.py | 6 +++--- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d317f0bfe3..8de2b9df9d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -51,7 +51,7 @@ Note: To accelerate contributions to and innovations around `torchtitan`, we are - After the model change, it should still load the original checkpoint correctly. - Document the reasons for the code change, similar to [composability.md](docs/composability.md). - Keep code modularized, especially for [train.py](train.py), so that it remains easy to copy-paste into a minimal code example. If necessary: - - Introduce new config options/category in [config_manager.py](torchtitan/config_manager.py). + - Introduce new config options/category in [job_config.py](torchtitan/config/job_config.py). - Create separate functions/files. ### Proof of Value diff --git a/docs/debugging.md b/docs/debugging.md index 61795c70ed..28bad0e3d1 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -37,19 +37,19 @@ You can override it at runtime via CLI with: To inspect how configuration values are interpreted—including those from `.toml` files and CLI overrides—run the config manager directly: ```bash -python -m torchtitan.config_manager [your cli args...] +python -m torchtitan.config.manager [your cli args...] ``` For example, ```bash -python -m torchtitan.config_manager --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml --profiling.enable_memory_snapshot +python -m torchtitan.config.manager --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml --profiling.enable_memory_snapshot ``` To list all available CLI flags and usage: ```bash -python -m torchtitan.config_manager --help +python -m torchtitan.config.manager --help ``` This will print a structured configuration to `stdout`, allowing you to verify that overrides are being applied correctly. diff --git a/docs/extension.md b/docs/extension.md index 25d90447c2..f529b05bb7 100644 --- a/docs/extension.md +++ b/docs/extension.md @@ -36,7 +36,7 @@ This is an ongoing effort, and the level of grouping is subject to change. ### Extending `JobConfig` -[`JobConfig`](../torchtitan/config_manager.py) supports custom extension through the `--experimental.custom_args_module` flag. +[`JobConfig`](../torchtitan/config/job_config.py) supports custom extension through the `--experimental.custom_args_module` flag. This lets you define a custom module that extends `JobConfig` with additional fields. When specified, your custom `JobConfig` is merged with the default: diff --git a/torchtitan/experiments/deepseek_v3/model_args.py b/torchtitan/experiments/deepseek_v3/model_args.py index 3672c70194..2e0bc12ff4 100644 --- a/torchtitan/experiments/deepseek_v3/model_args.py +++ b/torchtitan/experiments/deepseek_v3/model_args.py @@ -10,7 +10,7 @@ from torch import nn from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger diff --git a/torchtitan/experiments/deepseek_v3/train_ds_real.py b/torchtitan/experiments/deepseek_v3/train_ds_real.py index be4a92da53..c983a9c570 100644 --- a/torchtitan/experiments/deepseek_v3/train_ds_real.py +++ b/torchtitan/experiments/deepseek_v3/train_ds_real.py @@ -24,7 +24,7 @@ from torchtitan.components.metrics import build_metrics_processor from torchtitan.components.optimizer import build_optimizers -from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.config import ConfigManager, JobConfig from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.distributed import ParallelDims diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py index 7534dcdccf..0635c7a030 100644 --- a/torchtitan/experiments/multimodal/check_padding_mm.py +++ b/torchtitan/experiments/multimodal/check_padding_mm.py @@ -8,7 +8,7 @@ from mm_dataset import build_mm_dataloader from tokenizer.tiktoken import build_tiktoken_tokenizer -from torchtitan.config_manager import ConfigManager +from torchtitan.config import ConfigManager from torchtitan.tools.logging import init_logger diff --git a/torchtitan/experiments/multimodal/mm_dataset.py b/torchtitan/experiments/multimodal/mm_dataset.py index 5daf1d0eae..da69d6973a 100644 --- a/torchtitan/experiments/multimodal/mm_dataset.py +++ b/torchtitan/experiments/multimodal/mm_dataset.py @@ -16,12 +16,12 @@ from tokenizer.tiktoken import BaseTokenizer, IGNORE_INDEX from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset -from transform import CLIPTransform -from utils import load_image from torchtitan.components.dataloader import ParallelAwareDataloader -from torchtitan.config_manager import JobConfig +from torchtitan.config import JobConfig from torchtitan.tools.logging import logger +from transform import CLIPTransform +from utils import load_image def _load_obelics_dataset(dataset_path: str): From a9aa5069649157c42ab2945247fa3eeff4ab2f38 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Tue, 5 Aug 2025 22:47:50 -0700 Subject: [PATCH 072/128] unify moe implementation for llama4 and deepseek_v3 (#1534) Given the complexity of MoE and EP modules This PR 1. creates `torchtitan/models/moe.py` as the central moe implementation (this is similar to why we have `torchtitan/models/attention.py`) 2. creates `torchtitan/distributed/expert_parallel.py` as the central EP implementation 3. rename `torchtitan/distributed/pipeline.py` -> `torchtitan/distributed/pipeline_parallel.py` to be consistent with EP 4. apply temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467 before the memory leak issue with AC + PT-D all_to_all_single_autograd is fixed (cc @soulitzer) --- torchtitan/components/optimizer.py | 53 +++ torchtitan/components/quantization/float8.py | 4 +- torchtitan/components/quantization/mx.py | 7 +- .../infra => distributed}/expert_parallel.py | 36 +- torchtitan/distributed/parallel_dims.py | 5 - .../{pipeline.py => pipeline_parallel.py} | 0 torchtitan/distributed/utils.py | 20 +- torchtitan/experiments/forge/example_train.py | 6 +- torchtitan/experiments/llama4/__init__.py | 13 +- .../experiments/llama4/infra/parallelize.py | 8 +- torchtitan/experiments/llama4/model/args.py | 21 +- torchtitan/experiments/llama4/model/model.py | 25 +- torchtitan/experiments/llama4/optimizer.py | 66 --- .../llama4/train_configs/debug_model.toml | 2 +- torchtitan/models/deepseek_v3/README.md | 7 +- torchtitan/models/deepseek_v3/__init__.py | 68 ++-- .../models/deepseek_v3/infra/parallelize.py | 2 +- torchtitan/models/deepseek_v3/model/args.py | 21 +- torchtitan/models/deepseek_v3/model/model.py | 48 ++- torchtitan/models/deepseek_v3/model/moe.py | 375 ------------------ .../train_configs/debug_model.toml | 2 +- torchtitan/models/llama3/infra/pipeline.py | 2 +- .../llama4/model => models}/moe.py | 102 +++-- torchtitan/train.py | 6 +- 24 files changed, 315 insertions(+), 584 deletions(-) rename torchtitan/{experiments/llama4/infra => distributed}/expert_parallel.py (90%) rename torchtitan/distributed/{pipeline.py => pipeline_parallel.py} (100%) delete mode 100644 torchtitan/experiments/llama4/optimizer.py delete mode 100644 torchtitan/models/deepseek_v3/model/moe.py rename torchtitan/{experiments/llama4/model => models}/moe.py (80%) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 2a112177e0..ce71ac7f0c 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -24,6 +24,7 @@ __all__ = [ "OptimizersContainer", "build_optimizers", + "build_optimizers_with_moe_load_balancing", ] @@ -323,3 +324,55 @@ def build_optimizers( ) return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + +def build_optimizers_with_moe_load_balancing( + model_parts: list[nn.Module], + optimizer_config: OptimizerConfig, + parallel_dims: ParallelDims, + ft_manager: FTManager | None = None, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + optimizer_config=optimizer_config, + parallel_dims=parallel_dims, + ft_manager=ft_manager, + ) + + # for MoE auxiliary-loss-free load balancing + def _update_expert_bias( + model_parts: list[nn.Module], + parallel_dims: ParallelDims, + ): + dp_cp_mesh = ( + parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + ) + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. + for model_part in model_parts: + for transformer_block in model_part.layers.values(): + if transformer_block.moe_enabled: + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 58699b92ee..3629258154 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -10,9 +10,7 @@ from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.experiments.llama4.infra.expert_parallel import ( - set_token_group_alignment_size_m, -) +from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 276208c9a8..15c74b7fd7 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -13,6 +13,7 @@ from torchtitan.config.job_config import JobConfig, MX from torchtitan.distributed import ParallelDims +from torchtitan.distributed.expert_parallel import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, @@ -58,12 +59,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): # For MoE training with mxfp8, token group sizes must be multiples of 32 if job_config.mx.moe_fqns_prototype: - from torchtitan.experiments.llama4.infra.expert_parallel import ( - set_token_group_alignment_size, - ) - mxfp8_block_size = 32 - set_token_group_alignment_size(mxfp8_block_size) + set_token_group_alignment_size_m(mxfp8_block_size) logger.info(f"Setting token group alignment size to {mxfp8_block_size}") # Configure MXFP8 diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/distributed/expert_parallel.py similarity index 90% rename from torchtitan/experiments/llama4/infra/expert_parallel.py rename to torchtitan/distributed/expert_parallel.py index 9a9dad66ae..bc5d43f9f2 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -11,7 +11,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -24,6 +23,41 @@ from torch.distributed.tensor.placement_types import Placement +# from torch.distributed._functional_collectives import all_to_all_single_autograd +# TODO: there is memory leak issue with AC + all_to_all_single_autograd +# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467 +class _A2A(torch.autograd.Function): + @staticmethod + def forward(ctx, x, out_splits, in_splits, group): + if isinstance(out_splits, torch.Tensor): + out_splits = out_splits.tolist() + if isinstance(in_splits, torch.Tensor): + in_splits = in_splits.tolist() + T_out = int(sum(out_splits)) + + y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits + dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group) + + ctx.in_splits = in_splits + ctx.out_splits = out_splits + ctx.group = group + return y + + @staticmethod + def backward(ctx, grad_y): + # grad wrt input has length sum(in_splits) + T_in = int(sum(ctx.in_splits)) + grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:])) + dist.all_to_all_single( + grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group + ) + return grad_x, None, None, None + + +def all_to_all_single_autograd(x, out_splits, in_splits, group): + return _A2A.apply(x, out_splits, in_splits, group) + + TOKEN_GROUP_ALIGN_SIZE_M = 8 ValidTokenGroupAlignmentSize = Literal[8, 16, 32] diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 01e14cc0b0..3108049a6f 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -232,8 +232,3 @@ def seq_len_divisor(self): # when load balancing is enabled (by default). # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246 return self.tp * (self.cp * 2) - - @cached_property - def dense_params_mesh_ndim(self): - # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh - return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline_parallel.py similarity index 100% rename from torchtitan/distributed/pipeline.py rename to torchtitan/distributed/pipeline_parallel.py diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index aa25149d66..7d4dc935c3 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -16,11 +16,9 @@ from torch import distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor -from torch.nn.attention import SDPBackend from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP from torchtitan.distributed.parallel_dims import ParallelDims -from torchtitan.models.attention import ScaledDotProductAttention from torchtitan.tools.logging import logger from torchtitan.tools.utils import device_module, device_type @@ -202,6 +200,10 @@ def context(cp_context: Generator[None, None, None] | None = None): ) if cp_context is not None: + from torch.nn.attention import SDPBackend + + from torchtitan.models.attention import ScaledDotProductAttention + if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) assert ( @@ -319,7 +321,7 @@ def clip_grad_norm_( error_if_nonfinite: bool = False, foreach: bool | None = None, pp_mesh: DeviceMesh | None = None, - ep_dense_params_mesh_ndim: int | None = None, + ep_enabled: bool = False, ) -> torch.Tensor: """ Clip the gradient norm of an iterable of parameters. @@ -349,7 +351,7 @@ def clip_grad_norm_( Total norm of the parameter gradients (viewed as a single vector). """ - if ep_dense_params_mesh_ndim is not None: + if ep_enabled: return _clip_grad_norm_with_ep( parameters, max_norm, @@ -357,7 +359,6 @@ def clip_grad_norm_( error_if_nonfinite, foreach, pp_mesh, - ep_dense_params_mesh_ndim, ) if isinstance(parameters, torch.Tensor): @@ -401,7 +402,6 @@ def _clip_grad_norm_with_ep( error_if_nonfinite: bool, foreach: bool | None, pp_mesh: DeviceMesh | None, - dense_params_mesh_ndim: int, ) -> torch.Tensor: ep_params = [] non_ep_params = [] @@ -412,12 +412,12 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) - if p.device_mesh.ndim == dense_params_mesh_ndim: - non_ep_params.append(p) - non_ep_grads.append(p.grad) - else: + if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) + else: + non_ep_params.append(p) + non_ep_grads.append(p.grad) ep_grads_total_norm = torch.nn.utils.get_total_norm( ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index c54fc645c4..0bebd197d1 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -231,11 +231,7 @@ def train_step( pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), - ep_dense_params_mesh_ndim=( - parallel_dims.dense_params_mesh_ndim - if parallel_dims.ep_enabled - else None - ), + ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index 7e3dd8f07c..0ffe139dae 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -6,15 +6,16 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.llama3 import pipeline_llama +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from .infra.parallelize import parallelize_llama from .model.args import TransformerModelArgs from .model.model import Transformer -from .optimizer import build_llama4_optimizers __all__ = [ "TransformerModelArgs", @@ -40,7 +41,7 @@ multiple_of=2048, rope_theta=500000, max_seq_len=10485760, - num_experts=16, + moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, ), "17bx128e": TransformerModelArgs( @@ -51,7 +52,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, - num_experts=128, + moe_args=MoEArgs(num_experts=128), ), "debugmodel_irope": TransformerModelArgs( dim=256, @@ -73,7 +74,7 @@ multiple_of=2048, rope_theta=500000, max_seq_len=10485760, - num_experts=16, + moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, every_n_layers_nope=4, use_flex_attn=True, @@ -87,7 +88,7 @@ ffn_dim_multiplier=1.2, multiple_of=2048, rope_theta=500000, - num_experts=128, + moe_args=MoEArgs(num_experts=128), every_n_layers_nope=4, use_flex_attn=True, attn_mask_type="block_causal", @@ -102,7 +103,7 @@ model_args=llama4_configs, parallelize_fn=parallelize_llama, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_llama4_optimizers, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index b1e60f9962..4a7a860680 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -21,16 +21,16 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp -from torchtitan.tools.logging import logger - -from .expert_parallel import ( +from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, NoParallel, TensorParallel, ) +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.tools.logging import logger + def parallelize_llama( model: nn.Module, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index 741f00fd4e..dda130548d 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -5,11 +5,13 @@ # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch import nn from torchtitan.config import JobConfig + +from torchtitan.models.moe import MoEArgs from torchtitan.protocols import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability @@ -34,7 +36,6 @@ class TransformerModelArgs(BaseModelArgs): use_flex_attn: bool = False attn_mask_type: str = "causal" - eos_id: int = 0 # iRoPE settings # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is # used every n layers. Other layers uses RoPE (rotary positional embedding) and @@ -45,17 +46,11 @@ class TransformerModelArgs(BaseModelArgs): every_n_layers_nope: int | None = None fixed_attn_block_size: int = 8192 - # MoE args - moe_enabled: bool = True - num_experts: int = 8 - use_shared_expert: bool = True + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) auto_scale_hidden_dim: bool = True # frequency of using MoE layer instead of feedforward layer in a transformer block interleave_moe_layer_step: int = 2 - # token-choice - top_k: int = 1 - use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation - load_balance_coeff: float | None = 1e-3 def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len @@ -65,11 +60,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if self.use_grouped_mm and not has_cuda_capability(9, 0): + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", ) - self.use_grouped_mm = False + self.moe_args.use_grouped_mm = False if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise NotImplementedError( @@ -112,7 +107,7 @@ def get_nparams_and_flops( nparams_sparse_active = ( nparams_moe_router + nparams_shared_expert - + nparams_experts * self.top_k // self.num_experts + + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) logger.info( diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 4e276efbbc..eb46a22b00 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -10,10 +10,10 @@ from torch import nn from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.moe import MoE from torchtitan.protocols import ModelProtocol from .args import TransformerModelArgs -from .moe import MoE def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: @@ -296,12 +296,25 @@ def __init__( self.attention = Attention(model_args, attn_use_rope, fixed_attn_block_size) # use MoE layer for every interleave_moe_layer_step FFN layers - self.moe_enabled = ( - model_args.moe_enabled - and (layer_id + 1) % model_args.interleave_moe_layer_step == 0 - ) + moe_args = model_args.moe_args + self.moe_enabled = (layer_id + 1) % model_args.interleave_moe_layer_step == 0 if self.moe_enabled: - self.moe = MoE(model_args) + dim = model_args.dim + hidden_dim = 4 * model_args.dim + ffn_dim_multiplier = model_args.ffn_dim_multiplier + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + + hidden_dim_denom = 1 + if model_args.auto_scale_hidden_dim: + hidden_dim_denom = moe_args.top_k + moe_args.num_shared_experts + + if model_args.auto_scale_hidden_dim: + hidden_dim = int(hidden_dim / hidden_dim_denom) + hidden_dim += -hidden_dim % model_args.multiple_of + + self.moe = MoE(moe_args, dim=dim, hidden_dim=hidden_dim) else: self.feed_forward = FeedForward( dim=model_args.dim, diff --git a/torchtitan/experiments/llama4/optimizer.py b/torchtitan/experiments/llama4/optimizer.py deleted file mode 100644 index 0986452fae..0000000000 --- a/torchtitan/experiments/llama4/optimizer.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from torchtitan.components.ft import FTManager -from torchtitan.components.optimizer import build_optimizers, OptimizersContainer -from torchtitan.config import Optimizer as OptimizerConfig -from torchtitan.distributed import ParallelDims - - -# for MoE auxiliary-loss-free load balancing -def _update_expert_bias( - model_parts: list[nn.Module], - parallel_dims: ParallelDims, -): - dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None - ) - # TODO: Currently this sync is blocking (thus exposed) and happens on the - # default compute stream. Need to assess if this is OK performance-wise. - for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if transformer_block.moe_enabled: - moe = transformer_block.moe - if moe.load_balance_coeff is None: - return - - if dp_cp_mesh is not None: - torch.distributed.all_reduce( - moe.tokens_per_expert, group=dp_cp_mesh.get_group() - ) - - with torch.no_grad(): - expert_bias_delta = moe.load_balance_coeff * torch.sign( - moe.tokens_per_expert.mean() - moe.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - moe.expert_bias.add_(expert_bias_delta) - moe.tokens_per_expert.zero_() - - -def build_llama4_optimizers( - model_parts: list[nn.Module], - optimizer_config: OptimizerConfig, - parallel_dims: ParallelDims, - ft_manager: FTManager | None = None, -) -> OptimizersContainer: - optimizers = build_optimizers( - model_parts=model_parts, - optimizer_config=optimizer_config, - parallel_dims=parallel_dims, - ft_manager=ft_manager, - ) - - optimizers.register_step_pre_hook( - lambda *args, **kwargs: _update_expert_bias( - model_parts, parallel_dims=parallel_dims - ) - ) - - return optimizers diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index fb672cc4c7..a7f068c073 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -63,7 +63,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "none" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8] diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 6698852b47..38742cc716 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -1,4 +1,4 @@ -# DeepSeek-V3 in TorchTitan +# DeepSeek-V3 in `torchtitan` DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with Multi-head Latent Attention (MLA) architecture. @@ -50,11 +50,8 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml ## To be added -- Modeling - - Merge DeepSeek-V3 and Llama4 MoE common components - - Attention Layer: need to pass softmax_scale to sdpa() to support scaling - Parallelism - - Context Parallel support for DeepSeek-V3 + - Context Parallel support for DeepSeek V3 - torch.compile - Quantization - Testing diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 8243a0a84a..a39b35dfa2 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -8,10 +8,11 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers from torchtitan.models.llama3.infra.pipeline import pipeline_llama +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import register_train_spec, TrainSpec @@ -36,10 +37,14 @@ n_layers=3, n_dense_layers=1, n_heads=16, - n_routed_experts=8, - n_shared_experts=2, - n_activated_experts=3, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -55,10 +60,14 @@ n_layers=3, n_dense_layers=1, n_heads=16, - n_routed_experts=8, - n_shared_experts=2, - n_activated_experts=3, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -76,10 +85,14 @@ n_layers=27, n_dense_layers=1, n_heads=16, - n_routed_experts=64, - n_shared_experts=2, - n_activated_experts=6, - route_scale=1.0, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + score_before_experts=False, + ), q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, @@ -95,12 +108,17 @@ n_layers=60, n_dense_layers=1, n_heads=128, - n_routed_experts=160, - n_shared_experts=2, - n_activated_experts=6, + moe_args=MoEArgs( + num_experts=160, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=True, + route_scale=16.0, + score_before_experts=False, + ), n_expert_groups=8, n_limited_groups=3, - route_scale=16.0, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, @@ -115,13 +133,17 @@ n_layers=61, n_dense_layers=3, n_heads=128, - n_routed_experts=256, - n_shared_experts=1, - n_activated_experts=8, + moe_args=MoEArgs( + num_experts=256, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.5, + score_before_experts=False, + ), n_expert_groups=8, n_limited_groups=4, - route_scale=2.5, - score_func="sigmoid", q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, @@ -139,7 +161,7 @@ model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, pipelining_fn=pipeline_llama, - build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 532358b2da..8e289f01fb 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -17,7 +17,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.experiments.llama4.infra.expert_parallel import NoParallel +from torchtitan.distributed.expert_parallel import NoParallel from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index cd94104cdb..025a550b9b 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -7,12 +7,13 @@ # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal from torch import nn from torchtitan.config import JobConfig +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from torchtitan.tools.utils import has_cuda_capability @@ -67,16 +68,13 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_dense_layers: int = 1 n_heads: int = 16 norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE - n_routed_experts: int = 64 - n_shared_experts: int = 2 - n_activated_experts: int = 6 + moe_args: MoEArgs = field(default_factory=MoEArgs) + # TODO: node-limited routing is not supported yet n_expert_groups: int = 1 n_limited_groups: int = 1 - score_func: Literal["softmax", "sigmoid"] = "softmax" - route_scale: float = 1.0 - use_grouped_mm: bool = True - load_balance_coeff: float = 1e-3 + # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 kv_lora_rank: int = 512 @@ -85,6 +83,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): v_head_dim: int = 128 use_flex_attn: bool = False attn_mask_type: str = "causal" + # yarn original_seq_len: int = 4096 rope_theta: float = 10000.0 @@ -101,11 +100,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if self.use_grouped_mm and not has_cuda_capability(9, 0): + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): logger.warning( "Failed to use grouped mm, which is only supported on SM90 or later", ) - self.use_grouped_mm = False + self.moe_args.use_grouped_mm = False if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: raise NotImplementedError( @@ -149,7 +148,7 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in nparams_sparse_active = ( nparams_moe_router + nparams_shared_expert - + nparams_experts * self.n_activated_experts // self.n_routed_experts + + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) logger.info( diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 1d92c12545..cfdc794ca9 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,12 +8,50 @@ from typing import Tuple import torch +import torch.nn.functional as F from torch import nn + from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.moe import MoE from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs -from .moe import FeedForward, MoE + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -269,10 +307,14 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: - self.moe = MoE(model_args) + self.moe = MoE( + model_args.moe_args, + dim=model_args.dim, + hidden_dim=model_args.moe_inter_dim, + ) else: self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py deleted file mode 100644 index 02a094686c..0000000000 --- a/torchtitan/models/deepseek_v3/model/moe.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn.functional as F -from torch import nn -from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel - -from .args import DeepSeekV3ModelArgs - - -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. - - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - -class GroupedExperts(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - num_experts: int, - use_grouped_mm: bool, - ): - super().__init__() - self.num_experts = num_experts - self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) - self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) - self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) - self.use_grouped_mm = use_grouped_mm - - def forward( - self, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( - self.w1, self.w2, self.w3, x, num_tokens_per_expert - ) - else: - return GroupedExperts._run_experts_for_loop( - self.w1, self.w2, self.w3, x, num_tokens_per_expert - ) - - # TODO: keeping this for-loop implementation for comparison - # and readability, may remove later - @expert_parallel - @staticmethod - def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx])) - h = h * torch.matmul(x_expert, w3[expert_idx]) - h = torch.matmul(h, w2[expert_idx]) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1)) - h = h * torch.bmm(x, w3) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2) - - return out - - @expert_parallel - @staticmethod - def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 - - h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets)) - h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets) - out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x) - - return out - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) - nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) - - -class TokenChoiceTopKRouter(nn.Module): - """This class implements token-choice routing. In token-choice top-K routing, each token is - routed to top K experts based on the router scores. - - Args: - gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). - num_experts (int): Number of experts in each moe layer. - top_k (int): Number of experts each token will be routed to in token-choice routing. - use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. - """ - - def __init__( - self, - dim: int, - num_experts: int, - top_k: int, - use_sigmoid: bool = False, - route_sclaing_factor: float = 1.0, - ): - super().__init__() - - self.dim = dim - self.num_experts = num_experts - self.top_k = top_k - self.use_sigmoid = use_sigmoid - self.route_sclaing_factor = route_sclaing_factor - self.gate = nn.Linear(self.dim, self.num_experts, bias=False) - - def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - TODO: We haven't implement the group-based routing (node limit routing), - and currently EP is not supporting node limit routing yet. - - Args: - x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. - - Returns: - routed_input (torch.Tensor): - Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. - token_indices (torch.Tensor): - Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_tokens_per_expert (torch.Tensor): - Number of tokens assigned to each expert with shape ``(num_experts,)``. - """ - # scores shape (bs*slen, num_experts) - scores = self.gate(x) - - # By default, sigmoid or softmax is performed in float32 to avoid loss explosion - if self.use_sigmoid: - scores = torch.sigmoid(scores.to(torch.float32)) - else: - scores = F.softmax(scores.to(torch.float32), dim=1) - - # top scores shape (bs*slen, top_k) - # NOTE: The expert_bias is only used for routing. The gating value - # top_scores is still derived from the original scores. - if expert_bias is not None: - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) - else: - top_scores, selected_experts_indices = torch.topk( - scores, k=self.top_k, dim=1 - ) - - if self.use_sigmoid: - denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 - top_scores = top_scores / denominator - - # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_tokens_per_expert = torch.histc( - selected_experts_indices.view(-1), - bins=self.num_experts, - min=0, - max=self.num_experts, - ) - - # Reorder the token indices to match the order of the experts - # token_indices_experts_sorted shape (bs*slen*top_k,) - token_indices_experts_sorted = torch.argsort( - selected_experts_indices.view(-1), stable=True - ) - - # reorder the scores to match the order of the token indices - top_scores = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - - top_scores = ( - top_scores * self.route_sclaing_factor - ) # must multiply the scaling factor - return top_scores, token_indices_experts_sorted, num_tokens_per_expert - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) - - -class MoE(nn.Module): - def __init__(self, model_args: DeepSeekV3ModelArgs): - - super().__init__() - dim = model_args.dim - - num_experts = model_args.n_routed_experts - hidden_dim = model_args.moe_inter_dim - top_k = model_args.n_activated_experts - route_scaling_factor = model_args.route_scale - - self.experts = GroupedExperts( - dim=dim, - hidden_dim=hidden_dim, - num_experts=num_experts, - use_grouped_mm=model_args.use_grouped_mm, - ) - self.router = TokenChoiceTopKRouter( - dim=dim, - num_experts=num_experts, - top_k=top_k, - use_sigmoid=model_args.score_func == "sigmoid", - route_sclaing_factor=route_scaling_factor, - ) - self.shared_expert = ( - # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 - GroupedExperts( - dim=dim, - hidden_dim=hidden_dim * model_args.n_shared_experts, - num_experts=1, # Here needs to be 1 to make it equivalent to the MLP - use_grouped_mm=model_args.use_grouped_mm, - ) - if model_args.n_shared_experts > 0 - else None - ) - - # auxiliary-loss-free load balancing - self.load_balance_coeff = model_args.load_balance_coeff - if self.load_balance_coeff is not None: - assert self.load_balance_coeff > 0.0 - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - ) - else: - self.expert_bias = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. - - Returns: - out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. - """ - bs, slen, dim = x.shape - - # top_scores and selected_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - token_indices, - num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - - # tokens_per_expert will be used to update the expert bias for load balancing. - # Prevent extra local tokens accumulation on evaluation or activation recomputation. - if self.load_balance_coeff is not None and torch.is_grad_enabled(): - with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) - # shape (bs*slen*top_k, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather( - x.view(-1, dim), - dim=0, - index=token_indices, - ) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( - x.dtype - ) - - # shared expert - if self.shared_expert is not None: - out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( - bs * slen, dim - ) - else: - out = torch.zeros_like(x.reshape(bs * slen, dim)) - - # Accumulate multiple expert results becase each token can be routed to multiple experts - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) - out = out.reshape(bs, slen, dim) - return out - - def init_weights( - self, - init_std: float, - buffer_device: torch.device, - ): - self.experts.init_weights(init_std) - self.router.init_weights(init_std) - if self.shared_expert is not None: - self.shared_expert.init_weights(init_std) - - if self.load_balance_coeff is not None: - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 64d080126b..093f89a18b 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -65,7 +65,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = "none" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8] diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index db3d6465e6..8741b2eef4 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -19,7 +19,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.distributed.pipeline import ( +from torchtitan.distributed.pipeline_parallel import ( build_pipeline_schedule, generate_llm_fqn_per_model_part, pipeline_module_split, diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/models/moe.py similarity index 80% rename from torchtitan/experiments/llama4/model/moe.py rename to torchtitan/models/moe.py index 73a5d0a205..b8d777306c 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/models/moe.py @@ -4,13 +4,31 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Literal + import torch import torch.nn.functional as F from torch import nn -from ..infra.expert_parallel import expert_parallel +from torchtitan.distributed.expert_parallel import expert_parallel + + +@dataclass +class MoEArgs: + num_experts: int = 8 + num_shared_experts: int = 1 -from .args import TransformerModelArgs + # router + score_func: Literal["softmax", "sigmoid"] = "sigmoid" + route_norm: bool = False + route_scale: float = 1.0 + score_before_experts: bool = True + + # token-choice + top_k: int = 1 + use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation + load_balance_coeff: float | None = 1e-3 class GroupedExperts(nn.Module): @@ -142,13 +160,17 @@ def __init__( dim: int, num_experts: int, top_k: int, - use_sigmoid: bool = False, + score_func: Literal["softmax", "sigmoid"], + route_norm: bool, + route_scale: float, ): super().__init__() self.gate = nn.Linear(dim, num_experts, bias=False) self.num_experts = num_experts self.top_k = top_k - self.use_sigmoid = use_sigmoid + self.score_func = score_func + self.route_norm = route_norm + self.route_scale = route_scale def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None @@ -169,10 +191,12 @@ def forward( scores = self.gate(x) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion - if self.use_sigmoid: + if self.score_func == "sigmoid": scores = torch.sigmoid(scores.to(torch.float32)) - else: + elif self.score_func == "softmax": scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_function}") # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value @@ -187,6 +211,11 @@ def forward( scores, k=self.top_k, dim=1 ) + if self.score_func == "sigmoid" and self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward num_tokens_per_expert = torch.histc( selected_experts_indices.view(-1), @@ -194,10 +223,13 @@ def forward( min=0, max=self.num_experts, ) + + # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) token_indices_experts_sorted = torch.argsort( selected_experts_indices.view(-1), stable=True ) + top_scores = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k @@ -208,50 +240,43 @@ def init_weights(self, init_std: float): class MoE(nn.Module): - def __init__(self, model_args: TransformerModelArgs): + def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): super().__init__() - dim = model_args.dim - hidden_dim = 4 * model_args.dim - ffn_dim_multiplier = model_args.ffn_dim_multiplier - hidden_dim = int(2 * hidden_dim / 3) - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - - num_experts = model_args.num_experts - - hidden_dim_denom = 1 - if model_args.auto_scale_hidden_dim: - hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert) - - if model_args.auto_scale_hidden_dim: - hidden_dim = int(hidden_dim / hidden_dim_denom) - hidden_dim += -hidden_dim % model_args.multiple_of + num_experts = moe_args.num_experts self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - use_grouped_mm=model_args.use_grouped_mm, + use_grouped_mm=moe_args.use_grouped_mm, ) self.router = TokenChoiceTopKRouter( - dim=dim, num_experts=num_experts, top_k=model_args.top_k + dim=dim, + num_experts=num_experts, + top_k=moe_args.top_k, + score_func=moe_args.score_func, + route_norm=moe_args.route_norm, + route_scale=moe_args.route_scale, ) self.shared_expert = ( GroupedExperts( dim=dim, - hidden_dim=hidden_dim, + # TODO: if it doesn't use GroupedExperts.num_experts + # we can just use normal FeedForward + hidden_dim=hidden_dim * moe_args.num_shared_experts, num_experts=1, - use_grouped_mm=model_args.use_grouped_mm, + use_grouped_mm=moe_args.use_grouped_mm, ) - if model_args.use_shared_expert + if moe_args.num_shared_experts > 0 else None ) + self.score_before_experts = moe_args.score_before_experts # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert is accumulated in the model forward pass. # expert_bias is updated outside the model in an optimzer step pre hook # to work with gradient accumulation. - self.load_balance_coeff = model_args.load_balance_coeff + self.load_balance_coeff = moe_args.load_balance_coeff if self.load_balance_coeff is not None: assert self.load_balance_coeff > 0.0 self.register_buffer( @@ -284,8 +309,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. - # Prevent extra local tokens accumulation on evaluation or activation recomputation. - if self.load_balance_coeff is not None and torch.is_grad_enabled(): + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + if self.load_balance_coeff is not None: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) @@ -298,13 +325,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim=0, index=token_indices, ) - routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to( - x.dtype - ) + + if self.score_before_experts: + routed_input = ( + routed_input.to(torch.float32) * top_scores.reshape(-1, 1) + ).to(x.dtype) # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + if not self.score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores.reshape(-1, 1) + ).to(x.dtype) + # shared expert if self.shared_expert is not None: out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( diff --git a/torchtitan/train.py b/torchtitan/train.py index 04ad969046..807dea8bc5 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -480,11 +480,7 @@ def train_step( pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), - ep_dense_params_mesh_ndim=( - parallel_dims.dense_params_mesh_ndim - if parallel_dims.ep_enabled - else None - ), + ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() self.optimizers.step() From be211c8234e6ab210fab5f730890b65f256ceb34 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Tue, 5 Aug 2025 23:04:59 -0700 Subject: [PATCH 073/128] separate out diloco configs (#1516) --- torchtitan/components/ft/config/__init__.py | 13 +++ torchtitan/components/ft/config/job_config.py | 80 +++++++++++++++++++ torchtitan/components/ft/diloco/utils.py | 2 +- torchtitan/components/ft/manager.py | 2 +- torchtitan/config/job_config.py | 59 -------------- torchtitan/models/__init__.py | 1 + 6 files changed, 96 insertions(+), 61 deletions(-) create mode 100644 torchtitan/components/ft/config/__init__.py create mode 100644 torchtitan/components/ft/config/job_config.py diff --git a/torchtitan/components/ft/config/__init__.py b/torchtitan/components/ft/config/__init__.py new file mode 100644 index 0000000000..4936e3303e --- /dev/null +++ b/torchtitan/components/ft/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.ft.config.job_config import FaultTolerance, JobConfig + + +__all__ = [ + "FaultTolerance", + "JobConfig", +] diff --git a/torchtitan/components/ft/config/job_config.py b/torchtitan/components/ft/config/job_config.py new file mode 100644 index 0000000000..c5bc309f72 --- /dev/null +++ b/torchtitan/components/ft/config/job_config.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from torchtitan.config.job_config import FaultTolerance as BaseFaultTolerance + + +@dataclass +class FaultTolerance(BaseFaultTolerance): + """ + Extends fault tolerance to also support Streaming DiLoCo + """ + + sync_steps: int = 5 + """ + Number of steps to wait before performing synchronization. This is only used when "semi_sync_method" + is set. + """ + + should_quantize: bool = False + """ + Whether to quantize the gradients before allreduce. + + Disabled by default since the quantization does utilize the GPU + and uses more collectives. Enabling this requires knowing about + the tradeoffs between GPU utilization and communication. + + + This is only used when "semi_sync_method" is set. + """ + + fragment_sync_delay: int = 0 + """ + Controls the number of inner steps to wait before blocking on a + model fragment's synchronization. This is the "tao" parameter in + the Streaming DiLoCo paper. + + By default, each model fragment will be synced at the same step + at which the allreduce is issued. Enabling delay can improve + communication and computation overlap, but at the cost of compromising + model quality + + This is only used when "semi_sync_method" is set. + """ + + fragment_update_alpha: float = 0.0 + """ + Determines how to mix the local and global optimized parameters + + By default, we just use the global parameters. This ensures all + DDP replicas have the same parameters after syncrhonizing on + the fragment. Tuning this can also affect the model quality. + + This is only used when "semi_sync_method" is set. + """ + + module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list) + """ + Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment. + Each inner list represents one model fragment and contains the module names that belong to that fragment. + e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] + will create 3 chunks: the first containing tok_embeddings and layers.0, + the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. + """ + + num_fragments: int = 1 + """ + Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco". + This is used to automatically split the model into fragments provided that the model + implements FaultTolerantTrainSpec + """ + + +@dataclass +class JobConfig: + fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) diff --git a/torchtitan/components/ft/diloco/utils.py b/torchtitan/components/ft/diloco/utils.py index f83759cff6..f7eaf26593 100644 --- a/torchtitan/components/ft/diloco/utils.py +++ b/torchtitan/components/ft/diloco/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch.nn as nn -from torchtitan.config.job_config import FaultTolerance as FTConfig +from torchtitan.components.ft.config import FaultTolerance as FTConfig from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 1a33222c1e..6431fbc38c 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -15,7 +15,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.distributed_c10d import ReduceOp -from torchtitan.config.job_config import FaultTolerance as FTConfig +from torchtitan.components.ft.config import FaultTolerance as FTConfig if importlib.util.find_spec("torchft") is not None: import torchft as ft diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 1e13484291..d6c87585d3 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -653,65 +653,6 @@ class FaultTolerance: (https://github.com/pytorch/torchft/blob/360c5c534bdeac959507e9d238ba9f3902d3fda9/torchft/local_sgd.py#L41) """ - sync_steps: int = 5 - """ - Number of steps to wait before performing synchronization. This is only used when "semi_sync_method" - is set. - """ - - should_quantize: bool = False - """ - Whether to quantize the gradients before allreduce. - - Disabled by default since the quantization does utilize the GPU - and uses more collectives. Enabling this requires knowing about - the tradeoffs between GPU utilization and communication. - - - This is only used when "semi_sync_method" is set. - """ - - fragment_sync_delay: int = 0 - """ - Controls the number of inner steps to wait before blocking on a - model fragment's synchronization. This is the "tao" parameter in - the Streaming DiLoCo paper. - - By default, each model fragment will be synced at the same step - at which the allreduce is issued. Enabling delay can improve - communication and computation overlap, but at the cost of compromising - model quality - - This is only used when "semi_sync_method" is set. - """ - - fragment_update_alpha: float = 0.0 - """ - Determines how to mix the local and global optimized parameters - - By default, we just use the global parameters. This ensures all - DDP replicas have the same parameters after syncrhonizing on - the fragment. Tuning this can also affect the model quality. - - This is only used when "semi_sync_method" is set. - """ - - module_fqns_per_model_fragment: list[list[str]] = field(default_factory=list) - """ - Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model fragment. - Each inner list represents one model fragment and contains the module names that belong to that fragment. - e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] - will create 3 chunks: the first containing tok_embeddings and layers.0, - the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. - """ - - num_fragments: int = 1 - """ - Number of fragments to split the model into. This is only used when "semi_sync_method" is "diloco". - This is used to automatically split the model into fragments provided that the model - implements FaultTolerantTrainSpec - """ - @dataclass class Experimental: diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index 378f886658..bd6d6ee90c 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -9,3 +9,4 @@ # will be called. import torchtitan.models.deepseek_v3 # noqa: F401 import torchtitan.models.llama3 # noqa: F401 +import torchtitan.models.llama3_ft # noqa: F401 From 36ec54731bd0bc154d2a0b06f30ec922a3d144e1 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 6 Aug 2025 00:25:35 -0700 Subject: [PATCH 074/128] fix module import (#1537) --- torchtitan/components/ft/diloco/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/ft/diloco/utils.py b/torchtitan/components/ft/diloco/utils.py index f7eaf26593..4a65ed78e4 100644 --- a/torchtitan/components/ft/diloco/utils.py +++ b/torchtitan/components/ft/diloco/utils.py @@ -6,7 +6,7 @@ import torch.nn as nn from torchtitan.components.ft.config import FaultTolerance as FTConfig -from torchtitan.distributed.pipeline import generate_llm_fqn_per_model_part +from torchtitan.distributed.pipeline_parallel import generate_llm_fqn_per_model_part def module_split( From a1fdd7e43694bbfeff5d6ad8ac738c067bb90d41 Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 6 Aug 2025 11:26:23 -0700 Subject: [PATCH 075/128] use logger in ft (#1539) Summary: - wasn't seeing print statements getting printed - the statements show up using the logger - also added some logging to validate the model is being split for diloco --- torchtitan/components/ft/diloco/utils.py | 8 ++++++-- torchtitan/components/ft/manager.py | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/ft/diloco/utils.py b/torchtitan/components/ft/diloco/utils.py index 4a65ed78e4..766ecab64b 100644 --- a/torchtitan/components/ft/diloco/utils.py +++ b/torchtitan/components/ft/diloco/utils.py @@ -7,6 +7,7 @@ import torch.nn as nn from torchtitan.components.ft.config import FaultTolerance as FTConfig from torchtitan.distributed.pipeline_parallel import generate_llm_fqn_per_model_part +from torchtitan.tools.logging import logger def module_split( @@ -98,7 +99,9 @@ def _build_fragment_from_modules( fragment_idx, module_names, ) - print(f"building fragment_idx {fragment_idx} " f"with modules {module_names}") + logger.info( + f"building fragment_idx {fragment_idx} " f"with modules {module_names}" + ) model_fragments.append(model_fragment) return model_fragments @@ -118,6 +121,7 @@ def fragment_llm( if module_fqns_per_model_fragment == []: if ft_config.num_fragments == 1: + logger.info("Created 1 model fragments") return [model] module_fqns_per_model_fragment = generate_llm_fqn_per_model_part( @@ -125,6 +129,6 @@ def fragment_llm( ) model_fragments = module_split(model, module_fqns_per_model_fragment) - print(f"Created {len(model_fragments)} model fragments") + logger.info(f"Created {len(model_fragments)} model fragments") return model_fragments diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 6431fbc38c..38ec5173bd 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -16,6 +16,7 @@ from torch.distributed._composable.fsdp.fully_shard import FSDPModule from torch.distributed.distributed_c10d import ReduceOp from torchtitan.components.ft.config import FaultTolerance as FTConfig +from torchtitan.tools.logging import logger if importlib.util.find_spec("torchft") is not None: import torchft as ft @@ -125,6 +126,9 @@ def maybe_semi_sync_training( assert ( ft_manager._manager is not None ), "FTManager must be enabled to use semi-sync training." + logger.info( + f"using fragment function to split model: {fragment_fn is not None}" + ) if semi_sync_method.lower() == "diloco": if fragment_fn: model_parts = fragment_fn(model, ft_config, n_layers) From 23e4dfca5ca52587dbaf18f67bee8d73e875df5b Mon Sep 17 00:00:00 2001 From: Garrett Goon <44747910+garrett361@users.noreply.github.com> Date: Fri, 8 Aug 2025 19:35:55 -0400 Subject: [PATCH 076/128] fix: ep clipping with no ep grads (#1541) The current EP grad clipping logic assumes that when using EP all of the norms returned by `torch.nn.utils.get_total_norm` are `DTensor`s. This assumption can be violated and the subsequent `full_tensor` call can correspondingly fail in the edge case where the [ep_grad list](https://github.com/pytorch/torchtitan/blob/a1fdd7e43694bbfeff5d6ad8ac738c067bb90d41/torchtitan/distributed/utils.py?plain=1#L408) is empty, in which case `get_total_norm` returns `tensor(0.)`, a non-`DTensor`. https://github.com/pytorch/torchtitan/blob/a1fdd7e43694bbfeff5d6ad8ac738c067bb90d41/torchtitan/distributed/utils.py?plain=1#L421-L423 ``` File "/app/torchtitan/torchtitan/distributed/utils.py", line 423, in _clip_grad_norm_with_ep ).full_tensor() ^^^^^^^^^^^ AttributeError: 'Tensor' object has no attribute 'full_tensor' ``` This edge case can occur in PP+EP setups when model uses some fully dense and some MoE layers (like DSv3), in which case some pp ranks may not be assigned any MoE layers. I suppose it is possible that `non_ep_grads` could also be empty, but I can only imagine this happening in extreme cases, so I did not change the `non_ep_grads` code. CC @tianyu-l --- torchtitan/distributed/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 7d4dc935c3..13cd700eb2 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -420,7 +420,12 @@ def _clip_grad_norm_with_ep( non_ep_grads.append(p.grad) ep_grads_total_norm = torch.nn.utils.get_total_norm( ep_grads, norm_type, error_if_nonfinite, foreach - ).full_tensor() + ) + # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor + # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance. + if isinstance(ep_grads_total_norm, DTensor): + ep_grads_total_norm = ep_grads_total_norm.full_tensor() + non_ep_grads_total_norm = torch.nn.utils.get_total_norm( non_ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() From 2c8b5947991239913d67e2f7d22a255c3e2a9694 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Fri, 8 Aug 2025 16:49:22 -0700 Subject: [PATCH 077/128] Reorder validate and checkpoint in train (#1542) If validation and checkpoint occur on the same training step, do checkpointing first --- torchtitan/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 807dea8bc5..f55530a083 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -575,6 +575,10 @@ def train(self): logger.warning("Ran out of data; last step was canceled.") break + self.checkpointer.save( + self.step, last_step=(self.step == job_config.training.steps) + ) + # Run validation if validator is available if ( self.job_config.validation.enabled @@ -582,10 +586,6 @@ def train(self): ): self.validator.validate(self.model_parts, self.step) - self.checkpointer.save( - self.step, last_step=(self.step == job_config.training.steps) - ) - # signal the profiler that the next profiling step has started if torch_profiler: torch_profiler.step() From 59e57a4bb4ff96fb949d61af683f28aa89914361 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:54:19 -0700 Subject: [PATCH 078/128] fix EP fsdp gradient divide factor (#1551) issue pointed out in https://github.com/pytorch/torchtitan/pull/1534#issuecomment-3157429435 https://github.com/pytorch/pytorch/issues/160285 solution given by @rakkit in https://github.com/pytorch/torchtitan/pull/1534#issuecomment-3157429435 --- torchtitan/distributed/parallel_dims.py | 12 +++++++++--- torchtitan/experiments/llama4/infra/parallelize.py | 8 ++++++++ torchtitan/models/deepseek_v3/infra/parallelize.py | 1 + 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 3108049a6f..bbb3874b57 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from functools import cached_property from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -219,11 +218,18 @@ def pp_enabled(self): def ep_enabled(self): return self.ep > 1 - @cached_property + @property + def fsdp_gradient_divide_factor(self) -> int: + # This is needed for FSDP-sharded experts when Expert Parallel is enabled. + # Although the FSDP sharding of experts is done on a mesh of a different size than + # other parameters, the gradient division factor should be consistent with data. + return self.dp_replicate * self.dp_shard * self.cp + + @property def non_data_parallel_size(self): return self.cp * self.tp * self.pp - @cached_property + @property def seq_len_divisor(self): # Sequence Parallel requires that seq_len be divisible by TP degree. # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001 diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 4a7a860680..2db1c64b58 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -139,6 +139,7 @@ def parallelize_llama( if dp_mod_ep_mesh_dim_names else None ), + gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) if parallel_dims.dp_replicate_enabled: @@ -270,6 +271,7 @@ def apply_fsdp( cpu_offload: bool = False, reshard_after_forward_policy: str = "default", dp_mod_ep_mesh: DeviceMesh | None = None, + gradient_divide_factor: int | None = None, ): """ Apply data parallelism (via FSDP2) to the model. @@ -322,6 +324,12 @@ def apply_fsdp( **fsdp_mod_ep_config, reshard_after_forward=reshard_after_forward, ) + # NOTE: # Although the FSDP sharding of experts is done on a mesh of + # a different size than other parameters, the gradient division + # factor should be consistent with data. + transformer_block.moe.experts.set_gradient_divide_factor( + gradient_divide_factor, + ) fully_shard( transformer_block, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8e289f01fb..8c9af6618c 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -122,6 +122,7 @@ def parallelize_deepseekv3( if dp_mod_ep_mesh_dim_names else None ), + gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) if parallel_dims.dp_replicate_enabled: From fd5a87fa668505899401fb0a1188f1543b0530dd Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Mon, 11 Aug 2025 22:35:43 -0700 Subject: [PATCH 079/128] Better Support for Huggingface Asset Integration (#1526) ## Better Support for Huggingface Asset Integration This pr adds richer support for Huggingface assets through the `download_hf_assets.py` script and support for them through the `model.hf_assets_path` specification in job_config. `model.hf_assets_path` will contain the tokenizer, safetensor weights, safetensor.index mapping, and hf config file corresponding to a specific model or huggingface repo. These different assets can be downloaded using the new `download_hf_assets.py` script. The `hf_assets_path` will be used whenever a tokenizer is needed, or in the initialization of `state_dict_adapter` to create a canonical`fqn_to_index_mapping` which may be smartly constructed for speedy loading by huggingface. ### Fixes This also fixes a bug between `save_in_hf` with `pp > 1` where saving would fail when gathering the ranks' sharded safetensor checkpoints into full safetensor checkpoints. This was due to some pipeline stages not knowing the correct, full fqn mapping due to us naively generating the fqn mappings locally. ### Breaking Changes This change deprecates `model.tokenizer_path`. It is replaced by `model.hf_assets_path` and adds naive legacy support by assigning `model.hf_assets_path` the tokenizer path if `model.tokenizer_path` is still specified. However, users will not be able to decouple `tokenizer_path` and `hf_assets_path`. --- .gitignore | 6 +- README.md | 2 +- docs/checkpoint.md | 4 +- .../checkpoint_conversion/convert_from_hf.py | 2 +- .../checkpoint_conversion/convert_to_hf.py | 20 +- scripts/download_hf_assets.py | 258 ++++++++++++++++++ scripts/download_tokenizer.py | 169 ------------ tests/unit_tests/test_tokenizer.py | 7 +- torchtitan/components/checkpoint.py | 38 ++- torchtitan/components/tokenizer.py | 2 +- torchtitan/config/job_config.py | 10 +- torchtitan/config/manager.py | 20 +- .../experiments/flux/dataset/tokenizer.py | 2 +- .../flux/tests/integration_tests.py | 4 +- torchtitan/experiments/llama4/README.md | 2 +- .../multimodal/check_padding_mm.py | 2 +- .../multimodal/tokenizer/tiktoken.py | 2 +- torchtitan/models/deepseek_v3/README.md | 4 +- .../models/llama3/model/state_dict_adapter.py | 28 +- .../llama3/train_configs/debug_model.toml | 2 +- .../llama3/train_configs/llama3_8b.toml | 2 +- torchtitan/protocols/state_dict_adapter.py | 5 +- torchtitan/train.py | 4 +- 23 files changed, 370 insertions(+), 225 deletions(-) create mode 100644 scripts/download_hf_assets.py delete mode 100644 scripts/download_tokenizer.py diff --git a/.gitignore b/.gitignore index 81fb607959..6df39a9ead 100644 --- a/.gitignore +++ b/.gitignore @@ -15,10 +15,8 @@ wandb torchtitan/datasets/**/*.model -# tokenizer models -assets/**/*.model -assets/**/*.json -assets/**/*.txt +# hf assets +assets/hf/* torchtitan/experiments/flux/assets/* # temp files diff --git a/README.md b/README.md index be0dbebc02..30e3ffae34 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ Once you have confirmed access, you can run the following command to download th # Get your HF token from https://huggingface.co/settings/tokens # Llama 3.1 tokenizer -python scripts/download_tokenizer.py --repo_id meta-llama/Llama-3.1-8B --hf_token=... +python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=... ``` ### Start a training run diff --git a/docs/checkpoint.md b/docs/checkpoint.md index b662c52fda..da6598ca8d 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -68,13 +68,13 @@ NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable_che ### HuggingFace `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. -1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_model_only` and `--checkpoint.initial_load_in_hf`, and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. ```bash python ./scripts/checkpoint_conversion/convert_from_hf.py --model_name --model_flavor -python ./scripts/checkpoint_conversion/convert_to_hf.py --model_name --model_flavor +python ./scripts/checkpoint_conversion/convert_to_hf.py --hf_assets_path ./assets/hf/Llama3.1-8B --model_name --model_flavor # e.g. python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./initial_load_path/ --model_name llama3 --model_flavor 8B ``` diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index 42ed00bf27..f71af08363 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -24,7 +24,7 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor): model = train_spec.model_cls(model_args) model = ModelWrapper(model) - sd_adapter = train_spec.state_dict_adapter(model_args) + sd_adapter = train_spec.state_dict_adapter(model_args, None) assert ( sd_adapter is not None ), "trying to convert checkpoint from HF to DCP safetensors format, but sd_adapter is not provided." diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index 800b350789..39c46a16d2 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -15,7 +15,7 @@ @torch.inference_mode() -def convert_to_hf(input_dir, output_dir, model_name, model_flavor): +def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path): # load model and model args so that we can get the state dict shape train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] @@ -24,7 +24,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor): model = train_spec.model_cls(model_args) model = ModelWrapper(model) - sd_adapter = train_spec.state_dict_adapter(model_args) + sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path) assert ( sd_adapter is not None ), "trying to convert checkpoint from DCP to HF safetensors format, but sd_adapter is not provided." @@ -39,17 +39,10 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor): # convert state dict tt->hf hf_state_dict = sd_adapter.to_hf(state_dict) - fqn_to_index_mapping = {} - num_fqns_per_file = 30 - - for i, key in enumerate(hf_state_dict.keys()): - group_num = (i // num_fqns_per_file) + 1 - fqn_to_index_mapping[key] = group_num - storage_writer = HuggingFaceStorageWriter( path=output_dir, save_distributed=True, - fqn_to_index_mapping=fqn_to_index_mapping, + fqn_to_index_mapping=sd_adapter.fqn_to_index_mapping, enable_consolidation=True, thread_count_consolidation=5, ) @@ -68,6 +61,12 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor): parser.add_argument( "output_dir", type=Path, help="Output directory for HF checkpoint." ) + parser.add_argument( + "--hf_assets_path", + type=Path, + help="Path to HF assets directory. This is used to get the model.safetensors.index.json mapping", + default="./assets/hf/Llama3.1-8B", + ) parser.add_argument("--model_name", type=str, nargs="?", default="llama3") parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") args = parser.parse_args() @@ -77,4 +76,5 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor): args.output_dir, args.model_name, args.model_flavor, + args.hf_assets_path, ) diff --git a/scripts/download_hf_assets.py b/scripts/download_hf_assets.py new file mode 100644 index 0000000000..017cc0a405 --- /dev/null +++ b/scripts/download_hf_assets.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from fnmatch import fnmatch +from typing import Optional + +from requests.exceptions import HTTPError +from tqdm import tqdm + + +def download_hf_assets( + repo_id: str, + local_dir: str, + asset_types: str | list[str], + download_all: bool = False, + hf_token: Optional[str] = None, + additional_patterns: Optional[list] = None, +) -> None: + """ + Download relevant files from HuggingFace Hub repository. + + This function recursively searches through the HuggingFace Hub repository + and downloads all related files + + Asset types: + - tokenizer: + - tokenizer.json - Modern HuggingFace tokenizers (complete definition) + - tokenizer_config.json - Tokenizer configuration and metadata + - tokenizer.model - SentencePiece model files (Llama, T5, etc.) + - vocab.txt - Plain text vocabulary files + - vocab.json - JSON vocabulary files + - merges.txt - BPE merge rules (GPT-2, RoBERTa style) + - special_tokens_map.json - Special token mappings + - safetensors + - *.safetensors - Modern Huggingface model weights format for fast loading + - model.safetensors.index.json - Contains mapping from hf fqn to file name + - index + - model.safetensors.index.json - Contains mapping from hf fqn to file name + - config + - config.json - Defines the model architecture + - generation_config.json - Defines the model architecture params needed for generation + + Args: + repo_id (str): HuggingFace repository ID (e.g., meta-llama/Llama-3.1-8B") + local_dir (str): Local directory to save tokenizer files. A subdirectory + named after the model will be created automatically. + asset_types (list[str]): List of the asset types to download + hf_token (Optional[str]): HuggingFace API token for accessing private repositories. + Required for gated models like Llama. + additional_patterns (Optional[list]): Additional file patterns to search for and download + from the HuggingFace Hub repository. + download_all (bool): If True, download all files from the repository + """ + import os + + from huggingface_hub import hf_hub_download, list_repo_files + + # Extract model name from repo_id (part after "/") + if "/" not in repo_id: + raise ValueError( + f"Invalid repo_id format: '{repo_id}'. Expected format: 'organization/model-name'" + ) + model_name = repo_id.split("/")[-1].strip() + model_dir = os.path.join(local_dir, model_name) + + ASSET_PATTERNS = { + "tokenizer": [ + "tokenizer.json", + "tokenizer_config.json", + "tokenizer.model", + "vocab.txt", + "vocab.json", + "merges.txt", + "special_tokens_map.json", + ], + "safetensors": ["*.safetensors", "model.safetensors.index.json"], + "index": ["model.safetensors.index.json"], + "config": ["config.json", "generation_config.json"], + } + + if isinstance(asset_types, str): + asset_types = [asset_types] + + if download_all: + print("Downloading all files from repository...") + files_found = list_repo_files(repo_id=repo_id, token=hf_token) + else: + total_patterns = [] + for asset_type in asset_types: + if asset_type in ASSET_PATTERNS: + total_patterns.extend(ASSET_PATTERNS[asset_type]) + else: + raise ValueError( + "Unknown asset type {}. Available uses: --asset {} \n".format( + asset_type, " ".join(ASSET_PATTERNS.keys()) + ), + "Or specify exact patterns to download. Example: --additional_patterns '*.safetensors' README.md '*.json' \n", + "Or use --all to download all files", + ) + + # Add additional patterns if provided + if additional_patterns: + total_patterns.extend(additional_patterns) + asset_types.append("additional_patterns") + ASSET_PATTERNS["additional_patterns"] = additional_patterns + + def should_download(patterns: list[str], filename: str) -> bool: + """Check if a file matches a pattern to be downloaded.""" + basename = os.path.basename(filename) + for pattern in patterns: + pattern_lower = pattern.lower() + + # Exact name match + if basename == pattern_lower: + return True + # Do wildcard match if wildcards are in pattern + if "*" in pattern_lower or "?" in pattern_lower: + if fnmatch(basename, pattern_lower): + return True + return False + + try: + # Get list of available files in the repo + print(f"Scanning repository {repo_id} for files...") + available_files = list_repo_files(repo_id=repo_id, token=hf_token) + + # Filter for requested asset files + files_found = [ + f for f in available_files if should_download(total_patterns, f) + ] + + # Check each asset type individually to see if files were not found + for asset_type in asset_types: + if asset_type in ASSET_PATTERNS: + asset_patterns = ASSET_PATTERNS[asset_type] + matches_found = False + for f in available_files: + if should_download(asset_patterns, f): + matches_found = True + break + + if not matches_found: + print( + f"Warning: No matching files found for asset_type '{asset_type}' in {repo_id}" + ) + + if not files_found: + print(f"Warning: No matching files found in {repo_id}") + print(f"Available files: {available_files[:10]}...") + return + + except HTTPError as e: + if e.response and e.response.status_code == 401: + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) + raise e + + print(f"Found {len(files_found)} files:") + for f in files_found: + print(f" - {f}") + + downloaded_files = [] + missed_files = [] + + # Download files with progress bar + with tqdm(total=len(files_found), desc="Downloading files", unit="file") as pbar: + for filename in files_found: + try: + pbar.set_description(f"Downloading {os.path.basename(filename)}") + + hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=model_dir, + token=hf_token, + ) + downloaded_files.append(filename) + pbar.update(1) + + except HTTPError as e: + if e.response and e.response.status_code == 404: + print(f"File {filename} not found, skipping...") + missed_files.append(filename) + pbar.update(1) + continue + else: + raise e + + if downloaded_files: + print( + f"\nSuccessfully downloaded {len(downloaded_files)} files to: {model_dir}" + ) + if missed_files: + print(f"Warning: Some files could not be downloaded: \n{missed_files}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Download files from HuggingFace Hub. " + "Automatically detects and downloads files that match the specified file-types to download. " + ) + parser.add_argument( + "--repo_id", + type=str, + required=True, + help="Repository ID to download from (e.g., 'meta-llama/Llama-3.1-8B', 'deepseek-ai/DeepSeek-V3')", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="HuggingFace API token (required for private repos)", + ) + parser.add_argument( + "--local_dir", + type=str, + default="assets/hf/", + help="Local directory to save hf asset files (default: assets/hf/)", + ) + parser.add_argument( + "--assets", + type=str, + nargs="+", + default=[], + help="Asset types to download: tokenizer, safetensors, index, config", + ) + parser.add_argument( + "--additional_patterns", + type=str, + nargs="+", + default=[], + help="Additional file patterns to search for and download from the HuggingFace Hub repository", + ) + + parser.add_argument( + "--all", action="store_true", default=False, help="Download all files in repo" + ) + + args = parser.parse_args() + if not args.all and not args.assets and not args.additional_patterns: + parser.error( + "At least one of --all, --assets or --additional_patterns must be specified." + ) + + download_hf_assets( + args.repo_id, + args.local_dir, + args.assets, + args.all, + args.hf_token, + args.additional_patterns, + ) diff --git a/scripts/download_tokenizer.py b/scripts/download_tokenizer.py deleted file mode 100644 index 3996ac29a9..0000000000 --- a/scripts/download_tokenizer.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional - -from requests.exceptions import HTTPError - - -def download_hf_tokenizer_files( - repo_id: str, - local_dir: str, - hf_token: Optional[str] = None, - additional_patterns: Optional[list] = None, -) -> None: - """ - Download relevant tokenizer files from HuggingFace Hub repository. - - This function recursively searches through the HuggingFace Hub repository - and downloads all tokenizer-related files to enable tokenizer - loading with the build_hf_tokenizer() function. - - Files downloaded: - - tokenizer.json - Modern HuggingFace tokenizers (complete definition) - - tokenizer_config.json - Tokenizer configuration and metadata - - tokenizer.model - SentencePiece model files (Llama, T5, etc.) - - vocab.txt - Plain text vocabulary files - - vocab.json - JSON vocabulary files - - merges.txt - BPE merge rules (GPT-2, RoBERTa style) - - special_tokens_map.json - Special token mappings - - Args: - repo_id (str): HuggingFace repository ID (e.g., meta-llama/Llama-3.1-8B") - local_dir (str): Local directory to save tokenizer files. A subdirectory - named after the model will be created automatically. - hf_token (Optional[str]): HuggingFace API token for accessing private repositories. - Required for gated models like Llama. - additional_patterns (Optional[list]): Additional file patterns to search for and download - from the HuggingFace Hub repository. - """ - import os - - from huggingface_hub import hf_hub_download, list_repo_files - - # Extract model name from repo_id (part after "/") - if "/" not in repo_id: - raise ValueError( - f"Invalid repo_id format: '{repo_id}'. Expected format: 'organization/model-name'" - ) - model_name = repo_id.split("/")[-1].strip() - model_dir = os.path.join(local_dir, model_name) - - # Tokenizer file patterns to match (case-insensitive) - tokenizer_patterns = [ - "tokenizer.json", - "tokenizer_config.json", - "tokenizer.model", - "vocab.txt", - "vocab.json", - "merges.txt", - "special_tokens_map.json", - ] - - # Add additional files if provided - if additional_patterns: - tokenizer_patterns.extend(additional_patterns) - - def is_tokenizer_file(filename: str) -> bool: - """Check if a file is a tokenizer-related file.""" - filename_lower = filename.lower() - basename = os.path.basename(filename_lower) - - # Check exact matches - if basename in [pattern.lower() for pattern in tokenizer_patterns]: - return True - - return False - - try: - # Get list of available files in the repo - print(f"Scanning repository {repo_id} for tokenizer files...") - available_files = list_repo_files(repo_id=repo_id, token=hf_token) - - # Filter for tokenizer files - tokenizer_files_found = [f for f in available_files if is_tokenizer_file(f)] - - if not tokenizer_files_found: - print(f"Warning: No tokenizer files found in {repo_id}") - print(f"Available files: {available_files[:10]}...") - return - - print(f"Found {len(tokenizer_files_found)} tokenizer files:") - for f in tokenizer_files_found: - print(f" - {f}") - - downloaded_files = [] - for filename in tokenizer_files_found: - try: - hf_hub_download( - repo_id=repo_id, - filename=filename, - local_dir=model_dir, - token=hf_token, - ) - file_path = os.path.join(model_dir, filename) - print(f"Successfully downloaded {filename} to {file_path}") - downloaded_files.append(filename) - except HTTPError as e: - if e.response and e.response.status_code == 404: - print(f"File {filename} not found, skipping...") - continue - else: - raise e - - if downloaded_files: - print( - f"\nSuccessfully downloaded {len(downloaded_files)} tokenizer files to: {model_dir}" - ) - else: - print(f"Warning: No tokenizer files could be downloaded from {repo_id}") - - except HTTPError as e: - if e.response and e.response.status_code == 401: - print( - "You need to pass a valid `--hf_token=...` to download private checkpoints." - ) - raise e - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser( - description="Download tokenizer files from HuggingFace Hub. " - "Automatically detects and downloads common tokenizer files (tokenizer.json, " - "tokenizer_config.json, tokenizer.model, ...) that work with Tokenizer." - ) - parser.add_argument( - "--repo_id", - type=str, - required=True, - help="Repository ID to download from (e.g., 'meta-llama/Llama-3.1-8B', 'deepseek-ai/DeepSeek-V3')", - ) - parser.add_argument( - "--hf_token", - type=str, - default=None, - help="HuggingFace API token (required for private repos)", - ) - parser.add_argument( - "--local_dir", - type=str, - default="assets/tokenizer/", - help="Local directory to save tokenizer files (default: assets/tokenizer/)", - ) - parser.add_argument( - "--additional_patterns", - type=str, - nargs="*", - default=None, - help="Additional file patterns to search for and download from the HuggingFace Hub repository", - ) - - args = parser.parse_args() - download_hf_tokenizer_files( - args.repo_id, args.local_dir, args.hf_token, args.additional_patterns - ) diff --git a/tests/unit_tests/test_tokenizer.py b/tests/unit_tests/test_tokenizer.py index a7a3a7e623..a0306fde36 100644 --- a/tests/unit_tests/test_tokenizer.py +++ b/tests/unit_tests/test_tokenizer.py @@ -11,7 +11,7 @@ from requests.exceptions import HTTPError -from scripts.download_tokenizer import download_hf_tokenizer_files +from scripts.download_hf_assets import download_hf_assets from tokenizers import Tokenizer from torch.testing._internal.common_utils import ( @@ -23,7 +23,7 @@ class TestTokenizerIntegration(unittest.TestCase): - """Test integration between download_tokenizer and load_tokenizer functions.""" + """Test integration between download_hf_assets and load_tokenizer functions.""" def setUp(self): """Create a temporary directory for test files.""" @@ -262,9 +262,10 @@ def test_download_and_build_tokenizer(self, test_repo_id): """ # Step 1: Download tokenizer files try: - download_hf_tokenizer_files( + download_hf_assets( repo_id=test_repo_id, local_dir=self.temp_dir, + asset_types="tokenizer", ) except HTTPError as e: if test_repo_id == "meta-llama/Llama-3.1-8B": diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 5b649f5a80..e55f792dff 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -360,14 +360,7 @@ def dcp_save( ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided." state_dict = self.sd_adapter.to_hf(state_dict) - fqn_to_index_mapping = {} - num_fqns_per_file = 30 - # the use of 30 is just a heuristic for now. - # Once these fqns map to HF ones, we can use the fqn mapping - # from the model.safetensors.index.json file - for i, key in enumerate(state_dict.keys()): - group_num = (i // num_fqns_per_file) + 1 - fqn_to_index_mapping[key] = group_num + fqn_to_index_mapping = self.sd_adapter.fqn_to_index_mapping storage_writer = HuggingFaceStorageWriter( path=checkpoint_id, @@ -539,18 +532,32 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False if not os.path.exists(self.folder): + model_only = self.initial_load_model_only + from_hf = self.initial_load_in_hf + if from_hf: + assert ( + model_only + ), "Only model can be loaded when loading from HF's safetensors checkpoint." if self.initial_load_path: checkpoint_id = self.initial_load_path if not os.path.isdir(checkpoint_id): raise ValueError( "checkpoint.initial_load_path is specified but the path is not valid." ) - model_only = self.initial_load_model_only - from_hf = self.initial_load_in_hf if from_hf: - assert ( - model_only - ), "Only model can be loaded when loading from HF's safetensors checkpoint." + logger.info( + f"loading from HF safetensors from --checkpoint.initial_load_path: {self.initial_load_path}" + ) + elif from_hf: + checkpoint_id = self.sd_adapter.hf_assets_path + if not os.path.isdir(checkpoint_id): + raise ValueError( + "model.hf_assets_path is being used to load HF weights but the path is not valid. \ + Either make sure hf_assets_path is correct or provide a valid checkpoint.initial_load_path" + ) + logger.info( + f"loading HF safetensors from --model.hf_assets_path: {self.sd_adapter.hf_assets_path}" + ) else: return False else: @@ -559,6 +566,11 @@ def load(self, step: int = -1) -> bool: "checkpoint.initial_load_path is provided but the checkpoint.folder exists. " f"Checkpointer will use the checkpoints from the checkpoint.folder {self.folder}." ) + if self.initial_load_in_hf: + logger.warning( + "checkpoint.initial_load_in_hf is True but the checkpoint.folder exists. " + "Checkpointer will not load from HF safetensors" + ) step = self._find_load_step() if step == -1 else step if step == -1: return False diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index b0b7146945..24db9b3484 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -417,5 +417,5 @@ def build_hf_tokenizer( Returns: tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling """ - tokenizer = HuggingFaceTokenizer(job_config.model.tokenizer_path) + tokenizer = HuggingFaceTokenizer(job_config.model.hf_assets_path) return tokenizer diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index d6c87585d3..f407fe6e78 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -78,7 +78,15 @@ class Model: flavor: str = "debugmodel" """Which model config to train""" - tokenizer_path: str = "./tests/assets/tokenizer" + hf_assets_path: str = "./tests/assets/tokenizer" + """ + Path to HF assets folder. This folder contains local copies of Hugging Face assets, + including model weights in .safetensors format, the model.safetensor.index.json file + (fqn to file mapping), the config.json file, generation_config.json, and tokenizer files. + """ + + tokenizer_path: str | None = None + """DEPRECATED: Use hf_assets_path instead.""" """Tokenizer path""" converters: list[str] = field(default_factory=list) diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index d22b3d21fa..4ec6eb4ac9 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -170,27 +170,33 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: return cls(**result) def _validate_config(self) -> None: - # TODO: temporary mitigation of BC breaking change in + # TODO: temporary mitigation of BC breaking change in hf_assets_path # tokenizer default path, need to remove later - if not os.path.exists(self.config.model.tokenizer_path): + if self.config.model.tokenizer_path is not None: logger.warning( - f"Tokenizer path {self.config.model.tokenizer_path} does not exist!" + "tokenizer_path is deprecated, use model.hf_assets_path instead. " + "Setting hf_assets_path to tokenizer_path temporarily." + ) + self.config.model.hf_assets_path = self.config.model.tokenizer_path + if not os.path.exists(self.config.model.hf_assets_path): + logger.warning( + f"HF assets path {self.config.model.hf_assets_path} does not exist!" ) old_tokenizer_path = ( "torchtitan/datasets/tokenizer/original/tokenizer.model" ) if os.path.exists(old_tokenizer_path): - self.config.model.tokenizer_path = old_tokenizer_path + self.config.model.hf_assets_path = old_tokenizer_path logger.warning( f"Temporarily switching to previous default tokenizer path {old_tokenizer_path}. " - "Please download the new tokenizer model (python scripts/download_tokenizer.py) and update your config." + "Please download the new tokenizer files (python scripts/download_hf_assets.py) and update your config." ) else: # Check if we are using tokenizer.model, if so then we need to alert users to redownload the tokenizer - if self.config.model.tokenizer_path.endswith("tokenizer.model"): + if self.config.model.hf_assets_path.endswith("tokenizer.model"): raise Exception( "You are using the old tokenizer.model, please redownload the tokenizer ", - "(python scripts/download_tokenizer.py --repo_id meta-llama/Llama-3.1-8B) ", + "(python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer) ", " and update your config to the directory of the downloaded tokenizer.", ) diff --git a/torchtitan/experiments/flux/dataset/tokenizer.py b/torchtitan/experiments/flux/dataset/tokenizer.py index bf90bdcb26..41a02aa907 100644 --- a/torchtitan/experiments/flux/dataset/tokenizer.py +++ b/torchtitan/experiments/flux/dataset/tokenizer.py @@ -124,7 +124,7 @@ def build_flux_tokenizer(job_config: JobConfig) -> tuple[BaseTokenizer, BaseToke # NOTE: This tokenizer is used for offline CI and testing only, borrowed from llama3 tokenizer if job_config.training.test_mode: tokenizer_class = FluxTestTokenizer - t5_tokenizer_path = clip_tokenzier_path = job_config.model.tokenizer_path + t5_tokenizer_path = clip_tokenzier_path = job_config.model.hf_assets_path else: tokenizer_class = FluxTokenizer diff --git a/torchtitan/experiments/flux/tests/integration_tests.py b/torchtitan/experiments/flux/tests/integration_tests.py index ae4e688266..aa23add5cf 100755 --- a/torchtitan/experiments/flux/tests/integration_tests.py +++ b/torchtitan/experiments/flux/tests/integration_tests.py @@ -109,7 +109,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): t5_encoder_version_arg = ( "--encoder.t5_encoder torchtitan/experiments/flux/tests/assets/t5-v1_1-xxl/" ) - tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer" + hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer" all_ranks = ",".join(map(str, range(test_flavor.ngpu))) @@ -121,7 +121,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): cmd += " " + random_init_encoder_arg cmd += " " + clip_encoder_version_arg cmd += " " + t5_encoder_version_arg - cmd += " " + tokenzier_path_arg + cmd += " " + hf_assets_path_arg if override_arg: cmd += " " + " ".join(override_arg) logger.info( diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md index 23b75b8598..964bc3741f 100644 --- a/torchtitan/experiments/llama4/README.md +++ b/torchtitan/experiments/llama4/README.md @@ -12,7 +12,7 @@ https://github.com/pytorch/torchtitan/issues/1118 #### Download Llama 4 tokenizer ```bash # Llama 4 tokenizer.model -python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --hf_token=... +python scripts/download_hf_assets.py --assets tokenizer --repo_id meta-llama/Llama-4-Scout-17B-16E --hf_token=... ``` #### To be added diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py index 0635c7a030..18f6a4ac8d 100644 --- a/torchtitan/experiments/multimodal/check_padding_mm.py +++ b/torchtitan/experiments/multimodal/check_padding_mm.py @@ -39,7 +39,7 @@ def main( str(batch_size), "--training.seq_len", str(seq_len), - "--model.tokenizer_path", + "--model.hf_assets_path", tokenizer_path, ] ) diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py index 239cf3d339..b3a7a71ea0 100644 --- a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -232,4 +232,4 @@ def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]: def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer: - return TikTokenizer(job_config.model.tokenizer_path) + return TikTokenizer(job_config.model.hf_assets_path) diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 38742cc716..5a36c9198c 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -8,12 +8,12 @@ DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with Multi-head Late ```bash # DeepSeek 671B tokenizer (automatically downloads tokenizer.json and tokenizer_config.json) -python scripts/download_tokenizer.py --repo_id deepseek-ai/DeepSeek-V3 +python scripts/download_hf_assets.py --repo_id deepseek-ai/DeepSeek-V3 --assets tokenizer ``` ```bash # DeepSeek 16B tokenizer: -python scripts/download_tokenizer.py --repo_id deepseek-ai/deepseek-moe-16b-base +python scripts/download_hf_assets.py --repo_id deepseek-ai/deepseek-moe-16b-base --assets tokenizer ``` > **Note:** We are reusing the tokenizer from deepseek-ai/deepseek-moe-16b-base to help users test and run the 16B model. This is not the official tokenizer for the DeepSeek-V3-16B model. The DeepSeek-V3 model has a different architecture from the deepseek-moe models (different attention implementation, MoE router implementation, etc.), making it not feasible to load deepseek-moe-16b model weights into DeepSeek-V3-16B. diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 9305c1b4d3..91259d0bce 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -4,17 +4,23 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json +import logging +import os import re from typing import Any +logger = logging.getLogger() + from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs class Llama3StateDictAdapter(StateDictAdapter): - def __init__(self, model_args: TransformerModelArgs): + def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None): self.model_args = model_args + self.hf_assets_path = hf_assets_path self.from_hf_map = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", @@ -31,6 +37,26 @@ def __init__(self, model_args: TransformerModelArgs): "lm_head.weight": "output.weight", } + if hf_assets_path: + mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json") + try: + with open(mapping_path, "r") as f: + hf_safetensors_indx = json.load(f) + except FileNotFoundError: + logger.warning( + "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ + Defaulting to saving a single safetensors file if checkpoint is saved in HF format.", + ) + hf_safetensors_indx = None + + if hf_safetensors_indx: + self.fqn_to_index_mapping = {} + for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + indx = re.search(r"\d+", raw_indx).group(0) + self.fqn_to_index_mapping[hf_key] = indx + else: + self.fqn_to_index_mapping = None + # HuggingFace permutation function (exact copy from their conversion script) def _permute(self, w, n_heads_arg, dim1=None, dim2=None): if dim1 is None: diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 00a688dcf5..0607268a75 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -24,7 +24,7 @@ enable_wandb = false name = "llama3" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only -tokenizer_path = "./tests/assets/tokenizer" +hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ed1335fa80..f3c2931a55 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "8B" -tokenizer_path = "./assets/tokenizer/Llama-3.1-8B" +hf_assets_path = "./assets/hf/Llama-3.1-8B" # converters = ["float8"] [optimizer] diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 9bcbfc0463..1975a9ed08 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -15,10 +15,13 @@ class StateDictAdapter(ABC): This class defines the interface for converting between native model state dict format and other model state dict formats. + Args: + model_args: for initializing the model's memory space + hf_assets_path: path to HF assets folder containing tokenizer, model weights, etc. """ @abstractmethod - def __init__(self, model_args: BaseModelArgs): + def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): pass @abstractmethod diff --git a/torchtitan/train.py b/torchtitan/train.py index f55530a083..70b7d2ebde 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -307,7 +307,9 @@ def __init__(self, job_config: JobConfig): states={"train_state": self}, checkpoint_config=job_config.checkpoint, sd_adapter=( - self.train_spec.state_dict_adapter(model_args) + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) if self.train_spec.state_dict_adapter else None ), From d14f1e3bcb4570be461b4bb70e0be522b4bc9a1c Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Mon, 11 Aug 2025 22:37:42 -0700 Subject: [PATCH 080/128] Flux Batched Inference (#1548) This PR reuses the trainer's parallelization to perform batched inference on Flux. It follows up from and addresses comments on @CarlosGomes98 https://github.com/pytorch/torchtitan/pull/1227. This irons out some optimizations for better code clarity. In this implementation saves images locally from each rank but maintains the global information such as prompt and idx of each image. It also removes unnecessary padding and handles prompt sets that are not divisible by the batch size. --- .../experiments/flux/dataset/tokenizer.py | 28 ++++-- .../experiments/flux/inference/infer.py | 85 +++++++++++++++++++ .../experiments/flux/inference/prompts.txt | 28 ++++++ .../experiments/flux/inference/run_infer.sh | 22 +++++ torchtitan/experiments/flux/job_config.py | 15 ++++ torchtitan/experiments/flux/run_train.sh | 8 +- torchtitan/experiments/flux/sampling.py | 42 ++++++--- .../flux/train_configs/debug_model.toml | 6 ++ 8 files changed, 211 insertions(+), 23 deletions(-) create mode 100644 torchtitan/experiments/flux/inference/infer.py create mode 100644 torchtitan/experiments/flux/inference/prompts.txt create mode 100755 torchtitan/experiments/flux/inference/run_infer.sh diff --git a/torchtitan/experiments/flux/dataset/tokenizer.py b/torchtitan/experiments/flux/dataset/tokenizer.py index 41a02aa907..1e1e2498a1 100644 --- a/torchtitan/experiments/flux/dataset/tokenizer.py +++ b/torchtitan/experiments/flux/dataset/tokenizer.py @@ -45,13 +45,31 @@ def _pad_and_chunk_tokens( def get_vocab_size(self) -> int: return self.tiktokenizer.vocab_size - def encode(self, text: str) -> torch.Tensor: + def encode(self, text: str | list[str]) -> torch.Tensor: """ Use TikTokenizer to encode the text into tokens, and then pad and chunk the tokens to max_length. """ - tokens = self.tiktokenizer.encode(text, add_bos=True, add_eos=True) - tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id) - return torch.tensor(tokens) + if isinstance(text, list): + if len(text) == 1: + # for single item in list encode and add batch dimension + tokens = self.tiktokenizer.encode(text[0], add_bos=True, add_eos=True) + tokens = self._pad_and_chunk_tokens( + tokens, self._max_length, self.pad_id + ) + return torch.tensor(tokens).unsqueeze(0) + else: + all_tokens = [] + for t in text: + tokens = self.tiktokenizer.encode(t, add_bos=True, add_eos=True) + tokens = self._pad_and_chunk_tokens( + tokens, self._max_length, self.pad_id + ) + all_tokens.append(torch.tensor(tokens)) + return torch.stack(all_tokens) + else: + tokens = self.tiktokenizer.encode(text, add_bos=True, add_eos=True) + tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id) + return torch.tensor(tokens) def decode(self, t: List[int]) -> str: """ @@ -90,7 +108,7 @@ def get_vocab_size(self) -> int: def encode( self, - s: str, + s: str | list[str], ) -> torch.Tensor: """ Encode the prompt text into tokens. diff --git a/torchtitan/experiments/flux/inference/infer.py b/torchtitan/experiments/flux/inference/infer.py new file mode 100644 index 0000000000..f31fb15073 --- /dev/null +++ b/torchtitan/experiments/flux/inference/infer.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan.config import ConfigManager, JobConfig +from torchtitan.experiments.flux.dataset.tokenizer import build_flux_tokenizer +from torchtitan.experiments.flux.sampling import generate_image, save_image +from torchtitan.experiments.flux.train import FluxTrainer +from torchtitan.tools.logging import init_logger, logger + + +@torch.no_grad() +@record +def inference(config: JobConfig): + # Reuse trainer to perform forward passes + trainer = FluxTrainer(config) + + # Distributed processing setup: Each GPU/process handles a subset of prompts + world_size = int(os.environ["WORLD_SIZE"]) + global_rank = int(os.environ["RANK"]) + original_prompts = open(config.inference.prompts_path).readlines() + total_prompts = len(original_prompts) + + # Distribute prompts across processes using round-robin assignment + prompts = original_prompts[global_rank::world_size] + + trainer.checkpointer.load(step=config.checkpoint.load_step) + t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config) + + if global_rank == 0: + logger.info("Starting inference...") + + if prompts: + # Generate images for this process's assigned prompts + bs = config.inference.local_batch_size + + output_dir = os.path.join( + config.job.dump_folder, + config.inference.save_img_folder, + ) + + # Create mapping from local indices to global prompt indices + global_ids = list(range(global_rank, total_prompts, world_size)) + + for i in range(0, len(prompts), bs): + images = generate_image( + device=trainer.device, + dtype=trainer._dtype, + job_config=trainer.job_config, + model=trainer.model_parts[0], + prompt=prompts[i : i + bs], + autoencoder=trainer.autoencoder, + t5_tokenizer=t5_tokenizer, + clip_tokenizer=clip_tokenizer, + t5_encoder=trainer.t5_encoder, + clip_encoder=trainer.clip_encoder, + ) + for j in range(images.shape[0]): + # Extract single image while preserving batch dimension [1, C, H, W] + img = images[j : j + 1] + global_id = global_ids[i + j] + + save_image( + name=f"image_prompt{global_id}_rank{str(torch.distributed.get_rank())}.png", + output_dir=output_dir, + x=img, + add_sampling_metadata=True, + prompt=prompts[i + j], + ) + + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + inference(config) diff --git a/torchtitan/experiments/flux/inference/prompts.txt b/torchtitan/experiments/flux/inference/prompts.txt new file mode 100644 index 0000000000..23a76c5a34 --- /dev/null +++ b/torchtitan/experiments/flux/inference/prompts.txt @@ -0,0 +1,28 @@ +A serene mountain landscape at sunset with a crystal clear lake reflecting the golden sky +A futuristic cityscape with flying cars and neon lights illuminating the night sky +A cozy cafe interior with steam rising from coffee cups and warm lighting +A magical forest with glowing mushrooms and fireflies dancing between ancient trees +A peaceful beach scene with turquoise waves and palm trees swaying in the breeze +A steampunk-inspired mechanical dragon soaring through clouds +A mystical library with floating books and magical artifacts +A Japanese garden in spring with cherry blossoms falling gently +A space station orbiting a colorful nebula +A medieval castle on a hilltop during a dramatic thunderstorm +A underwater scene with bioluminescent creatures and coral reefs +A desert oasis with a majestic palace and palm trees +A cyberpunk street market with holographic signs and diverse crowds +A cozy winter cabin surrounded by snow-covered pine trees +A fantasy tavern filled with unique characters and magical atmosphere +A tropical rainforest with exotic birds and waterfalls +A steampunk airship navigating through storm clouds +A peaceful zen garden with a traditional Japanese tea house +A magical potion shop with bubbling cauldrons and mysterious ingredients +A futuristic space colony on Mars with domed habitats +A mystical temple hidden in the clouds +A vintage train station with steam locomotives and period architecture +A magical bakery with floating pastries and enchanted ingredients +A peaceful countryside scene with rolling hills and a rustic farmhouse +A underwater city with advanced technology and marine life +A fantasy marketplace with magical creatures and exotic goods +A peaceful meditation garden with lotus flowers and koi ponds +A steampunk laboratory with intricate machinery and glowing elements diff --git a/torchtitan/experiments/flux/inference/run_infer.sh b/torchtitan/experiments/flux/inference/run_infer.sh new file mode 100755 index 0000000000..126540b7ed --- /dev/null +++ b/torchtitan/experiments/flux/inference/run_infer.sh @@ -0,0 +1,22 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh +NGPU=${NGPU:-"8"} +export LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m torchtitan.experiments.flux.inference.infer --job.config_file ${CONFIG_FILE} \ +--checkpoint.enable_checkpoint \ +--checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer "$@" diff --git a/torchtitan/experiments/flux/job_config.py b/torchtitan/experiments/flux/job_config.py index 8c7589d045..0b139ed42f 100644 --- a/torchtitan/experiments/flux/job_config.py +++ b/torchtitan/experiments/flux/job_config.py @@ -54,6 +54,20 @@ class Validation: """Whether to generate all stratified timesteps per sample or use round robin""" +@dataclass +class Inference: + """Inference configuration""" + + save_img_folder: str = "inference_results" + """Path to save the inference results""" + prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt" + """Path to file with newline separated prompts to generate images for""" + local_batch_size: int = 2 + """Batch size for inference""" + img_size: int = 256 + """Image size for inference""" + + @dataclass class JobConfig: """ @@ -63,3 +77,4 @@ class JobConfig: training: Training = field(default_factory=Training) encoder: Encoder = field(default_factory=Encoder) validation: Validation = field(default_factory=Validation) + inference: Inference = field(default_factory=Inference) diff --git a/torchtitan/experiments/flux/run_train.sh b/torchtitan/experiments/flux/run_train.sh index 6902cc339a..231d66fc35 100755 --- a/torchtitan/experiments/flux/run_train.sh +++ b/torchtitan/experiments/flux/run_train.sh @@ -14,13 +14,7 @@ NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} -overrides="" -if [ $# -ne 0 ]; then - overrides="$*" -fi - - PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ --m torchtitan.experiments.flux.train --job.config_file ${CONFIG_FILE} $overrides +-m torchtitan.experiments.flux.train --job.config_file ${CONFIG_FILE} "$@" diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index 445e8c85fd..4a5a1157f8 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -76,7 +76,7 @@ def generate_image( dtype: torch.dtype, job_config: JobConfig, model: FluxModel, - prompt: str, + prompt: str | list[str], autoencoder: AutoEncoder, t5_tokenizer: BaseTokenizer, clip_tokenizer: BaseTokenizer, @@ -89,6 +89,9 @@ def generate_image( Since we will always use the local random seed on this rank, we don't need to pass in the seed again. """ + if isinstance(prompt, str): + prompt = [prompt] + # allow for packing and conversion to latent space. Use the same resolution as training time. img_height = 16 * (job_config.training.img_size // 16) img_width = 16 * (job_config.training.img_size // 16) @@ -98,8 +101,11 @@ def generate_image( ) # Tokenize the prompt. Unsqueeze to add a batch dimension. - clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0) - t5_tokens = t5_tokenizer.encode(prompt).unsqueeze(0) + clip_tokens = clip_tokenizer.encode(prompt) + t5_tokens = t5_tokenizer.encode(prompt) + if len(prompt) == 1: + clip_tokens = clip_tokens.unsqueeze(0) + t5_tokens = t5_tokens.unsqueeze(0) batch = preprocess_data( device=device, @@ -114,11 +120,16 @@ def generate_image( ) if enable_classifier_free_guidance: - empty_clip_tokens = clip_tokenizer.encode("").unsqueeze(0) - empty_t5_tokens = t5_tokenizer.encode("").unsqueeze(0) + num_images = len(prompt) + + empty_clip_tokens = clip_tokenizer.encode("") + empty_t5_tokens = t5_tokenizer.encode("") + empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1) + empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1) + empty_batch = preprocess_data( device=device, - dtype=torch.bfloat16, + dtype=dtype, autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, @@ -176,6 +187,13 @@ def denoise( # create denoising schedule timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True) + if enable_classifier_free_guidance: + # Double batch size for CFG: [unconditional, conditional] + latents = torch.cat([latents, latents], dim=0) + t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0) + clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0) + bsz *= 2 + # create positional encodings POSITION_DIM = 3 latent_pos_enc = create_position_encoding_for_latents( @@ -183,11 +201,6 @@ def denoise( ).to(latents) text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents) - if enable_classifier_free_guidance: - latents = torch.cat([latents, latents], dim=0) - t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0) - clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0) - # convert img-like latents into sequences of patches latents = pack_latents(latents) @@ -206,8 +219,15 @@ def denoise( pred_u, pred_c = pred.chunk(2) pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u) + # repeat along batch dimension to update both unconditional and conditional latents + pred = pred.repeat(2, 1, 1) + latents = latents + (t_prev - t_curr) * pred + # take the conditional latents for the final result + if enable_classifier_free_guidance: + latents = latents.chunk(2)[1] + # convert sequences of patches into img-like latents latents = unpack_latents(latents, latent_height, latent_width) diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index aad2580218..e565b23bd9 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -71,9 +71,15 @@ dataset = "coco-validation" freq = 5 local_batch_size = 8 steps = 1 +# args for sampling images enable_classifier_free_guidance = true classifier_free_guidance_scale = 5.0 denoising_steps = 4 save_img_count = 1 save_img_folder = "img" all_timesteps = false + +[inference] +save_img_folder = "inference_results" +prompts_path = "./torchtitan/experiments/flux/inference/prompts.txt" +local_batch_size = 2 From 9c42b9b474b7a6c18c8400dcf938d1e06ecfa1df Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 12 Aug 2025 10:56:12 -0700 Subject: [PATCH 081/128] [a2av] Add autograd support for token dispatch op (#1491) Added class `TokenDispatcher` which dispatches tokens to different experts, with backward support. Usage: ``` dispatcher = TokenDispatcher(group_name, align, max_inp_len, max_out_len, inp.shape[1:], world_size, ne, dtype) # inp, out, in_splits, out_splits_offsets must be symmetric tensors output = dispatcher(inp, out, in_splits, out_splits_offsets) ``` Supports: ``` torch.compile(dispatcher) ``` --- .../experiments/kernels/moe/dispatch.py | 312 ++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 torchtitan/experiments/kernels/moe/dispatch.py diff --git a/torchtitan/experiments/kernels/moe/dispatch.py b/torchtitan/experiments/kernels/moe/dispatch.py new file mode 100644 index 0000000000..7775e00848 --- /dev/null +++ b/torchtitan/experiments/kernels/moe/dispatch.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem + + +# Adding out-of-tree ops to the `symm_mem` library +lib = torch.library.Library("symm_mem", "FRAGMENT") # noqa: TOR901 + +""" +all_to_all_vdev_2d_offset_copy: +Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d_offset` to shuffle data +""" +lib.define( + "all_to_all_vdev_2d_offset_copy(" + "Tensor input, Tensor symm_in_buf, Tensor(a!) out, " + "Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()", + tags=[torch._C.Tag.needs_exact_strides], +) + + +@torch.library.impl(lib, "all_to_all_vdev_2d_offset_copy", "CUDA") +def _all_to_all_vdev_2d_offset_copy_cuda( + input: torch.Tensor, + symm_in_buf: torch.Tensor, + out: torch.Tensor, + in_splits_offsets: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: str, +) -> None: + if symm_in_buf.shape[0] < input.shape[0]: + raise RuntimeError( + f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}" + ) + if symm_in_buf.shape[1:] != input.shape[1:]: + raise RuntimeError( + f"symm_in_buf non-0 dims do not match that of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}" + ) + if symm_in_buf.dtype != input.dtype: + raise RuntimeError( + f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}" + ) + + symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input) + torch.ops.symm_mem.all_to_all_vdev_2d_offset( + symm_in_buf, + out, + in_splits_offsets, + out_splits_offsets, + group_name, + ) + + +class AllToAllVDev2d(torch.autograd.Function): + """ + Autograd function for `all_to_all_vdev_2d` + """ + + @staticmethod + def forward( # type: ignore[no-untyped-def] + ctx, + input: torch.Tensor, + out: torch.Tensor, + in_splits: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: str, + major_align: int, + # Buffers needed for backward pass + grad_out_buf: torch.Tensor, + grad_in_buf: torch.Tensor, + grad_in_splits_offsets: torch.Tensor, + ) -> torch.Tensor: + """ + Functionality is the same as `all_to_all_vdev_2d` but with functionalization. + """ + # Shuffle input to output + torch.ops.symm_mem.all_to_all_vdev_2d( + input, out, in_splits, out_splits_offsets, group_name, major_align + ) + + # Output splits in forward is the input splits in backward + ctx.save_for_backward( + out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets + ) + ctx.group_name = group_name + return out + + @staticmethod + def backward( # type: ignore[no-untyped-def] + ctx, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]: + """ + Backward pass of `all_to_all_vdev_2d` is `all_to_all_vdev_2d_offset`. + + Args: + `grad_output`: gradients of output passed back from the downstream. + + Returns: + `grad_input`: gradients of input. + """ + # Splits info + # Splits/offsets of grad_out is the same as out splits/offsets in forward + ( + grad_out_splits_offsets, + grad_out_buf, + grad_in_buf, + grad_in_splits_offsets, + ) = ctx.saved_tensors + + # Shuffle gradients back to the input + torch.ops.symm_mem.all_to_all_vdev_2d_offset_copy( + grad_output, + grad_out_buf, + grad_in_buf, + grad_out_splits_offsets, + grad_in_splits_offsets, + group_name=ctx.group_name, + ) + return grad_in_buf, None, None, None, None, None, None, None, None + + +class TokenDispatcher(torch.nn.Module): + """ + Dispatch tokens to different experts, with backward pass to shuffle gradients back to the input. + Args: + `group_name`: name of the group to use for communication. + `align`: alignment of the token offsets for each receiving expert. If + using Grouped Gemm next, this should be the same as Grouped Gemm's + alignment. + `in_len`: length of the input. + `out_len`: length of the output. + `token_shape`: shape of the tokens. + `num_ranks`: number of ranks in the group. + `num_local_experts`: number of local experts. + `dtype`: data type of the input/output. + `device`: device to use for communication. + """ + + def __init__( + self, + group_name: str, + align: int, + in_len, + out_len, + token_shape, + num_ranks, + num_local_experts, + dtype, + device: torch.device, + ) -> None: + super().__init__() + self.group_name = group_name + self.align = align + self.grad_out_buf = symm_mem.empty( + out_len, *token_shape, dtype=dtype, device=device + ) + self.grad_in_buf = symm_mem.empty( + in_len, *token_shape, dtype=dtype, device=device + ) + self.nsplits = num_ranks * num_local_experts + self.grad_in_splits_offsets = symm_mem.empty( + (2, self.nsplits), dtype=torch.int64, device=device + ) + + def forward( + self, + inp: torch.Tensor, + out: torch.Tensor, + in_splits: torch.Tensor, + out_splits_offsets: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + `inp`: input tensor. + `out`: buffer for output tensor. + `in_splits`: splits of the input tensor. + `out_splits_offsets`: splits and offsets of the output tensor. + See `all_to_all_vdev_2d` for more details. + Note: + All tensor arguments must be symmetrically allocated, i.e. + >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device) + >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device) + >>> in_splits = symm_mem.empty( + ... nsplits, dtype=torch.int64, device=device) + >>> out_splits_offsets = symm_mem.empty( + ... (2, nsplits), dtype=torch.int64, device=device) + """ + + if in_splits.numel() != self.nsplits: + raise ValueError(f"Expected {self.nsplits} splits, got {in_splits.numel()}") + if out_splits_offsets.shape != (2, self.nsplits): + raise ValueError( + f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}" + ) + + return AllToAllVDev2d.apply( + inp, + out, + in_splits, + out_splits_offsets, + self.group_name, + self.align, + self.grad_out_buf, + self.grad_in_buf, + self.grad_in_splits_offsets, + ) + + +def test_token_dispatch() -> None: + # Init + dist.init_process_group() + rank = dist.get_rank() + world_size = dist.get_world_size() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + # NVSHMEM backend specific + torch.cuda.set_device(device) + torch.empty(1, device=device) + # Set NVSHMEM as SymmMem backend + symm_mem.set_backend("NVSHMEM") + + # Mimics Group GEMM alignment + align = 8 + torch.manual_seed(42 + rank) + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + # Number of experts per rank + ne = 8 + nsplits = ne * world_size + + # Number of elements for an expert is random between [0, k) + k = 10 + inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device) + + # Max number of input elements (must be a constant across ranks for symmetric memory allocation) + max_inp_len = k * nsplits + # Max number of output elements (must be a constant across ranks for symmetric memory allocation) + overflow_factor = world_size # worst case: one rank receives all data + max_out_len = max_inp_len * overflow_factor + + hid = 4096 + inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device) + out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device) + in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device).copy_( + inp_splits + ) + # 2 rows: output splits, output offsets + out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device) + + dispatcher = TokenDispatcher( + group_name, + align, + max_inp_len, + max_out_len, + inp.shape[1:], + world_size, + ne, + dtype, + device, + ) + + compiled_dispatcher = torch.compile( + dispatcher, + fullgraph=True, + ) + + # Perform a Dot product with output, so that gradients passed back from + # different ranks are different + weight = torch.empty(max_out_len, dtype=dtype, device=device).fill_(rank + 1) + + # Run a few iterations + iters = 2 + for i in range(iters): + # Test if gradients would be passed back from inp to tokens + tokens = torch.randn( + max_inp_len, hid, dtype=dtype, device=device + ).requires_grad_(True) + tokens.grad = None + inp.copy_(tokens) + output = compiled_dispatcher(inp, out, in_splits, out_splits_offsets) + p = torch.matmul(weight, output) + p.sum().backward() + + # Check gradients + start = 0 + for i, split in enumerate(in_splits.tolist()): + grad_chunk = tokens.grad[start : start + split] + dst_rank = i // ne + torch.testing.assert_close( + grad_chunk, + torch.empty(split, hid, device=device).fill_(dst_rank + 1), + ) + start += split + + dist.destroy_process_group() + print(f"Rank {rank} passed") + + +if __name__ == "__main__": + # To run this test, use the following command: + # torchrun --nproc-per-node 4 --standalone dispatch.py + test_token_dispatch() From cf4de26517294c44c273be34cb32d5a18748c8ae Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 12 Aug 2025 10:57:13 -0700 Subject: [PATCH 082/128] [a2av] Add autograd support for token combine op (#1511) Added class `TokenCombiner` which combines tokens from different experts, with backward support. Usage: ``` combiner = TokenCombiner(group_name, align, max_inp_len, max_out_len, inp.shape[1:], world_size, ne, dtype) # inp, out, in_splits_offsets, out_splits_offsets must be symmetric tensors output = combiner(inp, out, in_splits_offsets, out_splits_offsets) ``` Supports: ``` torch.compile(combiner) ``` --- torchtitan/experiments/kernels/moe/combine.py | 337 ++++++++++++++++++ 1 file changed, 337 insertions(+) create mode 100644 torchtitan/experiments/kernels/moe/combine.py diff --git a/torchtitan/experiments/kernels/moe/combine.py b/torchtitan/experiments/kernels/moe/combine.py new file mode 100644 index 0000000000..34d38e9f1b --- /dev/null +++ b/torchtitan/experiments/kernels/moe/combine.py @@ -0,0 +1,337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem + + +# Adding out-of-tree ops to the `symm_mem` library +lib = torch.library.Library("symm_mem", "FRAGMENT") # noqa: TOR901 + +""" +all_to_all_vdev_2d_copy: +Copy data from `input` to `symm_in_buf` and call `all_to_all_vdev_2d` to shuffle data +""" +lib.define( + "all_to_all_vdev_2d_copy(" + "Tensor input, Tensor symm_in_buf, Tensor(a!) out, " + "Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()", + tags=[torch._C.Tag.needs_exact_strides], +) + + +@torch.library.impl(lib, "all_to_all_vdev_2d_copy", "CUDA") +def _all_to_all_vdev_2d_copy_cuda( + input: torch.Tensor, + symm_in_buf: torch.Tensor, + out: torch.Tensor, + in_splits: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: str, + major_align: Optional[int] = None, +) -> None: + if symm_in_buf.shape[0] < input.shape[0]: + raise RuntimeError( + f"symm_in_buf with dim-0 length {symm_in_buf.shape[0]} cannot fit input with dim-0 length {input.shape[0]}" + ) + if symm_in_buf.shape[1:] != input.shape[1:]: + raise RuntimeError( + f"symm_in_buf non-0 dims do not match that of input: {symm_in_buf.shape[1:]} vs {input.shape[1:]}" + ) + if symm_in_buf.dtype != input.dtype: + raise RuntimeError( + f"symm_in_buf dtype {symm_in_buf.dtype} does not match input dtype {input.dtype}" + ) + + symm_in_buf.narrow(0, 0, input.shape[0]).copy_(input) + torch.ops.symm_mem.all_to_all_vdev_2d( + symm_in_buf, + out, + in_splits, + out_splits_offsets, + group_name, + major_align, + ) + + +class AllToAllVDev2dOffset(torch.autograd.Function): + """ + Autograd function for `all_to_all_vdev_2d_offset` + """ + + @staticmethod + def forward( # type: ignore[no-untyped-def] + ctx, + input: torch.Tensor, + out: torch.Tensor, + in_splits_offsets: torch.Tensor, + out_splits_offsets: torch.Tensor, + group_name: str, + # Buffers needed for backward pass + major_align: int, + grad_out_buf: torch.Tensor, + grad_in_buf: torch.Tensor, + grad_in_splits_offsets: torch.Tensor, + ) -> torch.Tensor: + """ + Functionality is the same as `all_to_all_vdev_2d_offset` but with functionalization. + """ + # Shuffle input to output + torch.ops.symm_mem.all_to_all_vdev_2d_offset( + input, out, in_splits_offsets, out_splits_offsets, group_name + ) + + # Output splits in forward is the input splits in backward + ctx.save_for_backward( + out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets + ) + ctx.group_name = group_name + ctx.major_align = major_align + return out + + @staticmethod + def backward( # type: ignore[no-untyped-def] + ctx, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]: + """ + Backward pass of `all_to_all_vdev_2d_offset` is `all_to_all_vdev_2d`. + + Args: + `grad_output`: gradients of output passed back from the downstream. + + Returns: + `grad_input`: gradients of input. + """ + # Splits info + # Splits/offsets of grad_out is the same as out splits/offsets in forward + ( + grad_out_splits_offsets, + grad_out_buf, + grad_in_buf, + grad_in_splits_offsets, + ) = ctx.saved_tensors + grad_out_splits = grad_out_splits_offsets[0] + + # Shuffle gradients back to the input + # TODO: create an op that takes both in_splits_offsets and + # out_splits_offsets, instead of taking alignment + torch.ops.symm_mem.all_to_all_vdev_2d_copy( + grad_output, + grad_out_buf, + grad_in_buf, + grad_out_splits, + grad_in_splits_offsets, + ctx.group_name, + ctx.major_align, + ) + + return grad_in_buf, None, None, None, None, None, None, None, None + + +class TokenCombiner(torch.nn.Module): + """ + Combine tokens from different experts, with backward pass to shuffle gradients back to the input. + Args: + `group_name`: name of the group to use for communication. + `align`: alignment of the token offsets from each expert. If using + Grouped Gemm next, this should be the same as Grouped Gemm's alignment. + `in_len`: length of the input. + `out_len`: length of the output. + `token_shape`: shape of the tokens. + `num_ranks`: number of ranks in the group. + `num_local_experts`: number of local experts. + `dtype`: data type of the input/output. + """ + + def __init__( + self, + group_name: str, + align: int, + in_len, + out_len, + token_shape, + num_ranks, + num_local_experts, + dtype, + device: torch.device, + ) -> None: + super().__init__() + self.group_name = group_name + self.align = align + self.grad_out_buf = symm_mem.empty( + out_len, *token_shape, dtype=dtype, device=device + ) + self.grad_in_buf = symm_mem.empty( + in_len, *token_shape, dtype=dtype, device=device + ) + self.nsplits = num_ranks * num_local_experts + self.grad_in_splits_offsets = symm_mem.empty( + (2, self.nsplits), dtype=torch.int64, device=device + ) + + def forward( + self, + inp: torch.Tensor, + out: torch.Tensor, + in_splits_offsets: torch.Tensor, + out_splits_offsets: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + `inp`: input tensor. + `out`: buffer for output tensor. + `in_splits_offsets`: splits and offsets of the input tensor. + `out_splits_offsets`: splits and offsets of the output tensor. + See `all_to_all_vdev_2d_offset` for more details. + Note: + All tensor arguments must be symmetrically allocated, i.e. + >>> inp = symm_mem.empty(max_inp_len, dtype=dtype, device=device) + >>> out = symm_mem.empty(max_out_len, dtype=dtype, device=device) + >>> in_splits_offsets = symm_mem.empty( + ... (2, nsplits), dtype=torch.int64, device=device) + >>> out_splits_offsets = symm_mem.empty( + ... (2, nsplits), dtype=torch.int64, device=device) + """ + + if in_splits_offsets.shape != (2, self.nsplits): + raise ValueError( + f"Expected shape (2, {self.nsplits}), got {in_splits_offsets.shape}" + ) + if out_splits_offsets.shape != (2, self.nsplits): + raise ValueError( + f"Expected shape (2, {self.nsplits}), got {out_splits_offsets.shape}" + ) + + return AllToAllVDev2dOffset.apply( + inp, + out, + in_splits_offsets, + out_splits_offsets, + self.group_name, + self.align, + self.grad_out_buf, + self.grad_in_buf, + self.grad_in_splits_offsets, + ) + + +def test_token_combine() -> None: + # Init + dist.init_process_group() + rank = dist.get_rank() + world_size = dist.get_world_size() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + # NVSHMEM backend specific + torch.cuda.set_device(device) + torch.empty(1, device=device) + # Set NVSHMEM as SymmMem backend + symm_mem.set_backend("NVSHMEM") + + # Mimics Group GEMM alignment + align = 8 + torch.manual_seed(42 + rank) + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + # Number of experts per rank + ne = 8 + nsplits = ne * world_size + + # Number of elements for an expert is random between [0, k) + k = 10 + inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=device) + + # Max number of input elements (must be a constant across ranks for symmetric memory allocation) + max_inp_len = k * nsplits + # Max number of output elements (must be a constant across ranks for symmetric memory allocation) + overflow_factor = world_size # worst case: one rank receives all data + max_out_len = max_inp_len * overflow_factor + + # Use a dispatch to prepare the input for combine (this is just a + # preparation, not the test itself) + # Buffers for dispatch + hid = 4096 + inp = symm_mem.empty(max_inp_len, hid, dtype=dtype, device=device) + out = symm_mem.empty(max_out_len, hid, dtype=dtype, device=device) + # 2 rows: input splits, input offsets + in_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device) + # 2 rows: output splits, output offsets + out_splits_offsets = symm_mem.empty((2, nsplits), dtype=torch.int64, device=device) + + # Dispatch the tokens first so that we have a nice input for combine + in_splits_offsets[0].copy_(inp_splits) + torch.ops.symm_mem.all_to_all_vdev_2d( + inp, + out, + in_splits_offsets[0], + out_splits_offsets, + group_name, + major_align=align, + ) + + combiner = TokenCombiner( + group_name, + align, + max_out_len, + max_inp_len, + out.shape[1:], + world_size, + ne, + dtype, + device, + ) + + compiled_combiner = torch.compile( + combiner, + fullgraph=True, + ) + + # Perform a Dot product with output, so that gradients passed back from + # different ranks are different + weight = torch.empty(max_inp_len, dtype=dtype, device=device).fill_(rank + 1) + + # Now we start to test the autograd function + + # Requires grad for input of combine + out.requires_grad_(True) + + combine_out = compiled_combiner( + out, + inp, + out_splits_offsets, + in_splits_offsets, + ) + p = torch.matmul(weight, combine_out) + p.sum().backward() + + # Check gradients + # We also need to skip the padding in the input data + out_splits = out_splits_offsets[0].tolist() + out_offsets = out_splits_offsets[1].tolist() + for i, (split, offset) in enumerate(zip(out_splits, out_offsets)): + grad_chunk = out.grad[offset : offset + split] + dst_rank = i % world_size + torch.testing.assert_close( + grad_chunk, + torch.empty(split, hid, device=device).fill_(dst_rank + 1), + ) + + dist.destroy_process_group() + print(f"Rank {rank} passed") + + +if __name__ == "__main__": + # To run this test, use the following command: + # torchrun --nproc-per-node 4 --standalone combine.py + test_token_combine() From a6972ae36ee78f14361db1824fa90b4f05be39db Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Tue, 12 Aug 2025 16:22:07 -0700 Subject: [PATCH 083/128] Add state_dict converter for DeepSeekv3 in torchtitan (#1538) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support loading a DeepSeek HF weights to Deepseek-V3 model: 1. Support split / concat weight for GroupedExperts 2. Support _dequantization during loading HF checkpoints Numerical verification: (using offline conversion script) ``` python convert_from_hf.py /data/users/jianiw/dsv3-weights outputs/checkpoint-dsv3-cpu --model_name deepseek_v3 --model_flavor 671B > cpu_convert.txt 2>&1 ``` Screenshot 2025-08-11 at 4 31 50 PM Screenshot 2025-08-11 at 4 32 23 PM --- .../experiments/multimodal/mm_dataset.py | 4 +- torchtitan/models/deepseek_v3/README.md | 13 + torchtitan/models/deepseek_v3/__init__.py | 2 + .../models/deepseek_v3/model/quantization.py | 73 ++++++ .../deepseek_v3/model/state_dict_adapter.py | 231 ++++++++++++++++++ torchtitan/models/moe.py | 2 + 6 files changed, 323 insertions(+), 2 deletions(-) create mode 100644 torchtitan/models/deepseek_v3/model/quantization.py create mode 100644 torchtitan/models/deepseek_v3/model/state_dict_adapter.py diff --git a/torchtitan/experiments/multimodal/mm_dataset.py b/torchtitan/experiments/multimodal/mm_dataset.py index da69d6973a..29a42aeeb0 100644 --- a/torchtitan/experiments/multimodal/mm_dataset.py +++ b/torchtitan/experiments/multimodal/mm_dataset.py @@ -16,12 +16,12 @@ from tokenizer.tiktoken import BaseTokenizer, IGNORE_INDEX from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset +from transform import CLIPTransform +from utils import load_image from torchtitan.components.dataloader import ParallelAwareDataloader from torchtitan.config import JobConfig from torchtitan.tools.logging import logger -from transform import CLIPTransform -from utils import load_image def _load_obelics_dataset(dataset_path: str): diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 5a36c9198c..085403d47b 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -49,6 +49,19 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml - Pipeline Parallel (PP) +## HuggingFace -> DCP Checkpoint Conversion + +We implemented StateDictAdapter to preform HuggingFace safetensor to DCP format conversion. Currently, we only support conversion from HF checkpoints to DCP checkpoints offline (using CPU plain tensor). + +Run the offine conversion script: +```bash +python scripts/checkpoint_conversion/convert_from_hf.py --model_name deepseek_v3 --model_flavor 671B +``` + +Some limitations: +1. It can't be used to convert HF checkpoint on the fly using GPU DTensor, because of sharding and quantized blocks may not be aligned well and causing silent numerfical incorrectness. +2. It can't be used for weight sync to generate a state dict of bf16 because fake quantization to fp8 is applied. + ## To be added - Parallelism - Context Parallel support for DeepSeek V3 diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index a39b35dfa2..2e0f42a736 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -19,6 +19,7 @@ from .infra.parallelize import parallelize_deepseekv3 from .model.args import DeepSeekV3ModelArgs from .model.model import DeepSeekV3Model +from .model.state_dict_adapter import DeepSeekV3StateDictAdapter __all__ = [ "parallelize_deepseekv3", @@ -166,5 +167,6 @@ build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, ) ) diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py new file mode 100644 index 0000000000..a8ac6003a2 --- /dev/null +++ b/torchtitan/models/deepseek_v3/model/quantization.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtitan.tools.logging import logger + +# Fixed block size of 128x128 as specified in the algorithm +BLOCK_SIZE = 128 + + +def calculate_scale_shape( + weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE +) -> torch.Size: + # Calculate the scale tensor shape + orig_shape = weight.shape + + # Calculate number of blocks needed + block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE + block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE + + # Verify scale_inv shape matches expected block dimensions + expected_scale_shape = torch.Size((block_rows, block_cols)) + + return expected_scale_shape + + +def dequantize_from_fp8( + weight: torch.Tensor, + scale_inv: torch.Tensor, + dtype=torch.bfloat16, + BLOCK_SIZE: int = BLOCK_SIZE, +) -> torch.Tensor: + # Convert to float32 for computation + float_weight = weight.to(torch.float32) + # Get original dimensions + orig_shape = weight.shape + + # Verify scale_inv shape matches expected block dimensions + expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) + block_rows, block_cols = expected_scale_shape + if scale_inv.shape != expected_scale_shape: + logger.warning( + f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" + ) + + # NOTE: When processing large models on-the-fly, misalignment between block boundaries + # and DTensor local shape partitioning can lead to silent numerical inaccuracies. + dequantized = float_weight.detach().clone().to(dtype=dtype) + + # Apply scaling factors to each block + for i in range(block_rows): + row_start = i * BLOCK_SIZE + row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) + + for j in range(block_cols): + col_start = j * BLOCK_SIZE + col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) + + # Get the block + block = float_weight[row_start:row_end, col_start:col_end] + + scale = scale_inv[i, j] + block = block * scale + + # Explicitly convert block to dtype + block_converted = block.to(dtype=torch.float32) + # Store the dequantized block + dequantized[row_start:row_end, col_start:col_end] = block_converted + + return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py new file mode 100644 index 0000000000..890ae00f36 --- /dev/null +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from typing import Any + +import torch + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .args import DeepSeekV3ModelArgs +from .quantization import calculate_scale_shape, dequantize_from_fp8 + + +class DeepSeekV3StateDictAdapter(StateDictAdapter): + """ + StateDictAdapter for DeepSeekV3 model. + """ + + def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None): + self.model_args = model_args + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention Module + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attention.wkv_a.weight", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attention.kv_norm.weight", + "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attention.wkv_b.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + # MLP Module + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Transfomer Layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE Module + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2", + "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_expert.w1", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_expert.w3", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_expert.w2", + "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def _split_experts_weights( + self, weight: torch.Tensor, n_experts: int + ) -> list[torch.Tensor]: + """ + Split the weights of the experts into a list of tensors. + """ + split_weight = torch.split(weight, weight.shape[0] // n_experts, dim=0) + return split_weight + + def _concatenate_expert_weights( + self, expert_weights_by_layer: dict[str, Any], n_experts: int + ) -> torch.Tensor: + """ + Concatenate the weights of seprate experts into GroupedExpert weights. + """ + for layer, abstract_keys in list(expert_weights_by_layer.items()): + for abstract_key, experts in list(abstract_keys.items()): + # If we have all the experts for this abstract_key, concatenate them + if len(experts) == n_experts: + sorted_expert_ids = sorted(experts.keys()) + sorted_experts = [experts[i] for i in sorted_expert_ids] + stacked_tensor = torch.stack(sorted_experts, dim=0) + + # Remove these experts from the tracking dict to free memory + del expert_weights_by_layer[layer][abstract_key] + if not expert_weights_by_layer[layer]: + del expert_weights_by_layer[layer] + + return stacked_tensor + + return None + + def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Dequantize the weights from float8 to float32. + """ + + scale_inv_keys = [] + for key, weight in state_dict.items(): + if key.endswith(".weight") and key + "_scale_inv" in state_dict: + scale_inv = state_dict[key + "_scale_inv"] + dequantized_weight = dequantize_from_fp8( + weight, scale_inv, dtype=torch.float32 + ) + # update the weight and remove the scale_inv tensor + state_dict[key] = dequantized_weight + scale_inv_keys.append(key + "_scale_inv") + + for key in scale_inv_keys: + state_dict.pop(key) + + return state_dict + + def _add_quantization_scale_inv_tensors( + self, state_dict: dict[str, Any] + ) -> dict[str, Any]: + """ + Add quantization scale tensors the state_dict. + """ + non_quantized_keys = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + "norm.weight", + "lm_head.weight", + "embed_tokens.weight", + "mlp.gate.weight", + ] + + weight_scale_inv_state_dict = {} + for key, value in state_dict.items(): + if key.endswith(".weight") and not any( + non_quantized_key in key for non_quantized_key in non_quantized_keys + ): + expected_scale_shape = calculate_scale_shape(value) + # add weight_scale_inv to the state_dict + weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( + expected_scale_shape, dtype=torch.float32 + ) + + state_dict.update(weight_scale_inv_state_dict) + return state_dict + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. Convert between the HF shape and the torchtitan shape. + 2. Split the GroupedExperts' weight into seprate expert's wegiht. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + hf_state_dict = {} + + for key, value in state_dict.items(): + if "moe.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_abstract_key = to_hf_map[abstract_key] + + # Split expert weights into seperate expert weights + split_values = self._split_experts_weights( + value, self.model_args.moe_args.num_experts + ) + + for expert_num in range(0, self.model_args.moe_args.num_experts): + new_key = new_abstract_key.format(layer_num, expert_num) + hf_state_dict[new_key] = split_values[expert_num].squeeze() + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + new_key = new_key.format(layer_num) + + # torchtitan shape: (1, s[1], s[2]) -> HF shape: (s[1], s[2]) + if "shared_expert" in key: + value = value.squeeze(0) + + hf_state_dict[new_key] = value + + else: + new_key = to_hf_map[key] + hf_state_dict[new_key] = value + + hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( + hf_state_dict + ) + return hf_state_dict_with_scale_inv + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. + 2. Convert between the HF shape and the torchtitan shape. + 3. Concate seprate expert's wegiht into GroupedExperts' weight. + """ + # dequantize the tensor in state_dict and remove the scale_inv tensor + hf_state_dict = self._dequantize(hf_state_dict) + state_dict = {} + + expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} + + for key, value in hf_state_dict.items(): + if "mlp.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=2) + layer_num, expert_num = re.findall(r"\d+", key) + new_key = self.from_hf_map[abstract_key] + new_key = new_key.format(layer_num) + + # Store the expert's weight in expert_weights_by_layer for concating later. + if layer_num not in expert_weights_by_layer: + expert_weights_by_layer[layer_num] = {} + if abstract_key not in expert_weights_by_layer[layer_num]: + expert_weights_by_layer[layer_num][abstract_key] = {} + expert_weights_by_layer[layer_num][abstract_key][expert_num] = value + + # try to concat the expert's weight into GroupedExperts' weight. + stacked_value = self._concatenate_expert_weights( + expert_weights_by_layer, self.model_args.moe_args.num_experts + ) + if stacked_value is not None: + state_dict[new_key] = stacked_value + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + new_key = new_key.format(layer_num) + + # HF shape: (s[1], s[2]) -> torchtitan shape: (1, s[1], s[2]) + if "shared_experts" in key: + value = value.unsqueeze(0) + + state_dict[new_key] = value + + else: + new_key = self.from_hf_map[key] + state_dict[new_key] = value + + return state_dict diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index b8d777306c..195429b147 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -282,10 +282,12 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): self.register_buffer( "expert_bias", torch.zeros(num_experts, dtype=torch.float32), + persistent=True, ) self.register_buffer( "tokens_per_expert", torch.zeros(num_experts, dtype=torch.float32), + persistent=False, ) else: self.expert_bias = None From 8bd8c930efdcc999b49a8d13332148ffe2bd090a Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 12 Aug 2025 16:35:31 -0700 Subject: [PATCH 084/128] Move fqn mapping logic to StateDictAdapter (#1557) This moves the logic that parses `model.safetensors.index.json` and generates the `fqn_to_index_mapping` to `StateDictAdapter` since this logic should be shared by all classes that inherit from `StateDictAdapter`. --- .../checkpoint_conversion/convert_to_hf.py | 2 +- torchtitan/components/checkpoint.py | 6 ++-- torchtitan/experiments/forge/train_spec.py | 4 +-- torchtitan/models/README.md | 2 +- .../models/llama3/model/state_dict_adapter.py | 24 ++------------ torchtitan/protocols/__init__.py | 3 +- torchtitan/protocols/state_dict_adapter.py | 33 ++++++++++++++++++- torchtitan/protocols/train_spec.py | 4 +-- 8 files changed, 45 insertions(+), 33 deletions(-) diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index 39c46a16d2..db69a34b0e 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -65,7 +65,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat "--hf_assets_path", type=Path, help="Path to HF assets directory. This is used to get the model.safetensors.index.json mapping", - default="./assets/hf/Llama3.1-8B", + default="./assets/hf/Llama-3.1-8B", ) parser.add_argument("--model_name", type=str, nargs="?", default="llama3") parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e55f792dff..478062e8e1 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -37,7 +37,7 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP -from torchtitan.protocols import StateDictAdapter +from torchtitan.protocols import BaseStateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -177,7 +177,7 @@ class CheckpointManager: checkpoint_config (Checkpoint): The config used to configure the checkpointing. base_folder (str): The base folder to save the checkpoint. Will be concatenated with checkpoint_config.folder - sd_adapter (Optional[type[StateDictAdapter]]): The adapter used to convert model state + sd_adapter (Optional[type[BaseStateDictAdapter]]): The adapter used to convert model state dicts between native format and other formats. ft_manager (Optional[ft.Manager]): The FTManager from TorchFT. @@ -191,7 +191,7 @@ def __init__( lr_schedulers: LRSchedulersContainer, states: dict[str, Any], checkpoint_config: CheckpointConfig, - sd_adapter: StateDictAdapter | None, + sd_adapter: BaseStateDictAdapter | None, base_folder: str = "", ft_manager: FTManager | None = None, ) -> None: diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py index f3ab820535..463f608b2b 100644 --- a/torchtitan/experiments/forge/train_spec.py +++ b/torchtitan/experiments/forge/train_spec.py @@ -8,7 +8,7 @@ # Import torchtitan.models to ensure all train specs are registered import torchtitan.models # noqa: F401 -from torchtitan.protocols import BaseModelArgs, ModelProtocol, StateDictAdapter +from torchtitan.protocols import BaseModelArgs, BaseStateDictAdapter, ModelProtocol from torchtitan.protocols.train_spec import ( _train_specs, LossFunctionBuilder, @@ -30,7 +30,7 @@ class ForgeTrainSpec: build_optimizers_fn: OptimizersBuilder build_lr_schedulers_fn: LRSchedulersBuilder build_loss_fn: LossFunctionBuilder - state_dict_adapter: type[StateDictAdapter] | None = None + state_dict_adapter: type[BaseStateDictAdapter] | None = None # Copy and transform train specs from torchtitan.protocols.train_spec._train_specs diff --git a/torchtitan/models/README.md b/torchtitan/models/README.md index a007c6cb94..d76ac4fc24 100644 --- a/torchtitan/models/README.md +++ b/torchtitan/models/README.md @@ -20,7 +20,7 @@ The folder should be organized as follows - `init_weights()` is used to properly initialize the parameters and buffers in the model. Please define it in a recursive way so that every submodule has its own `init_weights()`. - Add additional files to reduce the complexity of `model.py` if it grows too large or complex, e.g. moe.py to host the `MoE`, `Router`, and `GroupedExperts` modules. - `state_dict_adapter.py` - - Inherit [`StateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between `torchtitan` model definition and other model definitions (e.g. from HuggingFace so that we can save / load model checkpoints in HF formats). + - Inherit [`BaseStateDictAdapter`](/torchtitan/protocols/state_dict_adapter.py) to implement state dict mappings between `torchtitan` model definition and other model definitions (e.g. from HuggingFace so that we can save / load model checkpoints in HF formats). - There are multiple ways such adapters could be used - Checkpoint conversion scripts in `scripts/checkpoint_conversion/` will use them to adapt state dicts containing non-sharded `torch.Tensor` on CPU. - During training, [`CheckpointManager`](/torchtitan/components/checkpoint.py) will use them to adapt state dicts containing (potentially sharded) `DTensor` on GPUs to save / load checkpoints in HF format. diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 91259d0bce..cae0b4c174 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -4,9 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json import logging -import os import re from typing import Any @@ -19,6 +17,8 @@ class Llama3StateDictAdapter(StateDictAdapter): def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + self.model_args = model_args self.hf_assets_path = hf_assets_path self.from_hf_map = { @@ -37,26 +37,6 @@ def __init__(self, model_args: TransformerModelArgs, hf_assets_path: str | None) "lm_head.weight": "output.weight", } - if hf_assets_path: - mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json") - try: - with open(mapping_path, "r") as f: - hf_safetensors_indx = json.load(f) - except FileNotFoundError: - logger.warning( - "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ - Defaulting to saving a single safetensors file if checkpoint is saved in HF format.", - ) - hf_safetensors_indx = None - - if hf_safetensors_indx: - self.fqn_to_index_mapping = {} - for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): - indx = re.search(r"\d+", raw_indx).group(0) - self.fqn_to_index_mapping[hf_key] = indx - else: - self.fqn_to_index_mapping = None - # HuggingFace permutation function (exact copy from their conversion script) def _permute(self, w, n_heads_arg, dim1=None, dim2=None): if dim1 is None: diff --git a/torchtitan/protocols/__init__.py b/torchtitan/protocols/__init__.py index 2d1b283f11..de1dc6f6da 100644 --- a/torchtitan/protocols/__init__.py +++ b/torchtitan/protocols/__init__.py @@ -6,7 +6,7 @@ from .model import BaseModelArgs, ModelProtocol from .model_converter import ModelConverter, ModelConvertersContainer -from .state_dict_adapter import StateDictAdapter +from .state_dict_adapter import BaseStateDictAdapter, StateDictAdapter __all__ = [ "BaseModelArgs", @@ -14,4 +14,5 @@ "ModelConverter", "ModelConvertersContainer", "StateDictAdapter", + "BaseStateDictAdapter", ] diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 1975a9ed08..ce03d732d6 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -4,13 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json +import logging +import os +import re from abc import ABC, abstractmethod from typing import Any +logger = logging.getLogger() + from .model import BaseModelArgs -class StateDictAdapter(ABC): +class BaseStateDictAdapter(ABC): """Abstract base class for state dict transformations. This class defines the interface for converting between native model @@ -47,3 +53,28 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: The converted native model state dict """ pass + + +class StateDictAdapter(BaseStateDictAdapter): + """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" + + def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): + if hf_assets_path: + mapping_path = os.path.join(hf_assets_path, "model.safetensors.index.json") + try: + with open(mapping_path, "r") as f: + hf_safetensors_indx = json.load(f) + except FileNotFoundError: + logger.warning( + "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ + Defaulting to saving a single safetensors file if checkpoint is saved in HF format.", + ) + hf_safetensors_indx = None + + if hf_safetensors_indx: + self.fqn_to_index_mapping = {} + for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + indx = re.search(r"\d+", raw_indx).group(0) + self.fqn_to_index_mapping[hf_key] = indx + else: + self.fqn_to_index_mapping = None diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py index fc1ed1b279..06fa3a1bc6 100644 --- a/torchtitan/protocols/train_spec.py +++ b/torchtitan/protocols/train_spec.py @@ -21,7 +21,7 @@ from torchtitan.config import LRScheduler from .model import BaseModelArgs, ModelProtocol -from .state_dict_adapter import StateDictAdapter +from .state_dict_adapter import BaseStateDictAdapter ParallelizeFunction: TypeAlias = Callable[..., nn.Module] @@ -53,7 +53,7 @@ class TrainSpec: build_loss_fn: LossFunctionBuilder build_validator_fn: ValidatorBuilder | None = None build_metrics_processor_fn: MetricsProcessorBuilder | None = None - state_dict_adapter: type[StateDictAdapter] | None = None + state_dict_adapter: type[BaseStateDictAdapter] | None = None _train_specs = {} From 21416c4c4841a40eb7d3e7696dd2b573c25695e4 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 12 Aug 2025 19:05:15 -0700 Subject: [PATCH 085/128] Update .gitignore (#1560) Added back `assets/tokenizer/*` to `.gitignore` for people still using old `tokenizer_path` --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6df39a9ead..34e53e84c9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ torchtitan/datasets/**/*.model # hf assets assets/hf/* +assets/tokenizer/* torchtitan/experiments/flux/assets/* # temp files From 0c51d924cf10a4ee081fcd4478c1a7d94ebc1ca8 Mon Sep 17 00:00:00 2001 From: ebsmothers Date: Wed, 13 Aug 2025 09:33:56 -0700 Subject: [PATCH 086/128] fix state dict adapter in forge engine (#1563) This was broken by #1526, updating to match the changes in train.py. Tested via a run in forge --- torchtitan/experiments/forge/engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 398a1c5d5e..45bfa9a605 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -216,7 +216,9 @@ def __init__(self, job_config: ForgeJobConfig): states={"train_state": self}, checkpoint_config=job_config.checkpoint, sd_adapter=( - self.train_spec.state_dict_adapter(model_args) + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) if self.train_spec.state_dict_adapter else None ), From 48b6520b01b1b91a76e4a100e535582222c9d424 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Wed, 13 Aug 2025 10:50:57 -0700 Subject: [PATCH 087/128] unit test for download_hf_assets script (#1556) --- tests/unit_tests/test_download_hf_assets.py | 258 ++++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 tests/unit_tests/test_download_hf_assets.py diff --git a/tests/unit_tests/test_download_hf_assets.py b/tests/unit_tests/test_download_hf_assets.py new file mode 100644 index 0000000000..f4e3a44298 --- /dev/null +++ b/tests/unit_tests/test_download_hf_assets.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest +from unittest.mock import Mock, patch + +from scripts.download_hf_assets import download_hf_assets + + +class TestDownloadHfAssets(unittest.TestCase): + """Tests for the download_hf_assets script + + We mock `huggingface_hub.list_repo_files` and `huggingface_hub.hf_hub_download` to simulate the meta-llama/Llama-3.1-8B repo + """ + + # Complete file list from the actual meta-llama/Llama-3.1-8B repository + COMPLETE_REPO_FILES = [ + "config.json", + "generation_config.json", + "model.safetensors.index.json", + "model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors", + "original/consolidated.00.pth", + "original/params.json", + "original/tokenizer.model", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json", + "LICENSE", + "README.md", + "USE_POLICY.md", + ] + + # Expected files for each asset type + EXPECTED_FILES = { + "tokenizer": [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "original/tokenizer.model", + ], + "config": ["config.json", "generation_config.json"], + "safetensors": [ + "model-00001-of-00004.safetensors", + "model-00002-of-00004.safetensors", + "model-00003-of-00004.safetensors", + "model-00004-of-00004.safetensors", + "model.safetensors.index.json", + ], + "index": ["model.safetensors.index.json"], + } + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.repo_id = "meta-llama/Llama-3.1-8B" + + def tearDown(self): + import shutil + + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _setup_mocks(self, mock_download, mock_list_files, repo_files=None): + """Helper to setup mock configurations""" + mock_list_files.return_value = repo_files or self.COMPLETE_REPO_FILES + mock_download.return_value = None + + def _get_downloaded_files(self, mock_download): + """Helper to extract downloaded filenames from mock calls""" + return [call[1]["filename"] for call in mock_download.call_args_list] + + def _assert_files_downloaded(self, mock_download, expected_files): + """Helper to assert expected files were downloaded""" + self.assertEqual(mock_download.call_count, len(expected_files)) + downloaded_files = self._get_downloaded_files(mock_download) + for expected_file in expected_files: + self.assertIn(expected_file, downloaded_files) + + def _call_download_hf_assets(self, **kwargs): + """Helper to call download_hf_assets with common defaults""" + defaults = { + "repo_id": self.repo_id, + "local_dir": self.temp_dir, + } + defaults.update(kwargs) + return download_hf_assets(**defaults) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_download_single_asset_types(self, mock_download, mock_list_files): + """Test downloading individual asset types""" + self._setup_mocks(mock_download, mock_list_files) + + # Test each asset type individually + for asset_type, expected_files in self.EXPECTED_FILES.items(): + with self.subTest(asset_type=asset_type): + mock_download.reset_mock() + self._call_download_hf_assets(asset_types=[asset_type]) + self._assert_files_downloaded(mock_download, expected_files) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_download_multiple_asset_types(self, mock_download, mock_list_files): + """Test downloading multiple asset types together""" + self._setup_mocks(mock_download, mock_list_files) + + # Get all expected files (removing duplicates) + all_expected_files = set() + for files in self.EXPECTED_FILES.values(): + all_expected_files.update(files) + + self._call_download_hf_assets(asset_types=list(self.EXPECTED_FILES.keys())) + self._assert_files_downloaded(mock_download, all_expected_files) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_download_all_files(self, mock_download, mock_list_files): + """Test downloading all files with --all option""" + self._setup_mocks(mock_download, mock_list_files) + + self._call_download_hf_assets(asset_types=[], download_all=True) + self._assert_files_downloaded(mock_download, self.COMPLETE_REPO_FILES) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_additional_patterns(self, mock_download, mock_list_files): + """Test downloading with additional file patterns""" + test_files = ["tokenizer.json", "custom_file.txt", "README.md"] + self._setup_mocks(mock_download, mock_list_files, repo_files=test_files) + + self._call_download_hf_assets( + asset_types=["tokenizer"], additional_patterns=["*.txt"] + ) + + # Only tokenizer.json and custom_file.txt should be downloaded + expected_files = ["tokenizer.json", "custom_file.txt"] + self._assert_files_downloaded(mock_download, expected_files) + + @patch("huggingface_hub.hf_hub_download") + def test_list_files(self, mock_download): + """Tests that list files returns correct list of files by using real huggingface_hub.list_files""" + """This test uses larger deepseek-ai/DeepSeek-V3 repo for more thorough test""" + + # Setup mock download + mock_download.return_value = None + + # Test downloading safetensors asset type + self._call_download_hf_assets( + repo_id="deepseek-ai/DeepSeek-V3", + asset_types=["safetensors"], + ) + + # Verify all 163 safetensors files plus index file are downloaded + expected_safetensors_files = [ + f"model-{i:05d}-of-000163.safetensors" for i in range(1, 164) + ] + expected_files = expected_safetensors_files + [ + "model.safetensors.index.json", + ] + + self._assert_files_downloaded(mock_download, expected_files) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_nested_directory_handling(self, mock_download, mock_list_files): + """Tests that files in nested directory files are detected and downloaded correctly""" + test_files = [ + "tokenizer.json", + "original/tokenizer.model", + "original/consolidated.00.pth", # Should NOT be downloaded (no .pth pattern) + "config.json", + ] + self._setup_mocks(mock_download, mock_list_files, repo_files=test_files) + + self._call_download_hf_assets(asset_types=["tokenizer", "config"]) + + # Verify nested tokenizer file is downloaded but .pth file is not + expected_files = ["tokenizer.json", "original/tokenizer.model", "config.json"] + self._assert_files_downloaded(mock_download, expected_files) + + # Verify .pth file was NOT downloaded + downloaded_files = self._get_downloaded_files(mock_download) + self.assertNotIn("original/consolidated.00.pth", downloaded_files) + + @patch("huggingface_hub.list_repo_files") + def test_missing_files_warning(self, mock_list_files): + """Test warning when requested files are not found""" + mock_list_files.return_value = ["config.json", "README.md"] + + with patch("builtins.print") as mock_print: + self._call_download_hf_assets(asset_types=["tokenizer"]) + + # Check that warning was printed + warning_calls = [ + call + for call in mock_print.call_args_list + if "Warning: No matching files found for asset_type 'tokenizer'" + in str(call) + ] + self.assertTrue(len(warning_calls) > 0) + + @patch("huggingface_hub.list_repo_files") + @patch("huggingface_hub.hf_hub_download") + def test_download_failure_handling(self, mock_download, mock_list_files): + """Test handling of download failures""" + from requests.exceptions import HTTPError + + self._setup_mocks( + mock_download, + mock_list_files, + repo_files=["tokenizer.json", "missing_file.json"], + ) + + # Mock 404 error for missing file + def download_side_effect(*args, **kwargs): + if kwargs["filename"] == "missing_file.json": + response = Mock() + response.status_code = 404 + raise HTTPError(response=response) + return None + + mock_download.side_effect = download_side_effect + + with patch("builtins.print") as mock_print: + self._call_download_hf_assets( + asset_types=["tokenizer"], additional_patterns=["missing_file.json"] + ) + + # Check that 404 error was handled gracefully + error_calls = [ + call + for call in mock_print.call_args_list + if "File missing_file.json not found, skipping..." in str(call) + ] + self.assertTrue(len(error_calls) > 0) + + def test_invalid_repo_id_format(self): + """Test error handling for invalid repo_id format""" + with self.assertRaises(ValueError) as context: + self._call_download_hf_assets( + repo_id="invalid-repo-id", asset_types=["tokenizer"] + ) + self.assertIn("Invalid repo_id format", str(context.exception)) + + def test_unknown_asset_type(self): + """Test error handling for unknown asset type""" + with self.assertRaises(ValueError) as context: + self._call_download_hf_assets(asset_types=["unknown_type"]) + self.assertIn("Unknown asset type unknown_type", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() From aeb3a4bafc7c6794183939059fe5d56c33c81b0f Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:07:30 -0700 Subject: [PATCH 088/128] [EP] add support for ETP=1 (#1555) This is a followup of original EP support https://github.com/pytorch/torchtitan/pull/1324 ### PR summary [TBA] description + figure ### numerics verification setup - optimizer Adam - steps 100, warmup_steps 20 - seed 42 comparison set - FSDP 2 - FSDP 2, CP 2, TP 2, EP 8, ETP 1 - FSDP 2 (EP 2), PP 2, TP 2 (ETP 2) image --- scripts/estimate/estimation.py | 1 + scripts/generate/test_generate.py | 1 + tests/unit_tests/test_model_converter.py | 1 + torchtitan/config/job_config.py | 24 +++- torchtitan/distributed/expert_parallel.py | 49 ++++++++ torchtitan/distributed/parallel_dims.py | 35 ++++-- torchtitan/experiments/forge/engine.py | 1 + .../experiments/llama4/infra/parallelize.py | 20 ++- .../llama4/train_configs/debug_model.toml | 1 + .../llama4/train_configs/llama4_17bx128e.toml | 2 + .../llama4/train_configs/llama4_17bx16e.toml | 2 + .../models/deepseek_v3/infra/parallelize.py | 5 +- .../train_configs/debug_model.toml | 5 +- .../train_configs/deepseek_v3_16b.toml | 5 +- .../train_configs/deepseek_v3_671b.toml | 5 +- torchtitan/models/moe.py | 117 +++++++++++++++--- torchtitan/train.py | 1 + 17 files changed, 230 insertions(+), 45 deletions(-) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 218e7a4c6e..510cc394f7 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -52,6 +52,7 @@ def estimate_memory(job_config: JobConfig): tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) # ParallelDims.build_mesh has to happen outside of the FakeTensorMode diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index ae20d11826..60cd3d04c1 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -125,6 +125,7 @@ def test_generate( tp=world_size, pp=1, ep=1, + etp=1, world_size=world_size, ) world_mesh = parallel_dims.world_mesh diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index 572a269a93..bfcb25189d 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size): tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) return parallel_dims diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index f407fe6e78..9a78451fc6 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -374,9 +374,27 @@ class Parallelism: expert_parallel_degree: int = 1 """ - Expert parallelism degree. 1 means disabled. - Currently, only "dp2ep" is supported, with the following constraints: - context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree + Expert parallelism degree. 1 means disabled. No effect for non-MoE models. + Currently, it is supported with the following constraints: + - when etp = tp: + - cp <= ep <= dp_shard * cp + - ep % cp == 0 + - dp_shard * cp % ep == 0 + - when etp = 1: + - cp * tp <= ep <= dp_shard * cp * tp + - ep % (cp * tp) == 0 + - dp_shard * cp * tp % ep == 0 + Note that this is still an experimental feature. Some contrains will be + relaxed soon when we have more flexible DeviceMesh support. + """ + + expert_tensor_parallel_degree: int = 1 + """ + Expert tensor parallelism degree. 1 means disabled. No effect for non-MoE models, or when ep = 1. + With this option, the tensor parallel degree on routed experts can be different from that on other params. + Currently, we only support either + - [partial dp -> ep] etp = tp + - [partial dp + all tp -> ep] etp = 1 Note that this is still an experimental feature. """ diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index bc5d43f9f2..915a5ac107 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -363,3 +363,52 @@ def wrapper( return out return wrapper + + +# This class is to support Sequence Parallel for ETP=1 +# when EP borrows from all TP and part of DP +class ReordererSequenceParallel(ParallelStyle): + def _prepare_inputput_fn(self, mod, inputs, device_mesh): + top_scores, selected_experts_indices = inputs + + top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),)) + selected_experts_indices = DTensor.from_local( + selected_experts_indices, device_mesh, (Replicate(),) + ) + + # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree + # if top_scores.shape[0] % device_mesh.size() != 0: + # num_tokens = top_scores.shape[0] + # tp_size = device_mesh.size() + # n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens + # selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad]) + # top_scores = F.pad(top_scores, [0, 0, 0, n_pad]) + assert top_scores.shape[0] % device_mesh.size() == 0 + + # split on the bs*slen dimension + top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local() + selected_experts_indices = selected_experts_indices.redistribute( + device_mesh, (Shard(0),) + ).to_local() + + return top_scores, selected_experts_indices + + def _prepare_output_fn(self, mod, outputs, device_mesh): + top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs + + # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks, + # the MoE gather and scatter still require global token indices. + num_tokens = top_scores.shape[0] + local_rank = device_mesh.get_local_rank() + token_indices_experts_sorted += num_tokens // device_mesh.size() * local_rank + + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=None, + input_fn=self._prepare_inputput_fn, + output_fn=self._prepare_output_fn, + ) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index bbb3874b57..dbb443c6b2 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -23,6 +23,7 @@ class ParallelDims: tp: int pp: int ep: int + etp: int world_size: int _world_mesh: DeviceMesh = None @@ -31,18 +32,19 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, cp, tp, pp, ep = ( + dp_replicate, dp_shard, cp, tp, pp, ep, etp = ( self.dp_replicate, self.dp_shard, self.cp, self.tp, self.pp, self.ep, + self.etp, ) - for d in (dp_replicate, cp, tp, pp, ep): + for d in (dp_replicate, cp, tp, pp, ep, etp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" - assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." + assert dp_shard == -1 or dp_shard >= 1, "dp_shard must -1 or >=1." if dp_shard < 0: self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) assert dp_shard >= 1 @@ -53,8 +55,13 @@ def _validate(self): ) if ep > 1: - # EP would borrow all cp and some dp_shard degree - assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1" + if etp == tp: + # EP would borrow all cp and some dp_shard degree + assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + elif etp == 1: + # EP would borrow all cp and tp and some dp_shard degree + assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel @@ -68,9 +75,15 @@ def build_mesh(self) -> DeviceMesh: def _build_mesh_with_ep(self) -> DeviceMesh: # With ep, dp_shard and ep are derived submeshes: # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp + if self.etp == self.tp: + # ep = dp_shard_in_ep * cp + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + else: + assert self.etp == 1 + # ep = dp_shard_in_ep * cp * tp + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) dims = [] names = [] @@ -121,6 +134,8 @@ def _build_mesh_with_ep(self) -> DeviceMesh: dp_shard_cp_mesh_dim_names.append("cp") dp_cp_mesh_dim_names.append("cp") ep_mesh_dim_names.append("cp") + if self.etp == 1 and self.tp_enabled: + ep_mesh_dim_names.append("tp") mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") @@ -218,6 +233,10 @@ def pp_enabled(self): def ep_enabled(self): return self.ep > 1 + @property + def etp_enabled(self): + return self.etp > 1 + @property def fsdp_gradient_divide_factor(self) -> int: # This is needed for FSDP-sharded experts when Expert Parallel is enabled. diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 45bfa9a605..e930131f56 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -80,6 +80,7 @@ def __init__(self, job_config: ForgeJobConfig): tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 2db1c64b58..bc6f828980 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -25,6 +25,7 @@ ExpertParallel, ExpertTensorParallel, NoParallel, + ReordererSequenceParallel, TensorParallel, ) @@ -87,7 +88,6 @@ def parallelize_llama( enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) - # TODO: shall we support tensorwise float8 comms for MoE TP if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, @@ -95,9 +95,12 @@ def parallelize_llama( ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( world_mesh["ep", "tp"] - if parallel_dims.tp_enabled and parallel_dims.ep_enabled + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled else None ), + etp_enabled=parallel_dims.etp_enabled, ) if job_config.activation_checkpoint.mode != "none": @@ -344,6 +347,7 @@ def apply_moe_ep_tp( tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, ep_tp_mesh: DeviceMesh | None, + etp_enabled: bool, ): for transformer_block in model.layers.values(): if not transformer_block.moe_enabled: @@ -365,13 +369,17 @@ def apply_moe_ep_tp( # input Replicate, output Partial "moe.shared_expert": TensorParallel(), } + if not etp_enabled: + # If TP is borrowed for EP, then split the tokens across TP ranks so that + # the reorderer, the all-to-all comms, and routed experts computation + # are effectively running Sequence Parallel (split along the folded bs*slen dim) + moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) parallelize_module( module=transformer_block, device_mesh=tp_mesh, parallelize_plan=moe_layer_plan, ) - # if ep_mesh is not None: experts_mesh, experts_plan = None, None if ep_mesh is None: experts_mesh = tp_mesh @@ -381,9 +389,13 @@ def apply_moe_ep_tp( experts_mesh = ep_mesh # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() - else: + elif etp_enabled: experts_mesh = ep_tp_mesh experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + else: + experts_mesh = ep_mesh + experts_plan = ExpertParallel() + parallelize_module( module=transformer_block.moe.experts, device_mesh=experts_mesh, diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index a7f068c073..0d1ed83628 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -53,6 +53,7 @@ enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index f316cd8380..00416eb91c 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -46,6 +46,8 @@ pipeline_parallel_degree = 4 # pipeline_parallel_schedule = "interleaved1f1b" # pipeline_parallel_microbatches = 2 context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 8 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 725bbe903d..6a2b660cdf 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -44,6 +44,8 @@ tensor_parallel_degree = 8 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 8 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8c9af6618c..7085cc1d04 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -82,9 +82,12 @@ def parallelize_deepseekv3( ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( world_mesh["ep", "tp"] - if parallel_dims.tp_enabled and parallel_dims.ep_enabled + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled else None ), + etp_enabled=parallel_dims.etp_enabled, ) if job_config.activation_checkpoint.mode != "none": diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 093f89a18b..065803ff01 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -1,5 +1,3 @@ -# torchtitan Config.toml - [job] dump_folder = "./outputs" description = "DeepSeek-V3 debug training" @@ -52,9 +50,10 @@ data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 enable_async_tensor_parallel = false -expert_parallel_degree = 1 pipeline_parallel_degree = 1 pipeline_parallel_schedule = "1F1B" +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 1cedc590d2..42e2cc6bc7 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -1,5 +1,3 @@ -# torchtitan Config.toml - [job] dump_folder = "./outputs" description = "DeepSeek-V3 16B model training" @@ -50,9 +48,10 @@ data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 enable_async_tensor_parallel = false -expert_parallel_degree = 1 pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 51e7ddbb50..fc1b512e28 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -1,5 +1,3 @@ -# torchtitan Config.toml - [job] dump_folder = "./outputs" description = "DeepSeek-V3 671B model training" @@ -50,9 +48,10 @@ data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 8 enable_async_tensor_parallel = false -expert_parallel_degree = 1 pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable_checkpoint = false diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 195429b147..1d2f9e5731 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -148,11 +148,12 @@ class TokenChoiceTopKRouter(nn.Module): routed to top K experts based on the router scores. Args: - gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). dim (int): Dimension of input tokens. num_experts (int): Number of experts in each moe layer. top_k (int): Number of experts each token will be routed to in token-choice routing. - use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. + score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores. + route_norm (bool): Whether to normalize the routing scores when using sigmoid. + route_scale (float): Scaling factor applied to the routing scores. """ def __init__( @@ -178,14 +179,17 @@ def forward( """ Args: x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + expert_bias (torch.Tensor | None, optional): Optional bias tensor for experts with shape ``(num_experts,)``. + Used for load balancing. Defaults to None. Returns: - routed_input (torch.Tensor): - Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. - token_indices (torch.Tensor): - Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_tokens_per_expert (torch.Tensor): - Number of tokens assigned to each expert with shape ``(num_experts,)``. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - top_scores (torch.Tensor): + Routing scores for selected experts with shape ``(bs*slen, top_k)``. + - selected_experts_indices (torch.Tensor): + Expert indices selected for each token with shape ``(bs*slen, top_k)``. + - num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. """ # scores shape (bs*slen, num_experts) scores = self.gate(x) @@ -224,19 +228,71 @@ def forward( max=self.num_experts, ) + return top_scores, selected_experts_indices, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +# NOTE: the reason we make this a stateless module is to support +# expert_tensor_parallel_degree=1 with consistent TP/EP APIs. +class TokenReorderer(nn.Module): + """ + This module reorders token indices to match the order of experts, enabling + efficient parallel processing of tokens by experts. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of experts each token will be routed to. + """ + + def __init__(self, num_experts: int, top_k: int): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + + def forward( + self, + top_scores: torch.Tensor, + selected_experts_indices: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reorders token indices to match the order of experts for MoE routing. + + Args: + top_scores (torch.Tensor): Routing scores for selected experts, + shape (batch_size*seq_len, top_k) + selected_experts_indices (torch.Tensor): Expert indices selected for each token, + shape (batch_size*seq_len, top_k) + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - top_scores_experts_sorted: Scores reordered to match expert ordering + - token_indices_experts_sorted: Token indices reordered to match expert ordering + - num_tokens_per_expert: Number of tokens assigned to each expert + """ + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) token_indices_experts_sorted = torch.argsort( selected_experts_indices.view(-1), stable=True ) - top_scores = top_scores.view(-1)[token_indices_experts_sorted] + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - return top_scores, token_indices_experts_sorted, num_tokens_per_expert - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + return ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) class MoE(nn.Module): @@ -258,6 +314,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): route_norm=moe_args.route_norm, route_scale=moe_args.route_scale, ) + self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k) self.shared_expert = ( GroupedExperts( dim=dim, @@ -302,11 +359,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ bs, slen, dim = x.shape - # top_scores and selected_indices shape (bs*slen*top_k,) + # top_scores and selected_experts_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) ( top_scores, - token_indices, + selected_experts_indices, num_tokens_per_expert, ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) @@ -318,19 +375,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = self.reorderer(top_scores, selected_experts_indices) + # shape (bs*slen*top_k, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape( + -1, 1 + ).expand(-1, dim) # shape (bs*slen*top_k, dim) routed_input = torch.gather( x.view(-1, dim), dim=0, - index=token_indices, + index=token_indices_experts_sorted, ) if self.score_before_experts: routed_input = ( - routed_input.to(torch.float32) * top_scores.reshape(-1, 1) + routed_input.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) # shape (bs*slen*top_k, dim) @@ -338,7 +412,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.score_before_experts: routed_output = ( - routed_output.to(torch.float32) * top_scores.reshape(-1, 1) + routed_output.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) ).to(x.dtype) # shared expert @@ -349,7 +424,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: out = torch.zeros_like(x.reshape(bs * slen, dim)) - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.scatter_add( + dim=0, index=token_indices_experts_sorted, src=routed_output + ) out = out.reshape(bs, slen, dim) return out diff --git a/torchtitan/train.py b/torchtitan/train.py index 70b7d2ebde..e38446a398 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -101,6 +101,7 @@ def __init__(self, job_config: JobConfig): tp=parallelism_config.tensor_parallel_degree, pp=parallelism_config.pipeline_parallel_degree, ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, world_size=world_size, ) From 6377dce8b7fcc745473d0a47d4b852bb4d81a4a9 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Wed, 13 Aug 2025 14:33:53 -0700 Subject: [PATCH 089/128] llama4: Avoid staticmethod nested graph break for MoE compile (#1565) This nested graph break is particularly bad, it is falling back the scaled grouped mm ops to eager Test plan: `NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" tlp ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_d egree=2 --training.compile` --- torchtitan/models/moe.py | 154 +++++++++++++++++++-------------------- 1 file changed, 75 insertions(+), 79 deletions(-) diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 1d2f9e5731..bd8116ea15 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -31,6 +31,79 @@ class MoEArgs: load_balance_coeff: float | None = 1e-3 +# TODO: keeping this for-loop implementation for comparison +# and readability, may remove later +@expert_parallel +def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) + h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) + h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) + h = h * torch.bmm(x, w3.transpose(-2, -1)) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2.transpose(-2, -1)) + + return out + + +@expert_parallel +def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, +) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + h = F.silu( + torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets) + ) + h = h * torch._grouped_mm( + x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets + ) + out = torch._grouped_mm(h, w2.bfloat16().transpose(-2, -1), offs=offsets).type_as(x) + + return out + + class GroupedExperts(nn.Module): def __init__( self, @@ -52,91 +125,14 @@ def forward( num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( + return _run_experts_grouped_mm( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) else: - return GroupedExperts._run_experts_for_loop( + return _run_experts_for_loop( self.w1, self.w2, self.w3, x, num_tokens_per_expert ) - # TODO: keeping this for-loop implementation for comparison - # and readability, may remove later - @expert_parallel - @staticmethod - def _run_experts_for_loop( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) - h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) - h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) - h = h * torch.bmm(x, w3.transpose(-2, -1)) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2.transpose(-2, -1)) - - return out - - @expert_parallel - @staticmethod - def _run_experts_grouped_mm( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, - ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 - - h = F.silu( - torch._grouped_mm( - x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets - ) - ) - h = h * torch._grouped_mm( - x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets - ) - out = torch._grouped_mm( - h, w2.bfloat16().transpose(-2, -1), offs=offsets - ).type_as(x) - - return out - def init_weights(self, init_std: float): nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) From 7354848dfb6dd2d67727a4702130f75c5985ed94 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Wed, 13 Aug 2025 17:43:39 -0700 Subject: [PATCH 090/128] [MoE/EP] apply dim-1 FSDP sharding for routed experts and rewrite shared experts with FFN (#1561) **apply dim-1 FSDP sharding for routed experts when `dp_mod_ep * ep > num_experts`** This is because our routed experts are defined of shape `(num_experts, ..., ...)`. EP already shards on dim-0. FSDP's default dim-0 sharding + EP sharding will be inefficient when `dp_mod_ep * ep > num_experts`. Tested: with 8 experts FSDP2 EP4, we see default dim-0 sharding > [rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0))) [rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0))) [rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0))) with 4 experts, FSDP2 EP4, we see dim-1 sharding > [rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0))) [rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0))) [rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0))) also tested integration works fine with: FSDP 2, CP 2 (EP 2), TP 2 (ETP 2) **rewrite shared experts with FFN** This is because - Same reason above, but using FFN is a simpler solution, especially considering shared experts are sharded together with TransformerBlock, so no need to complicate its `sharding_placement_fn`. - It turns out for multiple shared experts, we can just stack on the `hidden_dim` dimension, and TP will just work out fine. - It also simplifies the GroupedExperts module as it no longer needs to work with shared experts. **other changes** - rename `shared_expert` to `shared_experts` - merge two `tolist()` d2h for `input_splits` and `output_splits` in `token_dispatch` into one - state dict / checkpoint conversion changes (@wwwjn please help verify) --- torchtitan/distributed/expert_parallel.py | 73 +++++----- .../experiments/llama4/infra/parallelize.py | 83 ++++++++--- torchtitan/experiments/llama4/model/args.py | 10 +- .../scripts/convert_hf_to_dcp_with_gpus.py | 10 +- torchtitan/models/deepseek_v3/model/args.py | 10 +- torchtitan/models/deepseek_v3/model/model.py | 39 +---- .../deepseek_v3/model/state_dict_adapter.py | 16 +-- torchtitan/models/moe.py | 136 +++++++++--------- 8 files changed, 189 insertions(+), 188 deletions(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 915a5ac107..eef4bda714 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -29,12 +29,7 @@ class _A2A(torch.autograd.Function): @staticmethod def forward(ctx, x, out_splits, in_splits, group): - if isinstance(out_splits, torch.Tensor): - out_splits = out_splits.tolist() - if isinstance(in_splits, torch.Tensor): - in_splits = in_splits.tolist() T_out = int(sum(out_splits)) - y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group) @@ -176,6 +171,7 @@ def __init__(self): def _token_dispatch(self, mod, inputs, device_mesh): # annotate module input placements/sharding with input_layouts routed_input, num_tokens_per_expert = inputs + ep_size = device_mesh.shape[0] # generate the input splits and output splits for all-to-all with torch.no_grad(): @@ -187,15 +183,20 @@ def _token_dispatch(self, mod, inputs, device_mesh): num_tokens_per_expert, group=device_mesh.get_group(), ) - # NOTE: this would incur a device-to-host sync - self.input_splits = ( - num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + input_splits = ( + num_tokens_per_expert.view(ep_size, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=True) ) - self.output_splits = ( - num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + output_splits = ( + num_tokens_per_expert_group.view(ep_size, -1) .sum(dim=1) - .tolist() + .to(torch.device("cpu"), non_blocking=True) ) + # NOTE: this would incur a device-to-host sync + torch.cuda.current_stream().synchronize() + self.input_splits = input_splits.tolist() + self.output_splits = output_splits.tolist() # perform all-to-all routed_input = all_to_all_single_autograd( @@ -320,7 +321,7 @@ def wrapper( w2: torch.Tensor, w3: torch.Tensor, x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, + num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: global TOKEN_GROUP_ALIGN_SIZE_M if isinstance(w1, DTensor): @@ -328,37 +329,33 @@ def wrapper( w2 = w2.to_local() w3 = w3.to_local() - if num_tokens_per_expert is not None: - from torchtitan.experiments.kernels.moe.indices import ( - generate_permute_indices, + from torchtitan.experiments.kernels.moe.indices import generate_permute_indices + + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + TOKEN_GROUP_ALIGN_SIZE_M, ) - experts_per_ep_rank = w1.shape[0] - num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank - - with torch.no_grad(): - ( - permuted_indices, - num_tokens_per_expert, - _, # offsets, - ) = generate_permute_indices( - num_tokens_per_expert, - experts_per_ep_rank, - num_ep_ranks, - x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, - TOKEN_GROUP_ALIGN_SIZE_M, - ) - - x = torch.vstack((x, x.new_zeros((x.shape[-1])))) - input_shape = x.shape - x = x[permuted_indices, :] + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] out = func(w1, w2, w3, x, num_tokens_per_expert) - if num_tokens_per_expert is not None: - out_unpermuted = out.new_empty(input_shape) - out_unpermuted[permuted_indices, :] = out - out = out_unpermuted[:-1] + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] return out diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index bc6f828980..6d75b4986a 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -137,9 +137,10 @@ def parallelize_llama( pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if dp_mod_ep_mesh_dim_names + if parallel_dims.ep_enabled else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, @@ -273,6 +274,7 @@ def apply_fsdp( pp_enabled: bool, cpu_offload: bool = False, reshard_after_forward_policy: str = "default", + ep_degree: int = 1, dp_mod_ep_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, ): @@ -298,35 +300,57 @@ def apply_fsdp( if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() - for layer_id, transformer_block in model.layers.items(): - if reshard_after_forward_policy == "always": + match reshard_after_forward_policy: + case "always": reshard_after_forward = True - elif reshard_after_forward_policy == "never": + case "never": reshard_after_forward = False - elif reshard_after_forward_policy == "default": - if pp_enabled: - # For PP, do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = False - else: - # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 - else: + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: raise ValueError( f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." ) - # NOTE: in an MoE layer, the router and the shared experts - # are sharded together with the TransformerBlock - if transformer_block.moe_enabled and dp_mod_ep_mesh: + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + for layer_id, transformer_block in model.layers.items(): + # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping + # - the router and the shared experts are sharded together with the TransformerBlock + # - the routed experts are sharded with the remaining dp_mod_ep_mesh + if transformer_block.moe_enabled and ep_degree > 1: fsdp_mod_ep_config = fsdp_config.copy() fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + + # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). + # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding + # causes inefficiency, so we choose to do FSDP sharding on dim-1. + # Even when EP is not used, we may still want to shard the experts + # on non-0 dim. For now it may not be worth the complexity to support + # shard_placement_fn on the outer TransformerBlock-level FSDP. + _experts_shard_placement_fn = None + assert dp_mod_ep_mesh is not None + assert hasattr(transformer_block, "moe") + if ( + dp_mod_ep_mesh.size() * ep_degree + > transformer_block.moe.experts.num_experts + ): + _experts_shard_placement_fn = lambda param: Shard(1) + fully_shard( transformer_block.moe.experts, **fsdp_mod_ep_config, reshard_after_forward=reshard_after_forward, + shard_placement_fn=_experts_shard_placement_fn, ) + # NOTE: # Although the FSDP sharding of experts is done on a mesh of # a different size than other parameters, the gradient division # factor should be consistent with data. @@ -339,7 +363,17 @@ def apply_fsdp( **fsdp_config, reshard_after_forward=reshard_after_forward, ) - fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + + fully_shard(model, **fsdp_config) def apply_moe_ep_tp( @@ -366,14 +400,23 @@ def apply_moe_ep_tp( ), # replicate computation for the router "moe.router.gate": NoParallel(), - # input Replicate, output Partial - "moe.shared_expert": TensorParallel(), } if not etp_enabled: # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + if transformer_block.moe.shared_experts is not None: + # input Replicate, output Partial + moe_layer_plan.update( + { + "moe.shared_experts.w1": ColwiseParallel(), + "moe.shared_experts.w2": RowwiseParallel( + output_layouts=Partial() + ), + "moe.shared_experts.w3": ColwiseParallel(), + } + ) parallelize_module( module=transformer_block, device_mesh=tp_mesh, diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index dda130548d..949f4cf052 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -85,7 +85,7 @@ def get_nparams_and_flops( ) -> tuple[int, float]: nparams_embedding = 0 nparams_moe_router = 0 - nparams_shared_expert = 0 + nparams_shared_experts = 0 nparams_experts = 0 nparams_dense = 0 @@ -93,8 +93,8 @@ def get_nparams_and_flops( if "embedding" in name: nparams_embedding += p.numel() nparams_dense += p.numel() - elif "moe.shared_expert" in name: - nparams_shared_expert += p.numel() + elif "moe.shared_experts" in name: + nparams_shared_experts += p.numel() elif "moe.router" in name: nparams_moe_router += p.numel() elif "moe.experts" in name: @@ -102,11 +102,11 @@ def get_nparams_and_flops( else: nparams_dense += p.numel() - nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts nparams = nparams_dense + nparams_sparse nparams_sparse_active = ( nparams_moe_router - + nparams_shared_expert + + nparams_shared_experts + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py index bad69c0f7a..5cac0bba3e 100644 --- a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -57,11 +57,11 @@ def convert_to_titan_fqns(fqn: str) -> list[str]: elif "feed_forward.router.weight" in fqn: return [f"layers.{layer}.moe.router.gate.weight"] elif "feed_forward.shared_expert.down_proj.weight" in fqn: - return [f"layers.{layer}.moe.shared_expert.w2"] + return [f"layers.{layer}.moe.shared_experts.w2.weight"] elif "feed_forward.shared_expert.gate_proj.weight" in fqn: - return [f"layers.{layer}.moe.shared_expert.w3"] + return [f"layers.{layer}.moe.shared_experts.w3.weight"] elif "feed_forward.shared_expert.up_proj.weight" in fqn: - return [f"layers.{layer}.moe.shared_expert.w1"] + return [f"layers.{layer}.moe.shared_experts.w1.weight"] elif "post_attention_layernorm.weight" in fqn: return [f"layers.{layer}.ffn_norm.weight"] elif "self_attn.k_proj" in fqn: @@ -86,7 +86,7 @@ def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> li elif "shared_expert" in fqn: s = dtensor.shape # TODO: this is not right but I have to do this to load the checkpoint. - return torch.Size((s[2], s[1])) + return torch.Size((s[1], s[0])) return dtensor.shape @@ -96,7 +96,7 @@ def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tenso elif "shared_expert" in fqn: # TODO: this is not right but I have to do this to load the checkpoint. full_tensor = full_tensor.transpose(1, 0) - full_tensors = [full_tensor.unsqueeze(0)] + full_tensors = [full_tensor] else: full_tensors = [full_tensor] return full_tensors diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 025a550b9b..044420d37a 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -126,7 +126,7 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in """ nparams_embedding = 0 nparams_moe_router = 0 - nparams_shared_expert = 0 + nparams_shared_experts = 0 nparams_experts = 0 nparams_dense = 0 @@ -134,8 +134,8 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in if "embedding" in name: nparams_embedding += p.numel() nparams_dense += p.numel() - elif "moe.shared_expert" in name: - nparams_shared_expert += p.numel() + elif "moe.shared_experts" in name: + nparams_shared_experts += p.numel() elif "moe.router" in name: nparams_moe_router += p.numel() elif "moe.experts" in name: @@ -143,11 +143,11 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in else: nparams_dense += p.numel() - nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts nparams = nparams_dense + nparams_sparse nparams_sparse_active = ( nparams_moe_router - + nparams_shared_expert + + nparams_shared_experts + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index cfdc794ca9..dd31fc3181 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,52 +8,15 @@ from typing import Tuple import torch -import torch.nn.functional as F from torch import nn from torchtitan.models.attention import build_attention, init_attention_mask -from torchtitan.models.moe import MoE +from torchtitan.models.moe import FeedForward, MoE from torchtitan.protocols.train_spec import ModelProtocol from .args import DeepSeekV3ModelArgs -class FeedForward(nn.Module): - """ - FeedForward module - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. - - Attributes: - w1 (Linear): Linear transformation for the first layer. - w2 (Linear): Linear transformation for the second layer. - w3 (Linear): Linear transformation for the third layer. - - """ - - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) - for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) - - # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor: """ diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 890ae00f36..5a676b5a07 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -44,9 +44,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None): "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3", "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2", "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", - "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_expert.w1", - "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_expert.w3", - "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_expert.w2", + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight", "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias", "model.norm.weight": "norm.weight", "lm_head.weight": "output.weight", @@ -163,11 +163,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) - - # torchtitan shape: (1, s[1], s[2]) -> HF shape: (s[1], s[2]) - if "shared_expert" in key: - value = value.squeeze(0) - hf_state_dict[new_key] = value else: @@ -217,11 +212,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] new_key = new_key.format(layer_num) - - # HF shape: (s[1], s[2]) -> torchtitan shape: (1, s[1], s[2]) - if "shared_experts" in key: - value = value.unsqueeze(0) - state_dict[new_key] = value else: diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index bd8116ea15..40bd6c2cca 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -31,6 +31,38 @@ class MoEArgs: load_balance_coeff: float | None = 1e-3 +# can be used as dense FFN layer or shared experts in MoE layers +class FeedForward(nn.Module): + """ + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + # TODO: keeping this for-loop implementation for comparison # and readability, may remove later @expert_parallel @@ -39,39 +71,32 @@ def _run_experts_for_loop( w2: torch.Tensor, w3: torch.Tensor, x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, + num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: - if num_tokens_per_expert is not None: - # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() - - # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) - - # a tuple of tensors indexed by experts - # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, - dim=0, - ) - out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) - h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) - h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) - # h shape (tokens_per_expert(varying), dim) - out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - else: - # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) - h = h * torch.bmm(x, w3.transpose(-2, -1)) - # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, w2.transpose(-2, -1)) + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) + h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) + h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) return out @@ -82,16 +107,11 @@ def _run_experts_grouped_mm( w2: torch.Tensor, w3: torch.Tensor, x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, + num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: - if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 - else: - offsets = None - # fall back to regular bmm between 3D tensors - assert x.dim() == 3 + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 h = F.silu( torch._grouped_mm(x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets) @@ -122,7 +142,7 @@ def __init__( def forward( self, x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None = None, + num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: if self.use_grouped_mm: return _run_experts_grouped_mm( @@ -311,15 +331,8 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): route_scale=moe_args.route_scale, ) self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k) - self.shared_expert = ( - GroupedExperts( - dim=dim, - # TODO: if it doesn't use GroupedExperts.num_experts - # we can just use normal FeedForward - hidden_dim=hidden_dim * moe_args.num_shared_experts, - num_experts=1, - use_grouped_mm=moe_args.use_grouped_mm, - ) + self.shared_experts = ( + FeedForward(dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts) if moe_args.num_shared_experts > 0 else None ) @@ -354,6 +367,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ bs, slen, dim = x.shape + x = x.view(-1, dim) # top_scores and selected_experts_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) @@ -361,7 +375,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: top_scores, selected_experts_indices, num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + ) = self.router(x, self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- @@ -391,11 +405,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).expand(-1, dim) # shape (bs*slen*top_k, dim) - routed_input = torch.gather( - x.view(-1, dim), - dim=0, - index=token_indices_experts_sorted, - ) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) if self.score_before_experts: routed_input = ( @@ -413,12 +423,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).to(x.dtype) # shared expert - if self.shared_expert is not None: - out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( - bs * slen, dim - ) + if self.shared_experts is not None: + out = self.shared_experts(x) else: - out = torch.zeros_like(x.reshape(bs * slen, dim)) + out = torch.zeros_like(x) out = out.scatter_add( dim=0, index=token_indices_experts_sorted, src=routed_output @@ -433,8 +441,8 @@ def init_weights( ): self.experts.init_weights(init_std) self.router.init_weights(init_std) - if self.shared_expert is not None: - self.shared_expert.init_weights(init_std) + if self.shared_experts is not None: + self.shared_experts.init_weights(init_std) if self.load_balance_coeff is not None: with torch.device(buffer_device): From 6fc499f6f5b32151a799188be2208cfb09faed30 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Thu, 14 Aug 2025 18:02:21 -0700 Subject: [PATCH 091/128] quick fix dsv3 fsdp (#1575) as titled --- torchtitan/models/deepseek_v3/infra/parallelize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 7085cc1d04..8271d49dcd 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -120,9 +120,10 @@ def parallelize_deepseekv3( pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if dp_mod_ep_mesh_dim_names + if parallel_dims.ep_enabled else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, From e629fe58ae16bbfc2ebec21f26c87ce465057147 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 15 Aug 2025 08:14:15 -0700 Subject: [PATCH 092/128] Use PYTORCH_ALLOC_CONF as PYTORCH_CUDA_ALLOC_CONF is deprecated (#1577) --- run_train.sh | 2 +- torchtitan/experiments/deepseek_v3/run_training.sh | 2 +- torchtitan/experiments/flux/inference/run_infer.sh | 2 +- torchtitan/experiments/flux/run_train.sh | 2 +- .../experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh | 2 +- .../experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/run_train.sh b/run_train.sh index 01dddd0abd..0f9d1829b7 100755 --- a/run_train.sh +++ b/run_train.sh @@ -16,7 +16,7 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/torchtitan/experiments/deepseek_v3/run_training.sh b/torchtitan/experiments/deepseek_v3/run_training.sh index b2eb8009a8..e9d183d80f 100644 --- a/torchtitan/experiments/deepseek_v3/run_training.sh +++ b/torchtitan/experiments/deepseek_v3/run_training.sh @@ -23,7 +23,7 @@ fi TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ diff --git a/torchtitan/experiments/flux/inference/run_infer.sh b/torchtitan/experiments/flux/inference/run_infer.sh index 126540b7ed..b5419af2fb 100755 --- a/torchtitan/experiments/flux/inference/run_infer.sh +++ b/torchtitan/experiments/flux/inference/run_infer.sh @@ -14,7 +14,7 @@ NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -m torchtitan.experiments.flux.inference.infer --job.config_file ${CONFIG_FILE} \ diff --git a/torchtitan/experiments/flux/run_train.sh b/torchtitan/experiments/flux/run_train.sh index 231d66fc35..6fbd781102 100755 --- a/torchtitan/experiments/flux/run_train.sh +++ b/torchtitan/experiments/flux/run_train.sh @@ -14,7 +14,7 @@ NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -m torchtitan.experiments.flux.train --job.config_file ${CONFIG_FILE} "$@" diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh index 6530b8ce99..a0dcdca0eb 100644 --- a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh @@ -20,7 +20,7 @@ if [ $# -ne 0 ]; then overrides="$*" fi -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ convert_hf_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh index f3fd310934..2da5ac2b3f 100644 --- a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh @@ -19,7 +19,7 @@ if [ $# -ne 0 ]; then overrides="$*" fi -PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +PYTORCH_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides From 803906b8f2e4b3015d63a589296b36e6302fb75b Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 15 Aug 2025 11:41:26 -0400 Subject: [PATCH 093/128] Add DualPipeV (#1571) DualPipeV was added to pt-core (https://github.com/pytorch/pytorch/pull/159591) so just adding code to support it in titan To use, in .toml file set: ``` pipeline_parallel_schedule = "DualPipeV" ``` Ideally we don't have this if-statement check, so as a future BE task I can look into removing it --- torchtitan/distributed/pipeline_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index 96cf2ed790..c74c99daaf 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -18,6 +18,7 @@ get_schedule_class, PipelineScheduleMulti, PipelineScheduleSingle, + ScheduleDualPipeV, ScheduleZBVZeroBubble, ) @@ -335,7 +336,9 @@ def _build_stage_from_modules( models = [] schedule_class = get_schedule_class(pp_schedule) - style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): module_names = module_names_per_stage[stage_idx] From 297a72a14a977c58ec26e74b281df917afda7084 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 15 Aug 2025 10:56:07 -0700 Subject: [PATCH 094/128] Ignore tokenizer_path if it is an empty string (#1579) This allows us to use `--model.tokenizer_path=` to invalidate `tokenizer_path`. --- torchtitan/config/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 4ec6eb4ac9..ce0fe35c0f 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -172,7 +172,7 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: def _validate_config(self) -> None: # TODO: temporary mitigation of BC breaking change in hf_assets_path # tokenizer default path, need to remove later - if self.config.model.tokenizer_path is not None: + if self.config.model.tokenizer_path: logger.warning( "tokenizer_path is deprecated, use model.hf_assets_path instead. " "Setting hf_assets_path to tokenizer_path temporarily." From a59abeac98ba47e46e88504fb7e40bbf69a861f8 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Fri, 15 Aug 2025 11:47:30 -0700 Subject: [PATCH 095/128] added better guidance for if deprecated tokenizer path fails (#1568) Adds a check to see if the old tokenizer path is being used when tokenizer path fails. This way it can provide guidance to people to update to the supported `hf_assets_path` and `download_hf_assets.py` script --- torchtitan/components/tokenizer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 24db9b3484..022fcbc266 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -82,7 +82,16 @@ def _load_config(self, config_path: str) -> Optional[dict]: def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: """Load tokenizer from various file formats.""" if not os.path.exists(tokenizer_path): - raise FileNotFoundError(f"Tokenizer path '{tokenizer_path}' does not exist") + if "assets/tokenizer" in tokenizer_path: + raise FileNotFoundError( + "Detected ./assets/tokenizer path which was deprecated in https://github.com/pytorch/torchtitan/pull/1540.\n" + "Remove --model.tokenizer_path and download to --model.hf_assets_path using ./scripts/download_hf_assets.py\n" + "See example: https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/deepseek_v3#download-tokenizer" + ) + else: + raise FileNotFoundError( + f"Tokenizer path '{tokenizer_path}' does not exist" + ) # Define paths for different tokenizer file types tokenizer_json_path = os.path.join(tokenizer_path, "tokenizer.json") From 72b16b13abc88ba08f3e1796e5caee09abd94554 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Sat, 16 Aug 2025 22:18:33 -0700 Subject: [PATCH 096/128] Added doc for Val/Eval and lm_eval integration (#1573) This pr adds documentation for how to get started with - In training validation in TorchTitan - Third party evaluation with `lm_eval` --- docs/evaluation.md | 38 +++++++++++++++++++ .../llama3/train_configs/llama3_405b.toml | 6 +++ .../llama3/train_configs/llama3_70b.toml | 6 +++ .../llama3/train_configs/llama3_8b.toml | 6 +++ 4 files changed, 56 insertions(+) create mode 100644 docs/evaluation.md diff --git a/docs/evaluation.md b/docs/evaluation.md new file mode 100644 index 0000000000..69de104aaa --- /dev/null +++ b/docs/evaluation.md @@ -0,0 +1,38 @@ +# Validation and Evaluation + +`torchtitan` provides direct and indirect support for validation to support user's training goals. Direct support is provided by the `Validator` class which interacts directly with the training loop, and indirect support is provided through [HuggingFace checkpoint conversion](https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md#huggingface) for users who want to do evaluation using external tools such as ELeutherAI's `lm_eval`. + +## Validation +For users who want to perform validation directly during the training loop, we provide the `Validator` class which can be conveniently overloaded through the `TrainSpec` or configured in `JobConfig`. The validator class has access to and reuses many of the trainer's functions such as its parallelization, including pipelining. + +Below is an example validation config: + +```toml +[validation] +enabled = true +dataset = "c4_validation" +freq = 500 +steps = -1 # consumes the entire validation set +``` + +## Third-Party Evaluation +With `./scripts/checkpoint_conversion/convert_to_hf.py`, `torchtitan` offers support for converting checkpoints from DCP to safetensors format. Using this script, users can perform efficient evaluation separate from their training using external libraries that support HuggingFace e.g. `lm_eval` with `vllm` backend. + +### Example usage of `lm_eval` with `vllm`: +To use this specific setup make sure to include a HuggingFace `config.json` file which is not provided by conversion script or `last_save_in_hf` option. The HF config file can be downloaded by running `python ./scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets config`. + +Note that pip installing `lm-eval` may result in breaking `torchtitan` dev environment so we recommend creating a separate env. +```bash +pip install "lm-eval[vllm]" +lm_eval --model vllm \ + --model_args pretrained=./outputs/checkpoint/step-1000,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8, \ + --tasks mmlu \ + --batch_size auto +``` +| Groups |Version|Filter|n-shot|Metric| |Value | |Stderr| +|------------------|------:|------|------|------|---|-----:|---|-----:| +|mmlu | 2|none | |acc |↑ |0.6209|± |0.0038| +| - humanities | 2|none | |acc |↑ |0.5481|± |0.0066| +| - other | 2|none | |acc |↑ |0.7045|± |0.0078| +| - social sciences| 2|none | |acc |↑ |0.7351|± |0.0078| +| - stem | 2|none | |acc |↑ |0.5357|± |0.0085| diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index d34e85d213..63d91f41aa 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -60,3 +60,9 @@ mode = "full" # ["none", "selective", "full"] enable_fsdp_float8_all_gather = true precompute_float8_dynamic_scale_for_fsdp = true filter_fqns = ["output"] + +[validation] +enabled = false +dataset = "c4_validation" +freq = 500 +steps = -1 diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 3f2a0355d6..8d3289de85 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -59,3 +59,9 @@ mode = "full" enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output"] + +[validation] +enabled = false +dataset = "c4_validation" +freq = 500 +steps = -1 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index f3c2931a55..038f9b33f6 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -60,3 +60,9 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output"] + +[validation] +enabled = false +dataset = "c4_validation" +freq = 100 +steps = -1 From 9233d831882677fca37c45766be55808d43918da Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sun, 17 Aug 2025 17:50:59 -0700 Subject: [PATCH 097/128] [EP] bug fixes (#1586) fixes bug introduced in https://github.com/pytorch/torchtitan/pull/1555 --- torchtitan/distributed/expert_parallel.py | 12 +++++++++--- torchtitan/experiments/llama4/infra/parallelize.py | 2 +- .../llama4/train_configs/debug_model.toml | 1 + .../deepseek_v3/train_configs/debug_model.toml | 1 + 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index eef4bda714..c5b9d7bfd9 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -365,6 +365,10 @@ def wrapper( # This class is to support Sequence Parallel for ETP=1 # when EP borrows from all TP and part of DP class ReordererSequenceParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.num_tokens = None + def _prepare_inputput_fn(self, mod, inputs, device_mesh): top_scores, selected_experts_indices = inputs @@ -372,6 +376,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh): selected_experts_indices = DTensor.from_local( selected_experts_indices, device_mesh, (Replicate(),) ) + self.num_tokens = top_scores.shape[0] # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree # if top_scores.shape[0] % device_mesh.size() != 0: @@ -380,7 +385,7 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh): # n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens # selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad]) # top_scores = F.pad(top_scores, [0, 0, 0, n_pad]) - assert top_scores.shape[0] % device_mesh.size() == 0 + assert self.num_tokens % device_mesh.size() == 0 # split on the bs*slen dimension top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local() @@ -395,9 +400,10 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks, # the MoE gather and scatter still require global token indices. - num_tokens = top_scores.shape[0] local_rank = device_mesh.get_local_rank() - token_indices_experts_sorted += num_tokens // device_mesh.size() * local_rank + token_indices_experts_sorted += ( + self.num_tokens // device_mesh.size() * local_rank + ) return top_scores, token_indices_experts_sorted, num_tokens_per_expert diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 6d75b4986a..6fc343d282 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -401,7 +401,7 @@ def apply_moe_ep_tp( # replicate computation for the router "moe.router.gate": NoParallel(), } - if not etp_enabled: + if ep_mesh is not None and not etp_enabled: # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 0d1ed83628..0179b5f9a1 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -51,6 +51,7 @@ fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" context_parallel_degree = 1 expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 065803ff01..dd94556f27 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -52,6 +52,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 From 0d1b80d3b724cc804a15ee50138824c8ad0bea0d Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:26:20 -0700 Subject: [PATCH 098/128] [EP] remove token split overhead from DTensor in TokenReorderer pre hook (#1587) Due to the d2h sync in EP, training sometimes is CPU bounded. So we need to be more careful about DTensor overhead. See screenshots below for profiler traces. Numerics are verified to be the same. **before** image **after** image --- torchtitan/distributed/expert_parallel.py | 25 ++++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index c5b9d7bfd9..384d9e33fe 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -371,27 +371,28 @@ def __init__(self): def _prepare_inputput_fn(self, mod, inputs, device_mesh): top_scores, selected_experts_indices = inputs - - top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),)) - selected_experts_indices = DTensor.from_local( - selected_experts_indices, device_mesh, (Replicate(),) - ) self.num_tokens = top_scores.shape[0] - # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree + # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree # if top_scores.shape[0] % device_mesh.size() != 0: # num_tokens = top_scores.shape[0] # tp_size = device_mesh.size() # n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens # selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad]) # top_scores = F.pad(top_scores, [0, 0, 0, n_pad]) - assert self.num_tokens % device_mesh.size() == 0 - # split on the bs*slen dimension - top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local() - selected_experts_indices = selected_experts_indices.redistribute( - device_mesh, (Shard(0),) - ).to_local() + def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor: + assert x.is_contiguous() + assert self.num_tokens % device_mesh.size() == 0 + local_num_tokens = self.num_tokens // device_mesh.size() + local_rank = device_mesh.get_local_rank() + offset = local_rank * local_num_tokens + output = x[offset : offset + local_num_tokens] + + return output + + top_scores = _split_along_first_dim(top_scores) + selected_experts_indices = _split_along_first_dim(selected_experts_indices) return top_scores, selected_experts_indices From f9e8897a6f09fbe932e4b1a4df9fc23ca08c2b5e Mon Sep 17 00:00:00 2001 From: Hossein Kaviani Date: Mon, 18 Aug 2025 11:31:14 -0700 Subject: [PATCH 099/128] Adding Qwen3 model to the experiments folder (#1429) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In this PR, I added Qwen3 0.6 B dense model for torchtitan under experiments. Parity test has been done and Torch Titan native results match HF implementation. Profiler diagnostic has been attached. Computation/communication latency breakdown displays good performance. More explanation can be found in the README file. Thanks Rohan Pandey (@khoomeik) for helping out with the Rope implementation.Screenshot 2025-08-13 at 8 31 53 AM --------- Co-authored-by: Hossein Kavianihamedani Co-authored-by: Jiani Wang --- torchtitan/experiments/__init__.py | 1 + torchtitan/experiments/qwen3/README.md | 27 + torchtitan/experiments/qwen3/__init__.py | 121 +++++ .../experiments/qwen3/infra/parallelize.py | 219 ++++++++ torchtitan/experiments/qwen3/model/args.py | 65 +++ torchtitan/experiments/qwen3/model/model.py | 466 ++++++++++++++++++ .../qwen3/train_configs/qwen3_0.6b.toml | 60 +++ 7 files changed, 959 insertions(+) create mode 100644 torchtitan/experiments/qwen3/README.md create mode 100644 torchtitan/experiments/qwen3/__init__.py create mode 100644 torchtitan/experiments/qwen3/infra/parallelize.py create mode 100644 torchtitan/experiments/qwen3/model/args.py create mode 100644 torchtitan/experiments/qwen3/model/model.py create mode 100644 torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 4c54bdc13e..9d81f6b885 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -5,4 +5,5 @@ # LICENSE file in the root directory of this source tree. import torchtitan.experiments.llama4 # noqa: F401 +import torchtitan.experiments.qwen3 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md new file mode 100644 index 0000000000..dce71ed11c --- /dev/null +++ b/torchtitan/experiments/qwen3/README.md @@ -0,0 +1,27 @@ +**The Qwen3 model is still under development.** + + +#### Available features +QWEN3 0.6B Dense model is available for: + +- FSDP/HSDP, TP, DDP, AC, compile support + +Other model sizes are added to the args, but toml file configs need to be added and tested. Further testing is needed to check the coistency of the parallelism implementations. + +#### Download Qwen3 tokenizer + +```python scripts/download_tokenizer.py --repo_id Qwen/Qwen3-0.6B``` + + +#### Parity with HF + +Model parity test has been done and results suggest parity with HF implementation. Further investigation is needed to check the sanity of the Rope function. + +#### To be added +- Modeling + - Variants of Dense models up to 32B + - MoE alternatives + - Weight tying +- Testing + - The model should be tested against established performance benchmarks + - CI integration diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py new file mode 100644 index 0000000000..a5c31f3d88 --- /dev/null +++ b/torchtitan/experiments/qwen3/__init__.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize import parallelize_qwen3 +from .model.args import Qwen3ModelArgs +from .model.model import Transformer + +__all__ = [ + "parallelize_qwen3", + "Qwen3ModelArgs", + "Transformer", + "qwen3_configs", +] + + +# Adding different variants of the model + +qwen3_configs = { + "0.6B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=1024, + n_layers=28, + n_heads=16, + n_kv_heads=8, + qk_norm=True, + hidden_dim=3072, + rope_theta=1000000, + ), + "1.7B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=2048, + n_layers=28, + n_heads=16, + n_kv_heads=8, + qk_norm=True, + hidden_dim=6144, + rope_theta=1000000, + ), + "4B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=2560, + n_layers=36, + n_heads=32, + n_kv_heads=8, + qk_norm=True, + hidden_dim=9728, + rope_theta=1000000, + ), + "8B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=4096, + n_layers=36, + n_heads=32, + n_kv_heads=8, + qk_norm=True, + hidden_dim=12288, + rope_theta=1000000, + ), + "14B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=5120, + n_layers=40, + n_heads=40, + n_kv_heads=8, + qk_norm=True, + hidden_dim=17408, + rope_theta=1000000, + ), + "32B": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=4096, + head_dim=128, + dim=5120, + n_layers=64, + n_heads=64, + n_kv_heads=8, + qk_norm=True, + hidden_dim=25600, + rope_theta=1000000, + ), +} + + +register_train_spec( + TrainSpec( + name="qwen3", + model_cls=Transformer, + model_args=qwen3_configs, # Change from dict to Mapping + parallelize_fn=parallelize_qwen3, + pipelining_fn=None, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + ) +) diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py new file mode 100644 index 0000000000..4d4572d287 --- /dev/null +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -0,0 +1,219 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +import torch +import torch.nn as nn + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.expert_parallel import NoParallel +from torchtitan.models.llama3.infra.parallelize import ( + apply_ac, + apply_compile, + apply_ddp, + apply_fsdp, +) +from torchtitan.tools.logging import logger + + +def parallelize_qwen3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + + world_mesh = parallel_dims.world_mesh + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + if parallel_dims.tp_enabled: + if ( + job_config.parallelism.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Replicate(), Replicate()), + ), + "attention.wq": colwise_parallel(use_local_output=False), + "attention.wk": colwise_parallel(use_local_output=False), + "attention.wv": colwise_parallel(use_local_output=False), + "attention.q_norm": NoParallel(use_local_output=False), + "attention.k_norm": NoParallel(use_local_output=False), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) diff --git a/torchtitan/experiments/qwen3/model/args.py b/torchtitan/experiments/qwen3/model/args.py new file mode 100644 index 0000000000..a27b7bae14 --- /dev/null +++ b/torchtitan/experiments/qwen3/model/args.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass + +from torch import nn + +from torchtitan.config import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs + +from torchtitan.tools.logging import logger + + +@dataclass +class Qwen3ModelArgs(BaseModelArgs): + + dim: int = 1024 + n_layers: int = 28 + n_heads: int = 16 + n_kv_heads: int = 8 + vocab_size: int = 151936 + head_dim: int = 128 + hidden_dim: int = 3072 + + norm_eps: float = 1e-6 + rope_theta: float = 1000000 + qk_norm: bool = True + max_seq_len: int = 4096 + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 151645 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py new file mode 100644 index 0000000000..6697e39202 --- /dev/null +++ b/torchtitan/experiments/qwen3/model/model.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +import torch +import torch.nn.functional as F +from torch import nn + +from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import Qwen3ModelArgs + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + Note: + This function adds .transpose(-2,-1) to match HF implementation. This method assumes that last + dimension is [real_0, real_1, ..., real_{N-1}, imag_0, imag_1, ..., imag_{N-1}] while Rope in Llama3 + has [real_0, imag_0, real_1, imag_1, ..., real_{N-1}, imag_{N-1}]. This is the main difference + between Llama3 and Qwen3 Rope which is under investigation. + """ + xk_complex = torch.view_as_complex( + xk.view(*xk.shape[:-1], 2, xk.shape[-1] // 2) + .transpose(-2, -1) + .contiguous() + .float() + ) + xq_complex = torch.view_as_complex( + xq.view(*xq.shape[:-1], 2, xq.shape[-1] // 2) + .transpose(-2, -1) + .contiguous() + .float() + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex) + + xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.head_dim + + # RMSNorm added here to the here to include the q-k norm + # This is one of the main differences between Llama3 and Qwen3 + if model_args.qk_norm: + self.q_norm = nn.RMSNorm( + self.head_dim, eps=model_args.norm_eps, elementwise_affine=True + ) + self.k_norm = nn.RMSNorm( + self.head_dim, eps=model_args.norm_eps, elementwise_affine=True + ) + else: + self.q_norm = None + self.k_norm = None + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + if self.q_norm is not None: + self.q_norm.reset_parameters() + if self.k_norm is not None: + self.k_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # Adding the q_norm and k_norm here + # Last layer of adding q-k norm + if self.q_norm: + xq = self.q_norm(xq) + if self.k_norm: + xk = self.k_norm(xk) + + # repeat k/v heads if n_kv_heads < n_heads + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + output = self.sdpa(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + # Hidden dimension is directly added from the model argsS + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + + """ + + def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, hidden_dim=model_args.hidden_dim + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module, ModelProtocol): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + self.head_dim = model_args.head_dim + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.head_dim, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward( + self, + tokens: torch.Tensor, + input_batch: torch.Tensor | None = None, + eos_id: int | None = None, + ): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + input_batch (torch.Tensor): The input batch read from the dataloader. + This will always be the input batch regardless of the pipeline stage. + This field is required for non-first PP stages to perform document + masking attention (to analyze the boundary of the document). + eos_id (int | None): End-of-sequence token ID. If not provided, uses self.eos_id. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + if self.model_args.use_flex_attn: + init_attention_mask( + input_batch if input_batch is not None else tokens, + eos_id=eos_id, + ) + + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml new file mode 100644 index 0000000000..f9cf0cd2b4 --- /dev/null +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml @@ -0,0 +1,60 @@ +[job] +dump_folder = "./outputs" +description = "Qwen 3 0.6B training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "qwen3" +flavor = "0.6B" +tokenizer_path = "./assets/tokenizer/Qwen3-0.6B" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 3e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 1 # lr scheduler warm up + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4" + + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float16" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] From e4847c8054ad5f878add8a4ef990d311f1a08602 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Mon, 18 Aug 2025 19:15:46 -0700 Subject: [PATCH 100/128] added example for bidirectional checkpoint testing (#1540) This pr adds - an example script for bidirectional testing of checkpoint conversion scripts - a `checkpoint_conversion.md` to describe our methodology. --- scripts/checkpoint_conversion/README.md | 24 +++ .../numerical_tests_example.py | 160 ++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 scripts/checkpoint_conversion/README.md create mode 100644 scripts/checkpoint_conversion/numerical_tests_example.py diff --git a/scripts/checkpoint_conversion/README.md b/scripts/checkpoint_conversion/README.md new file mode 100644 index 0000000000..78d580bcda --- /dev/null +++ b/scripts/checkpoint_conversion/README.md @@ -0,0 +1,24 @@ +# Testing Checkpoint Conversion for Correctness + +When converting checkpoints between file types or model definitions, we need to ensure that the converted checkpoints are correct, i.e. their model definition remains the same, which includes that the converted checkpoint's weights will give the same outputs when loaded in the new intended program context. + +This guide provides a general framework on how to test your conversion script for correctness. The example that we will use here is bidirectional conversion between HuggingFace and `torchtitan`. + +## Methods + +### Sanity Check (Greedy Decode) +A quick way to sanity check if your conversion is correct is to perform greedy decoding inference on both the initial and converted checkpoints and confirm that they are the same. This method doesn't guarantee correctness but will very likely result in a fast **true negative** if the model definitions are not the same. For Llama3, greedy decoding can be achieved using the `generation/test_generate.py` script. Other models may not have an inference script, but the methodology holds the same. + +Note that your model definition needs to match your conversion script. For example, if converting from `torchtitan` to HuggingFace, be sure to include the correct `config.json` file that matches the `torchtitan` model architecture. Providing an incorrect `config.json` when loading the model with HuggingFace `transformers` will result in incorrect generations despite a correct weight conversion. + +### Comprehensive Check (KL Divergence) +In our `./scripts/checkpoint_conversion/numerical_test_example.py` this will be performing forward on DCP checkpoints loaded in `torchtitan` and safetensors checkpoints loaded in HuggingFace `AutoModelForCausalLM`. This script tests the HuggingFace -> `torchtitan` direction, as loading a HuggingFace checkpoint requires both +- converting the instantiated `torchtitan` state dict `to_hf` so that safetensors weights can be loaded into it, and +- converting the HF version of state dict back to torchtitan using `from_hf`. + +To convert Llama 3 between HuggingFace and `torchtitan` we had to perform a permutation on several of the attention matrices to account for difference between HuggingFace and native Llama RoPE implementations. To demonstrate how a KL divergence test can reveal subtle inaccuracies such as this, we additionally compare the KL divergence between the original and converted model with and without the permutation. The results are as follows: +``` +$ python ./scripts/checkpoint_conversion/example.py +Average loss of test from_hf is -1.45365707318601e-13 +Average loss of test from_hf_no_perm is 5.368335223465692e-06 +``` diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py new file mode 100644 index 0000000000..66eff8054e --- /dev/null +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +import torch.distributed.checkpoint as dcp +import torch.nn.functional as F +from torchtitan.components.checkpoint import ModelWrapper +from torchtitan.config import ConfigManager +from torchtitan.protocols.train_spec import get_train_spec +from torchtitan.tools.logging import logger +from transformers import AutoModelForCausalLM + +device_type = "cuda" if torch.cuda.is_available() else "cpu" + + +def loss_fn(logits1, logits2): + # Convert logits to probabilities + probs1 = F.log_softmax(logits1, dim=-1) + probs2 = F.softmax(logits2, dim=-1) + + # Calculate KL Divergence + kl_loss = F.kl_div(probs1, probs2, "mean") + return kl_loss + + +@torch.no_grad +def forward_hf(model_name, model_path: Optional[str], input_ids): + # Load the tokenizer and model + model_path = model_path if model_path else model_name + model = AutoModelForCausalLM.from_pretrained(model_path) + + device = torch.device(device_type) + model.to(device) + + # List to store outputs + outputs_list = [] + + for inputs in input_ids: + inputs = inputs.to(device) + outputs = model.generate( + inputs=inputs, + max_length=prompt_len + 1, + do_sample=False, + output_logits=True, + return_dict_in_generate=True, + ) + + outputs = torch.stack(outputs.logits) + outputs_list.append(outputs) + + del model + torch.cuda.empty_cache() + + return outputs_list + + +@torch.no_grad +def forward_tt(config_path, checkpoint_path, test_set): + + config_manager = ConfigManager() + config = config_manager.parse_args([f"--job.config_file={config_path}"]) + + train_spec = get_train_spec(config.model.name) + + model_args = train_spec.model_args[config.model.flavor] + model_args.update_from_config(config) + + model = train_spec.model_cls(model_args) + + # materalize model + device = torch.device(device_type) + model.to_empty(device=device) + model.init_weights(buffer_device=device) + model.eval() + + modelWrapper = ModelWrapper(model) + state_dict = modelWrapper._get_state_dict() + + # Checkpoint Loading + logger.info(f"Loading checkpoint at: {checkpoint_path}") + dcp.load(state_dict, checkpoint_id=checkpoint_path) + + output_list = [] + for prompt in test_set: + input_ids = prompt.to(device_type) + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + # obtains the logits of only the last token in the predictions + predictions = model(input_ids)[:, -1, :].unsqueeze(1) + output_list.append(predictions) + + del model + torch.cuda.empty_cache() + + return output_list + + +if __name__ == "__main__": + # hf params + hf_model_name = "meta-llama/Meta-Llama-3-8B" + + # tt params + config_path = "torchtitan/models/llama3/train_configs/llama3_8b.toml" + checkpoint_path = "outputs/test_checkpoint/step-0-fromhf" # dcp checkpoint from convert_from_hf.py + # dcp checkpoint from convert_from_hf.py without using sd_adapter's permute + checkpoint_path_no_perm = "outputs/test_checkpoint/step-0-fromhfnoperm" + + # test params + prompt_len = 8 + test_size = 100 + + config_manager = ConfigManager() + config = config_manager.parse_args([f"--job.config_file={config_path}"]) + train_spec = get_train_spec(config.model.name) + tokenizer = train_spec.build_tokenizer_fn(config) + + # Build test set of randomly generated token ids + test_set = [ + torch.randint( + 0, + tokenizer.get_vocab_size(), + ( + 1, # batch size + prompt_len, + ), + ) + for _ in range(test_size) + ] + + # baseline logits + baseline_hf_outputs = forward_hf(hf_model_name, None, test_set) + + # testing from hf conversion + from_hf_outputs = forward_tt(config_path, checkpoint_path, test_set) + from_hf_outputs_no_perm = forward_tt(config_path, checkpoint_path_no_perm, test_set) + + # Define the set of outputs to test loss for + test_configs = { + "from_hf": [baseline_hf_outputs, from_hf_outputs], + "from_hf_no_perm": [baseline_hf_outputs, from_hf_outputs_no_perm], + } + avg_losses = {} + + for test_name, (baseline_outputs, conversion_outputs) in test_configs.items(): + total_loss = 0 + for baseline, outputs in zip(baseline_outputs, conversion_outputs): + total_loss += loss_fn(baseline, outputs) + avg_loss = total_loss / len(test_set) + avg_losses[test_name] = avg_loss.item() + + for test_name, avg_loss in avg_losses.items(): + print(f"Average loss for test {test_name} is {avg_loss}") From a54725cfab21c82f5189cfdfff93c2c9347ac025 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 18 Aug 2025 21:10:26 -0700 Subject: [PATCH 101/128] MoE explicit prefetching in FSDP (#1594) As titled. It seems with this PR, prefetching works as expected, both in forward and in backward. This is developed on top of the example https://github.com/pytorch/torchtitan/pull/1581 from @weifengpy I verified in profiler trace, but it's a bit hard to show in a picture. image --- .../experiments/llama4/infra/parallelize.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 6fc343d282..b576b91ab3 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -375,6 +375,57 @@ def apply_fsdp( fully_shard(model, **fsdp_config) + # NOTE: set up explicit prefetching when EP is enabled, as D2H syncs + # in EP could interfere with implicit prefetching in FSDP + if ep_degree == 1: + return + + # forward + transformer_blocks = list(model.layers.values()) + next_transformer_blocks = transformer_blocks[1:] + [None] + + if model.tok_embeddings is not None and model.layers is not None: + model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) + + for transformer_block, next_transformer_block in zip( + transformer_blocks, next_transformer_blocks + ): + if next_transformer_block is not None: + if next_transformer_block.moe_enabled: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block, next_transformer_block.moe.experts] + ) + else: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block] + ) + elif model.norm is not None and model.output is not None: + transformer_block.set_modules_to_forward_prefetch( + [model.norm, model.output] + ) + + # backward + reversed_transformer_blocks = list(reversed(model.layers.values())) + prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + + if model.norm is not None and model.output is not None and model.layers is not None: + model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) + + for transformer_block, prev_transformer_block in zip( + reversed_transformer_blocks, prev_transformer_blocks + ): + if prev_transformer_block is not None: + if prev_transformer_block.moe_enabled: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block, prev_transformer_block.moe.experts] + ) + else: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block] + ) + elif model.tok_embeddings is not None: + transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) + def apply_moe_ep_tp( model: nn.Module, From 9e2468960969772b968972aad2d008575a5fb9f1 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Mon, 18 Aug 2025 21:27:49 -0700 Subject: [PATCH 102/128] [DeepSeek] add torch.compile + async TP (#1588) verified that torch.compile works. However, I didn't see async TP in trace. cc @danielvegamyhre @fegin Could you help take a look? --- .../experiments/llama4/infra/parallelize.py | 5 ++--- torchtitan/models/deepseek_v3/README.md | 7 ++++--- .../models/deepseek_v3/infra/parallelize.py | 21 ++++++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index b576b91ab3..35a72167d0 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -108,10 +108,9 @@ def parallelize_llama( # turn on per-TransformerBlock compile after AC wrapping and before FSDP if job_config.training.compile: - apply_compile(model) - # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True + apply_compile(model) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: @@ -503,7 +502,7 @@ def apply_compile(model: nn.Module): repeated structure. Alternatively one can compile the whole model (after applying DP). """ for layer_id, transformer_block in model.layers.named_children(): - # TODO: remove when torch.compile supports fullgraph=True for llama4 moe + # TODO: remove when torch.compile supports fullgraph=True for MoE fullgraph = True if transformer_block.moe_enabled: fullgraph = False diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 085403d47b..15860c2361 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -47,6 +47,7 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml - Tensor Parallel (TP) - Expert Parallel (EP) - Pipeline Parallel (PP) +- torch.compile ## HuggingFace -> DCP Checkpoint Conversion @@ -65,8 +66,8 @@ Some limitations: ## To be added - Parallelism - Context Parallel support for DeepSeek V3 -- torch.compile - Quantization - Testing - - perfomance and loss converging tests - - CI integration + - loss converging tests (verified) + - perfomance (WIP) + - CI integration (WIP) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8271d49dcd..d82270edd0 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -47,12 +47,11 @@ def parallelize_deepseekv3( raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: - if job_config.parallelism.enable_async_tensor_parallel: - # TODO(jianiw): This branch needs to be tested and enabled - raise NotImplementedError( - "Currently, async TP is not tested for deepseekv3. \ - torch.compile is not supported yet, which is required for async TP." - ) + if ( + job_config.parallelism.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -94,7 +93,9 @@ def parallelize_deepseekv3( apply_ac(model, job_config.activation_checkpoint) if job_config.training.compile: - raise NotImplementedError("torch.compile is not supported yet for deepseekv3") + # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE + torch._dynamo.config.capture_scalar_outputs = True + apply_compile(model) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: @@ -251,6 +252,12 @@ def apply_non_moe_tp( parallelize_plan=layer_plan, ) + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" "Tensor Parallelism to the model" From 7f1fa48157cbd8fd9d573dfea857b94589ffbeba Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Mon, 18 Aug 2025 22:08:18 -0700 Subject: [PATCH 103/128] [Qwen3] Switch to verified RoPE implementation + Add weight tying support (#1590) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Context 1. Current Qwen3 model RoPE used some trick to make numerical parity with HF. This trick is from un-official source and hard to reasoning mathematically. Switch to [torchtune based implementation](https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py#L14), which was directly contributed from Qwen team. Thanks @ebsmothers for point us to this implementation! - For RoPE embedding, I change it to the same way as complex representation based RoPE in llama3: We initialize and precompute the RoPE embedding cos/sin value only once, and pass it into Attention module during forward. In this way, TP can be applied seamlessly. - In contrast, torchtune passed the RoPE class into initialize function for each layers' attention module. 2. Add weight tying support for Qwen3, verified with FSDP + TP ## Numerical verification for RoPE Run end-to-end forward pass of Qwen3 model, the output and Screenshot 2025-08-18 at 2 48 48 PM ## Weight tying Verification: 1. With vs. without weight tying on torchtitan model: (FSDP=4, loss are exactly the same) Screenshot 2025-08-18 at 6 19 13 PM 2. torchtitan with weight tying vs. HF Screenshot 2025-08-18 at 9 37 50 PM 3. Weight tying memory address / id check: (in train.py) - passed ``` assert id(model.tok_embeddings.weight) == id(model.output.weight), "id check 2" assertEqual(model.tok_embeddings.weight, model.output.weight) # model.forward() assert id(model.tok_embeddings.weight.grad) == id(model.output.weight.grad), "id check 2" assertEqual(model.tok_embeddings.weight.grad, model.output.weight.grad) ``` --- torchtitan/experiments/qwen3/__init__.py | 6 +- .../experiments/qwen3/infra/parallelize.py | 4 + torchtitan/experiments/qwen3/model/args.py | 2 + torchtitan/experiments/qwen3/model/model.py | 160 ++++++++---------- .../qwen3/train_configs/qwen3_0.6b.toml | 7 +- 5 files changed, 83 insertions(+), 96 deletions(-) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index a5c31f3d88..d22053ff65 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -16,12 +16,12 @@ from .infra.parallelize import parallelize_qwen3 from .model.args import Qwen3ModelArgs -from .model.model import Transformer +from .model.model import Qwen3Model __all__ = [ "parallelize_qwen3", "Qwen3ModelArgs", - "Transformer", + "Qwen3Model", "qwen3_configs", ] @@ -107,7 +107,7 @@ register_train_spec( TrainSpec( name="qwen3", - model_cls=Transformer, + model_cls=Qwen3Model, model_args=qwen3_configs, # Change from dict to Mapping parallelize_fn=parallelize_qwen3, pipelining_fn=None, diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index 4d4572d287..8648326770 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -120,6 +120,10 @@ def parallelize_qwen3( enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) + # Enable weight tying after applying parallelisms + if model.model_args.enable_weight_tying: + model.output.weight = model.tok_embeddings.weight + return model diff --git a/torchtitan/experiments/qwen3/model/args.py b/torchtitan/experiments/qwen3/model/args.py index a27b7bae14..45e11d0a5a 100644 --- a/torchtitan/experiments/qwen3/model/args.py +++ b/torchtitan/experiments/qwen3/model/args.py @@ -38,6 +38,8 @@ class Qwen3ModelArgs(BaseModelArgs): attn_mask_type: str = "causal" eos_id: int = 151645 + enable_weight_tying: bool = False + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len if seq_len > self.max_seq_len: diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py index 6697e39202..07d05b734b 100644 --- a/torchtitan/experiments/qwen3/model/model.py +++ b/torchtitan/experiments/qwen3/model/model.py @@ -16,42 +16,45 @@ from .args import Qwen3ModelArgs - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' - and the end index 'end'. The 'theta' parameter scales the frequencies. - The returned tensor contains complex values in complex64 data type. - - Args: - dim (int): Dimension of the frequency tensor. - end (int): End index for precomputing frequencies. - theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. - - Returns: - torch.Tensor: Precomputed frequency tensor with complex exponentials. +# Adapted from https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_positional_embeddings.py +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.outer(t, freqs).float() + + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - Reshape frequency tensor for broadcasting it with another tensor. + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), and the first seqlen elements will be sliced, but dim must match x. Args: - freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. Returns: @@ -59,56 +62,31 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten """ ndim = x.ndim assert ndim > 1 - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary embeddings to input tensors using the given frequency tensor. + # input tensor x has shape [bsz, seq_len, num_heads, head_dim] + head_dim = xq.shape[-1] - This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided - frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are - returned as real tensors. + # reshape for broadcast + rope_cache = reshape_for_broadcast(rope_cache, xq) - Args: - xq (torch.Tensor): Query tensor to apply rotary embeddings. - xk (torch.Tensor): Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. - Note: - This function adds .transpose(-2,-1) to match HF implementation. This method assumes that last - dimension is [real_0, real_1, ..., real_{N-1}, imag_0, imag_1, ..., imag_{N-1}] while Rope in Llama3 - has [real_0, imag_0, real_1, imag_1, ..., real_{N-1}, imag_{N-1}]. This is the main difference - between Llama3 and Qwen3 Rope which is under investigation. - """ - xk_complex = torch.view_as_complex( - xk.view(*xk.shape[:-1], 2, xk.shape[-1] // 2) - .transpose(-2, -1) - .contiguous() - .float() - ) - xq_complex = torch.view_as_complex( - xq.view(*xq.shape[:-1], 2, xq.shape[-1] // 2) - .transpose(-2, -1) - .contiguous() - .float() - ) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex) - - xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3) + # [bsz, seq_len, 1, head_dim] + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + # xq: [bsz, seq_len, num_heads, head_dim] + # xk: [bsz, seq_len, num_kv_heads, head_dim] + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -189,14 +167,13 @@ def init_weights(self, init_std: float): def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, + rope_cache: torch.Tensor, ): """ Forward pass of the attention module. Args: x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed frequency tensor. Returns: torch.Tensor: Output tensor after attention. @@ -220,9 +197,10 @@ def forward( if self.k_norm: xk = self.k_norm(xk) - # repeat k/v heads if n_kv_heads < n_heads - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + # Apply rotary embedding + xq, xk = apply_rotary_emb(xq, xk, rope_cache) + # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -318,7 +296,7 @@ def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, + rope_cache: torch.Tensor, ): """ Perform a forward pass through the TransformerBlock. @@ -331,7 +309,7 @@ def forward( torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis) + h = x + self.attention(self.attention_norm(x), rope_cache) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -342,9 +320,9 @@ def init_weights(self): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module, ModelProtocol): +class Qwen3Model(nn.Module, ModelProtocol): """ - Transformer Module + Qwen3Model Module Args: model_args (TransformerModelArgs): Model configuration arguments. @@ -370,13 +348,18 @@ def __init__(self, model_args: Qwen3ModelArgs): self.head_dim = model_args.head_dim self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() def init_weights( @@ -394,9 +377,9 @@ def init_weights( ``init_weights``. We only call it in the constructor of this ``Transformer`` root module to avoid reinitializing tensors. """ - buffer_device = buffer_device or self.freqs_cis.device + buffer_device = buffer_device or self.rope_cache.device with torch.device(buffer_device): - self.freqs_cis = self._precompute_freqs_cis() + self.rope_cache = self._precompute_rope_cache() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): @@ -406,6 +389,8 @@ def init_weights( self.norm.reset_parameters() final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 + + # If weight tying is enabled, we don't need to initialize the output layer if self.output is not None: nn.init.trunc_normal_( self.output.weight, @@ -415,12 +400,9 @@ def init_weights( b=cutoff_factor * final_out_std, ) - def _precompute_freqs_cis(self) -> torch.Tensor: - return precompute_freqs_cis( - self.head_dim, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, self.model_args.max_seq_len, self.model_args.rope_theta, ) @@ -459,7 +441,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.rope_cache) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml index f9cf0cd2b4..5c73423af0 100644 --- a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml @@ -8,14 +8,14 @@ save_traces_folder = "profile_trace" profile_freq = 100 [metrics] -log_freq = 10 +log_freq = 1 enable_tensorboard = false save_tb_folder = "tb" [model] name = "qwen3" flavor = "0.6B" -tokenizer_path = "./assets/tokenizer/Qwen3-0.6B" +hf_assets_path = "./assets/hf/Qwen3-0.6B" # converters = ["float8"] [optimizer] @@ -24,7 +24,7 @@ lr = 3e-4 eps = 1e-8 [lr_scheduler] -warmup_steps = 1 # lr scheduler warm up +warmup_steps = 2 # lr scheduler warm up, 20% total steps [training] local_batch_size = 4 @@ -34,7 +34,6 @@ steps = 10 compile = false dataset = "c4" - [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 From 9f47cebfc92a464d75e7c1711663bd988bdb3cf4 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:01:15 -0700 Subject: [PATCH 104/128] [dsv3] Remove dtype to avoid confusion (#1599) Remove unused dtype field, we haven't supported FP8 training in torchtitan dsv3 now. --- torchtitan/models/deepseek_v3/__init__.py | 1 - torchtitan/models/deepseek_v3/model/args.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 2e0f42a736..f81db35341 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -150,7 +150,6 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - dtype="fp8", ), } diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 044420d37a..c25d0fbe61 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, field -from typing import Literal from torch import nn @@ -28,7 +27,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs): Attributes: max_batch_size (int): Maximum batch size. max_seq_len (int): Maximum sequence length. - dtype (Literal["bf16", "fp8"]): Data type for computations. vocab_size (int): Vocabulary size. dim (int): Model dimension. inter_dim (int): Intermediate dimension for MLP layers. @@ -59,7 +57,6 @@ class DeepSeekV3ModelArgs(BaseModelArgs): max_batch_size: int = 8 max_seq_len: int = 4096 * 4 - dtype: Literal["bf16", "fp8"] = "bf16" vocab_size: int = 102400 dim: int = 2048 inter_dim: int = 10944 From b5b7ffbd6c981d3311fd64e6be98c705f9411844 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Tue, 19 Aug 2025 13:20:35 -0700 Subject: [PATCH 105/128] [HF] Deprecate `tokenizer_path` in Toml Files (#1592) This PR deprecates the `model.tokenizer_path` in .toml files and replaces them with `model.hf_assets_path`. See https://github.com/pytorch/torchtitan/pull/1526 for more details. Reasoning: `tokenizer_path` is still supported in .toml files by naively overriding `hf_assets_path` when it is specified. This is meant to allow backwards compatibility, but it's not meant to be a well-maintained option in the future. `tokenizer_path` is used for: - loading a tokenizer `hf_assets_path` can be used for: - one stop shop for accessing hf repo's files - loading a tokenizer (or multiple tokenizers) - loading safetensor checkpoints - loading other hf assets within the same repo (encoders, autoencoders, etc.) The reason we change `tokenizer_path -> hf_assets_path` is to be more consistent with this new functionality. and we additionally change the path `assets/tokenizer -> assets/hf/` to reflect the new `download_hf_assets.py` script. Breaking Changes: - You may have to download or move tokenizer to the new default HF assets path `./assets/hf/` - You may have to download "duplicate" tokenizers for different versions of the same model - e.g. Llama-3.1-8B and Llama-3.1-405B will each require a tokenizer in their respective HF assets path --- .../experiments/deepseek_v3/train_configs/deepseek_v2.toml | 2 +- torchtitan/experiments/llama4/train_configs/debug_model.toml | 2 +- .../experiments/llama4/train_configs/llama4_17bx128e.toml | 2 +- torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml | 2 +- torchtitan/experiments/qwen3/README.md | 2 +- torchtitan/models/deepseek_v3/train_configs/debug_model.toml | 2 +- .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml | 2 +- .../models/deepseek_v3/train_configs/deepseek_v3_671b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_405b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_70b.toml | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml index cb0bfa72e9..eae923714e 100644 --- a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml +++ b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml @@ -22,7 +22,7 @@ enable_wandb = false name = "deepseek_v2" flavor = "deepseek-ai/DeepSeek-V2-Lite" # test tokenizer.model, for debug purpose only -tokenizer_path = "./tests/assets/tokenizer" +hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 0179b5f9a1..f445b2ad7a 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -22,7 +22,7 @@ enable_wandb = false name = "llama4" flavor = "debugmodel" # test tokenizer.model, for debug purpose only -tokenizer_path = "./tests/assets/tokenizer" +hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index 00416eb91c..cb69e63e24 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -17,7 +17,7 @@ save_tb_folder = "tb" [model] name = "llama4" flavor = "17bx128e" -tokenizer_path = "./assets/tokenizer/Llama-4-Scout-17B-16E" +hf_assets_path = "./assets/hf/Llama-4-Scout-17B-128E" # converters = ["float8"] [optimizer] diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 6a2b660cdf..4e7416fd24 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -17,7 +17,7 @@ save_tb_folder = "tb" [model] name = "llama4" flavor = "17bx16e" -tokenizer_path = "./assets/tokenizer/Llama-4-Scout-17B-16E" +hf_assets_path = "./assets/hf/Llama-4-Scout-17B-16E" # converters = ["float8"] [optimizer] diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md index dce71ed11c..77b23d55ce 100644 --- a/torchtitan/experiments/qwen3/README.md +++ b/torchtitan/experiments/qwen3/README.md @@ -10,7 +10,7 @@ Other model sizes are added to the args, but toml file configs need to be added #### Download Qwen3 tokenizer -```python scripts/download_tokenizer.py --repo_id Qwen/Qwen3-0.6B``` +```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --asset tokenizer``` #### Parity with HF diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index dd94556f27..bb564bd38a 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -22,7 +22,7 @@ enable_wandb = false name = "deepseek_v3" flavor = "debugmodel" # test tokenizer, for debug purpose only -tokenizer_path = "./tests/assets/tokenizer" +hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 42e2cc6bc7..15ce11bd07 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -20,7 +20,7 @@ enable_wandb = false [model] name = "deepseek_v3" flavor = "16B" -tokenizer_path = "./assets/tokenizer/deepseek-moe-16b-base" +hf_assets_path = "./assets/hf/deepseek-moe-16b-base" # converters = ["float8"] [optimizer] diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index fc1b512e28..614719dd24 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -20,7 +20,7 @@ enable_wandb = false [model] name = "deepseek_v3" flavor = "671B" -tokenizer_path = "./assets/tokenizer/DeepSeek-V3" +hf_assets_path = "./assets/hf/DeepSeek-V3" # converters = ["float8"] [optimizer] diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 63d91f41aa..471ed981bc 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "405B" -tokenizer_path = "./assets/tokenizer/Llama-3.1-8B" +hf_assets_path = "./assets/hf/Llama-3.1-405B" converters = ["float8"] [optimizer] diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 8d3289de85..8a3f2018e8 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -18,7 +18,7 @@ save_tb_folder = "tb" [model] name = "llama3" flavor = "70B" -tokenizer_path = "./assets/tokenizer/Llama-3.1-8B" +hf_assets_path = "./assets/hf/Llama-3.1-70B" # converters = ["float8"] [optimizer] From 084d307c41e013e897cb946633c2b726f6945dad Mon Sep 17 00:00:00 2001 From: lckr <15931380+lckr@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:35:31 +0200 Subject: [PATCH 106/128] [doc] update DeepSeekV3ModelArgs doc string (#1598) In this PR, I'm updated the outdated doc string for DeepSeekV3ModelArgs --- torchtitan/models/deepseek_v3/model/args.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index c25d0fbe61..48e8246fca 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -34,20 +34,17 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_layers (int): Number of transformer layers. n_dense_layers (int): Number of dense layers in the model. n_heads (int): Number of attention heads. - n_routed_experts (int): Number of routed experts for MoE layers. - n_shared_experts (int): Number of shared experts for MoE layers. - n_activated_experts (int): Number of activated experts in MoE layers. + norm_eps (float): Epsilon value used for RMSNorm. + moe_args (MoEArgs): MoE configuration. n_expert_groups (int): Number of expert groups. n_limited_groups (int): Number of limited groups for MoE routing. - score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. - route_scale (float): Scaling factor for routing scores. - use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers. - load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers. q_lora_rank (int): LoRA rank for query projections. kv_lora_rank (int): LoRA rank for key-value projections. qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. + use_flex_attn (bool): Whether to use FlexAttention. + attn_mask_type (str): Type of attention mask. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. rope_factor (float): Scaling factor for extended sequence lengths. From 9874e84d1844626a5a601718f5e68266d7f71442 Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:11:18 -0700 Subject: [PATCH 107/128] Change freq_cis from persistent buffer to non-persistent buffer (#1600) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Context As PP didn't need persistent buffer, and `torch.compile` works with non-persistent buffer now, change freq_cis from persistent buffer to non-persistent buffer . In this way, checkpointer doesn't need to explicitly exclude freq_cis when loading. ## Test 1. llama3 model with torch.compile ✅ 2. llama4 model with torch.compile ✅ 3. deepseek-v3 model with torch.compile ✅ --- scripts/generate/test_generate.py | 3 --- tests/unit_tests/test_checkpoint.py | 4 ++-- torchtitan/components/checkpoint.py | 9 --------- torchtitan/experiments/llama4/model/model.py | 11 +++-------- torchtitan/models/deepseek_v3/infra/parallelize.py | 7 ++++++- torchtitan/models/deepseek_v3/model/model.py | 2 +- torchtitan/models/llama3/model/model.py | 11 +++-------- 7 files changed, 15 insertions(+), 32 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 60cd3d04c1..21322ba232 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -24,7 +24,6 @@ parallelize_module, RowwiseParallel, ) -from torchtitan.components.checkpoint import excluded_parameters_for_model_only from torchtitan.components.metrics import build_device_memory_monitor from torchtitan.config import ConfigManager from torchtitan.distributed import ParallelDims, utils as dist_utils @@ -143,8 +142,6 @@ def test_generate( model.eval() state_dict = model.state_dict() - for k in excluded_parameters_for_model_only: - state_dict.pop(k, None) # Checkpoint Loading begin = time.monotonic() diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index 4d4c942c86..a0fb5d3bab 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -562,7 +562,7 @@ def test_enable_first_step_checkpoint(self, mock_save, mock_rank): @mock.patch("torch.distributed.get_rank", return_value=0) @mock.patch("torchtitan.components.checkpoint.dcp.save") - def test_excluded_parameters_not_saved(self, mock_save, mock_rank): + def test_non_persist_buffer_not_saved(self, mock_save, mock_rank): """Test that freqs_cis is not saved""" # Create a fake model with freqs_cis and other parameters @@ -572,7 +572,7 @@ def __init__(self): self.weight = nn.Parameter(torch.randn(2, 2)) self.bias = nn.Parameter(torch.randn(2)) # Register freqs_cis as a buffer (common pattern in transformer models) - self.register_buffer("freqs_cis", torch.randn(10, 5)) + self.register_buffer("freqs_cis", torch.randn(10, 5), persistent=False) self.other_param = nn.Parameter(torch.randn(3, 3)) fake_model = FakeModelWithFreqsCis() diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 478062e8e1..2df8f9cd4b 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -55,12 +55,6 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" -# For now, we will manually pop the freqs_cis buffer, as we made this permanent -# temporarily and we don't want to include it in the exported state_dict. -# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404 -excluded_parameters_for_model_only = {"freqs_cis"} - - class ModelWrapper(Stateful): def __init__(self, model: nn.Module | list[nn.Module]) -> None: self.model = [model] if isinstance(model, nn.Module) else model @@ -70,9 +64,6 @@ def _get_state_dict(self) -> dict[str, Any]: state_dict = { k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() } - # Exclude parameters that should not be saved - for excluded_key in excluded_parameters_for_model_only: - state_dict.pop(excluded_key, None) return state_dict def state_dict(self) -> dict[str, Any]: diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index eb46a22b00..84e5613de0 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -391,14 +391,9 @@ def __init__(self, model_args: TransformerModelArgs): self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(), persistent=False + ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index d82270edd0..1a64d34a1c 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import Replicate, Shard @@ -18,7 +19,11 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.expert_parallel import NoParallel -from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp, apply_moe_ep_tp +from torchtitan.experiments.llama4.infra.parallelize import ( + apply_compile, + apply_fsdp, + apply_moe_ep_tp, +) from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index dd31fc3181..5249a26d5d 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -322,7 +322,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) self.register_buffer( - "freqs_cis", precompute_freqs_cis(model_args), persistent=True + "freqs_cis", precompute_freqs_cis(model_args), persistent=False ) self.layers = torch.nn.ModuleDict() diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index e45af90ba2..039fe0da77 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -335,14 +335,9 @@ def __init__(self, model_args: TransformerModelArgs): self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) - # TODO persistent should be set to false, since this buffer can be recomputed. - # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, - # compile or pipeline-tracer will not correctly handle non-persistent buffers, - # so we need to fix that. (2) if we initialize pipeline-parallel models from - # a seed checkpoint rather than calling init_weights, we need freqs_cis to be - # initialized by the checkpoint, or we need to add a separate initializer for - # just the non-persistent buffers that is called after loading checkpoints. - self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(), persistent=False + ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.n_layers): From c0b2e5a321efd7e3597138a14db8d6c1ab4c38a4 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Wed, 20 Aug 2025 10:48:00 -0700 Subject: [PATCH 108/128] [HF] Model Definition Conversion Support for FLUX (#1582) This PR adds the `FluxStateDictAdapter`, allowing us to convert checkpoints to and from HF. Additional changes: - Modifies `download_hf_assets` script to support downloading diffusion-type safetensor files - Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf` so that conversion script can be reused - e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev` Tests: Performing KL divergence test on the forward pass of converted weights loaded in `torchtitan` and HF weights loaded with HF `FluxTransformer2DModel`, we get: ``` Average loss for test from_hf is 7.233546986222528e-13 ``` Addiitonally, we can now run inference with HF weights to verify changes made in https://github.com/pytorch/torchtitan/pull/1548 ### Batched Inference on TorchTitan: | | prompt0 | prompt1 | prompt2 | | --- | --- | --- | --- | | no CFG | prompt0_nocfg | prompt1_nocfg | prompt2_nocfg | | CFG | prompt0_cfg | prompt1_cfg | prompt2_cfg | --- .../checkpoint_conversion/convert_from_hf.py | 2 + .../checkpoint_conversion/convert_to_hf.py | 2 + scripts/download_hf_assets.py | 4 +- torchtitan/experiments/flux/__init__.py | 2 + .../flux/model/state_dict_adapter.py | 288 ++++++++++++++++++ torchtitan/experiments/flux/validate.py | 7 + torchtitan/protocols/state_dict_adapter.py | 4 +- 7 files changed, 305 insertions(+), 4 deletions(-) create mode 100644 torchtitan/experiments/flux/model/state_dict_adapter.py diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index f71af08363..fae7eec17b 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -16,6 +16,8 @@ @torch.inference_mode() def convert_from_hf(input_dir, output_dir, model_name, model_flavor): + if model_name == "flux": + import torchtitan.experiments.flux # noqa: F401 # initialize model to allocate memory for state dict train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index db69a34b0e..f0ea17cc63 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -16,6 +16,8 @@ @torch.inference_mode() def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path): + if model_name == "flux": + import torchtitan.experiments.flux # noqa: F401 # load model and model args so that we can get the state dict shape train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] diff --git a/scripts/download_hf_assets.py b/scripts/download_hf_assets.py index 017cc0a405..e1092b2d70 100644 --- a/scripts/download_hf_assets.py +++ b/scripts/download_hf_assets.py @@ -76,8 +76,8 @@ def download_hf_assets( "merges.txt", "special_tokens_map.json", ], - "safetensors": ["*.safetensors", "model.safetensors.index.json"], - "index": ["model.safetensors.index.json"], + "safetensors": ["*.safetensors", "*model.safetensors.index.json"], + "index": ["*model.safetensors.index.json"], "config": ["config.json", "generation_config.json"], } diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index b7c55c4ee4..693022fa56 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -17,6 +17,7 @@ from .model.args import FluxModelArgs from .model.autoencoder import AutoEncoderParams from .model.model import FluxModel +from .model.state_dict_adapter import FluxStateDictAdapter from .validate import build_flux_validator __all__ = [ @@ -119,5 +120,6 @@ build_tokenizer_fn=None, build_loss_fn=build_mse_loss, build_validator_fn=build_flux_validator, + state_dict_adapter=FluxStateDictAdapter, ) ) diff --git a/torchtitan/experiments/flux/model/state_dict_adapter.py b/torchtitan/experiments/flux/model/state_dict_adapter.py new file mode 100644 index 0000000000..de92026e4c --- /dev/null +++ b/torchtitan/experiments/flux/model/state_dict_adapter.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +import re + +from collections import defaultdict +from typing import Any + +import torch + +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter + +from .args import FluxModelArgs + +logger = logging.getLogger() + + +class FluxStateDictAdapter(BaseStateDictAdapter): + """ + State dict adapter for Flux model to convert between HuggingFace safetensors format + and torchtitan DCP format. + + This state dict adapter handles only the state dict of transformer from Flux HF model repo. + """ + + def __init__(self, model_args: FluxModelArgs, hf_assets_path: str | None): + + # Build fqn to index mapping if hf_assets_path + if hf_assets_path: + # If directory is multimodal ensure that hf_assets_path is to the folder containing transformer's safetensors + if os.path.exists(os.path.join(hf_assets_path, "model_index.json")): + hf_assets_path = os.path.join(hf_assets_path, "transformers") + + # Check if safetensors index file exists + index_files = [ + "model.safetensors.index.json", + "diffusion_pytorch_model.safetensors.index.json", + ] + + hf_safetensors_indx = None + for index_file in index_files: + mapping_path = os.path.join(hf_assets_path, index_file) + if os.path.exists(mapping_path): + with open(mapping_path, "r") as f: + hf_safetensors_indx = json.load(f) + break + if hf_safetensors_indx is None: + logger.warning( + f"no safetensors index file found at hf_assets_path: {hf_assets_path}. \ + Defaulting to saving a single safetensors file if checkpoint is saved in HF format.", + ) + + if hf_safetensors_indx: + self.fqn_to_index_mapping = {} + for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + indx = re.search(r"\d+", raw_indx).group(0) + self.fqn_to_index_mapping[hf_key] = indx + else: + self.fqn_to_index_mapping = None + + self.model_args = model_args + self.hf_assets_path = hf_assets_path + + # mapping containing direct 1 to 1 mappings from HF to torchtitan + self.from_hf_map_direct = { + "x_embedder.bias": "img_in.bias", + "x_embedder.weight": "img_in.weight", + "context_embedder.bias": "txt_in.bias", + "context_embedder.weight": "txt_in.weight", + "norm_out.linear.bias": "final_layer.adaLN_modulation.1.bias", + "norm_out.linear.weight": "final_layer.adaLN_modulation.1.weight", + "proj_out.bias": "final_layer.linear.bias", + "proj_out.weight": "final_layer.linear.weight", + "time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias", + "time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight", + "time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias", + "time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight", + "time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias", + "time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight", + "time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias", + "time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight", + "single_transformer_blocks.{}.attn.norm_k.weight": "single_blocks.{}.norm.key_norm.weight", + "single_transformer_blocks.{}.attn.norm_q.weight": "single_blocks.{}.norm.query_norm.weight", + "single_transformer_blocks.{}.norm.linear.bias": "single_blocks.{}.modulation.lin.bias", + "single_transformer_blocks.{}.norm.linear.weight": "single_blocks.{}.modulation.lin.weight", + "single_transformer_blocks.{}.proj_out.bias": "single_blocks.{}.linear2.bias", + "single_transformer_blocks.{}.proj_out.weight": "single_blocks.{}.linear2.weight", + "transformer_blocks.{}.attn.norm_added_k.weight": "double_blocks.{}.txt_attn.norm.key_norm.weight", + "transformer_blocks.{}.attn.norm_added_q.weight": "double_blocks.{}.txt_attn.norm.query_norm.weight", + "transformer_blocks.{}.attn.norm_k.weight": "double_blocks.{}.img_attn.norm.key_norm.weight", + "transformer_blocks.{}.attn.norm_q.weight": "double_blocks.{}.img_attn.norm.query_norm.weight", + "transformer_blocks.{}.attn.to_add_out.bias": "double_blocks.{}.txt_attn.proj.bias", + "transformer_blocks.{}.attn.to_add_out.weight": "double_blocks.{}.txt_attn.proj.weight", + "transformer_blocks.{}.attn.to_out.0.bias": "double_blocks.{}.img_attn.proj.bias", + "transformer_blocks.{}.attn.to_out.0.weight": "double_blocks.{}.img_attn.proj.weight", + "transformer_blocks.{}.ff.net.0.proj.bias": "double_blocks.{}.img_mlp.0.bias", + "transformer_blocks.{}.ff.net.0.proj.weight": "double_blocks.{}.img_mlp.0.weight", + "transformer_blocks.{}.ff.net.2.bias": "double_blocks.{}.img_mlp.2.bias", + "transformer_blocks.{}.ff.net.2.weight": "double_blocks.{}.img_mlp.2.weight", + "transformer_blocks.{}.ff_context.net.0.proj.bias": "double_blocks.{}.txt_mlp.0.bias", + "transformer_blocks.{}.ff_context.net.0.proj.weight": "double_blocks.{}.txt_mlp.0.weight", + "transformer_blocks.{}.ff_context.net.2.bias": "double_blocks.{}.txt_mlp.2.bias", + "transformer_blocks.{}.ff_context.net.2.weight": "double_blocks.{}.txt_mlp.2.weight", + "transformer_blocks.{}.norm1.linear.bias": "double_blocks.{}.img_mod.lin.bias", + "transformer_blocks.{}.norm1.linear.weight": "double_blocks.{}.img_mod.lin.weight", + "transformer_blocks.{}.norm1_context.linear.bias": "double_blocks.{}.txt_mod.lin.bias", + "transformer_blocks.{}.norm1_context.linear.weight": "double_blocks.{}.txt_mod.lin.weight", + } + + # combination plan to keep track of the order of layers to be combined + self.combination_plan = { + "single_blocks.{}.linear1.bias": [ + "single_transformer_blocks.{}.attn.to_q.bias", + "single_transformer_blocks.{}.attn.to_k.bias", + "single_transformer_blocks.{}.attn.to_v.bias", + "single_transformer_blocks.{}.proj_mlp.bias", + ], + "single_blocks.{}.linear1.weight": [ + "single_transformer_blocks.{}.attn.to_q.weight", + "single_transformer_blocks.{}.attn.to_k.weight", + "single_transformer_blocks.{}.attn.to_v.weight", + "single_transformer_blocks.{}.proj_mlp.weight", + ], + "double_blocks.{}.txt_attn.qkv.bias": [ + "transformer_blocks.{}.attn.add_q_proj.bias", + "transformer_blocks.{}.attn.add_k_proj.bias", + "transformer_blocks.{}.attn.add_v_proj.bias", + ], + "double_blocks.{}.txt_attn.qkv.weight": [ + "transformer_blocks.{}.attn.add_q_proj.weight", + "transformer_blocks.{}.attn.add_k_proj.weight", + "transformer_blocks.{}.attn.add_v_proj.weight", + ], + "double_blocks.{}.img_attn.qkv.bias": [ + "transformer_blocks.{}.attn.to_q.bias", + "transformer_blocks.{}.attn.to_k.bias", + "transformer_blocks.{}.attn.to_v.bias", + ], + "double_blocks.{}.img_attn.qkv.weight": [ + "transformer_blocks.{}.attn.to_q.weight", + "transformer_blocks.{}.attn.to_k.weight", + "transformer_blocks.{}.attn.to_v.weight", + ], + } + + # reverse of combination plan: maps fqns to the fqn they are combined into + self.reverse_combination_plan = { + value: key + for key, value_list in self.combination_plan.items() + for value in value_list + } + + # original flux implementation and HF swap shift and scale + # https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py#L63-L68 + def _swap_scale_shift(self, weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """Convert TorchTitan DCP state dict to HuggingFace safetensors format.""" + + to_hf_map_direct = { + v: k for k, v in self.from_hf_map_direct.items() if v is not None + } + hf_state_dict = {} + + for key, value in state_dict.items(): + # Extract layer_num and abstract key if necessary + if "blocks" in key: + layer_num = re.search(r"\d+", key).group(0) + key = re.sub(r"(\d+)", "{}", key, count=1) + else: + layer_num = None + + if key in to_hf_map_direct: + # handle direct mapping + new_key = to_hf_map_direct[key] + + # perform swap to be compatible with HF + if key in [ + "final_layer.adaLN_modulation.1.weight", + "final_layer.adaLN_modulation.1.bias", + ]: + value = self._swap_scale_shift(value) + + if new_key is None: + continue + if layer_num: + new_key = new_key.format(layer_num) + + hf_state_dict[new_key] = value + + elif key in self.combination_plan: + # handle splitting layers + if key in [ + "single_blocks.{}.linear1.bias", + "single_blocks.{}.linear1.weight", + ]: + mlp_hidden_dim = int( + self.model_args.hidden_size * self.model_args.mlp_ratio + ) + split_plan = [ + self.model_args.hidden_size, + self.model_args.hidden_size, + self.model_args.hidden_size, + mlp_hidden_dim, + ] + # split into q, k, v, mlp + split_vals = torch.split( + value, + split_plan, + dim=0, + ) + else: + # split into q, k, v + split_vals = torch.split(value, self.model_args.hidden_size, dim=0) + + new_keys = ( + abstract_key.format(layer_num) + for abstract_key in self.combination_plan[key] + ) + + for new_key, value in zip(new_keys, split_vals): + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """Convert HuggingFace safetensors state dict to TorchTitan DCP format.""" + state_dict = {} + + # Keeps track of HF fqn values to combine into one TT fqn later + # {tt_fqn : {hf_fqn1 : value}, {hf_fqn2 : value}, ...} + to_combine = defaultdict(dict) + + for key, value in hf_state_dict.items(): + # extract layer_num and abstract key if necessary + if "blocks" in key: + layer_num = re.search(r"\d+", key).group(0) + key = re.sub(r"(\d+)", "{}", key, count=1) + else: + layer_num = None + + if key in self.from_hf_map_direct: + new_key = self.from_hf_map_direct[key] + + # perform swap to be compatible with HF + if key in [ + "norm_out.linear.weight", + "norm_out.linear.bias", + ]: + value = self._swap_scale_shift(value) + if new_key is None: + continue + if layer_num: + new_key = new_key.format(layer_num) + + state_dict[new_key] = value + elif key in self.reverse_combination_plan: + # collect the layers that need to be combined + tt_abstract_key = self.reverse_combination_plan[key] + if tt_abstract_key is None: + continue + to_combine[tt_abstract_key.format(layer_num)][ + key.format(layer_num) + ] = value + + # combine collected values + for tt_fqn, hf_fqn_map in to_combine.items(): + layer_num = re.search(r"\d+", tt_fqn).group(0) + tt_abstract_key = re.sub(r"(\d+)", "{}", tt_fqn, count=1) + combine_values = [] + # use combination_plan to ensure correct order before concatenation + for hf_abstract_key in self.combination_plan[tt_abstract_key]: + hf_key = hf_abstract_key.format(layer_num) + combine_values.append(hf_fqn_map[hf_key]) + + value = torch.cat(combine_values, dim=0) + state_dict[tt_fqn] = value + + return state_dict diff --git a/torchtitan/experiments/flux/validate.py b/torchtitan/experiments/flux/validate.py index 059faf5b65..89dc4f8942 100644 --- a/torchtitan/experiments/flux/validate.py +++ b/torchtitan/experiments/flux/validate.py @@ -104,6 +104,10 @@ def validate( model = model_parts[0] model.eval() + # Disable cfg dropout during validation + training_cfg_prob = self.job_config.training.classifier_free_guidance_prob + self.job_config.training.classifier_free_guidance_prob = 0.0 + save_img_count = self.job_config.validation.save_img_count parallel_dims = self.parallel_dims @@ -244,6 +248,9 @@ def validate( # Set model back to train mode model.train() + # re-enable cfg dropout for training + self.job_config.training.classifier_free_guidance_prob = training_cfg_prob + def build_flux_validator( job_config: JobConfig, diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index ce03d732d6..106a7937ef 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -66,8 +66,8 @@ def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): hf_safetensors_indx = json.load(f) except FileNotFoundError: logger.warning( - "model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ - Defaulting to saving a single safetensors file if checkpoint is saved in HF format.", + f"model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ + Defaulting to saving a single safetensors file if checkpoint is saved in HF format." ) hf_safetensors_indx = None From 46a32e7883e8ab22b91007b5f5c77d9326e35480 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Wed, 20 Aug 2025 10:48:11 -0700 Subject: [PATCH 109/128] Deprecate Llama Conversion Script (#1603) This PR removes the `convert_from_llama.py` script since it is superseded by `convert_from_hf.py` instead. --- docs/checkpoint.md | 9 -- .../convert_from_llama.py | 146 ------------------ 2 files changed, 155 deletions(-) delete mode 100644 scripts/checkpoint_conversion/convert_from_llama.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index da6598ca8d..45915d8ad8 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -105,12 +105,3 @@ python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outp That's it. You have now successfully converted a sharded `torchtitan` checkpoint for use with pytorch formats. - -### PyTorch Meta Llama - -An example script for converting the original Llama3 checkpoints into DCP format to be used with `torchtitan` can be found in `scripts/convert_from_llama.py`. - -The script expects a path to the original checkpoint files, and a path to an output directory: -```bash -python -m scripts.convert_from_llama -``` diff --git a/scripts/checkpoint_conversion/convert_from_llama.py b/scripts/checkpoint_conversion/convert_from_llama.py deleted file mode 100644 index 9a6e1b1db3..0000000000 --- a/scripts/checkpoint_conversion/convert_from_llama.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import json -from pathlib import Path - -import torch -import torch.distributed.checkpoint as DCP -from torchtitan.tools.logging import init_logger, logger - - -@torch.inference_mode() -def convert_from_llama(input_dir, output_dir, max_seq_len: int): - with open(input_dir / "params.json", "r") as f: - params = json.load(f) - n_layers = params["n_layers"] - n_heads = params["n_heads"] - dim = params["dim"] - dims_per_head = dim // n_heads - - checkpoint_list = sorted([file for file in input_dir.rglob("*.pth")]) - logger.info( - f"Loading original Llama weights from {[ckpt.name for ckpt in checkpoint_list]}" - ) - shards = [ - torch.load(ckpt, map_location="cpu", weights_only=True, mmap=True) - for ckpt in checkpoint_list - ] - - if len(shards) == 1: - state_dict = shards[0] - else: # sharded - state_dict = {} - n_heads_per_shard = n_heads // len(shards) - num_key_value_heads = params["n_kv_heads"] - n_kv_heads_per_shard = num_key_value_heads // len(shards) - key_value_dim = dims_per_head * num_key_value_heads - for layer in range(n_layers): - state_dict[f"layers.{layer}.attention_norm.weight"] = shards[0][ - f"layers.{layer}.attention_norm.weight" - ] - for i in range(len(shards)): - del shards[i][f"layers.{layer}.attention_norm.weight"] - state_dict[f"layers.{layer}.ffn_norm.weight"] = shards[0][ - f"layers.{layer}.ffn_norm.weight" - ] - for i in range(len(shards)): - del shards[i][f"layers.{layer}.ffn_norm.weight"] - - for wn, nh in [ - ("wq", n_heads_per_shard), - ("wk", n_kv_heads_per_shard), - ("wv", n_kv_heads_per_shard), - ]: - state_dict[f"layers.{layer}.attention.{wn}.weight"] = torch.cat( - [ - shards[i][f"layers.{layer}.attention.{wn}.weight"].view( - nh, dims_per_head, dim - ) - for i in range(len(shards)) - ], - dim=0, - ).reshape(nh * len(shards) * dims_per_head, dim) - for i in range(len(shards)): - del shards[i][f"layers.{layer}.attention.{wn}.weight"] - - state_dict[f"layers.{layer}.attention.wo.weight"] = torch.cat( - [ - shards[i][f"layers.{layer}.attention.wo.weight"] - for i in range(len(shards)) - ], - dim=1, - ) - for i in range(len(shards)): - del shards[i][f"layers.{layer}.attention.wo.weight"] - - state_dict[f"layers.{layer}.feed_forward.w1.weight"] = torch.cat( - [ - shards[i][f"layers.{layer}.feed_forward.w1.weight"] - for i in range(len(shards)) - ], - dim=0, - ) - for i in range(len(shards)): - del shards[i][f"layers.{layer}.feed_forward.w1.weight"] - - state_dict[f"layers.{layer}.feed_forward.w2.weight"] = torch.cat( - [ - shards[i][f"layers.{layer}.feed_forward.w2.weight"] - for i in range(len(shards)) - ], - dim=1, - ) - for i in range(len(shards)): - del shards[i][f"layers.{layer}.feed_forward.w2.weight"] - - state_dict[f"layers.{layer}.feed_forward.w3.weight"] = torch.cat( - [ - shards[i][f"layers.{layer}.feed_forward.w3.weight"] - for i in range(len(shards)) - ], - dim=0, - ) - for i in range(len(shards)): - del shards[i][f"layers.{layer}.feed_forward.w3.weight"] - - state_dict["norm.weight"] = shards[0]["norm.weight"] - for i in range(len(shards)): - del shards[i]["norm.weight"] - state_dict["tok_embeddings.weight"] = torch.cat( - [shards[i]["tok_embeddings.weight"] for i in range(len(shards))], dim=0 - ) - for i in range(len(shards)): - del shards[i]["tok_embeddings.weight"] - state_dict["output.weight"] = torch.cat( - [shards[i]["output.weight"] for i in range(len(shards))], dim=0 - ) - for i in range(len(shards)): - del shards[i]["output.weight"] - - logger.info(f"Writing to DCP at '{output_dir}'") - output_dir.mkdir(parents=True, exist_ok=True) - storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8) - DCP.save(state_dict, storage_writer=storage_writer) - - -if __name__ == "__main__": - init_logger() - parser = argparse.ArgumentParser(description="Convert Llama weights to DCP format.") - parser.add_argument( - "input_dir", type=Path, help="Input directory with original Llama weights." - ) - parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") - parser.add_argument( - "--max_seq_len", - type=int, - default=131072, - help="The maximum sequence length of the model.", - ) - args = parser.parse_args() - - convert_from_llama(args.input_dir, args.output_dir, max_seq_len=args.max_seq_len) From 08b8b244f3a63dfb606444dbe18bddaedd04ce1e Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:00:21 -0700 Subject: [PATCH 110/128] [refactor] support compile model and loss separately (#1608) Creating a new field in `JobConfig`, with the default being ``` [compile] enable=false components = ["model", "loss"] ``` This way we get to compile loss separately to get memory reduction, even when the model is not ready to be compiled. This PR also applies loss compilation to DeepSeek 16B and 671B. --- docs/float8.md | 6 +++--- scripts/estimate/estimation.py | 4 ++-- scripts/estimate/run_memory_estimation.sh | 2 +- tests/integration_tests.py | 12 ++++++------ tests/integration_tests_h100.py | 6 +++--- torchtitan/components/loss.py | 2 +- torchtitan/components/quantization/float8.py | 6 +++++- torchtitan/components/quantization/mx.py | 7 +++++-- torchtitan/config/job_config.py | 17 +++++++++++++---- torchtitan/experiments/flux/loss.py | 2 +- .../flux/train_configs/debug_model.toml | 1 - .../flux/train_configs/flux_dev_model.toml | 1 - .../flux/train_configs/flux_schnell_model.toml | 1 - torchtitan/experiments/forge/job_config.py | 2 ++ .../experiments/llama4/infra/parallelize.py | 9 ++++++--- .../llama4/train_configs/debug_model.toml | 5 ++++- .../llama4/train_configs/llama4_17bx128e.toml | 5 ++++- .../llama4/train_configs/llama4_17bx16e.toml | 5 ++++- .../qwen3/train_configs/qwen3_0.6b.toml | 5 ++++- torchtitan/experiments/simple_fsdp/README.md | 2 +- .../experiments/simple_fsdp/parallelize.py | 9 ++++++--- .../simple_fsdp/tests/integration_tests.py | 2 +- .../models/deepseek_v3/infra/parallelize.py | 7 +++++-- .../deepseek_v3/train_configs/debug_model.toml | 7 +++++-- .../train_configs/deepseek_v3_16b.toml | 10 +++++++--- .../train_configs/deepseek_v3_671b.toml | 9 +++++++-- torchtitan/models/llama3/infra/parallelize.py | 11 +++++++---- .../llama3/train_configs/debug_model.toml | 5 ++++- .../llama3/train_configs/llama3_405b.toml | 5 ++++- .../models/llama3/train_configs/llama3_70b.toml | 5 ++++- .../models/llama3/train_configs/llama3_8b.toml | 5 ++++- 31 files changed, 119 insertions(+), 56 deletions(-) diff --git a/docs/float8.md b/docs/float8.md index 5d90e0617e..1a7277ff3d 100644 --- a/docs/float8.md +++ b/docs/float8.md @@ -11,14 +11,14 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --training.compile +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --compile.enable ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth. * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. * `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. -* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels +* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` @@ -26,7 +26,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling -* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels +* `--compile.enable` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels For parallelisms, for float8 with tensorwise scaling we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision. diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 510cc394f7..8103ae0b57 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -33,9 +33,9 @@ def estimate_memory(job_config: JobConfig): # Get the world size world_size = int(os.environ["WORLD_SIZE"]) - if job_config.training.compile or job_config.parallelism.enable_compiled_autograd: + if job_config.compile.enable or job_config.parallelism.enable_compiled_autograd: logger.info("Compile mode is not supported yet. Switching to eager mode.") - job_config.training.compile = False + job_config.compile.enable = False job_config.parallelism.enable_compiled_autograd = False # init fake pg diff --git a/scripts/estimate/run_memory_estimation.sh b/scripts/estimate/run_memory_estimation.sh index e8f9ecc88f..9d766a07a2 100755 --- a/scripts/estimate/run_memory_estimation.sh +++ b/scripts/estimate/run_memory_estimation.sh @@ -23,4 +23,4 @@ fi # Export WORLD_SIZE and LOCAL_RANK export WORLD_SIZE=$((NGPU * NNODES)) export LOCAL_RANK=0 -python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides +python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enable $overrides diff --git a/tests/integration_tests.py b/tests/integration_tests.py index f7512836c6..73ded45482 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -57,7 +57,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", ], ], "1D compile", @@ -66,7 +66,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--activation_checkpoint.mode selective", "--activation_checkpoint.selective_ac_option op", ], @@ -86,7 +86,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.tensor_parallel_degree 2", ], ], @@ -97,7 +97,7 @@ def build_test_list(): # OverrideDefinitions( # [ # [ - # "--training.compile", + # "--compile.enable", # "--parallelism.tensor_parallel_degree 2", # "--parallelism.enable_async_tensor_parallel", # ], @@ -267,7 +267,7 @@ def build_test_list(): "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--training.compile", + "--compile.enable", ], ], "PP+DP+TP 3D test with torch.compile", @@ -464,7 +464,7 @@ def build_test_list(): # OverrideDefinitions( # [ # [ - # "--memory_estimation.enabled", + # "--memory_estimation.enable", # ] # ], # "FSDP2 Memory Tracking and Estimation", diff --git a/tests/integration_tests_h100.py b/tests/integration_tests_h100.py index 29c11476b5..b45d6f3159 100755 --- a/tests/integration_tests_h100.py +++ b/tests/integration_tests_h100.py @@ -32,7 +32,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.tensor_parallel_degree 2", "--parallelism.enable_async_tensor_parallel", ], @@ -54,7 +54,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.data_parallel_shard_degree=2", "--parallelism.tensor_parallel_degree=2", "--parallelism.pipeline_parallel_degree=2", @@ -71,7 +71,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--training.compile", + "--compile.enable", "--parallelism.data_parallel_shard_degree=2", "--parallelism.data_parallel_replicate_degree=2", "--parallelism.context_parallel_degree=2", diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 6aa1dd5699..84ae786834 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -24,7 +24,7 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor def build_cross_entropy_loss(job_config: JobConfig): loss_fn = cross_entropy_loss - if job_config.training.compile: + if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn) return loss_fn diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 3629258154..22134c65f0 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -28,8 +28,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = False float8_config: Float8 = job_config.float8 + compile_config = job_config.compile + model_compile_enabled = ( + compile_config.enable and "model" in compile_config.components + ) if has_cuda_capability(8, 9) or ( - float8_config.emulate and not job_config.training.compile + float8_config.emulate and not model_compile_enabled ): pass else: diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 15c74b7fd7..84216dadbd 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -52,9 +52,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ), "MXFP8 is only supported on SM100 or architectures" # TP not yet supported with torch.compile + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) assert not ( - job_config.training.compile - and job_config.parallelism.tensor_parallel_degree > 1 + model_compile_enabled and job_config.parallelism.tensor_parallel_degree > 1 ), "TP not yet supported with torch.compile for mxfp8" # For MoE training with mxfp8, token group sizes must be multiples of 32 diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 9a78451fc6..a688cdadae 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -218,9 +218,6 @@ class Training: This feature only takes effect when data_parallel_shard_degree > 1 """ - compile: bool = False - """Whether to compile the model""" - gc_freq: int = 50 """Python garbage control scheduling interval, in steps""" @@ -550,6 +547,17 @@ class ActivationCheckpoint: """ +@dataclass +class Compile: + enable: bool = False + """Whether to apply torch.compile""" + + components: list[Literal["model", "loss"]] = field( + default_factory=lambda: ["model", "loss"] + ) + """Which components to compile""" + + @dataclass class Float8: enable_fsdp_float8_all_gather: bool = False @@ -630,7 +638,7 @@ class Comm: @dataclass class MemoryEstimation: - enabled: bool = False + enable: bool = False """Whether to estimate memory usage for FSDP""" disable_fake_mode: bool = False @@ -747,6 +755,7 @@ class JobConfig: activation_checkpoint: ActivationCheckpoint = field( default_factory=ActivationCheckpoint ) + compile: Compile = field(default_factory=Compile) float8: Float8 = field(default_factory=Float8) mx: MX = field(default_factory=MX) comm: Comm = field(default_factory=Comm) diff --git a/torchtitan/experiments/flux/loss.py b/torchtitan/experiments/flux/loss.py index 9159b40b8a..6bf93f9d00 100644 --- a/torchtitan/experiments/flux/loss.py +++ b/torchtitan/experiments/flux/loss.py @@ -21,7 +21,7 @@ def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: def build_mse_loss(job_config: JobConfig): loss_fn = mse_loss - if job_config.training.compile: + if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn) return loss_fn diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index e565b23bd9..9be99b0424 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -36,7 +36,6 @@ decay_ratio = 0.0 # no decay, stay stable during training local_batch_size = 4 max_norm = 2.0 # grad norm clipping steps = 10 -compile = false dataset = "cc12m-test" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 5fbdcb6fca..083ad7977a 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay local_batch_size = 32 max_norm = 1.0 # grad norm clipping steps = 30_000 -compile = false dataset = "cc12m-wds" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index d479710e62..0a9cce71c7 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -35,7 +35,6 @@ decay_ratio = 0.0 # no decay local_batch_size = 64 max_norm = 1.0 # grad norm clipping steps = 30_000 -compile = false dataset = "cc12m-wds" classifier_free_guidance_prob = 0.447 img_size = 256 diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py index 56602e3520..f65488b012 100644 --- a/torchtitan/experiments/forge/job_config.py +++ b/torchtitan/experiments/forge/job_config.py @@ -11,6 +11,7 @@ ActivationCheckpoint, Checkpoint, Comm, + Compile, Float8, LRScheduler, Model, @@ -31,6 +32,7 @@ class ForgeJobConfig: activation_checkpoint: ActivationCheckpoint = field( default_factory=ActivationCheckpoint ) + compile: Compile = field(default_factory=Compile) float8: Float8 = field(default_factory=Float8) comm: Comm = field(default_factory=Comm) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 35a72167d0..a716c78907 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -62,12 +62,15 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -107,7 +110,7 @@ def parallelize_llama( apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP - if job_config.training.compile: + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index f445b2ad7a..0bdb16ecb9 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -41,7 +41,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -68,6 +67,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index cb69e63e24..c40437b377 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -34,7 +34,6 @@ local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = false dataset = "c4" [parallelism] @@ -60,6 +59,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 4e7416fd24..ab718cf6f9 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -34,7 +34,6 @@ local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = false dataset = "c4" [parallelism] @@ -58,6 +57,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml index 5c73423af0..38dc259496 100644 --- a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml @@ -31,7 +31,6 @@ local_batch_size = 4 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4" [parallelism] @@ -53,6 +52,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index 82eeec877e..43edc6d80a 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -13,7 +13,7 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si ### Enable SimpleFSDP Training ```bash -CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --compile.enable ``` ### Composability Support diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index ef02a4bf63..4d909e4fe4 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -37,12 +37,15 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -99,7 +102,7 @@ def parallelize_llama( ) logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) - if job_config.training.compile: + if model_compile_enabled: torch._inductor.config.reorder_for_peak_memory = False model = torch.compile(model, fullgraph=True) diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index d0579adcd1..45e99aa6c9 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -220,7 +220,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): for idx, override_arg in enumerate(test_flavor.override_args): cmd = ( f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh " - f"--model.name llama3_simple_fsdp --training.compile " + f"--model.name llama3_simple_fsdp --compile.enable " ) # dump compile trace for debugging purpose cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1a64d34a1c..1aedd73ad3 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -51,10 +51,13 @@ def parallelize_deepseekv3( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): raise RuntimeError("Async TP requires --training.compile") @@ -97,7 +100,7 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) - if job_config.training.compile: + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index bb564bd38a..79a15bd2e2 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -41,7 +41,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -66,7 +65,11 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] -selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 15ce11bd07..3ef6e67fc0 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -39,7 +39,6 @@ local_batch_size = 8 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -50,7 +49,7 @@ tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 1 +expert_parallel_degree = 8 expert_tensor_parallel_degree = 1 [checkpoint] @@ -62,7 +61,12 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" [activation_checkpoint] -mode = "full" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=true +components = ["loss"] # ["model", "loss"] [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 614719dd24..23dc315d05 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -46,7 +46,7 @@ dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 8 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 pipeline_parallel_schedule = "Interleaved1F1B" @@ -62,7 +62,12 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" [activation_checkpoint] -mode = "full" # ["none", "selective", "full"] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=true +components = ["loss"] # ["model", "loss"] [float8] enable_fsdp_float8_all_gather = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 6d9bf60c11..2e2e81302a 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -63,12 +63,15 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -93,7 +96,7 @@ def parallelize_llama( apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP - if job_config.training.compile: + if model_compile_enabled: apply_compile(model) if parallel_dims.fsdp_enabled: @@ -129,7 +132,7 @@ def parallelize_llama( apply_ddp( model, world_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 0607268a75..d446027f48 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -43,7 +43,6 @@ local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 -compile = false dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -67,6 +66,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = "selective" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 471ed981bc..5895f7f255 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -34,7 +34,6 @@ local_batch_size = 2 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 -compile = true dataset = "c4" [parallelism] @@ -56,6 +55,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" # ["none", "selective", "full"] +[compile] +enable=true +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = true precompute_float8_dynamic_scale_for_fsdp = true diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 8a3f2018e8..9a2eddd093 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -34,7 +34,6 @@ local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" [parallelism] @@ -55,6 +54,10 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = "full" +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 038f9b33f6..d9a9c331f7 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -34,7 +34,6 @@ local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 -compile = false dataset = "c4" [parallelism] @@ -52,6 +51,10 @@ last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] +[compile] +enable=false +components = ["model", "loss"] + [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy From 82d6c3b0382dcde617ddbb7195e98e47e121302a Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:12:12 -0700 Subject: [PATCH 111/128] [DSV3] Upgrade to DeepSeek-V3.1 (#1609) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tested Loading weights from https://huggingface.co/deepseek-ai/DeepSeek-V3.1-Base Screenshot 2025-08-20 at 10 28
20 PM --- torchtitan/models/deepseek_v3/README.md | 2 +- .../models/deepseek_v3/train_configs/deepseek_v3_671b.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index 15860c2361..fcccd65fc5 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -8,7 +8,7 @@ DeepSeek-V3 is a Mixture-of-Experts (MoE) transformer model with Multi-head Late ```bash # DeepSeek 671B tokenizer (automatically downloads tokenizer.json and tokenizer_config.json) -python scripts/download_hf_assets.py --repo_id deepseek-ai/DeepSeek-V3 --assets tokenizer +python scripts/download_hf_assets.py --repo_id deepseek-ai/DeepSeek-V3.1-Base --assets tokenizer ``` ```bash diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 23dc315d05..1a748d56f1 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -20,7 +20,7 @@ enable_wandb = false [model] name = "deepseek_v3" flavor = "671B" -hf_assets_path = "./assets/hf/DeepSeek-V3" +hf_assets_path = "./assets/hf/DeepSeek-V3.1-Base" # converters = ["float8"] [optimizer] From fd230800d2385e782f4443e24f1e45d83a6e6b0d Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:12:28 -0700 Subject: [PATCH 112/128] Fix Typo (#1611) We want to include this PR in our next release ASAP. Created another branch and revert CODE_OF_CONDUCT.md from @BioGeek 's #1583 . Much appreciated for @BioGeek's contribution! --------- Co-authored-by: Jeroen Van Goey --- docs/composability.md | 2 +- docs/debugging.md | 2 +- tests/unit_tests/test_activation_checkpoint.py | 2 +- tests/unit_tests/test_lr_scheduler.py | 2 +- torchtitan/components/checkpoint.py | 10 +++++----- torchtitan/components/ft/config/job_config.py | 2 +- torchtitan/components/ft/manager.py | 2 +- torchtitan/components/lr_scheduler.py | 2 +- torchtitan/config/job_config.py | 4 ++-- torchtitan/config/manager.py | 2 +- torchtitan/distributed/utils.py | 4 ++-- torchtitan/experiments/deepseek_v3/group_gemms.py | 2 +- torchtitan/experiments/deepseek_v3/model.py | 8 ++++---- torchtitan/experiments/flux/README.md | 2 +- torchtitan/experiments/flux/dataset/flux_dataset.py | 4 ++-- torchtitan/experiments/flux/job_config.py | 4 ++-- torchtitan/experiments/flux/model/layers.py | 4 ++-- torchtitan/experiments/flux/model/model.py | 2 +- torchtitan/experiments/flux/sampling.py | 2 +- torchtitan/experiments/flux/train.py | 2 +- torchtitan/experiments/flux/utils.py | 2 +- .../triton_contiguous_group_gemm/cg_forward.py | 2 +- torchtitan/experiments/llama4/README.md | 2 +- torchtitan/experiments/multimodal/model.py | 6 +++--- torchtitan/experiments/simple_fsdp/simple_fsdp.py | 2 +- torchtitan/models/README.md | 2 +- torchtitan/models/attention.py | 2 +- torchtitan/models/deepseek_v3/README.md | 4 ++-- .../models/deepseek_v3/model/state_dict_adapter.py | 12 ++++++------ torchtitan/models/moe.py | 2 +- torchtitan/protocols/model_converter.py | 2 +- torchtitan/tools/utils.py | 2 +- 32 files changed, 52 insertions(+), 52 deletions(-) diff --git a/docs/composability.md b/docs/composability.md index 8063fa89b1..1efe7bb9d4 100644 --- a/docs/composability.md +++ b/docs/composability.md @@ -14,7 +14,7 @@ Example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to actually reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward would do the right thing. This means we don't have to make a separate runtime pp_forward that glues together child modules per stage. The first change was using a moduledict instead of modulelist to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers - e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list- this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP) since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional- if the layer exists, we run it, otherwise we feed the input through to bypass it. With these two changes, we can just (meta)-initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint. ## Using a seed checkpoint for init -Initializing the pipeline-parallel model is challenging becuase we assume the model could be so large as to not fit on local GPU (or possibly, even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging or comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, and also serializing initialization operations globally for consistent RNG order. +Initializing the pipeline-parallel model is challenging because we assume the model could be so large as to not fit on local GPU (or possibly, even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging or comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, and also serializing initialization operations globally for consistent RNG order. For now, we sidestep all these problems with a simple but brutal solution: Initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. For future work, we consider adding a more elaborate initialization scheme to `torch.pipelining`. diff --git a/docs/debugging.md b/docs/debugging.md index 28bad0e3d1..f8479d8e0a 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -116,4 +116,4 @@ Here's a typical comparison setup (maintaining an overall DP degree of 4): To reproduce loss curves across above runs, you'll need to create a seed checkpoint, and then load the same seed checkpoint for all runs to ensure consistent model initialization on each rank. You might also need to set the `deterministic` mode to ensure consistent training behavior. -We also provided an example of verifying the numerical consistency across parallism plans configs on Llama 3 in https://github.com/pytorch/torchtitan/blob/main/docs/converging.md. +We also provided an example of verifying the numerical consistency across parallelism plans configs on Llama 3 in https://github.com/pytorch/torchtitan/blob/main/docs/converging.md. diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index a253c4fb5b..a4dbc21a5f 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -171,7 +171,7 @@ def get_act_mem(model_fn): self.assertEqual(mem_with_force_last, 1.0) self.assertEqual(mem_full_ac, 0.0) # Note: SAC > no-AC here because it unnecessarily saves "output" - # even that is not needed for recomputaion and output is double + # even that is not needed for recomputation and output is double # the size of the other two mms. def test_correctness(self): diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py index dfa51751dc..00c817a46a 100644 --- a/tests/unit_tests/test_lr_scheduler.py +++ b/tests/unit_tests/test_lr_scheduler.py @@ -256,7 +256,7 @@ def test_warmup_stable_only(self): def test_warmup_plus_decay_exceeds_training(self): """Test when warmup + decay steps exceed training steps.""" # Create a job config where warmup + decay steps > training steps - # Expected behaviro: warmup steps = 5, decay steps = 5 + # Expected behavior: warmup steps = 5, decay steps = 5 config = self.create_job_config( training_steps=10, warmup_steps=5, diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 2df8f9cd4b..bf6f0b3ae2 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -138,7 +138,7 @@ class CheckpointManager: We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object into one state dict before saving/loading. We rely on the individual - state_dicts to not collide, which is gauranteed for the model by correct pipeline + state_dicts to not collide, which is guaranteed for the model by correct pipeline splitting and for the optimizer by the flattening support described in (1). 3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers @@ -146,12 +146,12 @@ class CheckpointManager: Note: TorchFT checkpointing flow - There are two types of checkpoints: when TorchFT is enabled: 1) the full perisistent + There are two types of checkpoints: when TorchFT is enabled: 1) the full persistent checkpoint, 2) the per-replica checkpoint. - The full perisistent checkpoint is saved by the replica with + The full persistent checkpoint is saved by the replica with ``ft_manager.participating_rank() == 0``. It contains everything including the model, - optimizer, lr_scheduler, dataloader, and train_state. Right now the full perisistent + optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent checkpoint is loaded by all replicas. However, we can optimize it to only load if there are no other alive replicas. @@ -294,7 +294,7 @@ def load_state_dict(state_dict): self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM else: raise ValueError( - f"Unkown checkpoint async_mode {checkpoint_config.async_mode}" + f"Unknown checkpoint async_mode {checkpoint_config.async_mode}" ) logger.info( diff --git a/torchtitan/components/ft/config/job_config.py b/torchtitan/components/ft/config/job_config.py index c5bc309f72..df33c049a4 100644 --- a/torchtitan/components/ft/config/job_config.py +++ b/torchtitan/components/ft/config/job_config.py @@ -52,7 +52,7 @@ class FaultTolerance(BaseFaultTolerance): Determines how to mix the local and global optimized parameters By default, we just use the global parameters. This ensures all - DDP replicas have the same parameters after syncrhonizing on + DDP replicas have the same parameters after synchronizing on the fragment. Tuning this can also affect the model quality. This is only used when "semi_sync_method" is set. diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 38ec5173bd..5d64d34b09 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -49,7 +49,7 @@ def __init__( elif ft_config.process_group == "nccl": pg = ft.ProcessGroupNCCL(timeout=process_group_timeout) else: - raise ValueError(f"Unsuported process group: {ft_config.process_group}") + raise ValueError(f"Unsupported process group: {ft_config.process_group}") # If the training method is specific, then the quorum should be synchronous self.use_async_quorum = ft_config.semi_sync_method is None diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 9bdccf7981..6384feb641 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -124,7 +124,7 @@ def build_lr_schedulers( decay_steps = training_steps - warmup_steps else: decay_steps = training_steps - warmup_steps - # Add a vitual last step to prevent the learning rate from dropping to 0 + # Add a virtual last step to prevent the learning rate from dropping to 0 stable_steps = training_steps + 1 - warmup_steps - decay_steps lr_decay_type = lr_scheduler_config.decay_type min_lr_factor = lr_scheduler_config.min_lr_factor diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index a688cdadae..a43b3ce060 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -35,7 +35,7 @@ class Profiling: """Trace files location""" profile_freq: int = 10 - """How often to collect profile traces, in interations""" + """How often to collect profile traces, in iterations""" enable_memory_snapshot: bool = False """Whether to dump memory snapshot""" @@ -381,7 +381,7 @@ class Parallelism: - cp * tp <= ep <= dp_shard * cp * tp - ep % (cp * tp) == 0 - dp_shard * cp * tp % ep == 0 - Note that this is still an experimental feature. Some contrains will be + Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. """ diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index ce0fe35c0f..f2dbe1d357 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -204,7 +204,7 @@ def _validate_config(self) -> None: def register_tyro_rules(registry: tyro.constructors.ConstructorRegistry) -> None: @registry.primitive_rule def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): - """Support for comma seperated string parsing""" + """Support for comma separate string parsing""" if type_info.type != list[str]: return None return tyro.constructors.PrimitiveConstructorSpec( diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 13cd700eb2..5ee31c39c6 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -122,10 +122,10 @@ def set_determinism( torch.distributed.broadcast(seed_tensor, src=0) seed = seed_tensor.to("cpu").view(torch.uint64).item() - # Set distinct seed for each rank in mesh dimensions, with dimension name provdied by `distinct_seed_mesh_dim` + # Set distinct seed for each rank in mesh dimensions, with dimension name provided by `distinct_seed_mesh_dim` # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, # and choose a unique seed for each rank on the PP mesh. - # TODO(jianiw): We could further extend this to support mutiple distinct dimensions instead of just one. + # TODO(jianiw): We could further extend this to support multiple distinct dimensions instead of just one. if ( c10d.get_world_size() > 1 and distinct_seed_mesh_dim in world_mesh.mesh_dim_names diff --git a/torchtitan/experiments/deepseek_v3/group_gemms.py b/torchtitan/experiments/deepseek_v3/group_gemms.py index 0a52ec50f4..3e886fc1f1 100644 --- a/torchtitan/experiments/deepseek_v3/group_gemms.py +++ b/torchtitan/experiments/deepseek_v3/group_gemms.py @@ -403,7 +403,7 @@ def arrange_expert_weights(self, all_weights, submod_name, module): fp8, scales = dsgemm_utils.prepare_fp8_weight(combined_weights) # prescale weights - # TODO - this creates 2 sets of weights, need to resolve this for traiing aspect. + # TODO - this creates 2 sets of weights, need to resolve this for training aspect. module.register_parameter( f"{submod_name}_fp8", nn.Parameter( diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py index 615f20ba8b..5ee68524c6 100644 --- a/torchtitan/experiments/deepseek_v3/model.py +++ b/torchtitan/experiments/deepseek_v3/model.py @@ -382,7 +382,7 @@ def __init__(self, config): if self.topk_method == "noaux_tc": self.e_score_correction_bias = nn.Parameter( # Changed from torch.empty to torch.rand to avoid non-even - # distribution for runs without actual weigths + # distribution for runs without actual weights torch.rand((self.n_routed_experts)) ) self.reset_parameters() @@ -519,7 +519,7 @@ def __init__(self, config): assert ( MoE.group_mm in MoE.group_gemm_strategies - ), f"selected group gemm {self.group_mm} is not avaiable!" + ), f"selected group gemm {self.group_mm} is not available!" # keep active gg ready self.group_gemm_instance = MoE.group_gemm_strategies[MoE.group_mm] self._buffer_initialized = False @@ -695,7 +695,7 @@ def moe_forward(self, x, topk_ids, topk_weight): # TODO: don't use `received` gathered_tokens = token_gather_buf[:received] else: # "torch_all_to_all" - # Prepare input ans output splits + # Prepare input and output splits with torch.no_grad(): output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum( dim=1 @@ -1349,7 +1349,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if ( attention_mask is not None diff --git a/torchtitan/experiments/flux/README.md b/torchtitan/experiments/flux/README.md index 880a74b9a8..f47face41c 100644 --- a/torchtitan/experiments/flux/README.md +++ b/torchtitan/experiments/flux/README.md @@ -53,7 +53,7 @@ python -m torchtitan.experiments.flux.tests.integration_tests - Parallelism: The model supports FSDP, HSDP for training on multiple GPUs. - Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training. - Distributed checkpointing and loading. - - Notes on the current checkpointing implementation: To keep the model wieghts are sharded the same way as checkpointing, we need to shard the model weights before saving the checkpoint. This is done by checking each module at the end of envaluation, and sharding the weights of the module if it is a FSDPModule. + - Notes on the current checkpointing implementation: To keep the model weights are sharded the same way as checkpointing, we need to shard the model weights before saving the checkpoint. This is done by checking each module at the end of evaluation, and sharding the weights of the module if it is a FSDPModule. - CI for FLUX model. Supported periodically running integration tests on 8 GPUs, and unittests. diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index df266496c2..02fd73afec 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -246,8 +246,8 @@ def __iter__(self): # TODO: Add support for robust data loading and error handling. # Currently, we assume the dataset is well-formed and does not contain corrupted samples. # If a corrupted sample is encountered, the program will crash and throw an exception. - # You can NOT try to catch the exception and continue, becuase the iterator within dataset - # is not broken after raising an exception, so calling next() will thorw StopIteration and might cause re-loop. + # You can NOT try to catch the exception and continue, because the iterator within dataset + # is not broken after raising an exception, so calling next() will throw StopIteration and might cause re-loop. try: sample = next(dataset_iterator) except StopIteration: diff --git a/torchtitan/experiments/flux/job_config.py b/torchtitan/experiments/flux/job_config.py index 0b139ed42f..60422de2ee 100644 --- a/torchtitan/experiments/flux/job_config.py +++ b/torchtitan/experiments/flux/job_config.py @@ -16,7 +16,7 @@ class Training: img_size: int = 256 """Image width to sample""" test_mode: bool = False - """Whether to use intergration test mode, which will randomly initialize the encoder and use a dummy tokenizer""" + """Whether to use integration test mode, which will randomly initialize the encoder and use a dummy tokenizer""" @dataclass @@ -71,7 +71,7 @@ class Inference: @dataclass class JobConfig: """ - Extend the tyro parser with custom config classe for Flux model. + Extend the tyro parser with custom config classes for Flux model. """ training: Training = field(default_factory=Training) diff --git a/torchtitan/experiments/flux/model/layers.py b/torchtitan/experiments/flux/model/layers.py index a8d2d3af46..3aff5df5e1 100644 --- a/torchtitan/experiments/flux/model/layers.py +++ b/torchtitan/experiments/flux/model/layers.py @@ -232,13 +232,13 @@ def forward( attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp( (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift ) - # calculate the txt bloks + # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp( (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift diff --git a/torchtitan/experiments/flux/model/model.py b/torchtitan/experiments/flux/model/model.py index b8429878d6..33a3615307 100644 --- a/torchtitan/experiments/flux/model/model.py +++ b/torchtitan/experiments/flux/model/model.py @@ -25,7 +25,7 @@ class FluxModel(nn.Module, ModelProtocol): """ Transformer model for flow matching on sequences. - Agrs: + Args: model_args: FluxModelArgs. Attributes: diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index 4a5a1157f8..8bc3464dcd 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -85,7 +85,7 @@ def generate_image( ) -> torch.Tensor: """ Sampling and save a single images from noise using a given prompt. - For randomized noise generation, the random seend should already be set at the begining of training. + For randomized noise generation, the random seend should already be set at the beginning of training. Since we will always use the local random seed on this rank, we don't need to pass in the seed again. """ diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 7af97cff35..e364b30e53 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -104,7 +104,7 @@ def forward_backward_step( # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. - # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not + # explicitly convert flux model to be Bfloat16 no matter FSDP is applied or not model = self.model_parts[0] # image in latent space transformed by self.auto_encoder diff --git a/torchtitan/experiments/flux/utils.py b/torchtitan/experiments/flux/utils.py index d49fd77bc6..8b4cb00d31 100644 --- a/torchtitan/experiments/flux/utils.py +++ b/torchtitan/experiments/flux/utils.py @@ -65,7 +65,7 @@ def generate_noise_latent( dtype: torch.dtype, seed: int | None = None, ) -> Tensor: - """Generate noise latents for the Flux flow model. The random seed will be set at the begining of training. + """Generate noise latents for the Flux flow model. The random seed will be set at the beginning of training. Args: bsz (int): batch_size. diff --git a/torchtitan/experiments/kernels/triton_contiguous_group_gemm/cg_forward.py b/torchtitan/experiments/kernels/triton_contiguous_group_gemm/cg_forward.py index 68553e08ae..7a5878e738 100644 --- a/torchtitan/experiments/kernels/triton_contiguous_group_gemm/cg_forward.py +++ b/torchtitan/experiments/kernels/triton_contiguous_group_gemm/cg_forward.py @@ -29,7 +29,7 @@ # ============ Triton kernel for contiguous grouped GEMM ============ -# L2 Caching optmization +# L2 Caching optimization @triton.jit diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md index 964bc3741f..635c8c71b9 100644 --- a/torchtitan/experiments/llama4/README.md +++ b/torchtitan/experiments/llama4/README.md @@ -26,5 +26,5 @@ python scripts/download_hf_assets.py --assets tokenizer --repo_id meta-llama/Lla - Quantization - efficient float8 Grouped MM kernels (from torchao) - Testing - - perfomance and loss converging tests + - performance and loss converging tests - CI integration diff --git a/torchtitan/experiments/multimodal/model.py b/torchtitan/experiments/multimodal/model.py index d1783a5bd7..83d43c4b95 100644 --- a/torchtitan/experiments/multimodal/model.py +++ b/torchtitan/experiments/multimodal/model.py @@ -839,7 +839,7 @@ def forward( Processes images and returns the tokens and hidden states. Multiple images per sample: we add a dimension num_imgs to the input. This is useful when a single - sample constains multiple images, for example: + sample contains multiple images, for example: - sample 1: " what animal is this?" - sample 2: "I like more than " @@ -999,7 +999,7 @@ def forward( class FeedForwardForDecoder(nn.Module): """ FeedForward module for the decoder. It's different from the one in the encoder. - This is the component which is orignally used in llama3. + This is the component which is originally used in llama3. """ def __init__( @@ -1301,7 +1301,7 @@ class FusionLayer(nn.Module): """ Deep Fusion model architectures combine pretrained encoder models with pretrained language models by infusing the encoder outputs into the middle layers of the LLM. - This allows the language model to interpret the enocder outputs as text and + This allows the language model to interpret the encoder outputs as text and "understand" any modality for which you can train an decoder. To enable the language model to adapt to the encoder outputs, the FusionLayer fuses a new learnable layer to an existing decoder (language model) layer. This additional layer can take the encoder embeddings and diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index d90f3a67e5..8f7a2f4da8 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -227,7 +227,7 @@ def replicate_compute(self, x): ) # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh - # TODO: DTensor should support this mesh collasping operation + # TODO: DTensor should support this mesh collapsing operation replicated_local_tensor = replicated_dtensor.to_local( grad_placements=self.grad_placements ) diff --git a/torchtitan/models/README.md b/torchtitan/models/README.md index d76ac4fc24..9c8b960609 100644 --- a/torchtitan/models/README.md +++ b/torchtitan/models/README.md @@ -4,7 +4,7 @@ For offline explorations, we recommend the same steps, unless otherwise noted. ## Adding the model -Please refer to the [Llama 3 folder](.llama3) as an example. +Please refer to the [Llama 3 folder](llama3) as an example. The folder should be organized as follows - `model` folder: a self-contained folder of model definition and args diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 570d894f51..277d64be1b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -118,7 +118,7 @@ def _fixed_block_mask_mod( mask_mod: _mask_mod_signature, fixed_block_size: int ) -> _mask_mod_signature: """ - Given an arbirary mask_mod, divide the input sequence to blocks + Given an arbitrary mask_mod, divide the input sequence to blocks and only allow attention within the same block. Args: diff --git a/torchtitan/models/deepseek_v3/README.md b/torchtitan/models/deepseek_v3/README.md index fcccd65fc5..3cfb5e7d3c 100644 --- a/torchtitan/models/deepseek_v3/README.md +++ b/torchtitan/models/deepseek_v3/README.md @@ -52,9 +52,9 @@ CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml ## HuggingFace -> DCP Checkpoint Conversion -We implemented StateDictAdapter to preform HuggingFace safetensor to DCP format conversion. Currently, we only support conversion from HF checkpoints to DCP checkpoints offline (using CPU plain tensor). +We implemented StateDictAdapter to perform HuggingFace safetensor to DCP format conversion. Currently, we only support conversion from HF checkpoints to DCP checkpoints offline (using CPU plain tensor). -Run the offine conversion script: +Run the offline conversion script: ```bash python scripts/checkpoint_conversion/convert_from_hf.py --model_name deepseek_v3 --model_flavor 671B ``` diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 5a676b5a07..0bdf456ef4 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -36,7 +36,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None): "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", - # Transfomer Layer + # Transformer Layer "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", # MoE Module @@ -65,7 +65,7 @@ def _concatenate_expert_weights( self, expert_weights_by_layer: dict[str, Any], n_experts: int ) -> torch.Tensor: """ - Concatenate the weights of seprate experts into GroupedExpert weights. + Concatenate the weights of separate experts into GroupedExpert weights. """ for layer, abstract_keys in list(expert_weights_by_layer.items()): for abstract_key, experts in list(abstract_keys.items()): @@ -137,7 +137,7 @@ def _add_quantization_scale_inv_tensors( def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. - 2. Split the GroupedExperts' weight into seprate expert's wegiht. + 2. Split the GroupedExperts' weight into separate expert's wegiht. """ to_hf_map = {v: k for k, v in self.from_hf_map.items()} @@ -149,7 +149,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] - # Split expert weights into seperate expert weights + # Split expert weights into separate expert weights split_values = self._split_experts_weights( value, self.model_args.moe_args.num_experts ) @@ -178,7 +178,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. 2. Convert between the HF shape and the torchtitan shape. - 3. Concate seprate expert's wegiht into GroupedExperts' weight. + 3. Concate separate expert's wegiht into GroupedExperts' weight. """ # dequantize the tensor in state_dict and remove the scale_inv tensor hf_state_dict = self._dequantize(hf_state_dict) @@ -193,7 +193,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: new_key = self.from_hf_map[abstract_key] new_key = new_key.format(layer_num) - # Store the expert's weight in expert_weights_by_layer for concating later. + # Store the expert's weight in expert_weights_by_layer for concatenating later. if layer_num not in expert_weights_by_layer: expert_weights_by_layer[layer_num] = {} if abstract_key not in expert_weights_by_layer[layer_num]: diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 40bd6c2cca..da7fc1bfa0 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -340,7 +340,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert is accumulated in the model forward pass. - # expert_bias is updated outside the model in an optimzer step pre hook + # expert_bias is updated outside the model in an optimizer step pre hook # to work with gradient accumulation. self.load_balance_coeff = moe_args.load_balance_coeff if self.load_balance_coeff is not None: diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index 300c4231c3..dbfc3a99c3 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -25,7 +25,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ... def convert(self, model: nn.Module): - """Inplace convertion of the model.""" + """Inplace conversion of the model.""" ... def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]): diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 45bbd4ab83..070cc2938f 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -54,7 +54,7 @@ def run(self, step_count: int): ) gc.collect() elif step_count > 1 and step_count % self.gc_freq == 0: - self.collect("Peforming periodical GC collection") + self.collect("Performing periodical GC collection") @staticmethod def collect(reason: str, generation: int = 1): From 2bfcdd8e149e49b9e958fd58a9fbed261754a1af Mon Sep 17 00:00:00 2001 From: rakkit <401872089@qq.com> Date: Fri, 22 Aug 2025 04:05:00 +0300 Subject: [PATCH 113/128] improve MoE bias update logic in optimizer (#1593) We put all experts' usage into a buffer such that we only need one reduce rather than #number-of-layers times Additionally, handle cases where tokens per expert are counted twice during full recompute. Co-authored-by: wang55 --- torchtitan/components/optimizer.py | 63 ++++++++++++++++++++++-------- torchtitan/models/moe.py | 27 +++++++------ 2 files changed, 61 insertions(+), 29 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index ce71ac7f0c..d3e9628103 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl from torch.distributed.checkpoint.state_dict import ( get_optimizer_state_dict, set_optimizer_state_dict, @@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing( ) # for MoE auxiliary-loss-free load balancing + def _is_recomputation_enabled(module): + return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT + def _update_expert_bias( model_parts: list[nn.Module], parallel_dims: ParallelDims, @@ -349,25 +353,52 @@ def _update_expert_bias( ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. + tokens_per_expert_list = [] for model_part in model_parts: for transformer_block in model_part.layers.values(): - if transformer_block.moe_enabled: + if not transformer_block.moe_enabled: + continue + if transformer_block.moe.load_balance_coeff is None: + return + tokens_per_expert = transformer_block.moe.tokens_per_expert + if _is_recomputation_enabled(transformer_block): + # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice. + # This does not affect to expert choice, but affects the experts usage metrics. + # We divide by 2 to correct for this double-counting due to recomputation + # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888 + tokens_per_expert = tokens_per_expert // 2 + tokens_per_expert_list.append(tokens_per_expert) + + tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) + + if dp_cp_mesh is not None: + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM + ) + + moe_layer_idx = 0 + with torch.no_grad(): + for model_part in model_parts: + for transformer_block in model_part.layers.values(): + if not transformer_block.moe_enabled: + continue moe = transformer_block.moe - if moe.load_balance_coeff is None: - return - - if dp_cp_mesh is not None: - torch.distributed.all_reduce( - moe.tokens_per_expert, group=dp_cp_mesh.get_group() - ) - - with torch.no_grad(): - expert_bias_delta = moe.load_balance_coeff * torch.sign( - moe.tokens_per_expert.mean() - moe.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - moe.expert_bias.add_(expert_bias_delta) - moe.tokens_per_expert.zero_() + + tokens_per_expert = tokens_per_expert_by_layer[ + moe_layer_idx + ].float() + moe_layer_idx += 1 + + # update the expert bias + # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed + expert_bias_delta = moe.load_balance_coeff * torch.sign( + tokens_per_expert.mean() - tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() optimizers.register_step_pre_hook( lambda *args, **kwargs: _update_expert_bias( diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index da7fc1bfa0..3a9dd1b28a 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): torch.zeros(num_experts, dtype=torch.float32), persistent=True, ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - persistent=False, - ) else: self.expert_bias = None + # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) = self.router(x, self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- # first in the forward pass, and then in the backward pass. However, this has no # effect on the expert bias update thanks to the torch.sign() operator. - if self.load_balance_coeff is not None: - with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) @@ -444,11 +445,11 @@ def init_weights( if self.shared_experts is not None: self.shared_experts.init_weights(init_std) - if self.load_balance_coeff is not None: - with torch.device(buffer_device): + with torch.device(buffer_device): + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + if self.load_balance_coeff is not None: self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) From 255a6ab67bc14f14c94a2a9e9394f183691f8726 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Fri, 22 Aug 2025 15:28:12 +0800 Subject: [PATCH 114/128] fix qwen3 compile config in parallelize.py (#1623) fix compile config in parallelize.py, ref https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py --- .../experiments/qwen3/infra/parallelize.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index 8648326770..8f7cb06ef7 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -45,12 +45,22 @@ def parallelize_qwen3( Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + + if ( + job_config.parallelism.context_parallel_degree > 1 + and model.model_args.use_flex_attn + ): + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel - and not job_config.training.compile + and not model_compile_enabled ): - raise RuntimeError("Async TP requires --training.compile") + raise RuntimeError("Async TP requires torch.compile") enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -75,7 +85,7 @@ def parallelize_qwen3( apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP - if job_config.training.compile: + if model_compile_enabled: apply_compile(model) if parallel_dims.fsdp_enabled: @@ -95,11 +105,6 @@ def parallelize_qwen3( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ) - if parallel_dims.dp_replicate_enabled: - logger.info("Applied HSDP to the model") - else: - logger.info("Applied FSDP to the model") - if parallel_dims.dp_replicate_enabled: logger.info("Applied HSDP to the model") else: @@ -116,7 +121,7 @@ def parallelize_qwen3( apply_ddp( model, world_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) From 7d744b2318b279f9ae3afbec5a5f1f53dd740964 Mon Sep 17 00:00:00 2001 From: Garrett Goon <44747910+garrett361@users.noreply.github.com> Date: Fri, 22 Aug 2025 03:30:01 -0400 Subject: [PATCH 115/128] add model_parts ref to MetricsProcessor (#1578) Adds a `ModelProtocol.get_extra_metrics` method for more flexible custom metric reporting, as discussed in #1576 Probably this should be an abstract method, but I was wary of making this a breaking change for users who inherit this commit. The current signature is `get_extra_metrics(self, parallel_dims: ParallelDims) -> None | dict`. I also considered adding some subset of `JobConfig`, `TrainSpec`, and `pp_has_{first,last}_stage`; not sure what else might be useful. Tested via running the debugmodel with print statements. CC @rakkit @wwwjn --- torchtitan/components/metrics.py | 2 ++ torchtitan/train.py | 1 + 2 files changed, 3 insertions(+) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index dcd8782810..720e2b9d6d 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -319,6 +319,7 @@ class MetricsProcessor: num_flops_per_token: int optimizers: OptimizersContainer | None lr_schedulers: LRSchedulersContainer | None + model_parts: list[torch.nn.Module] | None def __init__( self, @@ -349,6 +350,7 @@ def __init__( self.num_flops_per_token = -1 self.optimizers = None self.lr_schedulers = None + self.model_parts = None def should_log(self, step: int) -> bool: return step == 1 or step % self.job_config.metrics.log_freq == 0 diff --git a/torchtitan/train.py b/torchtitan/train.py index e38446a398..1954c356db 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -294,6 +294,7 @@ def __init__(self, job_config: JobConfig): ) ) self.metrics_processor.optimizers = self.optimizers + self.metrics_processor.model_parts = self.model_parts # Initialize trainer states that will be saved in checkpoint. # These attributes must be initialized before checkpoint loading. From 8a749c61877471e2ca6f9d3aa4966d45462c2c15 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 22 Aug 2025 08:35:05 -0700 Subject: [PATCH 116/128] Move the call to init_attention_mask to trainer (#1616) One perspective on the attention mask is that it should be coupled with the dataloader rather than the modeling component. Therefore, this PR moves the creation of the attention mask to the trainer, removing it from the model itself. This PR also fixes https://github.com/pytorch/torchtitan/issues/1612 ``` -> % LOG_RANK=6 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.pipeline_parallel_degree=4 + NGPU=8 + export LOG_RANK=6 + LOG_RANK=6 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 6 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --training.steps=100 --parallelism.pipeline_parallel_degree=4 W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] ***************************************** W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] ***************************************** [rank6]:[titan] 2025-08-21 10:14:50,681 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank6]:[titan] 2025-08-21 10:14:53,248 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank6]:[titan] 2025-08-21 10:14:53,250 - root - INFO - Building 2-D device mesh with ['pp', 'dp_shard'], [4, 2] [rank6]:[titan] 2025-08-21 10:14:53,265 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank6]:[titan] 2025-08-21 10:14:56,937 - root - INFO - Loading tokenizer from tokenizer.json [rank6]:[titan] 2025-08-21 10:14:57,076 - root - INFO - Preparing c4 dataset from allenai/c4 [rank6]:[titan] 2025-08-21 10:15:00,743 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, vocab_size=102400, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, moe_args=MoEArgs(num_experts=64, num_shared_experts=2, score_func='softmax', route_norm=True, route_scale=1.0, score_before_experts=False, top_k=6, use_grouped_mm=True, load_balance_coeff=0.001), n_expert_groups=1, n_limited_groups=1, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7) [rank6]:[titan] 2025-08-21 10:15:00,966 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Total parameter count: dense 858,385,920, sparse 14,848,098,304, active 2,661,150,208 [rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Model deepseek_v3 16B size: 15,706,484,224 total parameters [rank6]:Stage 3: Modules to keep: {'layers.14', 'layers.13', 'layers.11', 'layers.12'} [rank6]:Stage 7: Modules to keep: {'output', 'norm', 'layers.26', 'layers.25'} [rank6]:[titan] 2025-08-21 10:15:01,029 - root - INFO - PP rank 3 is building stage_idx 3 with modules ['layers.11', 'layers.12', 'layers.13', 'layers.14'] [rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - PP rank 3 is building stage_idx 7 with modules ['layers.25', 'layers.26', 'norm', 'output'] [rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - Applied full activation checkpointing to the model [rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied FSDP to the model [rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied full activation checkpointing to the model [rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Applied FSDP to the model [rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Using pipeline schedule Interleaved1F1B with 8 microbatches and 8 stages. [rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - CUDA memory usage for model: 6.94GiB(7.31%) [rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup steps (200) exceed total training steps (100). Adjusting warmup steps to 100. [rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup (100) + decay (80) steps exceed total training steps (100). Adjusting decay steps to 0. [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Mixed precision training is handled by fully_shard [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Trainer is initialized with local batch size 8, global batch size 16, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200) [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Training starts at step 1 [rank6]:[rank6]:[W821 10:15:10.781655306 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator()) [rank6]:NCCL version 2.27.5+cuda12.6 [rank6]:[rank6]:[W821 10:15:16.977607954 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator()) [rank6]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.) [rank6]: return _C._get_float32_matmul_precision() [rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - step: 1 loss: 12.0194 grad_norm: 1.8958 memory: 53.94GiB(56.78%) tps: 296 tflops: 5.16 mfu: 0.52% [rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 no^[^[[rank6]:[titan] 2025-08-21 10:15:43,154 - root - INFO - step: 10 loss: 10.3629 grad_norm: 3.0762 memory: 67.11GiB(70.64%) tps: 5,092 tflops: 88.73 mfu: 8.97% [rank6]:[titan] 2025-08-21 10:15:59,017 - root - INFO - step: 20 loss: 8.9238 grad_norm: 2.5020 memory: 67.11GiB(70.64%) tps: 5,165 tflops: 90.00 mfu: 9.10% [rank6]:[titan] 2025-08-21 10:16:15,051 - root - INFO - step: 30 loss: 7.8167 grad_norm: 1.7460 memory: 67.11GiB(70.64%) tps: 5,109 tflops: 89.04 mfu: 9.00% [rank6]:[titan] 2025-08-21 10:16:31,989 - root - INFO - step: 40 loss: 7.1761 grad_norm: 1.1432 memory: 67.11GiB(70.64%) tps: 4,837 tflops: 84.29 mfu: 8.52% [rank6]:[titan] 2025-08-21 10:16:48,455 - root - INFO - step: 50 loss: 6.7850 grad_norm: 1.4950 memory: 67.11GiB(70.64%) tps: 4,975 tflops: 86.70 mfu: 8.77% [rank6]:[titan] 2025-08-21 10:17:04,602 - root - INFO - step: 60 loss: 6.8310 grad_norm: 1.2972 memory: 67.11GiB(70.64%) tps: 5,074 tflops: 88.42 mfu: 8.94% [rank6]:[titan] 2025-08-21 10:17:22,231 - root - INFO - step: 70 loss: 6.6627 grad_norm: 1.1630 memory: 67.11GiB(70.64%) tps: 4,647 tflops: 80.98 mfu: 8.19% [rank6]:[titan] 2025-08-21 10:17:41,358 - root - INFO - step: 80 loss: 6.3542 grad_norm: 0.8215 memory: 67.11GiB(70.64%) tps: 4,283 tflops: 74.64 mfu: 7.55% [rank6]:[titan] 2025-08-21 10:17:58,336 - root - INFO - step: 90 loss: 6.4442 grad_norm: 1.2542 memory: 67.11GiB(70.64%) tps: 4,825 tflops: 84.09 mfu: 8.50% [rank6]:[titan] 2025-08-21 10:18:12,542 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - step: 100 loss: 6.7519 grad_norm: 1.3966 memory: 67.11GiB(70.64%) tps: 5,048 tflops: 87.97 mfu: 8.89% [rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - Training completed [rank6]:[titan] 2025-08-21 10:18:17,159 - root - INFO - Process group destroyed ``` --- torchtitan/distributed/utils.py | 4 +--- torchtitan/experiments/forge/example_train.py | 7 +++++++ torchtitan/experiments/llama4/model/args.py | 9 --------- torchtitan/experiments/llama4/model/model.py | 8 +------- torchtitan/experiments/qwen3/model/model.py | 10 +--------- torchtitan/models/attention.py | 17 ++++++++++++++++- torchtitan/models/deepseek_v3/model/args.py | 9 --------- torchtitan/models/deepseek_v3/model/model.py | 7 +------ torchtitan/models/llama3/model/args.py | 8 -------- torchtitan/models/llama3/model/model.py | 8 +------- torchtitan/train.py | 12 ++++++++++-- 11 files changed, 38 insertions(+), 61 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 5ee31c39c6..74d310dfc1 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -206,9 +206,7 @@ def context(cp_context: Generator[None, None, None] | None = None): if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - assert ( - ScaledDotProductAttention.backends - ), "No valid SDPA backends with CP." + stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 0bebd197d1..92209aeb76 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -160,6 +160,13 @@ def forward_backward_step( # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage inputs = input_dict["input"] + # Create the FlexAttention mask according to the input + if getattr(self.model_args, "use_flex_attn", False): + cp_mesh = ( + parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None + ) + init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh) + optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py index 949f4cf052..272936a153 100644 --- a/torchtitan/experiments/llama4/model/args.py +++ b/torchtitan/experiments/llama4/model/args.py @@ -71,15 +71,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for FlexAttention is still in progress." ) - if ( - job_config.parallelism.pipeline_parallel_degree > 1 - and self.use_flex_attn - and self.attn_mask_type == "block_causal" - ): - raise RuntimeError( - "PP + block causal FlexAttention support will be fixed soon." - ) - def get_nparams_and_flops( self, model: nn.Module, seq_len: int ) -> tuple[int, float]: diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py index 84e5613de0..f5557b6d25 100644 --- a/torchtitan/experiments/llama4/model/model.py +++ b/torchtitan/experiments/llama4/model/model.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch import nn -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.attention import build_attention from torchtitan.models.moe import MoE from torchtitan.protocols import ModelProtocol @@ -451,7 +451,6 @@ def _precompute_freqs_cis(self) -> torch.Tensor: def forward( self, tokens: torch.Tensor, - eos_id: int | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -471,11 +470,6 @@ def forward( torch.Tensor: Output logits after applying the Transformer model. """ - if self.model_args.use_flex_attn: - init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=eos_id - ) - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens diff --git a/torchtitan/experiments/qwen3/model/model.py b/torchtitan/experiments/qwen3/model/model.py index 07d05b734b..e5792cdbb3 100644 --- a/torchtitan/experiments/qwen3/model/model.py +++ b/torchtitan/experiments/qwen3/model/model.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from torch import nn -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol from .args import Qwen3ModelArgs @@ -411,7 +411,6 @@ def forward( self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None, - eos_id: int | None = None, ): """ Perform a forward pass through the Transformer model. @@ -425,18 +424,11 @@ def forward( This will always be the input batch regardless of the pipeline stage. This field is required for non-first PP stages to perform document masking attention (to analyze the boundary of the document). - eos_id (int | None): End-of-sequence token ID. If not provided, uses self.eos_id. Returns: torch.Tensor: Output logits after applying the Transformer model. """ - if self.model_args.use_flex_attn: - init_attention_mask( - input_batch if input_batch is not None else tokens, - eos_id=eos_id, - ) - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 277d64be1b..f66361a6d2 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,10 +6,12 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import functools from typing import Callable, ClassVar import torch import torch.nn.functional as F +from torch.distributed.tensor.experimental._attention import create_cp_block_mask from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, @@ -239,5 +241,18 @@ def build_attention( return ScaledDotProductAttention(attn_mask_type) -def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: +def init_attention_mask( + batch: torch.Tensor, + eos_id: int | None, + cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None, +) -> None: + + # This is not functional yet because we currently gate the use of Flex + CP + # while we continue debugging accuracy issues. However, we want to evaluate + # the user experience with CP enabled. + if cp_mesh is not None: + FlexAttention.compiled_create_block_mask = functools.partial( + create_cp_block_mask, device_mesh=cp_mesh + ) + FlexAttention.init_attention_mask(batch, eos_id) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 48e8246fca..d6afedfa34 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -105,15 +105,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for FlexAttention is still in progress." ) - if ( - job_config.parallelism.pipeline_parallel_degree > 1 - and self.use_flex_attn - and self.attn_mask_type == "block_causal" - ): - raise RuntimeError( - "PP + block causal FlexAttention support will be fixed soon." - ) - def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: """ Adopted from llama4 implementation. diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 5249a26d5d..e2c4bbeda9 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -10,7 +10,7 @@ import torch from torch import nn -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.attention import build_attention from torchtitan.models.moe import FeedForward, MoE from torchtitan.protocols.train_spec import ModelProtocol @@ -364,7 +364,6 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: def forward( self, tokens: torch.Tensor, - eos_id: int | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -383,10 +382,6 @@ def forward( Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ - if self.model_args.use_flex_attn: - init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=eos_id - ) h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 73c8e27700..e2f698f8b1 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -50,14 +50,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for FlexAttention is still in progress." ) - if ( - job_config.parallelism.pipeline_parallel_degree > 1 - and self.use_flex_attn - and self.attn_mask_type == "block_causal" - ): - raise RuntimeError( - "PP + block causal FlexAttention support will be fixed soon." - ) self.max_seq_len = seq_len def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 039fe0da77..f2284920aa 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from torch import nn -from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol from .args import TransformerModelArgs @@ -395,7 +395,6 @@ def _precompute_freqs_cis(self) -> torch.Tensor: def forward( self, tokens: torch.Tensor, - eos_id: int | None = None, input_batch: torch.Tensor | None = None, ): """ @@ -415,11 +414,6 @@ def forward( torch.Tensor: Output logits after applying the Transformer model. """ - if self.model_args.use_flex_attn: - init_attention_mask( - input_batch if input_batch is not None else tokens, eos_id=eos_id - ) - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens diff --git a/torchtitan/train.py b/torchtitan/train.py index 1954c356db..ca86e50b48 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -24,6 +24,7 @@ ) from torchtitan.config import ConfigManager, JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.models.attention import init_attention_mask from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.logging import init_logger, logger @@ -408,9 +409,16 @@ def forward_backward_step( model_parts = self.model_parts parallel_dims = self.parallel_dims + inputs = input_dict["input"] + # Create the FlexAttention mask according to the input + if getattr(self.model_args, "use_flex_attn", False): + cp_mesh = ( + parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None + ) + init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh) + # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage - inputs = input_dict["input"] optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], @@ -450,7 +458,7 @@ def forward_backward_step( with self.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id) + pred = model_parts[0](inputs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred From f738a03bf45e1c063aaffa6f7e4df2cfcb47de42 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 22 Aug 2025 08:37:15 -0700 Subject: [PATCH 117/128] Switch DeepSeekV3 to Use FlexAttention by Default (#1610) Currently, the only available backend for SDPA for DeepSeekV3 is efficient attention kernel. For FlashAttentionV2 (what current SDPA supports), the V embedding dimension must be the same as Q and K. For cuDNN attention, it is complaining the head dimension is too large. The reason for defaulting the attention to SDPA in TorchTitan is that FlexCP is not yet ready. However, the combination of SDPA + CP + DeepSeekV3 is also not functional. This PR updates all DeepSeekV3 configurations to use FlexAttention, which significantly improves the overall performance. **Document masking also contributes to MFU improvement, but the majority is from FlexAttention itself**. ``` CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.expert_parallel_degree=8 ``` SDPA: ``` [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200) [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - step: 1 loss: 12.0401 grad_norm: 1.7464 memory: 63.55GiB(66.89%) tps: 1,416 tflops: 24.67 mfu: 2.49% [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 18:29:46,138 - root - INFO - step: 10 loss: 10.3087 grad_norm: 3.1896 memory: 78.14GiB(82.25%) tps: 7,008 tflops: 122.12 mfu: 12.35% [rank0]:[titan] 2025-08-20 18:30:33,628 - root - INFO - step: 20 loss: 8.7601 grad_norm: 2.5195 memory: 78.14GiB(82.25%) tps: 6,900 tflops: 120.24 mfu: 12.16% [rank0]:[titan] 2025-08-20 18:31:22,497 - root - INFO - step: 30 loss: 7.7450 grad_norm: 1.9296 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.82% [rank0]:[titan] 2025-08-20 18:32:19,709 - root - INFO - step: 40 loss: 6.9795 grad_norm: 0.6893 memory: 78.14GiB(82.25%) tps: 5,728 tflops: 99.81 mfu: 10.09% [rank0]:[titan] 2025-08-20 18:33:34,343 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 18:33:43,863 - root - INFO - step: 50 loss: 6.8381 grad_norm: 1.1848 memory: 78.14GiB(82.25%) tps: 3,894 tflops: 67.86 mfu: 6.86% [rank0]:[titan] 2025-08-20 18:34:37,289 - root - INFO - step: 60 loss: 6.5727 grad_norm: 0.9871 memory: 78.14GiB(82.25%) tps: 6,133 tflops: 106.88 mfu: 10.81% [rank0]:[titan] 2025-08-20 18:35:27,959 - root - INFO - step: 70 loss: 6.5041 grad_norm: 1.5895 memory: 78.14GiB(82.25%) tps: 6,467 tflops: 112.70 mfu: 11.40% [rank0]:[titan] 2025-08-20 18:36:16,732 - root - INFO - step: 80 loss: 6.3179 grad_norm: 0.9556 memory: 78.14GiB(82.25%) tps: 6,719 tflops: 117.08 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:05,604 - root - INFO - step: 90 loss: 6.2124 grad_norm: 0.8286 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.81% [rank0]:[titan] 2025-08-20 18:37:49,285 - root - INFO - [GC] Peforming periodical GC collection 0.04 seconds [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - step: 100 loss: 6.2596 grad_norm: 1.5143 memory: 78.14GiB(82.25%) tps: 6,721 tflops: 117.12 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 18:37:56,364 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 18:37:57,535 - root - INFO - Process group destroyed ``` FlexAttention (now) ``` [rank0]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.) [rank0]: return _C._get_float32_matmul_precision() [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - step: 1 loss: 11.9984 grad_norm: 1.7288 memory: 63.55GiB(66.89%) tps: 727 tflops: 12.67 mfu: 1.28% [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 22:17:32,228 - root - INFO - step: 10 loss: 10.3101 grad_norm: 2.9111 memory: 78.14GiB(82.25%) tps: 9,066 tflops: 157.99 mfu: 15.97% [rank0]:[titan] 2025-08-20 22:18:08,957 - root - INFO - step: 20 loss: 8.7431 grad_norm: 2.5391 memory: 78.14GiB(82.25%) tps: 8,922 tflops: 155.47 mfu: 15.72% [rank0]:[titan] 2025-08-20 22:18:46,981 - root - INFO - step: 30 loss: 7.7133 grad_norm: 1.7743 memory: 78.14GiB(82.25%) tps: 8,618 tflops: 150.18 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:19:26,672 - root - INFO - step: 40 loss: 6.9643 grad_norm: 0.7227 memory: 78.14GiB(82.25%) tps: 8,256 tflops: 143.88 mfu: 14.55% [rank0]:[titan] 2025-08-20 22:20:01,975 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 22:20:06,015 - root - INFO - step: 50 loss: 6.8046 grad_norm: 1.0556 memory: 78.14GiB(82.25%) tps: 8,329 tflops: 145.15 mfu: 14.68% [rank0]:[titan] 2025-08-20 22:20:45,784 - root - INFO - step: 60 loss: 6.5364 grad_norm: 1.7141 memory: 78.14GiB(82.25%) tps: 8,240 tflops: 143.59 mfu: 14.52% [rank0]:[titan] 2025-08-20 22:21:25,078 - root - INFO - step: 70 loss: 6.4709 grad_norm: 1.2385 memory: 78.14GiB(82.25%) tps: 8,340 tflops: 145.33 mfu: 14.69% [rank0]:[titan] 2025-08-20 22:22:03,088 - root - INFO - step: 80 loss: 6.2786 grad_norm: 2.2534 memory: 78.14GiB(82.25%) tps: 8,621 tflops: 150.24 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:22:41,254 - root - INFO - step: 90 loss: 6.1441 grad_norm: 0.6878 memory: 78.14GiB(82.25%) tps: 8,586 tflops: 149.62 mfu: 15.13% [rank0]:[titan] 2025-08-20 22:23:15,059 - root - INFO - [GC] Peforming periodical GC collection 0.05 seconds [rank0]:[titan] 2025-08-20 22:23:19,063 - root - INFO - step: 100 loss: 6.1348 grad_norm: 1.2875 memory: 78.14GiB(82.25%) tps: 8,667 tflops: 151.04 mfu: 15.27% [rank0]:[titan] 2025-08-20 22:23:19,064 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 22:23:21,065 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 22:23:22,436 - root - INFO - Process group destroyed ``` --- torchtitan/models/deepseek_v3/__init__.py | 6 ++++++ torchtitan/models/llama3/infra/parallelize.py | 1 + 2 files changed, 7 insertions(+) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index f81db35341..1c3d2b19d2 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -100,6 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -125,6 +127,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( vocab_size=129280, @@ -150,6 +154,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 2e2e81302a..ecd08990e7 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -243,6 +243,7 @@ def apply_tp( # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, } From cab22e73578d4a32c6d4c6248e1756a08ab77282 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Fri, 22 Aug 2025 13:19:53 -0700 Subject: [PATCH 118/128] Centralize Async TP Enablement with maybe_enable_async_tp API (#1619) This PR addresses duplicated code related to enabling async TP across different parts of the codebase. It introduces a new API, `maybe_enable_async_tp()`, which centralizes the enablement logic and is reused consistently in all models. Note that while this PR fixes one async TP bug in TorchTitan, it does not fully resolve https://github.com/pytorch/torchtitan/issues/1613, as there appear to be additional bugs in PyTorch's async TP implementation. --- torchtitan/distributed/tensor_parallel.py | 27 +++++++++++++++++++ .../experiments/llama4/infra/parallelize.py | 18 +++---------- .../models/deepseek_v3/infra/parallelize.py | 23 +++------------- torchtitan/models/llama3/infra/parallelize.py | 18 +++---------- 4 files changed, 37 insertions(+), 49 deletions(-) create mode 100644 torchtitan/distributed/tensor_parallel.py diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py new file mode 100644 index 0000000000..a2749f4c11 --- /dev/null +++ b/torchtitan/distributed/tensor_parallel.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.config import JobConfig +from torchtitan.tools.logging import logger + + +def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh): + if not job_config.parallelism.enable_async_tensor_parallel: + return + + if not (job_config.compile.enable and "model" in job_config.compile.components): + raise RuntimeError("Async TP requires --training.compile") + + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info("Async TP is enabled") diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index a716c78907..e511686575 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -28,6 +28,7 @@ ReordererSequenceParallel, TensorParallel, ) +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger @@ -66,12 +67,6 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) if parallel_dims.tp_enabled: - if ( - job_config.parallelism.enable_async_tensor_parallel - and not model_compile_enabled - ): - raise RuntimeError("Async TP requires torch.compile") - enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( "rowwise", @@ -88,8 +83,8 @@ def parallelize_llama( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, - enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( @@ -177,7 +172,6 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - enable_async_tp: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -256,14 +250,8 @@ def apply_non_moe_tp( parallelize_plan=layer_plan, ) - if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) - logger.info( - f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" "Tensor Parallelism to the model" ) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 1aedd73ad3..8423c2a8e6 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -19,6 +19,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.expert_parallel import NoParallel +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.experiments.llama4.infra.parallelize import ( apply_compile, apply_fsdp, @@ -51,16 +52,7 @@ def parallelize_deepseekv3( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) if parallel_dims.tp_enabled: - if ( - job_config.parallelism.enable_async_tensor_parallel - and not model_compile_enabled - ): - raise RuntimeError("Async TP requires --training.compile") - enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( "rowwise", @@ -79,8 +71,8 @@ def parallelize_deepseekv3( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - enable_async_tp=False, ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( @@ -100,7 +92,7 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) - if model_compile_enabled: + if job_config.compile.enable and "model" in job_config.compile.components: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) @@ -167,7 +159,6 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - enable_async_tp: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -260,13 +251,7 @@ def apply_non_moe_tp( parallelize_plan=layer_plan, ) - if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) - logger.info( - f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" "Tensor Parallelism to the model" ) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index ecd08990e7..4ed3363606 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -31,6 +31,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.distributed import ParallelDims +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -67,12 +68,6 @@ def parallelize_llama( job_config.compile.enable and "model" in job_config.compile.components ) if parallel_dims.tp_enabled: - if ( - job_config.parallelism.enable_async_tensor_parallel - and not model_compile_enabled - ): - raise RuntimeError("Async TP requires torch.compile") - enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( "rowwise", @@ -89,8 +84,8 @@ def parallelize_llama( world_mesh["tp"], loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, - enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -144,7 +139,6 @@ def apply_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - enable_async_tp: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -221,14 +215,8 @@ def apply_tp( parallelize_plan=layer_plan, ) - if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) - logger.info( - f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" "Tensor Parallelism to the model" ) From cd337db6303870a4182a730d1c4d6941f47bb2b8 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Fri, 22 Aug 2025 13:20:20 -0700 Subject: [PATCH 119/128] [Cleanup] Miscellaneous Refactors (#1607) This PR makes several miscellaneous refactors to clean up `torchtitan` before release. Changes: - Sets each of `model_parts` to eval mode in `Validator` class to support PP (Bug fix) - Refactor `checkpoint.enable_checkpoint -> checkpoint.enable` (Refactor) - Refacotr `validation.enabled -> validation.enable` (Refactor) --- docs/checkpoint.md | 16 ++++++------ docs/debugging.md | 2 +- docs/evaluation.md | 2 +- tests/integration_tests.py | 26 +++++++++---------- tests/integration_tests_ft.py | 2 +- tests/unit_tests/test_checkpoint.py | 4 +-- torchtitan/components/checkpoint.py | 12 ++++----- torchtitan/components/validate.py | 16 ++++-------- torchtitan/config/job_config.py | 6 ++--- .../train_configs/deepseek_v2.toml | 2 +- .../experiments/flux/inference/run_infer.sh | 2 +- .../flux/tests/integration_tests.py | 8 +++--- torchtitan/experiments/flux/train.py | 4 +-- .../flux/train_configs/debug_model.toml | 4 +-- .../flux/train_configs/flux_dev_model.toml | 4 +-- .../train_configs/flux_schnell_model.toml | 4 +-- torchtitan/experiments/flux/validate.py | 7 ----- torchtitan/experiments/forge/example_train.py | 7 ++--- .../experiments/llama4/scripts/REAME.md | 4 +-- .../llama4/train_configs/debug_model.toml | 2 +- .../llama4/train_configs/llama4_17bx128e.toml | 2 +- .../llama4/train_configs/llama4_17bx16e.toml | 2 +- .../qwen3/train_configs/qwen3_0.6b.toml | 2 +- .../simple_fsdp/tests/integration_tests.py | 14 +++++----- .../train_configs/debug_model.toml | 2 +- .../train_configs/deepseek_v3_16b.toml | 2 +- .../train_configs/deepseek_v3_671b.toml | 2 +- .../llama3/train_configs/debug_model.toml | 4 +-- .../llama3/train_configs/llama3_405b.toml | 4 +-- .../llama3/train_configs/llama3_70b.toml | 4 +-- .../llama3/train_configs/llama3_8b.toml | 4 +-- torchtitan/protocols/state_dict_adapter.py | 2 +- torchtitan/train.py | 6 ++--- 33 files changed, 84 insertions(+), 100 deletions(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 45915d8ad8..6e3112309b 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -5,10 +5,10 @@ You may want to enable checkpointing in `torchtitan` for better fault tolerance ## A general guide to use checkpoints during training 1. ENABLE CHECKPOINTING -In your `torchtitan` training config, ensure that `enable_checkpoint` is set to True. +In your `torchtitan` training config, ensure that under `[checkpoint]`, `enable` is set to True. ``` [checkpoint] -enable_checkpoint = true +enable = true folder = "checkpoint" interval = 500 ``` @@ -16,7 +16,7 @@ interval = 500 By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size. ``` [checkpoint] -enable_checkpoint = true +enable = true last_save_model_only = true ``` @@ -24,7 +24,7 @@ last_save_model_only = true The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`. ``` [checkpoint] -enable_checkpoint = true +enable = true last_save_model_only = true export_dtype = "bfloat16" ``` @@ -34,7 +34,7 @@ In some cases, you may want to partially load from a previous-trained checkpoint This parameter takes a list of string that should be excluded from loading. ``` [checkpoint] -enable_checkpoint = true +enable = true exclude_from_loading = ["data_loader", "lr_scheduler"] ``` When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`. @@ -42,7 +42,7 @@ When used in command line, the parameter should be a comma-separated list of str 5. EXAMPLE CHECKPOINT CONFIGURATION ``` [checkpoint] -enable_checkpoint = true +enable = true folder = "checkpoint" interval = 10 load_step = 5 @@ -60,7 +60,7 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l To create a seed checkpoint, use the same model config as you use for training. e.g. ```bash -NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 +NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` ## Conversion support @@ -86,7 +86,7 @@ This guide will walk you through the steps required to convert a checkpoint from 1. CHECKPOINT CONFIGURATION ``` [checkpoint] -enable_checkpoint = true +enable = true folder = "checkpoint" interval = 10 last_save_model_only = true diff --git a/docs/debugging.md b/docs/debugging.md index f8479d8e0a..bc683bd9b3 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us ```bash -NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 +NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1 ``` **Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc. diff --git a/docs/evaluation.md b/docs/evaluation.md index 69de104aaa..64306cdb34 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -9,7 +9,7 @@ Below is an example validation config: ```toml [validation] -enabled = true +enable = true dataset = "c4_validation" freq = 500 steps = -1 # consumes the entire validation set diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 73ded45482..9142e8b5e5 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -108,10 +108,10 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", ], [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--training.steps 20", ], ], @@ -121,13 +121,13 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.folder hf_checkpoint", "--checkpoint.last_save_model_only", "--checkpoint.last_save_in_hf", ], [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.initial_load_path artifacts-to-be-uploaded/model_only_hf_checkpoint/hf_checkpoint/step-10/", "--checkpoint.initial_load_model_only", "--checkpoint.initial_load_in_hf", @@ -139,7 +139,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.last_save_model_only", ], ], @@ -149,7 +149,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.last_save_model_only", "--checkpoint.export_dtype bfloat16", ], @@ -244,14 +244,14 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", ], [ "--training.steps 20", - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", @@ -443,7 +443,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--parallelism.tensor_parallel_degree=2", "--parallelism.context_parallel_degree=2", "--training.enable_cpu_offload", @@ -474,7 +474,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", ], [ # placeholder for the generation script's generate step @@ -497,13 +497,13 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--training.steps 10", ], # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be # excluded during loading to avoid errors caused by mismatched dp_degree. [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", "--parallelism.tensor_parallel_degree 2", "--training.steps 20", @@ -542,7 +542,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--validation.enabled", + "--validation.enable", "--validation.dataset c4_test", "--parallelism.tensor_parallel_degree=2", "--parallelism.context_parallel_degree=2", diff --git a/tests/integration_tests_ft.py b/tests/integration_tests_ft.py index 6430a54dd5..c0c64e4e74 100644 --- a/tests/integration_tests_ft.py +++ b/tests/integration_tests_ft.py @@ -32,7 +32,7 @@ def build_test_list(): integration_tests_flavors["debug_model.toml"] = [ OverrideDefinitions( [ - ["--training.steps 10", "--checkpoint.enable_checkpoint"], + ["--training.steps 10", "--checkpoint.enable"], ], "Default TorchFT integration test", "default_torchft", diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py index a0fb5d3bab..bd4b892719 100644 --- a/tests/unit_tests/test_checkpoint.py +++ b/tests/unit_tests/test_checkpoint.py @@ -83,7 +83,7 @@ class DummyJobConfig: def __init__(self, job): self.job = job self.checkpoint = CheckpointConfig( - enable_checkpoint=True, + enable=True, async_mode="disabled", folder="", interval=1, @@ -114,7 +114,7 @@ def setUp(self): self.ft_manager = DummyFTManager() ckpt_cfg = CheckpointConfig( - enable_checkpoint=True, + enable=True, async_mode="DISABLED", folder="", interval=1, diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index bf6f0b3ae2..fcec601850 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -186,7 +186,7 @@ def __init__( base_folder: str = "", ft_manager: FTManager | None = None, ) -> None: - self.enable_checkpoint = checkpoint_config.enable_checkpoint + self.enable = checkpoint_config.enable self.ft_manager = ( ft_manager.manager if ft_manager and ft_manager.enabled else None @@ -216,10 +216,10 @@ def load_state_dict(state_dict): async_mode = checkpoint_config.async_mode.lower() self.enable_staging = ( - self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM + self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM ) or self.ft_manager - if not self.enable_checkpoint and self.ft_manager is None: + if not self.enable and self.ft_manager is None: return self.states = states @@ -305,7 +305,7 @@ def __del__(self): self.close() def close(self): - if hasattr(self, "enable_checkpoint") and self.enable_checkpoint: + if hasattr(self, "enable") and self.enable: if hasattr(self, "mp") and self.mp and self.mp.is_alive(): self.mp_queue_send.put(Terminate()) self.mp.join() @@ -517,7 +517,7 @@ def load(self, step: int = -1) -> bool: if self.ft_manager: self._ft_load() - if not self.enable_checkpoint: + if not self.enable: return False model_only = False @@ -739,7 +739,7 @@ def _save_last_step(self, curr_step: int) -> None: ) def _should_save(self, curr_step: int, last_step: bool = False) -> bool: - if not self.enable_checkpoint: + if not self.enable: return False if curr_step == 1 and self.enable_first_step_checkpoint: diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 1bdb854e80..a88b41a508 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -from torch.distributed.fsdp import FSDPModule from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader from torchtitan.components.loss import LossFunction @@ -82,8 +81,8 @@ def validate( step: int, ) -> None: # Set model to eval mode - model = model_parts[0] - model.eval() + for model in model_parts: + model.eval() parallel_dims = self.parallel_dims @@ -148,7 +147,7 @@ def validate( with self.validation_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.maybe_enable_amp: - predictions = model(inputs) + predictions = model_parts[0](inputs) loss = self.loss_fn(predictions, labels) accumulated_losses.append(loss.detach()) @@ -167,14 +166,9 @@ def validate( self.metrics_processor.log_validation(loss=global_avg_loss, step=step) - # Reshard after run forward pass - # This is to ensure the model weights are sharded the same way for checkpoint saving. - for module in model.modules(): - if isinstance(module, FSDPModule): - module.reshard() - # Set model back to train mode - model.train() + for model in model_parts: + model.train() def build_validator( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index a43b3ce060..a2247aa210 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -398,13 +398,13 @@ class Parallelism: @dataclass class Checkpoint: - enable_checkpoint: bool = False + enable: bool = False """Whether to enable checkpoint""" folder: str = "checkpoint" """ The folder to store the checkpoints. - When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}. + When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}. """ interval: int = 500 @@ -710,7 +710,7 @@ class Experimental: @dataclass class Validation: - enabled: bool = False + enable: bool = False """Enable validation to default run validation after each training loop""" dataset: str = "c4_validation" diff --git a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml index eae923714e..1b20031ca6 100644 --- a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml +++ b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml @@ -56,7 +56,7 @@ context_parallel_degree = 1 # expert_parallel_degree = 2 set in custom_args [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 model_weights_only = false diff --git a/torchtitan/experiments/flux/inference/run_infer.sh b/torchtitan/experiments/flux/inference/run_infer.sh index b5419af2fb..dc7fc8ea90 100755 --- a/torchtitan/experiments/flux/inference/run_infer.sh +++ b/torchtitan/experiments/flux/inference/run_infer.sh @@ -18,5 +18,5 @@ PYTORCH_ALLOC_CONF="expandable_segments:True" \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ -m torchtitan.experiments.flux.inference.infer --job.config_file ${CONFIG_FILE} \ ---checkpoint.enable_checkpoint \ +--checkpoint.enable \ --checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer "$@" diff --git a/torchtitan/experiments/flux/tests/integration_tests.py b/torchtitan/experiments/flux/tests/integration_tests.py index aa23add5cf..9cc4e1eee7 100755 --- a/torchtitan/experiments/flux/tests/integration_tests.py +++ b/torchtitan/experiments/flux/tests/integration_tests.py @@ -44,10 +44,10 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", ], [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--training.steps 20", ], ], @@ -57,7 +57,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.last_save_model_only", ], ], @@ -65,7 +65,7 @@ def build_test_list(): "last_save_model_only_fp32", ), OverrideDefinitions( - [["--validation.enabled"]], "Flux Validation Test", "validation" + [["--validation.enable"]], "Flux Validation Test", "validation" ), # Parallelism tests. OverrideDefinitions( diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index e364b30e53..624792e83e 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -78,7 +78,7 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ) - if job_config.validation.enabled: + if job_config.validation.enable: self.validator.flux_init( device=self.device, _dtype=self._dtype, @@ -167,7 +167,7 @@ def forward_backward_step( int(os.environ["WORLD_SIZE"]) == 1 ), "Must create seed checkpoint using a single device, to disable sharding." assert ( - config.checkpoint.enable_checkpoint + config.checkpoint.enable ), "Must enable checkpointing when creating a seed checkpoint." trainer.checkpointer.save(curr_step=0, last_step=True) logger.info("Created seed checkpoint") diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index 9be99b0424..d22815e59e 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -57,7 +57,7 @@ custom_args_module = "torchtitan.experiments.flux.job_config" mode = "full" [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 last_save_model_only = false @@ -65,7 +65,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [validation] -enabled = false +enable = false dataset = "coco-validation" freq = 5 local_batch_size = 8 diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 083ad7977a..389b1aa9a7 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -56,7 +56,7 @@ custom_args_module = "torchtitan.experiments.flux.job_config" mode = "full" [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 1_000 last_save_model_only = true @@ -64,7 +64,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [validation] -enabled = false +enable = false dataset = "coco-validation" local_batch_size = 32 steps = 1 diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index 0a9cce71c7..9e1cbb85fa 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -57,7 +57,7 @@ custom_args_module = "torchtitan.experiments.flux.job_config" mode = "full" [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 1_000 last_save_model_only = true @@ -65,7 +65,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [validation] -enabled = false +enable = false dataset = "coco-validation" local_batch_size=64 freq = 1000 diff --git a/torchtitan/experiments/flux/validate.py b/torchtitan/experiments/flux/validate.py index 89dc4f8942..28f09156b1 100644 --- a/torchtitan/experiments/flux/validate.py +++ b/torchtitan/experiments/flux/validate.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from torch.distributed.fsdp import FSDPModule from torch.distributed.pipelining.schedules import _PipelineSchedule from torchtitan.components.dataloader import BaseDataLoader @@ -239,12 +238,6 @@ def validate( self.metrics_processor.log_validation(loss=global_avg_loss, step=step) - # Reshard after run forward pass - # This is to ensure the model weights are sharded the same way for checkpoint saving. - for module in model.modules(): - if isinstance(module, FSDPModule): - module.reshard() - # Set model back to train mode model.train() diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 92209aeb76..a43f27c926 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -100,7 +100,7 @@ def __init__(self, job_config: JobConfig): self.step = 0 # Build validator if validation is configured - if job_config.validation.enabled: + if job_config.validation.enable: self.validator = build_validator( job_config=job_config, dp_world_size=self.dp_degree, @@ -297,10 +297,7 @@ def train(self): break # Run validation if validator is available - if ( - self.job_config.validation.enabled - and self.validator.should_validate(self.step) - ): + if self.job_config.enable and self.validator.should_validate(self.step): self.validator.validate(self.model_parts) self.checkpointer.save( diff --git a/torchtitan/experiments/llama4/scripts/REAME.md b/torchtitan/experiments/llama4/scripts/REAME.md index c4cd6c3241..97285a2d9f 100644 --- a/torchtitan/experiments/llama4/scripts/REAME.md +++ b/torchtitan/experiments/llama4/scripts/REAME.md @@ -7,11 +7,11 @@ This folder contains the scripts for converting officially released Llama 4 chec From Meta format: ```bash -CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 ``` From HuggingFace format: ```bash -CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 ``` diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 0bdb16ecb9..58484d34d3 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -56,7 +56,7 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 last_save_model_only = false diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index c40437b377..6de020fad0 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -49,7 +49,7 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 8 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index ab718cf6f9..bc9e9bc4f5 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -47,7 +47,7 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 8 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml index 38dc259496..220708d144 100644 --- a/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_0.6b.toml @@ -41,7 +41,7 @@ tensor_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_weights_only = false diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index 45e99aa6c9..f4efbbc599 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -78,10 +78,10 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", ], [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--training.steps 20", ], ], @@ -91,14 +91,14 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", ], [ "--training.steps 20", - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--parallelism.pipeline_parallel_degree 2", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", @@ -180,20 +180,20 @@ def build_test_list(): OverrideDefinitions( [ [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--training.steps 10", ], # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be # excluded during loading to avoid errors caused by mismatched dp_degree. [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", "--parallelism.tensor_parallel_degree 2", "--training.steps 20", ], # load at [tp:4]. [ - "--checkpoint.enable_checkpoint", + "--checkpoint.enable", "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", "--parallelism.tensor_parallel_degree 4", "--training.steps 30", diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 79a15bd2e2..dc9f37f443 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -56,7 +56,7 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 last_save_model_only = false diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml index 3ef6e67fc0..6612cb5cf8 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml @@ -53,7 +53,7 @@ expert_parallel_degree = 8 expert_tensor_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 last_save_model_only = true diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 1a748d56f1..ad238839a1 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -54,7 +54,7 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index d446027f48..ecabf6e5db 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -55,7 +55,7 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 last_save_model_only = false @@ -76,7 +76,7 @@ precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output"] [validation] -enabled = false +enable = false dataset = "c4_validation" freq = 5 steps = 10 diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 5895f7f255..824918eae2 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -45,7 +45,7 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true @@ -65,7 +65,7 @@ precompute_float8_dynamic_scale_for_fsdp = true filter_fqns = ["output"] [validation] -enabled = false +enable = false dataset = "c4_validation" freq = 500 steps = -1 diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 9a2eddd093..20756ba421 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -44,7 +44,7 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true @@ -64,7 +64,7 @@ precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output"] [validation] -enabled = false +enable = false dataset = "c4_validation" freq = 500 steps = -1 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index d9a9c331f7..577d489c33 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -44,7 +44,7 @@ pipeline_parallel_degree = 1 context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true @@ -65,7 +65,7 @@ precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output"] [validation] -enabled = false +enable = false dataset = "c4_validation" freq = 100 steps = -1 diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 106a7937ef..916368cb94 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -67,7 +67,7 @@ def __init__(self, model_args: BaseModelArgs, hf_assets_path: str | None): except FileNotFoundError: logger.warning( f"model.safetensors.index.json not found at hf_assets_path: {mapping_path}. \ - Defaulting to saving a single safetensors file if checkpoint is saved in HF format." + Defaulting to saving a single safetensors file if checkpoint is saved in HF format" ) hf_safetensors_indx = None diff --git a/torchtitan/train.py b/torchtitan/train.py index ca86e50b48..758a5a6995 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -334,7 +334,7 @@ def __init__(self, job_config: JobConfig): ) # Build validator if validation is configured - if job_config.validation.enabled: + if job_config.validation.enable: assert self.train_spec.build_validator_fn is not None pp_schedule, pp_has_first_stage, pp_has_last_stage = ( @@ -593,7 +593,7 @@ def train(self): # Run validation if validator is available if ( - self.job_config.validation.enabled + self.job_config.validation.enable and self.validator.should_validate(self.step) ): self.validator.validate(self.model_parts, self.step) @@ -648,7 +648,7 @@ def close(self) -> None: int(os.environ["WORLD_SIZE"]) == 1 ), "Must create seed checkpoint using a single device, to disable sharding." assert ( - config.checkpoint.enable_checkpoint + config.checkpoint.enable ), "Must enable checkpointing when creating a seed checkpoint." trainer.checkpointer.save(curr_step=0, last_step=True) logger.info("Created seed checkpoint") From 2025abb5fbfed71212a46a8e2c34cb23abbc9fc9 Mon Sep 17 00:00:00 2001 From: tianyu-l <150487191+tianyu-l@users.noreply.github.com> Date: Sun, 24 Aug 2025 21:53:48 -0700 Subject: [PATCH 120/128] async tp minor fix (#1629) follow up of https://github.com/pytorch/torchtitan/pull/1619 to fix remaining errors. also fixing a TODO --- .../experiments/llama4/infra/parallelize.py | 8 ++++---- .../experiments/simple_fsdp/parallelize.py | 15 +++------------ .../experiments/simple_fsdp/simple_fsdp.py | 16 +++++++--------- .../models/deepseek_v3/infra/parallelize.py | 7 +++++-- torchtitan/models/llama3/infra/parallelize.py | 6 +++--- 5 files changed, 22 insertions(+), 30 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index e511686575..3e4dd43f70 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -63,9 +63,6 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -104,6 +101,9 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE @@ -160,7 +160,7 @@ def parallelize_llama( apply_ddp( model, dp_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index 4d909e4fe4..5feffdabb3 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -9,6 +9,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp from torchtitan.tools.logging import logger @@ -37,16 +38,7 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) if parallel_dims.tp_enabled: - if ( - job_config.parallelism.enable_async_tensor_parallel - and not model_compile_enabled - ): - raise RuntimeError("Async TP requires torch.compile") - enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( "rowwise", @@ -64,8 +56,8 @@ def parallelize_llama( tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, - enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, ) + maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) @@ -98,11 +90,10 @@ def parallelize_llama( mode=dp_mode, ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - tp_mesh=tp_mesh if parallel_dims.tp_enabled else None, ) logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) - if model_compile_enabled: + if job_config.compile.enable and "model" in job_config.compile.components: torch._inductor.config.reorder_for_peak_memory = False model = torch.compile(model, fullgraph=True) diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 8f7a2f4da8..38074a2844 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -185,7 +185,12 @@ def _custom_policy(ctx, func, *args, **kwargs): class ReplicateComputation(torch.nn.Module): def __init__( - self, device_mesh, param_sharding, mode, regional_ac, mp_policy, tp_mesh + self, + device_mesh, + param_sharding, + mode, + regional_ac, + mp_policy, ): super().__init__() self.device_mesh = device_mesh @@ -197,7 +202,6 @@ def __init__( mp_policy = mp_policy or MixedPrecisionPolicy() self.param_dtype = mp_policy.param_dtype self.reduce_dtype = mp_policy.reduce_dtype - self.tp_mesh = tp_mesh def replicate_compute(self, x): # data parallel runtime replicate parameters and do local compute @@ -207,10 +211,7 @@ def replicate_compute(self, x): # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim) if x._spec.mesh.mesh_dim_names[-1] == "tp": tp_placement = x._spec.placements[-1] - # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"] - # after DeviceMesh supports slicing a non-root mesh - # dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"] - dp_mesh, tp_mesh = self.device_mesh, self.tp_mesh + dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"] # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather sharded_local_tensor = x.to_local() @@ -270,7 +271,6 @@ def data_parallel( mode="replicate", ac_mode: str = "none", mp_policy: Optional[MixedPrecisionPolicy] = None, - tp_mesh: Optional[DeviceMesh] = None, ): if mode == "replicate": param_sharding = (Replicate(),) @@ -314,7 +314,6 @@ def data_parallel( # mode, # regional_ac, # mp_policy=mp_policy, - # tp_mesh=tp_mesh, # ), # unsafe=True, # ) @@ -328,7 +327,6 @@ def data_parallel( mode, regional_ac, mp_policy=mp_policy, - tp_mesh=tp_mesh, ), ) return model diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 8423c2a8e6..c77250d0f3 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -92,7 +92,10 @@ def parallelize_deepseekv3( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) - if job_config.compile.enable and "model" in job_config.compile.components: + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True apply_compile(model) @@ -147,7 +150,7 @@ def parallelize_deepseekv3( apply_ddp( model, dp_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 4ed3363606..7d0b5de92b 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -64,9 +64,6 @@ def parallelize_llama( ): raise NotImplementedError("CP support for FlexAttention is still in progress.") - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.float8.recipe_name in ( @@ -90,6 +87,9 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: apply_compile(model) From 91979081a6e98bd8f6e2f27cf69dd0102c1f8bde Mon Sep 17 00:00:00 2001 From: Lan Li Date: Tue, 26 Aug 2025 02:28:00 +0800 Subject: [PATCH 121/128] fix(dataloader): Prevent RuntimeError from DataloaderStopIteration (#1627) The DataloaderStopIteration exception inherited from StopIteration. According to PEP 479, raising a StopIteration subclass from a generator causes a RuntimeError in Python 3.7+. This change modifies the base class to `Exception` to ensure it can be caught correctly by user code without triggering this behavior. Fixes ISSUE #1626 --- torchtitan/components/dataloader.py | 10 ++++++++-- torchtitan/experiments/forge/example_train.py | 6 +++--- torchtitan/train.py | 6 +++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 2b555a91bb..071af84d54 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -16,8 +16,14 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.tools.logging import logger - -class DataloaderStopIteration(StopIteration): +# NOTE: This class deliberately inherits from `Exception` and not `StopIteration`. +# According to PEP 479, raising a `StopIteration` or its subclass from within a +# generator will wrap it in a `RuntimeError`. Since this exception is designed +# to be raised from a generator-based dataloader and caught by the training loop, +# inheriting from `StopIteration` would make it uncatchable and would crash the +# program. +# See: https://peps.python.org/pep-0479/ +class DataloaderExhaustedError(Exception): """An exception that indicates dataloader exhaustion.""" pass diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index a43f27c926..0c728a1f72 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -13,7 +13,7 @@ from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module -from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.metrics import build_metrics_processor from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.components.validate import build_validator @@ -135,7 +135,7 @@ def batch_generator( except StopIteration as ex: # If data runs out during gradient accumulation, that # entire step will not be executed. - raise DataloaderStopIteration() from ex + raise DataloaderExhaustedError() from ex data_load_start = time.perf_counter() input_dict, labels = batch self.metrics_processor.ntokens_since_last_log += labels.numel() @@ -292,7 +292,7 @@ def train(self): self.gc_handler.run(self.step) try: self.train_step(data_iterator) - except DataloaderStopIteration: + except DataloaderExhaustedError: logger.warning("Ran out of data; last step was canceled.") break diff --git a/torchtitan/train.py b/torchtitan/train.py index 758a5a6995..9b69fd6798 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -15,7 +15,7 @@ import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager -from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( @@ -386,7 +386,7 @@ def batch_generator( except StopIteration as ex: # If data runs out during gradient accumulation, that # entire step will not be executed. - raise DataloaderStopIteration() from ex + raise DataloaderExhaustedError() from ex input_dict, labels = batch ntokens_batch = labels.numel() self.ntokens_seen += ntokens_batch @@ -583,7 +583,7 @@ def train(self): self.gc_handler.run(self.step) try: self.train_step(data_iterator) - except DataloaderStopIteration: + except DataloaderExhaustedError: logger.warning("Ran out of data; last step was canceled.") break From 030879fa59be48e3edafe997a00877671c461e6c Mon Sep 17 00:00:00 2001 From: Jiani Wang <40016222+wwwjn@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:40:51 -0700 Subject: [PATCH 122/128] [Qwen3] Fix weight tying for Qwen3 according to Huggingface configs (#1633) As titled. Only enable weight tying for smaller model --- torchtitan/experiments/qwen3/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index d22053ff65..9ea4582aaf 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -40,6 +40,7 @@ qk_norm=True, hidden_dim=3072, rope_theta=1000000, + enable_weight_tying=True, ), "1.7B": Qwen3ModelArgs( vocab_size=151936, @@ -52,6 +53,7 @@ qk_norm=True, hidden_dim=6144, rope_theta=1000000, + enable_weight_tying=True, ), "4B": Qwen3ModelArgs( vocab_size=151936, @@ -64,6 +66,7 @@ qk_norm=True, hidden_dim=9728, rope_theta=1000000, + enable_weight_tying=True, ), "8B": Qwen3ModelArgs( vocab_size=151936, From ad06609c4e3ac119112b097b6b317f31c5c760f5 Mon Sep 17 00:00:00 2001 From: wesleytruong Date: Mon, 25 Aug 2025 16:54:25 -0700 Subject: [PATCH 123/128] Fix variable name in NotImplementedError message (#1637) self.score_function should be self.score_func --- torchtitan/models/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 3a9dd1b28a..8be14ecbf0 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -216,7 +216,7 @@ def forward( elif self.score_func == "softmax": scores = F.softmax(scores.to(torch.float32), dim=1) else: - raise NotImplementedError(f"Unknown score function {self.score_function}") + raise NotImplementedError(f"Unknown score function {self.score_func}") # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value From 4191def494fced2be9564014bc6f7a4780a0e1fc Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 26 Aug 2025 11:08:39 -1000 Subject: [PATCH 124/128] Update torchft.md (#1596) Add some basic documentation on how to use Titan with TorchFt for DiLoCo. LMK if anything needs clarification @vishal9-team --- docs/torchft.md | 50 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/docs/torchft.md b/docs/torchft.md index 02dec2409d..68733ec1ec 100644 --- a/docs/torchft.md +++ b/docs/torchft.md @@ -10,6 +10,8 @@ TorchFT is designed to provide fault tolerance when training with replicated wei Before using TorchFT with TorchTitan, you need to install TorchFT by following the instructions in the [TorchFT README](https://github.com/pytorch/torchft/blob/main/README.md) to install TorchFT. +Alternatively, you can install TorchFT with `pip install torchft-nightly`. + ## Configuring TorchTitan for Using TorchFT When using TorchFT with TorchTitan, you need to launch multiple replica groups, each of which is a separate TorchTitan instance. Each replica group is responsible for maintaining a copy of the model weights. In case of a failure, the other replica groups can continue training without lossing weight information. @@ -21,20 +23,23 @@ For example, if you want to run HSDP on a single machine with eight GPUs, where Let's consider an example where we want to run HSDP on a single machine with eight GPUs, where weights are sharded within four GPUs with two replica groups (2, 4 device mesh). Without using TorchFT, you can launch such a training process by specifying `--parallelism.data_parallel_replica_degree=2 --parallelism.data_parallel_shard_degree=4`. However, in the event of a trainer failure (emulating a real-world machine failure), the entire training process would need to stop and recover from the last checkpoint. This can lead to significant downtime and wasted resources. With TorchFT, we can tolerate one replica group failure, ensuring that the training process continues uninterrupted. To achieve this, we can launch two TorchTitan instances, each managing four GPUs and communicating with each other through TorchFT. This setup allows for seamless fault tolerance and minimizes the impact of individual trainer failures. - -### Launching TorchFT with TorchTitan +### Launching TorchFT with TorchTitan (Example 1) To launch TorchFT with TorchTitan, you need to execute the following three commands in different shell sessions: 1. Launch TorchFT lighthouse: + ```bash RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 ``` + 2. Launch the first TorchTitan instance: + ```bash NGPU=4 CUDA_VISIBLE_DEVICES=0,1,2,3 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=0 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4 ``` 3. Launch the second TorchTitan instance: + ```bash NGPU=4 CUDA_VISIBLE_DEVICES=4,5,6,7 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=1 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4 ``` @@ -48,3 +53,44 @@ NGPU=4 CUDA_VISIBLE_DEVICES=4,5,6,7 CONFIG_FILE="./torchtitan/models/llama3/trai * Note that the alive replica group with the smallest replica ID will perform checkpointing saving. In a real-world scenario, `torchft_lighthouse` would likely be on a different machine. The `TORCHFT_LIGHTHOUSE` environment variable is used to tell TorchFT how to communicate with `torchft_lighthouse`. The default value is `http://localhost:29510`. + +### Using semi-synchronous training (Example 2) + +TorchFT provides algorithms that do not require per-step synchronization and +the replica groups can sychronize weights every N steps. + +**Note on Batch Sizes**: For DiLoCo, there's an important distinction in batch size terminology: + +The `--training.global_batch_size` parameter refers to global batch size that will be split across all replica groups. + +- **Global batch size**: The total batch size across all DiLoCo islands/replica groups +- **Inner global batch size**: The batch size within each individual DiLoCo island. This is determined by dividing global batch size by number of replica groups. + +#### Replica Group 0 +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config +``` + +#### Replica Group 1 +```bash +CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config +``` + +## Fault Tolerance Configuration Options + +For complete configuration options, run `NGPU=1 ./run_train.sh --help`. + +[Optional] Only for semi-synchronous training: + +- `--fault_tolerance.sync_steps`: The number of training steps before synchronization. +- `--fault_tolerance.semi_sync_method`: Synchronization method (e.g., "local_sgd", "diloco") + +For more semi-synchronouse configuration options, see [ft/config/job_config.py](config/job_config.py). + +## Environment Variables + +- `TORCHFT_LIGHTHOUSE`: URL of the lighthouse service +- `TORCHFT_MANAGER_PORT`: Port for the TorchFT manager +- `REPLICA_GROUP_ID`: Identifier for the replica group +- `RUST_LOGS`: Logging level for Rust components +- `RUST_BACKTRACE`: Enable backtrace for debugging From e65ef30dec74e8e592191a1487e92256d27ecb2b Mon Sep 17 00:00:00 2001 From: Hossein Kaviani Date: Tue, 26 Aug 2025 14:13:28 -0700 Subject: [PATCH 125/128] Adding StateDictAdapter (#1601) In this PR, I'm adding the StateDictAdapter for Qwen3 to enable loading HF checkpoints. We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa. This can enable us to do a parity test with the HF implementation and make sure that our results are aligned with the HF implementation. --------- Co-authored-by: Hossein Kavianihamedani --- torchtitan/experiments/qwen3/README.md | 9 +- torchtitan/experiments/qwen3/__init__.py | 3 +- .../qwen3/model/state_dict_adapter.py | 86 +++++++++++++++++++ 3 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 torchtitan/experiments/qwen3/model/state_dict_adapter.py diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md index 77b23d55ce..d6e7591811 100644 --- a/torchtitan/experiments/qwen3/README.md +++ b/torchtitan/experiments/qwen3/README.md @@ -6,22 +6,21 @@ QWEN3 0.6B Dense model is available for: - FSDP/HSDP, TP, DDP, AC, compile support -Other model sizes are added to the args, but toml file configs need to be added and tested. Further testing is needed to check the coistency of the parallelism implementations. +Other model sizes are added to the args, but toml file configs need to be added and tested. #### Download Qwen3 tokenizer -```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --asset tokenizer``` - +```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer``` #### Parity with HF -Model parity test has been done and results suggest parity with HF implementation. Further investigation is needed to check the sanity of the Rope function. +Model parity test has been done and results suggest parity with HF implementation. #### To be added - Modeling - Variants of Dense models up to 32B - MoE alternatives - - Weight tying + - Testing - The model should be tested against established performance benchmarks - CI integration diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index 9ea4582aaf..b5aa870d4e 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -17,6 +17,7 @@ from .infra.parallelize import parallelize_qwen3 from .model.args import Qwen3ModelArgs from .model.model import Qwen3Model +from .model.state_dict_adapter import Qwen3StateDictAdapter __all__ = [ "parallelize_qwen3", @@ -25,7 +26,6 @@ "qwen3_configs", ] - # Adding different variants of the model qwen3_configs = { @@ -120,5 +120,6 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, + state_dict_adapter=Qwen3StateDictAdapter, ) ) diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py new file mode 100644 index 0000000000..760cc662be --- /dev/null +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script is adapted from torchtitan/models/llama3/model/state_dict_adapter.py. + +We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa. +This can enable us to do a parity test with the HF implementation and make sure that our results are +aligned with the HF implementation. + +""" +import re +from typing import Any + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .args import Qwen3ModelArgs + + +class Qwen3StateDictAdapter(StateDictAdapter): + def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + + self.model_args = model_args + self.hf_assets_path = hf_assets_path + + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = to_hf_map[key] + + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = self.from_hf_map[key] + + state_dict[new_key] = value + return state_dict From 17ef753ef90889451782e67dda7fe1dc3b7f6a5b Mon Sep 17 00:00:00 2001 From: Anna Cai <31640097+anana10c@users.noreply.github.com> Date: Wed, 27 Aug 2025 12:25:22 -0700 Subject: [PATCH 126/128] add wandb team entity and run name options (#1643) Summary: Allow user to configure WandB entity (aka team) and run name through environment variables. Differential Revision: D80499210 --- torchtitan/components/metrics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 720e2b9d6d..e42753c82c 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -144,7 +144,9 @@ def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None): os.makedirs(log_dir, exist_ok=True) self.wandb.init( + entity=os.getenv("WANDB_TEAM", None), project=os.getenv("WANDB_PROJECT", "torchtitan"), + name=os.getenv("WANDB_RUN_NAME", None), dir=log_dir, config=job_config.to_dict(), ) From a481c267b12a504da515d46911f2120896781602 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Wed, 27 Aug 2025 15:34:42 -0700 Subject: [PATCH 127/128] Solving the validation hanging issue (#1634) As the [discussion](https://github.com/pytorch/torchtitan/issues/1618#issuecomment-3215472411), I added: - warning message when the validation `steps`=-1 in both comment and logger - change the default `steps` to reasonable values with the common setup (world size = 8). - Add infinite loop support for validator to avoid hang when `steps` is large enough to exhaust the dataset. - Add the same fix for flux. ## Test - 8 GPUs with `steps=-1`: hang around step 1270 - 8 GPUs with `steps=1200`: good - 8 GPUs with `steps=1500`: `infinite` automatically set to true. Exhaust the dataset and re-iterate, but won't hang - Flux: `steps=-1` doesn't hang - Flux: `steps=60` doesn't hang; re-loop the dataset. Full thread: https://github.com/pytorch/torchtitan/issues/1618 cc @ebsmothers @tianyu-l --- torchtitan/components/validate.py | 7 +++++++ torchtitan/config/job_config.py | 5 ++++- torchtitan/datasets/hf_datasets.py | 3 ++- torchtitan/experiments/flux/dataset/flux_dataset.py | 5 ++++- torchtitan/experiments/flux/train_configs/debug_model.toml | 2 +- .../experiments/flux/train_configs/flux_dev_model.toml | 2 +- .../experiments/flux/train_configs/flux_schnell_model.toml | 1 + torchtitan/experiments/flux/validate.py | 7 +++++++ torchtitan/models/llama3/train_configs/llama3_405b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_70b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_8b.toml | 4 ++-- 11 files changed, 31 insertions(+), 9 deletions(-) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index a88b41a508..d704447703 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -17,6 +17,7 @@ from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.tools import utils +from torchtitan.tools.logging import logger class BaseValidator: @@ -66,6 +67,7 @@ def __init__( dp_world_size=dp_world_size, dp_rank=dp_rank, tokenizer=tokenizer, + infinite=self.job_config.validation.steps != -1, ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp @@ -74,6 +76,11 @@ def __init__( self.pp_has_first_stage = pp_has_first_stage self.pp_has_last_stage = pp_has_last_stage + if self.job_config.validation.steps == -1: + logger.warning( + "Setting validation steps to -1 could cause hangs due to mismatch among ranks." + ) + @torch.no_grad() def validate( self, diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index a2247aa210..14ff6e7e8a 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -729,7 +729,10 @@ class Validation: """Frequency of validation""" steps: int = -1 - """Number of steps to take in the validation set, -1 means consuming all the data in the validation dataset""" + """ + Number of steps to take in the validation set, -1 means consuming all the data in the validation dataset + WARNING: When setting to -1 there could be hangs due to mismatch among ranks + """ def __post_init__(self): assert ( diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 0e30f8fe51..fce42f0655 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -209,6 +209,7 @@ def build_hf_validation_dataloader( dp_rank: int, tokenizer: BaseTokenizer, job_config: JobConfig, + infinite: bool = False, ) -> ParallelAwareDataloader: """Build a validation data loader for HuggingFace datasets.""" dataset_name = job_config.validation.dataset @@ -223,7 +224,7 @@ def build_hf_validation_dataloader( seq_len=seq_len, dp_rank=dp_rank, dp_world_size=dp_world_size, - infinite=False, + infinite=infinite, ) return ParallelAwareDataloader( diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index 02fd73afec..c7656bfd26 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -367,6 +367,7 @@ def __init__( dp_rank: int = 0, dp_world_size: int = 1, generate_timesteps: bool = True, + infinite: bool = False, ) -> None: # Call parent constructor correctly super().__init__( @@ -377,7 +378,7 @@ def __init__( job_config=job_config, dp_rank=dp_rank, dp_world_size=dp_world_size, - infinite=False, + infinite=infinite, ) # Initialize timestep generation for validation @@ -406,6 +407,7 @@ def build_flux_validation_dataloader( # This parameter is not used, keep it for compatibility tokenizer: BaseTokenizer | None, generate_timestamps: bool = True, + infinite: bool = False, ) -> ParallelAwareDataloader: """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.validation.dataset @@ -423,6 +425,7 @@ def build_flux_validation_dataloader( dp_rank=dp_rank, dp_world_size=dp_world_size, generate_timesteps=generate_timestamps, + infinite=infinite, ) return ParallelAwareDataloader( diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index d22815e59e..ee1bae220f 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -69,7 +69,7 @@ enable = false dataset = "coco-validation" freq = 5 local_batch_size = 8 -steps = 1 +steps = 48 # Recommended value with the current settings and world_size=8 # args for sampling images enable_classifier_free_guidance = true classifier_free_guidance_scale = 5.0 diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 389b1aa9a7..e425e6f774 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -67,7 +67,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] enable = false dataset = "coco-validation" local_batch_size = 32 -steps = 1 +steps = 12 # Recommended value with the current settings and world_size=8 freq = 1000 enable_classifier_free_guidance = true classifier_free_guidance_scale = 5.0 diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index 9e1cbb85fa..fdf977786d 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -68,6 +68,7 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] enable = false dataset = "coco-validation" local_batch_size=64 +steps = 6 # Recommended value with the current settings and world_size=8 freq = 1000 enable_classifier_free_guidance = true classifier_free_guidance_scale = 5.0 diff --git a/torchtitan/experiments/flux/validate.py b/torchtitan/experiments/flux/validate.py index 28f09156b1..7e4f015b82 100644 --- a/torchtitan/experiments/flux/validate.py +++ b/torchtitan/experiments/flux/validate.py @@ -32,6 +32,7 @@ preprocess_data, unpack_latents, ) +from torchtitan.tools.logging import logger class FluxValidator(Validator): @@ -72,12 +73,18 @@ def __init__( dp_rank=dp_rank, tokenizer=tokenizer, generate_timestamps=not self.all_timesteps, + infinite=self.job_config.validation.steps != -1, ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp self.metrics_processor = metrics_processor self.t5_tokenizer, self.clip_tokenizer = build_flux_tokenizer(self.job_config) + if self.job_config.validation.steps == -1: + logger.warning( + "Setting validation steps to -1 could cause hangs due to mismatch among ranks." + ) + def flux_init( self, device: torch.device, diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 824918eae2..08cb70dbbe 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -68,4 +68,4 @@ filter_fqns = ["output"] enable = false dataset = "c4_validation" freq = 500 -steps = -1 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 20756ba421..ae3024211d 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -67,4 +67,4 @@ filter_fqns = ["output"] enable = false dataset = "c4_validation" freq = 500 -steps = -1 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 577d489c33..0bb9ae0d05 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -67,5 +67,5 @@ filter_fqns = ["output"] [validation] enable = false dataset = "c4_validation" -freq = 100 -steps = -1 +freq = 500 +steps = 1200 # Recommend value for c4_validation with world-size=8 and seq_len=8192 From 715641639d689f4ace2cfb900a58503686e0b114 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Thu, 28 Aug 2025 01:45:09 -0700 Subject: [PATCH 128/128] update warning message (#1648) Follow up for https://github.com/pytorch/torchtitan/pull/1634 Updated the warning message according to [Jiani's suggestion](https://github.com/pytorch/torchtitan/pull/1634#pullrequestreview-3161237853) cc @wwwjn --- torchtitan/components/validate.py | 3 ++- torchtitan/experiments/flux/validate.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index d704447703..f725bd0563 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -78,7 +78,8 @@ def __init__( if self.job_config.validation.steps == -1: logger.warning( - "Setting validation steps to -1 could cause hangs due to mismatch among ranks." + "Setting validation steps to -1 might cause hangs because of " + "unequal sample counts across ranks when dataset is exhausted." ) @torch.no_grad() diff --git a/torchtitan/experiments/flux/validate.py b/torchtitan/experiments/flux/validate.py index 7e4f015b82..6cd9a6db3e 100644 --- a/torchtitan/experiments/flux/validate.py +++ b/torchtitan/experiments/flux/validate.py @@ -82,7 +82,8 @@ def __init__( if self.job_config.validation.steps == -1: logger.warning( - "Setting validation steps to -1 could cause hangs due to mismatch among ranks." + "Setting validation steps to -1 might cause hangs because of " + "unequal sample counts across ranks when dataset is exhausted." ) def flux_init(