From 69a73837b0f1bd8619fbdab953f97715fe5f5970 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Nov 2025 14:13:29 -0800 Subject: [PATCH 01/38] Use new DeviceMesh unflatten to rewrite parallel_dims This is a demonstration of how parallel_dims will be when using https://github.com/pytorch/pytorch/pull/161224 stack. ghstack-source-id: d29d2e2908a0529a9c343633961f82dc05cef0b7 Pull-Request: https://github.com/pytorch/torchtitan/pull/1885 --- torchtitan/distributed/parallel_dims.py | 141 +++++++++++++++++++++++- torchtitan/train.py | 9 +- 2 files changed, 143 insertions(+), 7 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 44822039a6..3144dfdbe4 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from collections import defaultdict from dataclasses import dataclass from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -25,6 +26,7 @@ class ParallelDims: ep: int etp: int world_size: int + mesh_dim_names: tuple[str] = tuple() _world_mesh: DeviceMesh = None @@ -63,6 +65,134 @@ def _validate(self): # EP would borrow all cp and tp and some dp_shard degree assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 + def build_mesh(self) -> "ParallelDims": + """Build the device mesh with the required mesh dimensions. + + The following mesh dimensions may be created based on the parallel configuration: + + pp: For PP. + dp_replicate: For DDP or HSDP replicate dimension. + dp_shard_cp: For FSDP or HSDP shard dimension. This includes + ``cp`` even if ``cp`` is 1, so we just use the name + ``dp_shard_cp``. As a result, we always use the name + ``dp_shard_cp`` and ``dp_shard`` is not created as a + dimension. + dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``, + ``dp_shard``, and ``cp`` as all of them are data parallelisms. + dp: This is used by data loading. It includes both ``dp_replicate`` + and ``dp_shard``. + The naming can be confusing; ``batch`` could be a better name. + cp: For CP. + tp: For TP. + ep: For EP. + dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. + + Note: These dimensions won't exist at the same time. The meshes we need to + unflatten from world_mesh, assuming all degrees are > 1 except for ``pp``: + + ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader + doesn't need it. + + ["dp_cp", "tp"]: loss computation + + ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. + + ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. + + ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. 
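+
+ A hypothetical example to make the shapes concrete: with pp=1, dp_replicate=2,
+ dp_shard=4, cp=2, tp=2, ep=8, and etp=tp (world_size=32), dp_shard_mod_ep is
+ dp_shard * cp // ep = 1, and the unflattened meshes would have degrees
+ ["dp"=8, "cp"=2, "tp"=2], ["dp_replicate"=2, "dp_shard_cp"=8, "tp"=2], and
+ ["dp_replicate"=2, "dp_shard_mod_ep"=1, "ep"=8, "tp"=2], each multiplying
+ out to the world size.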
+ """ + + def add_dim(name, degree, config): + config["name"].append(name) + config["degree"].append(degree) + + world_mesh = init_device_mesh(device_type, [self.world_size]) + dp_shard_mod_ep = ( + self.dp_shard * self.cp // self.ep + if self.etp == self.tp + else self.dp_shard * self.cp * self.tp // self.ep + ) + + data_mesh_dims = defaultdict(list) + non_ep_computation_dims = defaultdict(list) + ep_computation_dims = defaultdict(list) + + if self.pp_enabled: + add_dim("pp", self.pp, data_mesh_dims) + add_dim("pp", self.pp, non_ep_computation_dims) + add_dim("pp", self.pp, ep_computation_dims) + + if self.dp_enabled: + add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims) + if self.dp_replicate_enabled: + add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims) + add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) + if self.dp_shard_enabled: + add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) + add_dim("dp_shard_mod_ep", dp_shard_mod_ep, ep_computation_dims) + + if self.cp_enabled: + add_dim("cp", self.cp, data_mesh_dims) + + if self.tp_enabled: + add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims) + if self.etp == self.tp: + add_dim("tp", self.tp, ep_computation_dims) + + self._all_meshes = [] + + if self.dp_enabled: + data_mesh = world_mesh._unflatten( + 0, data_mesh_dims["degree"], data_mesh_dims["name"] + ) + self._all_meshes.append(data_mesh) + # Note that we don't create loss_mesh as it is easier to flatten + # from data_mesh + if self.cp_enabled: + self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + else: + self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp") + + if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled: + self._all_meshes.append( + world_mesh._unflatten( + 0, + non_ep_computation_dims["degree"], + non_ep_computation_dims["name"], + ) + ) + + if self.ep_enabled: + add_dim("ep", self.ep, ep_computation_dims) + self._all_meshes.append( + world_mesh._unflatten( + 0, ep_computation_dims["degree"], ep_computation_dims["name"] + ) + ) + + self._world_mesh = world_mesh + self.mesh_dim_names = tuple( + name for m in self._all_meshes for name in m.mesh_dim_names + ) + return self + + def __getitem__(self, name): + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. For now, + # this is just a quick hack to make it work with unflatten() + + if "mesh_dim_names" == name: + return [name for m in self._all_meshes for name in m.mesh_dim_names] + + for mesh in self._all_meshes: + try: + submesh = mesh[name] + return submesh + except KeyError: + pass + raise AttributeError(f"ParallelDims has no attribute {name}") + + """ def build_mesh(self) -> DeviceMesh: # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel # is not very clean, due to the limited support from DeviceMesh @@ -188,14 +318,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh: mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") return mesh + """ @property - def world_mesh(self) -> DeviceMesh: + def world_mesh(self) -> "ParallelDims": + # This is a hack to make ParallelDims behave like a DeviceMesh. + # We will need to change trainer if design is concluded. 
For now, + # this is just a quick hack to make it work with unflatten() + # doing late init so ParallelDims can still be used as a lightweight # dataclass without having to initialize the world mesh if self._world_mesh is None: - self._world_mesh = self.build_mesh() - return self._world_mesh + self.build_mesh() + return self @property def dp_enabled(self): diff --git a/torchtitan/train.py b/torchtitan/train.py index c897ee3c8a..73441f3a0f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -14,9 +14,8 @@ import torch -from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.protocols.train_spec as train_spec_module +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -109,12 +108,14 @@ def __init__(self, job_config: JobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). + """ dist_utils.set_determinism( - world_mesh, + world_mesh._world_mesh, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], ) + """ self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # build tokenizer and dataloader @@ -687,7 +688,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + world_mesh=self.parallel_dims._world_mesh, ) if torch.distributed.get_rank() == 0: From 225bcfb41329e33eda89dd888f533634f6a08743 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:04 -0700 Subject: [PATCH 02/38] misc ghstack-source-id: f7c3fefa0e48bd8deb201b5741b39744e1597e59 Pull-Request: https://github.com/pytorch/torchtitan/pull/1886 --- torchtitan/distributed/parallel_dims.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 3144dfdbe4..c06500e266 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -87,19 +87,24 @@ def build_mesh(self) -> "ParallelDims": ep: For EP. dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. - Note: These dimensions won't exist at the same time. The meshes we need to - unflatten from world_mesh, assuming all degrees are > 1 except for ``pp``: + Note: These dimensions won't exist at the same time. If we consider + unflatten() operator only, following are all the meshes required + assuming all degrees are > 1 except for ``pp``: ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader doesn't need it. - ["dp_cp", "tp"]: loss computation - ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. - ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. - ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. + + In reality, we don't actually need to create all of these meshes. + For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. + So we don't actually need to create ["dp_cp", "tp"]. + + But there are some meses we MUST create if that mesh will be used for a + parameter. So Non-EP-region-computation mesh and EP-region-computation mesh + are required. 
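+
+ For example, instead of creating ["dp_cp", "tp"], the loss mesh can be
+ derived from the data mesh with a single call (sketch):
+ data_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")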
""" def add_dim(name, degree, config): From f8fda7ba3a42cae2c0a5e41abd06bf2edb4ff5a3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:08 -0700 Subject: [PATCH 03/38] Delete legacy code ghstack-source-id: cf7ad2a7aca42beec0a90e5bcc34177f0c5293ae Pull-Request: https://github.com/pytorch/torchtitan/pull/1887 --- torchtitan/distributed/parallel_dims.py | 128 ------------------------ 1 file changed, 128 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index c06500e266..6b0a3de79f 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -197,134 +197,6 @@ def __getitem__(self, name): pass raise AttributeError(f"ParallelDims has no attribute {name}") - """ - def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. - if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], - ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") - - return mesh - 
- def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" - ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - - return mesh - """ - @property def world_mesh(self) -> "ParallelDims": # This is a hack to make ParallelDims behave like a DeviceMesh. From c078db9d675fd384b182981d9b90cd35a69c971d Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:12 -0700 Subject: [PATCH 04/38] misc ghstack-source-id: f7c3fefa0e48bd8deb201b5741b39744e1597e59 Pull-Request: https://github.com/pytorch/torchtitan/pull/1888 --- torchtitan/distributed/parallel_dims.py | 128 ++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 6b0a3de79f..c06500e266 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -197,6 +197,134 @@ def __getitem__(self, name): pass raise AttributeError(f"ParallelDims has no attribute {name}") + """ + def build_mesh(self) -> DeviceMesh: + # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel + # is not very clean, due to the limited support from DeviceMesh + # for creating two staggered meshes. Will improve. 
+ if self.ep > 1: + return self._build_mesh_with_ep() + else: + return self._build_mesh_without_ep() + + def _build_mesh_with_ep(self) -> DeviceMesh: + # With ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_mod_ep * dp_shard_in_ep + if self.etp == self.tp: + # ep = dp_shard_in_ep * cp + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + else: + assert self.etp == 1 + # ep = dp_shard_in_ep * cp * tp + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) + + dims = [] + names = [] + for d, name in zip( + [ + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ], + ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + ): + # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_mod_ep": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + # Mesh for ep + ep_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + # dp_shard_mod_ep is always needed, even if it's 1 + dp_mesh_dim_names.append("dp_shard_mod_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") + dp_cp_mesh_dim_names.append("dp_shard_mod_ep") + if "dp_shard_in_ep" in names: + dp_mesh_dim_names.append("dp_shard_in_ep") + dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") + dp_cp_mesh_dim_names.append("dp_shard_in_ep") + ep_mesh_dim_names.append("dp_shard_in_ep") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + ep_mesh_dim_names.append("cp") + if self.etp == 1 and self.tp_enabled: + ep_mesh_dim_names.append("tp") + + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + return mesh + + def _build_mesh_without_ep(self) -> DeviceMesh: + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1: + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + 
dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + + return mesh + """ + @property def world_mesh(self) -> "ParallelDims": # This is a hack to make ParallelDims behave like a DeviceMesh. From 9089a044f752e495e7f6230ed51251d764ea6fc8 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:17 -0700 Subject: [PATCH 05/38] misc ghstack-source-id: 6173cc5ba175dc598412fa1669c625d73cffed62 Pull-Request: https://github.com/pytorch/torchtitan/pull/1889 --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 73441f3a0f..6e907228da 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -110,7 +110,7 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). """ dist_utils.set_determinism( - world_mesh._world_mesh, + world_mesh, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], From 5500624f68c707c401857ee1d9d115e00c17989f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:21 -0700 Subject: [PATCH 06/38] lint ghstack-source-id: 065ffd4952bdecef2bd5fb297cdaa7c844f502fe Pull-Request: https://github.com/pytorch/torchtitan/pull/1890 --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 6e907228da..a96dba2bb2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -13,9 +13,9 @@ from typing import Any, Generator, Iterable import torch +from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module -from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training From cdb9b6d3692a7fc9388e164a1930134617985d5e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:25 -0700 Subject: [PATCH 07/38] misc ghstack-source-id: 08dd4a6beebb3c479d32478127e57f4ee2a9d57d Pull-Request: https://github.com/pytorch/torchtitan/pull/1891 --- torchtitan/distributed/parallel_dims.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index c06500e266..685d384457 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -85,7 +85,7 @@ def build_mesh(self) -> "ParallelDims": cp: For CP. tp: For TP. ep: For EP. - dp_shard_mod_ep: For FSDP or HSDP shard dimension in EP region. + dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region. Note: These dimensions won't exist at the same time. If we consider unflatten() operator only, following are all the meshes required @@ -95,8 +95,8 @@ def build_mesh(self) -> "ParallelDims": doesn't need it. ["dp_cp", "tp"]: loss computation ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. - ["dp_replicate", "dp_shard_mod_ep", "ep", "tp"]: EP region computation if etp == tp. 
- ["dp_replicate", "dp_shard_mod_ep", "ep"]: EP region computation if etp == 1. + ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp. + ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1. In reality, we don't actually need to create all of these meshes. For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. @@ -112,7 +112,7 @@ def add_dim(name, degree, config): config["degree"].append(degree) world_mesh = init_device_mesh(device_type, [self.world_size]) - dp_shard_mod_ep = ( + dp_shard_in_ep = ( self.dp_shard * self.cp // self.ep if self.etp == self.tp else self.dp_shard * self.cp * self.tp // self.ep @@ -134,7 +134,7 @@ def add_dim(name, degree, config): add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) if self.dp_shard_enabled: add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) - add_dim("dp_shard_mod_ep", dp_shard_mod_ep, ep_computation_dims) + add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims) if self.cp_enabled: add_dim("cp", self.cp, data_mesh_dims) From e306eddf2c0d1afd39f4c9bbc4c8d7ecfcb999e1 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:29 -0700 Subject: [PATCH 08/38] misc ghstack-source-id: dcf962b85158e2c57a592619e01f89eb4d9acaf9 Pull-Request: https://github.com/pytorch/torchtitan/pull/1892 --- torchtitan/distributed/parallel_dims.py | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 685d384457..e265099db0 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -73,27 +73,27 @@ def build_mesh(self) -> "ParallelDims": pp: For PP. dp_replicate: For DDP or HSDP replicate dimension. dp_shard_cp: For FSDP or HSDP shard dimension. This includes - ``cp`` even if ``cp`` is 1, so we just use the name - ``dp_shard_cp``. As a result, we always use the name - ``dp_shard_cp`` and ``dp_shard`` is not created as a - dimension. + ``cp`` even if ``cp`` is 1. As a result, we always + use the name ``dp_shard_cp``, and ``dp_shard`` is not + created as a dimension. dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``, ``dp_shard``, and ``cp`` as all of them are data parallelisms. - dp: This is used by data loading. It includes both ``dp_replicate`` - and ``dp_shard``. - The naming can be confusing; ``batch`` could be a better name. + dp: This is used by data loading to decide the global batch size and + which part of data this raunk should read. This dim includes both + ``dp_replicate`` and ``dp_shard``. + The name is confusing; ``batch`` could be a better name. cp: For CP. tp: For TP. ep: For EP. - dp_shard_in_ep: For FSDP or HSDP shard dimension in EP region. + dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region. Note: These dimensions won't exist at the same time. If we consider - unflatten() operator only, following are all the meshes required + the unflatten() operator only, the following are all the meshes required assuming all degrees are > 1 except for ``pp``: - ["dp", "cp", "tp"]: ``dp`` process group is wasted as dataloader - doesn't need it. - ["dp_cp", "tp"]: loss computation + ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader + doesn't need it for communication. + ["dp_cp", "tp"]: Loss computation. ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation. 
["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp. ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1. @@ -102,7 +102,7 @@ def build_mesh(self) -> "ParallelDims": For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"]. So we don't actually need to create ["dp_cp", "tp"]. - But there are some meses we MUST create if that mesh will be used for a + But there are some meshes we MUST create if that mesh will be used for a parameter. So Non-EP-region-computation mesh and EP-region-computation mesh are required. """ From d5fa24e71e7bf87bd750d14343ee6d73a3c6a7c3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 15 Oct 2025 11:35:33 -0700 Subject: [PATCH 09/38] Another round ghstack-source-id: c9fdc967e6558fd5821f187d82f58905d5649fdd Pull-Request: https://github.com/pytorch/torchtitan/pull/1893 --- torchtitan/distributed/parallel_dims.py | 416 +++++++----------- torchtitan/distributed/pipeline_parallel.py | 2 +- torchtitan/distributed/utils.py | 75 ++-- torchtitan/models/llama3/infra/parallelize.py | 22 +- torchtitan/models/llama4/infra/parallelize.py | 59 ++- torchtitan/train.py | 50 +-- 6 files changed, 247 insertions(+), 377 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index e265099db0..d2b91c7c82 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,8 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -26,9 +25,9 @@ class ParallelDims: ep: int etp: int world_size: int - mesh_dim_names: tuple[str] = tuple() - _world_mesh: DeviceMesh = None + _meshes: dict[str, DeviceMesh] = field(default_factory=dict) + _world_mesh: DeviceMesh | None = None def __post_init__(self): self._validate() @@ -65,277 +64,180 @@ def _validate(self): # EP would borrow all cp and tp and some dp_shard degree assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 - def build_mesh(self) -> "ParallelDims": - """Build the device mesh with the required mesh dimensions. - - The following mesh dimensions may be created based on the parallel configuration: - - pp: For PP. + def build_mesh(self) -> DeviceMesh: + """ + Build the device mesh with the required mesh dimensions. + + The following mesh dimensions will be created: + + pp: Pipeline Parallelism (PP). + spmd: Used by SPMD DTensor RNG seed. + batch: Used by data loading to determine the global batch size and which + part of the data each rank should read. This dimension includes both + ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for + this dimension to avoid unnecessary process group creation. + loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``, + ``dp_shard``, and ``cp`` degrees, as all are data parallelisms. dp_replicate: For DDP or HSDP replicate dimension. - dp_shard_cp: For FSDP or HSDP shard dimension. This includes - ``cp`` even if ``cp`` is 1. As a result, we always - use the name ``dp_shard_cp``, and ``dp_shard`` is not - created as a dimension. - dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``, - ``dp_shard``, and ``cp`` as all of them are data parallelisms. 
- dp: This is used by data loading to decide the global batch size and
- which part of data this rank should read. This dim includes both
- ``dp_replicate`` and ``dp_shard``.
- The name is confusing; ``batch`` could be a better name.
- cp: For CP.
- tp: For TP.
- ep: For EP.
- dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
-
- Note: These dimensions won't exist at the same time. If we consider
- the unflatten() operator only, the following are all the meshes required
- assuming all degrees are > 1 except for ``pp``:
-
- ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
- doesn't need it for communication.
- ["dp_cp", "tp"]: Loss computation.
- ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
- ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
- ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.
-
- In reality, we don't actually need to create all of these meshes.
- For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"].
- So we don't actually need to create ["dp_cp", "tp"].
-
- But there are some meshes we MUST create if that mesh will be used for a
- parameter. So Non-EP-region-computation mesh and EP-region-computation mesh
- are required.
+ fsdp: For FSDP dimension. This includes ``cp``.
+ cp: Context Parallelism (CP).
+ tp: Tensor Parallelism (TP).
+ ep: Expert Parallelism (EP).
+ efsdp: FSDP in the EP region.
+ etp: TP in the EP region.
+
+ Note: All the dimensions above are created by unflattening the world mesh.
+ This API performs the following unflatten operations:
+
+ ["pp", "spmd"]
+ ["pp", "batch", "cp", "tp"]
+ ["pp", "loss", "tp"]
+ ["pp", "dp_replicate", "fsdp", "tp"]
+ ["pp", "dp_replicate", "efsdp", "ep", "etp"]
+
+ Note: DeviceMesh currently recreates the process group for each dimension.
+ It should share the process group for the same dim group to avoid unnecessary
+ process group creation.
+ """
+
+ def unflatten_mesh(
+ world_mesh: DeviceMesh, dim_names: tuple[str], dim_degrees: tuple[int]
+ ):
+ """Unflatten the world mesh to create the required mesh dimensions.
+
+ Uses fake backend for dimensions with degree 1 or for 'batch' dimension
+ to avoid unnecessary process group creation. 
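+
+ For example (hypothetical degrees): dim_names=("pp", "batch", "cp", "tp")
+ with dim_degrees=(1, 8, 1, 4) would build backend_override as
+ {"pp": "fake", "batch": "fake", "cp": "fake"}, so only "tp" gets a real
+ process group.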
+ """ + backend_override = {} + for name, degree in zip(dim_names, dim_degrees, strict=True): + if degree == 1 or name == "batch": + backend_override[name] = "fake" + + return world_mesh._unflatten( + 0, dim_degrees, dim_names, backend_override=backend_override + ) - world_mesh = init_device_mesh(device_type, [self.world_size]) - dp_shard_in_ep = ( - self.dp_shard * self.cp // self.ep - if self.etp == self.tp - else self.dp_shard * self.cp * self.tp // self.ep + logger.info( + f"Building device mesh with parallelism: " + f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " + f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" ) - data_mesh_dims = defaultdict(list) - non_ep_computation_dims = defaultdict(list) - ep_computation_dims = defaultdict(list) - - if self.pp_enabled: - add_dim("pp", self.pp, data_mesh_dims) - add_dim("pp", self.pp, non_ep_computation_dims) - add_dim("pp", self.pp, ep_computation_dims) - - if self.dp_enabled: - add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims) - if self.dp_replicate_enabled: - add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims) - add_dim("dp_replicate", self.dp_replicate, ep_computation_dims) - if self.dp_shard_enabled: - add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims) - add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims) - - if self.cp_enabled: - add_dim("cp", self.cp, data_mesh_dims) - - if self.tp_enabled: - add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims) - if self.etp == self.tp: - add_dim("tp", self.tp, ep_computation_dims) - - self._all_meshes = [] - - if self.dp_enabled: - data_mesh = world_mesh._unflatten( - 0, data_mesh_dims["degree"], data_mesh_dims["name"] - ) - self._all_meshes.append(data_mesh) - # Note that we don't create loss_mesh as it is easier to flatten - # from data_mesh - if self.cp_enabled: - self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp") - else: - self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp") - - if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled: - self._all_meshes.append( - world_mesh._unflatten( - 0, - non_ep_computation_dims["degree"], - non_ep_computation_dims["name"], - ) - ) + batch = self.dp_replicate * self.dp_shard + loss = self.dp_replicate * self.dp_shard * self.cp + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + spmd = self.world_size // self.pp - if self.ep_enabled: - add_dim("ep", self.ep, ep_computation_dims) - self._all_meshes.append( - world_mesh._unflatten( - 0, ep_computation_dims["degree"], ep_computation_dims["name"] - ) - ) + self._world_mesh = init_device_mesh( + device_type, (self.world_size,), mesh_dim_names=("world",) + ) + pp_spmd_mesh = unflatten_mesh(self._world_mesh, ("pp", "spmd"), (self.pp, spmd)) + data_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "batch", "cp", "tp"), + (self.pp, batch, self.cp, self.tp), + ) + loss_mesh = unflatten_mesh( + self._world_mesh, ("pp", "loss", "tp"), (self.pp, loss, self.tp) + ) + dense_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "fsdp", "tp"), + (self.pp, self.dp_replicate, fsdp, self.tp), + ) + sparse_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "efsdp", "ep", "etp"), + (self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ) - self._world_mesh = world_mesh - self.mesh_dim_names = tuple( - name for m in self._all_meshes for name in m.mesh_dim_names + self._meshes = { + "pp": pp_spmd_mesh["pp"], + "spmd": 
pp_spmd_mesh["spmd"], + "batch": data_mesh["batch"], + "loss": loss_mesh["loss"], + "dp_replicate": dense_mesh["dp_replicate"], + "fsdp": dense_mesh["fsdp"], + "cp": data_mesh["cp"], + "tp": data_mesh["tp"], + "ep": sparse_mesh["ep"], + "efsdp": sparse_mesh["efsdp"], + "etp": sparse_mesh["etp"], + } + + # Validate mesh sizes + self._validate_meshes() + + logger.info( + f"Successfully created meshes with active dimensions: " + f"{list(self.get_all_meshes().keys())}" ) - return self - def __getitem__(self, name): - # This is a hack to make ParallelDims behave like a DeviceMesh. - # We will need to change trainer if design is concluded. For now, - # this is just a quick hack to make it work with unflatten() + return self._world_mesh + + def _validate_meshes(self): + """Validate that created meshes have the expected sizes.""" + expected_sizes = { + "pp": self.pp, + "spmd": self.world_size // self.pp, + "batch": self.dp_replicate * self.dp_shard, + "loss": self.dp_replicate * self.dp_shard * self.cp, + "dp_replicate": self.dp_replicate, + "fsdp": self.dp_shard * self.cp, + "cp": self.cp, + "tp": self.tp, + "ep": self.ep, + "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + "etp": self.etp, + } + + for mesh_name, expected_size in expected_sizes.items(): + actual_size = self._meshes[mesh_name].size() + assert actual_size == expected_size, ( + f"Mesh '{mesh_name}' has unexpected size: " + f"expected {expected_size}, got {actual_size}" + ) - if "mesh_dim_names" == name: - return [name for m in self._all_meshes for name in m.mesh_dim_names] + def get_mesh(self, dim: str) -> DeviceMesh | None: + """Get a device mesh by dimension name. - for mesh in self._all_meshes: - try: - submesh = mesh[name] - return submesh - except KeyError: - pass - raise AttributeError(f"ParallelDims has no attribute {name}") + Args: + dim: Name of the mesh dimension. Valid options include: + 'pp', 'spmd', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp' - """ - def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. 
- if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], - ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") - - return mesh - - def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - 
dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" + Returns: + DeviceMesh for the requested dimension, or None if the dimension + has size 1 (i.e., parallelism is disabled for that dimension). + + Raises: + ValueError: If the requested dimension name is not valid. + """ + if not self._meshes: + self.build_mesh() + + if dim not in self._meshes: + valid_dims = sorted(self._meshes.keys()) + raise ValueError( + f"Invalid mesh dim: '{dim}'. Valid dimensions are: {valid_dims}" ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - return mesh - """ + if self._meshes[dim].size() == 1: + return None - @property - def world_mesh(self) -> "ParallelDims": - # This is a hack to make ParallelDims behave like a DeviceMesh. - # We will need to change trainer if design is concluded. For now, - # this is just a quick hack to make it work with unflatten() + return self._meshes[dim] - # doing late init so ParallelDims can still be used as a lightweight - # dataclass without having to initialize the world mesh + def get_all_meshes(self) -> dict[str, DeviceMesh]: + if not self._meshes: + self.build_mesh() + return {k: v for k, v in self._meshes.items() if v.size() > 1} + + @property + def world_mesh(self) -> DeviceMesh: if self._world_mesh is None: self.build_mesh() - return self + return self._world_mesh @property def dp_enabled(self): @@ -354,7 +256,7 @@ def cp_enabled(self): return self.cp > 1 @property - def dp_cp_enabled(self): + def batch_enabled(self): return self.dp_enabled or self.cp_enabled @property diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index 06dba40d6f..38f3bad1ba 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -47,7 +47,7 @@ def pipeline_llm( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 6a73ffd083..2f16bc3837 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +import itertools import math import os from collections.abc import Generator, Iterable @@ -84,24 +85,17 @@ def set_determinism( world_mesh: DeviceMesh | None, device: torch.device, debug_config: DebugConfig, - distinct_seed_mesh_dims: list[str], + distinct_seed_mesh_dim: str = "pp", ) -> None: """ Set the same DTensor manual seed for all dimensions in world mesh, but only different seeds - across dimensions denoted by `distinct_seed_mesh_dims`. An example use case is pipeline parallelism, + across dimension denoted by `distinct_seed_mesh_dim`. An example use case is pipeline parallelism, where we want to have the same seed across SPMD groups, but different seeds across PP groups. Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms, and DTensor manages its own RNG tracker, but we could extend to support both if needed. 
Set Determinism flags for increased reproducibility with loss of performance. - - Args: - world_mesh: Device mesh for distributed training - device: Device to use - distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across. - seed: Base seed value (if None, will be determined automatically) - deterministic: Whether to enable deterministic algorithms """ if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") @@ -140,45 +134,28 @@ def set_determinism( torch.distributed.broadcast(seed_tensor, src=0) seed = seed_tensor.to("cpu").view(torch.uint64).item() - # Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims` + # Set distinct seed for each rank in mesh dimensions, with dimension name provided by `distinct_seed_mesh_dim` # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, # and choose a unique seed for each rank on the PP mesh. - # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed. - distinct_dims_in_mesh = [ - dim - for dim in distinct_seed_mesh_dims - if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names - ] - - if c10d.get_world_size() > 1 and distinct_dims_in_mesh: - # Each dimension contributes: local_rank * (product of all previous dimension sizes) - # This guarantees uniqueness like multi-dimensional array indexing - seed_offset = 0 - cumulative_size = 1 - - for dim in distinct_dims_in_mesh: - distinct_mesh = world_mesh[dim] - local_rank = distinct_mesh.get_local_rank() - # Add contribution from this dimension - seed_offset += local_rank * cumulative_size - # Update cumulative size for next dimension - cumulative_size *= distinct_mesh.size() - - seed += seed_offset + # TODO(jianiw): We could further extend this to support multiple distinct dimensions instead of just one. + if ( + c10d.get_world_size() > 1 + and distinct_seed_mesh_dim in world_mesh.mesh_dim_names + ): + distinct_mesh = world_mesh[distinct_seed_mesh_dim] + seed += distinct_mesh.get_local_rank() seed %= 2**64 logger.debug( - f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}" + f"{distinct_seed_mesh_dim} rank {distinct_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}" + ) + duplicate_seed_mesh = list( + filter( + lambda name: name != distinct_seed_mesh_dim, world_mesh.mesh_dim_names + ) ) - - # Filter out all distinct dimensions to get duplicate_seed_mesh - duplicate_seed_mesh_dims = [ - name - for name in world_mesh.mesh_dim_names - if name not in distinct_dims_in_mesh - ] duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None + world_mesh[duplicate_seed_mesh] if len(duplicate_seed_mesh) else None ) else: duplicate_seed_mesh = world_mesh @@ -351,7 +328,10 @@ def _get_distributed_backend(enable_cpu_backend): return torch.distributed.get_world_size() -def set_pg_timeouts(timeout, world_mesh): +def set_pg_timeouts( + timeout: timedelta, + parallel_dims: ParallelDims, +): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. 
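 A hypothetical call site, mirroring how the trainer invokes this after the
 first optimizer step (sketch):
 set_pg_timeouts(timeout=timedelta(seconds=100), parallel_dims=parallel_dims)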
@@ -370,11 +350,10 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.barrier(device_ids=[device_module.current_device()]) device_module.synchronize() - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - # None represents the 'default' PG, not part of the mesh - groups.append(None) - for group in groups: + for group in itertools.chain( + [None], [mesh.get_group() for mesh in parallel_dims.get_all_meshes().values()] + ): torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) @@ -498,9 +477,7 @@ def _clip_grad_norm_with_ep( if math.isinf(norm_type): total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) else: - total_norm = ( - ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type - ) + total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type total_norm **= 1.0 / norm_type if pp_mesh is not None: diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 13a968be96..b644f45112 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -61,7 +61,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -84,13 +83,14 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -110,15 +110,16 @@ def parallelize_llama( apply_compile(model, job_config.compile) if parallel_dims.fsdp_enabled: - # apply FSDP or HSDP, potentially with Context Parallel + # dp_mesh is the mesh for FSDP/HSDP if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh = DeviceMesh._concatenate( + [parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")] + ) else: - dp_mesh_dim_names = ("dp_shard_cp",) - + dp_mesh = parallel_dims.get_mesh("fsdp") apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -137,11 +138,12 @@ def parallelize_llama( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_replicate_mesh, enable_compile=model_compile_enabled, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 0fb2b54eac..a9d43c29ff 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ 
b/torchtitan/models/llama4/infra/parallelize.py @@ -68,7 +68,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -91,26 +90,21 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), + tp_mesh=tp_mesh, + ep_mesh=parallel_dims.get_mesh("ep"), + etp_mesh=parallel_dims.get_mesh("etp"), etp_enabled=parallel_dims.etp_enabled, ) @@ -130,21 +124,27 @@ def parallelize_llama( if model_compile_enabled: apply_compile(model, job_config.compile, parallel_dims.ep_enabled) - dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: - # apply FSDP or HSDP, potentially with Context Parallel + # dp_mesh is the mesh for FSDP/HSDP if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh = DeviceMesh._concatenate( + [parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")] + ) else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh = parallel_dims.get_mesh("fsdp") # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] + dp_mod_ep_mesh = None if parallel_dims.ep_enabled: if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + dp_mod_ep_mesh = DeviceMesh._concatenate( + [ + parallel_dims.get_mesh("dp_replicate"), + parallel_dims.get_mesh("efsdp"), + ] + ) + else: + dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp") apply_fsdp( model, @@ -155,11 +155,7 @@ def parallelize_llama( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + dp_mod_ep_mesh=dp_mod_ep_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -174,9 +170,9 @@ def parallelize_llama( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -442,8 +438,7 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, 
- etp_enabled: bool, + etp_mesh: DeviceMesh | None, ): assert ep_mesh is not None or tp_mesh is not None @@ -465,7 +460,7 @@ def apply_moe_ep_tp( # replicate computation for the router "moe.router.gate": NoParallel(), } - if ep_mesh is not None and not etp_enabled: + if ep_mesh is not None and etp_mesh is None: # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) @@ -492,12 +487,12 @@ def apply_moe_ep_tp( experts_mesh = tp_mesh # input Replicate, output Partial experts_plan = TensorParallel() - elif tp_mesh is None or not etp_enabled: + elif tp_mesh is None or etp_mesh is None: experts_mesh = ep_mesh # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = DeviceMesh._concatenate([ep_mesh, etp_mesh]) experts_plan = ExpertTensorParallel() parallelize_module( diff --git a/torchtitan/train.py b/torchtitan/train.py index a96dba2bb2..25b6b38710 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -88,18 +88,14 @@ def __init__(self, job_config: JobConfig): # init distributed and build meshes self.parallel_dims = parallel_dims = self.init_distributed() - # Logging needs to happen after distributed initialized - job_config.maybe_log() - - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: - dp_degree, dp_rank = 1, 0 + batch_degree, batch_rank = 1, 0 self.ft_manager = FTManager(job_config.fault_tolerance) - dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank) # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( @@ -108,14 +104,12 @@ def __init__(self, job_config: JobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). - """ dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], ) - """ self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # build tokenizer and dataloader @@ -126,8 +120,8 @@ def __init__(self, job_config: JobConfig): ) self.dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, job_config=job_config, ) @@ -194,19 +188,20 @@ def __init__(self, job_config: JobConfig): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. 
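 # (Illustrative arithmetic, not from any config: with local_batch_size=2
 # and batch_degree=8, the default is 2 * 8 = 16, i.e. one accumulation
 # step, while a configured global_batch_size of 32 would yield
 # 32 // (2 * 8) = 2 steps.)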
- global_batch_size = job_config.training.local_batch_size * dp_degree + global_batch_size = job_config.training.local_batch_size * batch_degree assert global_batch_size > 0 assert ( - global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + global_batch_size % (job_config.training.local_batch_size * batch_degree) + == 0 ), ( f"global batch size must be multiple of local batch size times " f"data-parallel degree ({global_batch_size} " - f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + f"% ({job_config.training.local_batch_size} * {batch_degree}) != 0)" ) # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( - job_config.training.local_batch_size * dp_degree + job_config.training.local_batch_size * batch_degree ) assert self.gradient_accumulation_steps > 0 self.loss_fn = rescale_accumulated_loss( @@ -340,8 +335,8 @@ def __init__(self, job_config: JobConfig): self.validator = self.train_spec.build_validator_fn( job_config=job_config, - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, parallel_dims=parallel_dims, loss_fn=self.loss_fn, @@ -486,7 +481,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=cp_buffers, cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, @@ -565,9 +560,7 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None - ), + pp_mesh=parallel_dims.get_mesh("pp"), ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() @@ -581,17 +574,18 @@ def train_step( if not self.metrics_processor.should_log(self.step): return - if parallel_dims.dp_cp_enabled: + if parallel_dims.batch_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg + batch_mesh = parallel_dims.get_mesh("batch") global_avg_loss, global_max_loss, global_ntokens_seen = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, batch_mesh, ft_pg), + dist_utils.dist_max(loss, batch_mesh, ft_pg), dist_utils.dist_sum( torch.tensor( self.ntokens_seen, dtype=torch.int64, device=self.device ), - parallel_dims.world_mesh["dp_cp"], + batch_mesh, ft_pg, ), ) @@ -688,7 +682,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims._world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: From 8219f76d2fb1a53f2929fbf98c7cf3f6943142eb Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 28 Oct 2025 12:27:29 -0700 Subject: [PATCH 10/38] misc --- torchtitan/config/job_config.py | 14 +---- torchtitan/distributed/parallel_dims.py | 56 ++++++++----------- torchtitan/models/llama3/infra/parallelize.py | 10 ++-- torchtitan/models/llama4/infra/parallelize.py | 26 ++++----- torchtitan/train.py | 2 +- 5 files changed, 40 insertions(+), 68 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index c806041bb6..d185af2b52 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -388,19 +388,7 @@ class Parallelism: """ Expert parallelism degree. 1 means disabled. 
No effect for non-MoE models. - Currently, it is supported with the following constraints: - - - when etp = tp: - - - cp <= ep <= dp_shard * cp - - ep % cp == 0 - - dp_shard * cp % ep == 0 - - - when etp = 1: - - - cp * tp <= ep <= dp_shard * cp * tp - - ep % (cp * tp) == 0 - - dp_shard * cp * tp % ep == 0 + Currently, etp is either 1 or is the same as tp. Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index d2b91c7c82..5b49ba1bca 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -57,12 +57,6 @@ def _validate(self): if ep > 1: assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1" - if etp == tp: - # EP would borrow all cp and some dp_shard degree - assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - elif etp == 1: - # EP would borrow all cp and tp and some dp_shard degree - assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 def build_mesh(self) -> DeviceMesh: """ @@ -71,7 +65,6 @@ def build_mesh(self) -> DeviceMesh: The following mesh dimensions will be created: pp: Pipeline Parallelism (PP). - spmd: Used by SPMD DTensor RNG seed. batch: Used by data loading to determine the global batch size and which part of the data each rank should read. This dimension includes both ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for @@ -79,7 +72,7 @@ def build_mesh(self) -> DeviceMesh: loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``, ``dp_shard``, and ``cp`` degrees, as all are data parallelisms. dp_replicate: For DDP or HSDP replicate dimension. - fsdp: For FSDP dimension. This includes ``cp``. + fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``. cp: Context Parallelism (CP). tp: Tensor Parallelism (TP). ep: Expert Parallelism (EP). @@ -89,7 +82,6 @@ def build_mesh(self) -> DeviceMesh: Note: All the dimensions above are created by unflattening the world mesh. 
This API performs the following unflatten operations: - ["pp", "spmd"] ["pp", "batch", "cp", "tp"] ["pp", "loss", "tp"] ["pp", "dp_replicate", "fsdp", "tp"] @@ -127,20 +119,16 @@ def unflatten_mesh( loss = self.dp_replicate * self.dp_shard * self.cp fsdp = self.dp_shard * self.cp efsdp = fsdp * self.tp // (self.etp * self.ep) - spmd = self.world_size // self.pp self._world_mesh = init_device_mesh( device_type, (self.world_size,), mesh_dim_names=("world",) ) - pp_spmd_mesh = unflatten_mesh(self._world_mesh, ("pp", "spmd"), (self.pp, spmd)) - data_mesh = unflatten_mesh( + dataloading_mesh = unflatten_mesh( self._world_mesh, ("pp", "batch", "cp", "tp"), (self.pp, batch, self.cp, self.tp), ) - loss_mesh = unflatten_mesh( - self._world_mesh, ("pp", "loss", "tp"), (self.pp, loss, self.tp) - ) + loss_mesh = dataloading_mesh["batch", "cp"].flatten("loss_mesh") dense_mesh = unflatten_mesh( self._world_mesh, ("pp", "dp_replicate", "fsdp", "tp"), @@ -153,14 +141,13 @@ def unflatten_mesh( ) self._meshes = { - "pp": pp_spmd_mesh["pp"], - "spmd": pp_spmd_mesh["spmd"], - "batch": data_mesh["batch"], + "pp": dataloading_mesh["pp"], + "batch": dataloading_mesh["batch"], "loss": loss_mesh["loss"], "dp_replicate": dense_mesh["dp_replicate"], "fsdp": dense_mesh["fsdp"], - "cp": data_mesh["cp"], - "tp": data_mesh["tp"], + "cp": dataloading_mesh["cp"], + "tp": dataloading_mesh["tp"], "ep": sparse_mesh["ep"], "efsdp": sparse_mesh["efsdp"], "etp": sparse_mesh["etp"], @@ -180,7 +167,6 @@ def _validate_meshes(self): """Validate that created meshes have the expected sizes.""" expected_sizes = { "pp": self.pp, - "spmd": self.world_size // self.pp, "batch": self.dp_replicate * self.dp_shard, "loss": self.dp_replicate * self.dp_shard * self.cp, "dp_replicate": self.dp_replicate, @@ -199,34 +185,38 @@ def _validate_meshes(self): f"expected {expected_size}, got {actual_size}" ) - def get_mesh(self, dim: str) -> DeviceMesh | None: - """Get a device mesh by dimension name. + def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None: + """Get a device mesh by dimension names. Args: - dim: Name of the mesh dimension. Valid options include: - 'pp', 'spmd', 'batch', 'loss', 'dp_replicate', 'fsdp', + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', 'cp', 'tp', 'ep', 'etp', 'efsdp' Returns: - DeviceMesh for the requested dimension, or None if the dimension - has size 1 (i.e., parallelism is disabled for that dimension). + DeviceMesh for the requested dimension(s), or None if any of + dimension(s) has size 1 (i.e., parallelism is disabled for that dimension). Raises: - ValueError: If the requested dimension name is not valid. + ValueError: If the requested dimension name(s) is not valid. """ if not self._meshes: self.build_mesh() - if dim not in self._meshes: + if isinstance(dims, str): + dims = [dims] + + if not all(dim in self._meshes for dim in dims): valid_dims = sorted(self._meshes.keys()) raise ValueError( - f"Invalid mesh dim: '{dim}'. Valid dimensions are: {valid_dims}" + f"Invalid mesh dim: '{dims}'. 
Valid dimensions are: {valid_dims}" ) - if self._meshes[dim].size() == 1: + if any(self._meshes[dim].size() == 1 for dim in dims): return None - return self._meshes[dim] + meshes = [self._meshes[dim] for dim in dims] + return meshes[0] if len(meshes) == 1 else DeviceMesh._concatenate(meshes) def get_all_meshes(self) -> dict[str, DeviceMesh]: if not self._meshes: @@ -256,7 +246,7 @@ def cp_enabled(self): return self.cp > 1 @property - def batch_enabled(self): + def dp_cp_enabled(self): return self.dp_enabled or self.cp_enabled @property diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index b644f45112..7896afab8b 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -111,12 +111,10 @@ def parallelize_llama( if parallel_dims.fsdp_enabled: # dp_mesh is the mesh for FSDP/HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh = DeviceMesh._concatenate( - [parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")] - ) - else: - dp_mesh = parallel_dims.get_mesh("fsdp") + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, dp_mesh, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index a9d43c29ff..8f991295e1 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -105,7 +105,7 @@ def parallelize_llama( tp_mesh=tp_mesh, ep_mesh=parallel_dims.get_mesh("ep"), etp_mesh=parallel_dims.get_mesh("etp"), - etp_enabled=parallel_dims.etp_enabled, + ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]), ) model_compile_enabled = ( @@ -126,23 +126,16 @@ def parallelize_llama( if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # dp_mesh is the mesh for FSDP/HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh = DeviceMesh._concatenate( - [parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")] - ) - else: - dp_mesh = parallel_dims.get_mesh("fsdp") + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh = None if parallel_dims.ep_enabled: if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh = DeviceMesh._concatenate( - [ - parallel_dims.get_mesh("dp_replicate"), - parallel_dims.get_mesh("efsdp"), - ] - ) + dp_mod_ep_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) else: dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp") @@ -439,6 +432,7 @@ def apply_moe_ep_tp( tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, etp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, ): assert ep_mesh is not None or tp_mesh is not None @@ -482,17 +476,19 @@ def apply_moe_ep_tp( parallelize_plan=moe_layer_plan, ) - experts_mesh, experts_plan = None, None + expert_mesh, experts_plan = None, None if ep_mesh is None: + assert ep_etp_mesh is None experts_mesh = tp_mesh # input Replicate, output Partial experts_plan = TensorParallel() elif tp_mesh is None or etp_mesh is None: + assert ep_etp_mesh is None experts_mesh = ep_mesh # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = DeviceMesh._concatenate([ep_mesh, etp_mesh]) + experts_mesh = ep_etp_mesh experts_plan = ExpertTensorParallel() parallelize_module( diff --git 
a/torchtitan/train.py b/torchtitan/train.py index 25b6b38710..91d8b07251 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -574,7 +574,7 @@ def train_step( if not self.metrics_processor.should_log(self.step): return - if parallel_dims.batch_enabled: + if parallel_dims.dp_cp_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg batch_mesh = parallel_dims.get_mesh("batch") From d830822608c51c80bbf5713bd9fd89815138fa98 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 28 Oct 2025 13:33:05 -0700 Subject: [PATCH 11/38] misc --- torchtitan/distributed/parallel_dims.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 5b49ba1bca..1920a05202 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -128,7 +128,7 @@ def unflatten_mesh( ("pp", "batch", "cp", "tp"), (self.pp, batch, self.cp, self.tp), ) - loss_mesh = dataloading_mesh["batch", "cp"].flatten("loss_mesh") + loss_mesh = dataloading_mesh["batch", "cp"]._flatten("loss_mesh") dense_mesh = unflatten_mesh( self._world_mesh, ("pp", "dp_replicate", "fsdp", "tp"), @@ -143,7 +143,7 @@ def unflatten_mesh( self._meshes = { "pp": dataloading_mesh["pp"], "batch": dataloading_mesh["batch"], - "loss": loss_mesh["loss"], + "loss": loss_mesh, "dp_replicate": dense_mesh["dp_replicate"], "fsdp": dense_mesh["fsdp"], "cp": dataloading_mesh["cp"], From 1f950c7775d0899f3c07f370b431d8205beae7ec Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 28 Oct 2025 14:26:56 -0700 Subject: [PATCH 12/38] misc --- torchtitan/distributed/parallel_dims.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 1920a05202..d586b719d7 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -79,11 +79,11 @@ def build_mesh(self) -> DeviceMesh: efsdp: FSDP in the EP region. etp: TP in the EP region. - Note: All the dimensions above are created by unflattening the world mesh. + Note: Most dimensions above are created by unflattening the world mesh, except for loss, + which is created by flattening the batch and cp dimensions. This API performs the following unflatten operations: ["pp", "batch", "cp", "tp"] - ["pp", "loss", "tp"] ["pp", "dp_replicate", "fsdp", "tp"] ["pp", "dp_replicate", "efsdp", "ep", "etp"] From f169408e3dba2dc6b348c58da26cc27a987c6fa6 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Nov 2025 15:06:12 -0800 Subject: [PATCH 13/38] misc --- torchtitan/distributed/utils.py | 67 +++++++++++++------ torchtitan/models/llama4/infra/parallelize.py | 2 +- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 2f16bc3837..d64eec3cbf 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -82,20 +82,26 @@ def dist_mean( def set_determinism( - world_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, device: torch.device, debug_config: DebugConfig, - distinct_seed_mesh_dim: str = "pp", + distinct_seed_mesh_dims: list[str], ) -> None: """ Set the same DTensor manual seed for all dimensions in world mesh, but only different seeds - across dimension denoted by `distinct_seed_mesh_dim`. An example use case is pipeline parallelism, + across dimensions denoted by `distinct_seed_mesh_dims`. 
An example use case is pipeline parallelism,
    where we want to have the same seed across SPMD groups, but different seeds across PP groups.
    Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
    and DTensor manages its own RNG tracker, but we could extend to support both if needed.

    Set Determinism flags for increased reproducibility with loss of performance.
+
+    Args:
+        parallel_dims: ParallelDims that provides the device meshes used for seeding.
+        device: Device to use
+        debug_config: Debug config to use
+        distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across.
    """
    if debug_config.deterministic:
        logger.info("Deterministic algorithm enabled (expect perf degradation).")
@@ -118,7 +124,7 @@ def set_determinism(
        FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)

    seed = debug_config.seed
-    if not world_mesh:
+    if parallel_dims.world_size == 1:
        if seed is not None:
            torch.manual_seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
@@ -134,31 +140,46 @@ def set_determinism(
        torch.distributed.broadcast(seed_tensor, src=0)
        seed = seed_tensor.to("cpu").view(torch.uint64).item()

-    # Set distinct seed for each rank in mesh dimensions, with dimension name provided by `distinct_seed_mesh_dim`
+    # Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims`
    # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
    # and choose a unique seed for each rank on the PP mesh.
-    # TODO(jianiw): We could further extend this to support multiple distinct dimensions instead of just one.
-    if (
-        c10d.get_world_size() > 1
-        and distinct_seed_mesh_dim in world_mesh.mesh_dim_names
-    ):
-        distinct_mesh = world_mesh[distinct_seed_mesh_dim]
-        seed += distinct_mesh.get_local_rank()
+    # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed.
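The offset scheme implemented below is mixed-radix indexing over the distinct mesh dimensions; a minimal standalone sketch, with made-up mesh sizes and local ranks:

    # Two distinct dimensions, e.g. pp of size 2 and dp_replicate of size 4.
    sizes = [2, 4]  # distinct_mesh.size() for each distinct dimension
    ranks = [1, 3]  # distinct_mesh.get_local_rank() for each distinct dimension

    seed_offset, cumulative_size = 0, 1
    for local_rank, size in zip(ranks, sizes):
        seed_offset += local_rank * cumulative_size
        cumulative_size *= size

    # seed_offset == 1 * 1 + 3 * 2 == 7. Every combination of local ranks maps
    # to a distinct offset in [0, 8), exactly like flattening a multi-dim index.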
+    distinct_seed_meshes = [
+        parallel_dims.get_mesh(dim) for dim in distinct_seed_mesh_dims
+    ]
+    distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None]
+
+    if distinct_dims_in_mesh:
+        # Each dimension contributes: local_rank * (product of all previous dimension sizes)
+        # This guarantees uniqueness like multi-dimensional array indexing
+        seed_offset = 0
+        cumulative_size = 1
+
+        for distinct_mesh in distinct_seed_meshes:
+            local_rank = distinct_mesh.get_local_rank()
+            # Add contribution from this dimension
+            seed_offset += local_rank * cumulative_size
+            # Update cumulative size for next dimension
+            cumulative_size *= distinct_mesh.size()
+
+        seed += seed_offset
        seed %= 2**64
        logger.debug(
-            f"{distinct_seed_mesh_dim} rank {distinct_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
-        )
-        duplicate_seed_mesh = list(
-            filter(
-                lambda name: name != distinct_seed_mesh_dim, world_mesh.mesh_dim_names
-            )
+            f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}"
        )
+
+        # Filter out all distinct dimensions to get duplicate_seed_mesh
+        duplicate_seed_mesh_dims = [
+            v
+            for k, v in parallel_dims.get_all_meshes()
+            if k not in distinct_dims_in_mesh
+        ]
        duplicate_seed_mesh = (
-            world_mesh[duplicate_seed_mesh] if len(duplicate_seed_mesh) else None
+            world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None
        )
    else:
-        duplicate_seed_mesh = world_mesh
+        duplicate_seed_meshes = [parallel_dims.world_mesh]
        logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")

    # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
@@ -168,8 +189,10 @@ def set_determinism(

    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
    # IF PP is also used, this seed is unique per PP rank.
-    if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
-        torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)
+    # TODO: remove the need of duplicate_seed_meshes once torch.distributed.tensor._random.manual_seed
+    # doesn't require a mesh input.
+    if duplicate_seed_meshes:
+        torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_meshes[0])


 def create_context_parallel_ctx(
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 8f991295e1..a370c63514 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -476,7 +476,7 @@ def apply_moe_ep_tp(
         parallelize_plan=moe_layer_plan,
     )

-    expert_mesh, experts_plan = None, None
+    experts_mesh, experts_plan = None, None
     if ep_mesh is None:
         assert ep_etp_mesh is None
         experts_mesh = tp_mesh

From 4db7bea76f04fec47082d7c487081be4f76fe9f0 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 3 Nov 2025 15:17:20 -0800
Subject: [PATCH 14/38] misc
---
 torchtitan/distributed/parallel_dims.py | 26 ++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index d586b719d7..55f2db2928 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -140,6 +140,13 @@ def unflatten_mesh(
             (self.pp, self.dp_replicate, efsdp, self.ep, self.etp),
         )

+        # We have created all the required 1D meshes. This part is to create
+        # all the 2D meshes.
We pre-created 2D meshes and error out if the users + # try to access a 2D mesh that is not pre-created. + hsdp_mesh = dense_mesh["dp_replicate", "fsdp"] + ehsdp_mesh = sparse_mesh["dp_replicate", "efsdp"] + ep_etp_mesh = sparse_mesh["ep", "etp"] + self._meshes = { "pp": dataloading_mesh["pp"], "batch": dataloading_mesh["batch"], @@ -151,6 +158,9 @@ def unflatten_mesh( "ep": sparse_mesh["ep"], "efsdp": sparse_mesh["efsdp"], "etp": sparse_mesh["etp"], + "dp_replicate_fsdp": hsdp_mesh, + "dp_replicate_efsdp": ehsdp_mesh, + "ep_etp": ep_etp_mesh, } # Validate mesh sizes @@ -176,6 +186,12 @@ def _validate_meshes(self): "ep": self.ep, "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), "etp": self.etp, + "dp_replicate_fsdp": (self.dp_replicate, self.dp_shard * self.cp), + "dp_replicate_efsdp": ( + self.dp_replicate, + self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + ), + "ep_etp": (self.ep, self.etp), } for mesh_name, expected_size in expected_sizes.items(): @@ -206,17 +222,17 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None: if isinstance(dims, str): dims = [dims] - if not all(dim in self._meshes for dim in dims): - valid_dims = sorted(self._meshes.keys()) + mesh_name = "_".join(dims) + if mesh_name not in self._meshes: raise ValueError( - f"Invalid mesh dim: '{dims}'. Valid dimensions are: {valid_dims}" + f"Invalid mesh dim: '{mesh_name}'. " + f"Valid dimensions are: {list(self._meshes.keys())}" ) if any(self._meshes[dim].size() == 1 for dim in dims): return None - meshes = [self._meshes[dim] for dim in dims] - return meshes[0] if len(meshes) == 1 else DeviceMesh._concatenate(meshes) + return self._meshes[mesh_name] def get_all_meshes(self) -> dict[str, DeviceMesh]: if not self._meshes: From baa61f9d8e69234a9b1dd5177c6803386e48e2e5 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Nov 2025 15:36:43 -0800 Subject: [PATCH 15/38] misc --- torchtitan/experiments/forge/engine.py | 7 +++---- torchtitan/models/flux/train.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index 5035129008..e9818ba19f 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -86,10 +86,9 @@ def __init__(self, job_config: ForgeJobConfig): world_size=world_size, ) - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + dp_degree, dp_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 self.dp_degree, self.dp_rank = dp_degree, dp_rank @@ -102,7 +101,7 @@ def __init__(self, job_config: ForgeJobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], # same as `torchtitan/train.py` diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 5af9959050..0f2e67d7c9 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -28,7 +28,7 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). 
# For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader dist_utils.set_determinism( - self.parallel_dims.world_mesh, + self.parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], @@ -129,7 +129,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=self.parallel_dims.world_mesh["cp"], + cp_mesh=self.parallel_dims.get_mesh("cp"), cp_buffers=[ latents, latent_pos_enc, From 61663ee3fbea6f9ad91ac35ed91b4e2c34c0b9d4 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 3 Nov 2025 16:05:42 -0800 Subject: [PATCH 16/38] misc --- torchtitan/components/optimizer.py | 2 +- torchtitan/components/validate.py | 4 ++-- torchtitan/distributed/utils.py | 10 ++++++---- .../experiments/compiler_toolkit/common_utils.py | 6 ++++-- torchtitan/experiments/compiler_toolkit/graph_utils.py | 4 +--- torchtitan/experiments/forge/example_train.py | 10 +++++----- torchtitan/experiments/gpt_oss/infra/parallelize.py | 8 ++++---- .../experiments/simple_fsdp/deepseek_v3/parallelize.py | 10 +++++----- .../experiments/simple_fsdp/llama3/parallelize.py | 2 +- torchtitan/experiments/vlm/infra/loss.py | 2 +- torchtitan/models/deepseek_v3/infra/parallelize.py | 10 +++++----- torchtitan/models/flux/validate.py | 4 ++-- torchtitan/models/qwen3/infra/parallelize.py | 8 ++++---- 13 files changed, 41 insertions(+), 39 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 80557366da..c961c4181b 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -358,7 +358,7 @@ def _update_expert_bias( parallel_dims: ParallelDims, ): dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + parallel_dims.get_mesh("dp_cp") if parallel_dims.dp_cp_enabled else None ) # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. 
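A note on the call-site rewrites in this patch: they lean on `get_mesh` returning None when the requested dimension has degree 1, so the old `... if parallel_dims.x_enabled else None` guards can often be dropped. A sketch of the pattern, assuming a hypothetical ParallelDims instance `parallel_dims` built with cp=1 and tp=4:

    cp_mesh = parallel_dims.get_mesh("cp")  # None, since the cp degree is 1
    tp_mesh = parallel_dims.get_mesh("tp")  # a 1-D DeviceMesh of size 4

    # Old style: world_mesh["cp"] if parallel_dims.cp_enabled else None
    # New style: parallel_dims.get_mesh("cp")  # None already encodes "disabled"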
diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 93fb68a3cc..8087bb2a63 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -113,7 +113,7 @@ def validate( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -167,7 +167,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_mesh("dp_cp") ) else: global_avg_loss = loss.item() diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index d64eec3cbf..e30eef5222 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -149,7 +149,7 @@ def set_determinism( ] distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None] - if distinct_dims_in_mesh: + if distinct_seed_meshes: # Each dimension contributes: local_rank * (product of all previous dimension sizes) # This guarantees uniqueness like multi-dimensional array indexing seed_offset = 0 @@ -166,17 +166,19 @@ def set_determinism( seed %= 2**64 logger.debug( - f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}" + f"Distinct dims {distinct_seed_mesh_dims}, Global rank {c10d.get_rank()} using seed: {seed}" ) # Filter out all distinct dimensions to get duplicate_seed_mesh duplicate_seed_mesh_dims = [ v for k, v in parallel_dims.get_all_meshes() - if k not in distinct_dims_in_mesh + if k not in distinct_seed_mesh_dims ] duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None + parallel_dims.world_mesh[duplicate_seed_mesh_dims] + if duplicate_seed_mesh_dims + else None ) else: duplicate_seed_meshes = [parallel_dims.world_mesh] diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 997af9a2c4..2b2a1f5244 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -25,10 +25,12 @@ def disable_compile(job_config: JobConfig): job_config.compile.enable = original_value -def parallelize_inputs(world_mesh, args, kwargs): +def parallelize_inputs(parallel_dims, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): - return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) + return DTensor.from_local( + tensor, parallel_dims.get_mesh("tp"), [Replicate()] + ) return tensor dt_args = tree_map(to_dtensor, args) diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index e097579cc0..551bf695c5 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -194,9 +194,7 @@ def parameters(self, *args, **kwargs) -> Any: def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" - dt_args, dt_kwargs = self.parallelize_inputs( - self.parallel_dims.world_mesh, args, kwargs - ) + dt_args, dt_kwargs = self.parallelize_inputs(self.parallel_dims, args, kwargs) if self.joint_graph_module is None: self.joint_graph_module = self.joint_graph_builder( diff 
--git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 66ad151dd0..02069ab2b2 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -169,7 +169,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -244,7 +244,7 @@ def train_step( self.job_config.training.max_norm, foreach=True, pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + parallel_dims.get_mesh("pp") if parallel_dims.pp_enabled else None ), ep_enabled=parallel_dims.ep_enabled, ) @@ -262,8 +262,8 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"]), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"]), + dist_utils.dist_mean(loss, parallel_dims.get_mesh("dp_cp")), + dist_utils.dist_max(loss, parallel_dims.get_mesh("dp_cp")), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -329,7 +329,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 4d1177d1ab..aa5acaaf0a 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -86,7 +86,7 @@ def parallelize_gptoss( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, @@ -95,10 +95,10 @@ def parallelize_gptoss( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, + ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, ep_tp_mesh=( - world_mesh["ep", "tp"] + parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 83e24d7dc1..7547977195 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -87,19 +87,19 @@ def parallelize_deepseekv3( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + tp_mesh=parallel_dims.get_mesh("tp") if 
parallel_dims.tp_enabled else None, + ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, ep_tp_mesh=( - world_mesh["ep", "tp"] + parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 484d3d4747..90af5cdf12 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -97,7 +97,7 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - tp_mesh = parallel_dims.world_mesh["tp"] + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, tp_mesh, diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py index bba51f2819..291cd193f3 100644 --- a/torchtitan/experiments/vlm/infra/loss.py +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -104,7 +104,7 @@ def build_token_imbalance_ce_loss( # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: # DP split the batch dim B # CP split the sequence dim S - token_mesh = parallel_dims.world_mesh["dp_cp"] + token_mesh = parallel_dims.get_mesh("dp_cp") ft_pg = ft_manager.loss_sync_pg loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) if job_config.compile.enable and "loss" in job_config.compile.components: diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index d66a30a83d..dfc6dfc9db 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -85,19 +85,19 @@ def parallelize_deepseekv3( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, + ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, ep_tp_mesh=( - world_mesh["ep", "tp"] + parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 189385e0f2..3d06cb05b1 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -213,7 +213,7 @@ def validate( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], + cp_mesh=parallel_dims.get_mesh("cp"), cp_buffers=[ latents, latent_pos_enc, @@ -259,7 +259,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_mesh("dp_cp") ) else: global_avg_loss = loss.item() diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 517435714b..24cfcfcc62 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ 
b/torchtitan/models/qwen3/infra/parallelize.py
@@ -91,7 +91,7 @@ def parallelize_qwen3(

        apply_non_moe_tp(
            model,
-            world_mesh["tp"],
+            parallel_dims.get_mesh("tp"),
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
@@ -100,10 +100,10 @@
    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
        apply_moe_ep_tp(
            model,
-            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+            tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
+            ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
            ep_tp_mesh=(
-                world_mesh["ep", "tp"]
+                parallel_dims.get_mesh("ep_etp")
                if parallel_dims.tp_enabled
                and parallel_dims.ep_enabled
                and parallel_dims.etp_enabled

From 22bb1f389fecd9be2c0b15cac35fd311769f5949 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 4 Nov 2025 11:13:07 -0800
Subject: [PATCH 17/38] fix
---
 torchtitan/distributed/parallel_dims.py | 14 +++++++++++---
 torchtitan/distributed/utils.py         | 15 ++++++++++++---
 2 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 55f2db2928..d84cca981a 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -195,7 +195,10 @@ def _validate_meshes(self):

        for mesh_name, expected_size in expected_sizes.items():
-            actual_size = self._meshes[mesh_name].size()
+            if isinstance(expected_size, tuple):
+                actual_size = self._meshes[mesh_name].shape
+            else:
+                actual_size = self._meshes[mesh_name].size()
            assert actual_size == expected_size, (
                f"Mesh '{mesh_name}' has unexpected size: "
                f"expected {expected_size}, got {actual_size}"
            )

-    def get_all_meshes(self) -> dict[str, DeviceMesh]:
+    def get_all_meshes(self, one_dimensional_only: bool = True) -> dict[str, DeviceMesh]:
        if not self._meshes:
            self.build_mesh()
-        return {k: v for k, v in self._meshes.items() if v.size() > 1}
+        if one_dimensional_only:
+            return {
+                k: v for k, v in self._meshes.items() if v.ndim == 1 and v.size() > 1
+            }
+        else:
+            return {k: v for k, v in self._meshes.items() if v.size() > 1}

    @property
    def world_mesh(self) -> DeviceMesh:
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index e30eef5222..0559c5aff2 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -53,9 +53,12 @@ def _dist_reduce(

 def dist_max(
     x: torch.Tensor,
-    mesh: DeviceMesh,
+    mesh: DeviceMesh | None = None,
     extra_pg: dist.ProcessGroup | None = None,
 ) -> float:
+    if mesh is None:
+        assert not isinstance(x, DTensor), "mesh is required for DTensor input"
+        return x.item()
     return _dist_reduce(
         x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh, extra_pg=extra_pg
     )
@@ -63,9 +66,12 @@ def dist_max(

 def dist_sum(
     x: torch.Tensor,
-    mesh: DeviceMesh,
+    mesh: DeviceMesh | None = None,
     extra_pg: dist.ProcessGroup | None = None,
 ) -> float:
+    if mesh is None:
+        assert not isinstance(x, DTensor), "mesh is required for DTensor input"
+        return x.item()
     return _dist_reduce(
         x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg
     )
@@ -73,9 +79,12 @@ def dist_sum(

 def dist_mean(
     x: torch.Tensor,
-    mesh: DeviceMesh,
+    mesh: DeviceMesh | None = None,
extra_pg: dist.ProcessGroup | None = None, ) -> float: + if mesh is None: + assert not isinstance(x, DTensor), "mesh is required for DTensor input" + return x.item() return _dist_reduce( x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh, extra_pg=extra_pg ) From e5871c8f5455462b7e51e7c152da3be3a8204c09 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 12:14:35 -0800 Subject: [PATCH 18/38] misc --- torchtitan/distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 0559c5aff2..4467549ad7 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -181,11 +181,11 @@ def set_determinism( # Filter out all distinct dimensions to get duplicate_seed_mesh duplicate_seed_mesh_dims = [ v - for k, v in parallel_dims.get_all_meshes() + for k, v in parallel_dims.get_all_meshes().items() if k not in distinct_seed_mesh_dims ] duplicate_seed_mesh = ( - parallel_dims.world_mesh[duplicate_seed_mesh_dims] + parallel_dims.get_mesh(duplicate_seed_mesh_dims) if duplicate_seed_mesh_dims else None ) From a144b19cc5196c316da9b116fdfe9c672ed24744 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 14:02:42 -0800 Subject: [PATCH 19/38] misc --- scripts/generate/test_generate.py | 7 +++---- tests/unit_tests/test_set_determinism.py | 6 +++--- torchtitan/distributed/utils.py | 11 ++++------- .../experiments/gpt_oss/infra/parallelize.py | 14 ++++++-------- .../simple_fsdp/deepseek_v3/parallelize.py | 11 +++++------ .../experiments/simple_fsdp/llama3/parallelize.py | 8 ++++---- torchtitan/experiments/vlm/infra/parallelize.py | 12 ++++++------ torchtitan/models/deepseek_v3/infra/parallelize.py | 13 ++++++------- torchtitan/models/flux/infra/parallelize.py | 12 ++++++------ torchtitan/models/qwen3/infra/parallelize.py | 14 +++++++------- 10 files changed, 50 insertions(+), 58 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index b1d19ad17f..2efec2a494 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -113,7 +113,7 @@ def test_generate( logger.info(f"Init model on init_device: {init_device}") model = train_spec.model_cls(model_args) - world_mesh = None + parallel_dims = None # Init distributed env if world_size > 1: dist_utils.init_distributed(config.comm) @@ -127,15 +127,14 @@ def test_generate( etp=1, world_size=world_size, ) - world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.get_mesh("tp")) debug_config = DebugConfig(seed=seed, deterministic=deterministic) dist_utils.set_determinism( - world_mesh=world_mesh, + parallel_dims=parallel_dims, device=device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index c8087731c5..5d5ecb1557 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -90,7 +90,7 @@ def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): # Call set_determinism with distinct seeds only on PP dimension debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, 
debug_config=debug_config, distinct_seed_mesh_dims=["pp"], @@ -159,7 +159,7 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], @@ -223,7 +223,7 @@ def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size): debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 4467549ad7..5732743d00 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -179,15 +179,10 @@ def set_determinism( ) # Filter out all distinct dimensions to get duplicate_seed_mesh - duplicate_seed_mesh_dims = [ + duplicate_seed_meshes = list( v for k, v in parallel_dims.get_all_meshes().items() if k not in distinct_seed_mesh_dims - ] - duplicate_seed_mesh = ( - parallel_dims.get_mesh(duplicate_seed_mesh_dims) - if duplicate_seed_mesh_dims - else None ) else: duplicate_seed_meshes = [parallel_dims.world_mesh] @@ -511,7 +506,9 @@ def _clip_grad_norm_with_ep( if math.isinf(norm_type): total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) else: - total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + total_norm = ( + ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + ) total_norm **= 1.0 / norm_type if pp_mesh is not None: diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index aa5acaaf0a..2c119604e2 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -57,8 +57,6 @@ def parallelize_gptoss( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh - assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -123,10 +121,10 @@ def parallelize_gptoss( if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_dim_names = ["dp_shard_cp"] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] @@ -145,7 +143,7 @@ def parallelize_gptoss( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) if parallel_dims.ep_enabled else None ), @@ -162,9 +160,9 @@ def parallelize_gptoss( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - 
dp_mesh = world_mesh apply_ddp( model, dp_mesh, diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 7547977195..f3e2127c97 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -54,7 +54,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -125,16 +124,16 @@ def parallelize_deepseekv3( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["dp_shard_cp"] dp_mode = "fully_shard" - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] @@ -142,7 +141,7 @@ def parallelize_deepseekv3( if parallel_dims.dp_replicate_enabled: dp_mod_ep_mesh_dim_names.append("dp_replicate") dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") - dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + dp_mod_ep_mesh = parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) for _, transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 90af5cdf12..72ae20fd11 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -126,13 +126,13 @@ def parallelize_llama( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["dp_shard_cp"] dp_mode = "fully_shard" mp_policy = MixedPrecisionPolicy( @@ -142,7 +142,7 @@ def parallelize_llama( model = data_parallel( model, - parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(dp_mesh_dim_names), mode=dp_mode, mp_policy=mp_policy, ) diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index b6ada94d00..d9c72c3431 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -38,7 +38,6 @@ def parallelize_vlm( the model must fit on GPU or CPU memory. """ assert isinstance(model.encoder, nn.Module) - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. 
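The tuple-to-list changes in the hunks below exist because `get_mesh` now accepts a list of names and joins them with "_" to look up a pre-created mesh. A sketch of the HSDP selection pattern, assuming a built ParallelDims instance `parallel_dims` and the mesh names registered by the earlier parallel_dims patch ("dp_replicate", "fsdp", and the 2-D "dp_replicate_fsdp"):

    # Hypothetical usage; assumes these names exist in self._meshes.
    names = ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
    dp_mesh = parallel_dims.get_mesh(names)
    # With both names this resolves the pre-created 2-D "dp_replicate_fsdp" mesh;
    # with one name it returns the 1-D "fsdp" mesh (or None if that degree is 1).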
@@ -75,13 +74,13 @@ def parallelize_vlm( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["dp_shard_cp"] apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(dp_mesh_dim_names), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -100,11 +99,12 @@ def parallelize_vlm( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=job_config.compile.enable, ) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index dfc6dfc9db..f8a7961457 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -54,7 +54,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -126,10 +125,10 @@ def parallelize_deepseekv3( if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_dim_names = ["dp_shard_cp"] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] @@ -148,7 +147,7 @@ def parallelize_deepseekv3( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) if parallel_dims.ep_enabled else None ), @@ -166,9 +165,9 @@ def parallelize_deepseekv3( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index fc9c926af0..fa1e11aee6 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -29,13 +29,13 @@ def parallelize_flux( if parallel_dims.fsdp_enabled: if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["dp_shard_cp"] apply_fsdp( model, - 
parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(dp_mesh_dim_names), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -131,16 +131,16 @@ def parallelize_encoders( ): if parallel_dims.dp_shard_enabled: # apply FSDP or HSDP if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard") + dp_mesh_dim_names = ["dp_replicate", "dp_shard"] else: - dp_mesh_dim_names = ("dp_shard",) + dp_mesh_dim_names = ["dp_shard"] mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) fsdp_config = { - "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + "mesh": parallel_dims.get_mesh(dp_mesh_dim_names), "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 24cfcfcc62..b7f341da0e 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -56,7 +56,6 @@ def parallelize_qwen3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -128,10 +127,10 @@ def parallelize_qwen3( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_dim_names = ["dp_shard_cp"] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] @@ -150,7 +149,7 @@ def parallelize_qwen3( reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) if parallel_dims.ep_enabled else None ), @@ -168,11 +167,12 @@ def parallelize_qwen3( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=model_compile_enabled, ) From b4f1a2781cf3e8549745b4809d727a8224a9ea15 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 15:34:41 -0800 Subject: [PATCH 20/38] misc --- torchtitan/components/validate.py | 4 +--- torchtitan/distributed/utils.py | 17 ++++++----------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 8087bb2a63..5dcebfb94b 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -166,9 +166,7 @@ def validate( loss = torch.sum(torch.stack(accumulated_losses)) loss /= num_steps if parallel_dims.dp_cp_enabled: - global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.get_mesh("dp_cp") - ) + global_avg_loss = 
dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss"))
         else:
             global_avg_loss = loss.item()
 
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 5732743d00..814170ddc6 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -27,7 +27,7 @@
 def _dist_reduce(
     x: torch.Tensor,
     reduceOp: str,
-    mesh: DeviceMesh,
+    mesh: DeviceMesh | None,
     extra_pg: dist.ProcessGroup | None,
 ) -> float:
     """Perform distributed reduction on a tensor.
@@ -35,7 +35,8 @@
     Args:
         x (torch.Tensor): Input tensor.
         reduceOp (str): Reduce operation to perform.
-        mesh (DeviceMesh): Device mesh to use for reduction.
+        mesh (DeviceMesh | None): Device mesh to use for reduction.
+            If None, no reduction is performed and the tensor is simply converted to a float.
         extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction.
             Defaults to None. If provided, this all_reduce will be called for the extra
             process group, and then the result will be all_reduced for the mesh.
@@ -47,6 +48,9 @@
     if extra_pg is not None:
         x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg)
 
+    if mesh is None:
+        return x.item()
+
     assert x.numel() == 1  # required by `.item()`
     return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
 
@@ -56,9 +60,6 @@ def dist_max(
     mesh: DeviceMesh | None = None,
     extra_pg: dist.ProcessGroup | None = None,
 ) -> float:
-    if mesh is None:
-        assert not isinstance(x, DTensor), "mesh is required for DTensor input"
-        return x.item()
     return _dist_reduce(
         x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh, extra_pg=extra_pg
     )
@@ -69,9 +70,6 @@ def dist_sum(
     mesh: DeviceMesh | None = None,
     extra_pg: dist.ProcessGroup | None = None,
 ) -> float:
-    if mesh is None:
-        assert not isinstance(x, DTensor), "mesh is required for DTensor input"
-        return x.item()
     return _dist_reduce(
         x, reduceOp=c10d.ReduceOp.SUM.name, mesh=mesh, extra_pg=extra_pg
     )
@@ -82,9 +80,6 @@ def dist_mean(
     mesh: DeviceMesh | None = None,
     extra_pg: dist.ProcessGroup | None = None,
 ) -> float:
-    if mesh is None:
-        assert not isinstance(x, DTensor), "mesh is required for DTensor input"
-        return x.item()
     return _dist_reduce(
         x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh, extra_pg=extra_pg
     )

From df3810e475226db699a0b28df81f2bbbce975706 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 4 Nov 2025 15:40:45 -0800
Subject: [PATCH 21/38] fix

---
 tests/unit_tests/test_parallel_dims.py        | 561 ++++++++++++++++++
 torchtitan/components/optimizer.py            |   8 +-
 torchtitan/experiments/forge/example_train.py |   4 +-
 torchtitan/experiments/vlm/infra/loss.py      |   2 +-
 torchtitan/models/flux/validate.py            |   4 +-
 5 files changed, 569 insertions(+), 10 deletions(-)
 create mode 100644 tests/unit_tests/test_parallel_dims.py

diff --git a/tests/unit_tests/test_parallel_dims.py b/tests/unit_tests/test_parallel_dims.py
new file mode 100644
index 0000000000..1c3276dc6c
--- /dev/null
+++ b/tests/unit_tests/test_parallel_dims.py
@@ -0,0 +1,561 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
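An aside on the `torchtitan/distributed/utils.py` hunks in the previous patch ([PATCH 20/38]): the `mesh is None` fast path moves into the shared `_dist_reduce` helper, so `dist_max`/`dist_sum`/`dist_mean` no longer need per-wrapper guards. A minimal standalone sketch of the resulting contract; the stub below stands in for `funcol.all_reduce`, which is elided purely for illustration:

```python
import torch

def _dist_reduce_sketch(x: torch.Tensor, reduceOp: str, mesh=None, extra_pg=None) -> float:
    # Stand-in for torchtitan's _dist_reduce; the real collective calls are elided.
    if mesh is None:
        # No mesh (e.g., parallelism disabled): skip the collective entirely.
        return x.item()
    assert x.numel() == 1  # `.item()` below requires a scalar
    return x.item()  # real code: funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()

# Wrappers like dist_mean now forward unconditionally:
def dist_mean_sketch(x: torch.Tensor, mesh=None, extra_pg=None) -> float:
    return _dist_reduce_sketch(x, reduceOp="AVG", mesh=mesh, extra_pg=extra_pg)

assert dist_mean_sketch(torch.tensor(2.5)) == 2.5  # mesh=None path
```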
+ +import unittest +from unittest.mock import patch + +import torch.distributed as dist +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torchtitan.distributed import ParallelDims + + +class TestParallelDimsValidation(unittest.TestCase): + """Test ParallelDims validation logic without mesh building.""" + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_basic_initialization(self): + """Test basic initialization with valid parameters.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_replicate, 2) + self.assertEqual(parallel_dims.dp_shard, 2) + self.assertEqual(parallel_dims.cp, 1) + self.assertEqual(parallel_dims.tp, 2) + self.assertEqual(parallel_dims.pp, 1) + self.assertEqual(parallel_dims.ep, 1) + self.assertEqual(parallel_dims.etp, 1) + self.assertEqual(parallel_dims.world_size, 8) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_auto_calculate_dp_shard(self): + """Test automatic calculation of dp_shard when set to -1.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=-1, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_shard, 2) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_world_size(self): + """Test validation fails when parallelism degrees don't match world_size.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=10, # Invalid: 2*2*1*2*1 = 8, not 10 + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_etp(self): + """Test validation fails when etp is not equal to tp or 1.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=4, + pp=1, + ep=2, + etp=2, # Invalid: etp must be tp or 1 when ep > 1 + world_size=8, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_zero_parallelism(self): + """Test validation fails when parallelism degree is 0.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=0, # Invalid: must be >= 1 + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_dp_shard(self): + """Test validation fails when dp_shard is invalid (not -1 and not >=1).""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=0, # Invalid: must be -1 or >= 1 + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_enabled_properties(self): + """Test all enabled properties.""" + # Test with DP enabled + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test 
with CP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=2, + ) + self.assertFalse(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.dp_cp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with EP and ETP enabled (EP * ETP must not contribute to world_size) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with PP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=2, + ep=1, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.pp_enabled) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_fsdp_gradient_divide_factor(self): + """Test fsdp_gradient_divide_factor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=3, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=12, + ) + # Should be dp_replicate * dp_shard * cp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.fsdp_gradient_divide_factor, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_non_data_parallel_size(self): + """Test non_data_parallel_size calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=2, + tp=3, + pp=2, + ep=1, + etp=1, + world_size=48, + ) + # Should be cp * tp * pp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.non_data_parallel_size, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_seq_len_divisor(self): + """Test seq_len_divisor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=1, + cp=2, + tp=4, + pp=1, + ep=1, + etp=1, + world_size=16, + ) + # Should be tp * (cp * 2) = 4 * 4 = 16 + self.assertEqual(parallel_dims.seq_len_divisor, 16) + + +class TestParallelDimsMeshOperations(unittest.TestCase): + """Test ParallelDims mesh operations with single-rank distributed environment.""" + + def setUp(self): + """Initialize distributed environment for CPU testing.""" + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:12356", + world_size=1, + rank=0, + ) + + def tearDown(self): + """Clean up distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_invalid_name(self): + """Test getting mesh with invalid name raises error.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + with self.assertRaises(ValueError) as context: + parallel_dims.get_mesh("invalid_mesh") + self.assertIn("Invalid mesh dim", str(context.exception)) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_lazy_initialization(self): + """Test that get_mesh triggers build_mesh if not built yet.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + # Don't call build_mesh explicitly + self.assertEqual(len(parallel_dims._meshes), 0) + + # get_mesh should trigger build_mesh + result = parallel_dims.get_mesh("tp") + # Result is None because tp has size 1, but build_mesh should have been called + 
self.assertGreater(len(parallel_dims._meshes), 0) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_single_rank_mesh_operations(self): + """Comprehensive test for all single-rank mesh operations. + + This test verifies mesh building, mesh retrieval, mesh sizes, and property + access when all parallelism dimensions are set to 1 (single rank). + """ + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 1) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + + # Validate 1D mesh sizes - all should be 1 for single rank + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 1) + self.assertEqual(parallel_dims._meshes["fsdp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 1) + self.assertEqual(parallel_dims._meshes["batch"].size(), 1) + self.assertEqual(parallel_dims._meshes["loss"].size(), 1) + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual(parallel_dims._meshes["efsdp"].size(), 1) + + # Validate 2D mesh shapes + self.assertEqual(parallel_dims._meshes["dp_replicate_fsdp"].shape, (1, 1)) + self.assertEqual(parallel_dims._meshes["dp_replicate_efsdp"].shape, (1, 1)) + self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1)) + + # Test get_mesh returns None when all dimensions have size 1 + self.assertIsNone(parallel_dims.get_mesh("tp")) + self.assertIsNone(parallel_dims.get_mesh("dp_replicate")) + self.assertIsNone(parallel_dims.get_mesh("pp")) + self.assertIsNone(parallel_dims.get_mesh("cp")) + self.assertIsNone(parallel_dims.get_mesh("fsdp")) + + # Test get_mesh with list input + self.assertIsNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"])) + + # Test get_all_meshes returns empty when all dimensions have size 1 + one_d_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=True) + self.assertEqual(len(one_d_meshes), 0) + + all_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=False) + self.assertEqual(len(all_meshes), 0) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 1) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_with_list_input(self): + """Test get_mesh accepts both string and list inputs.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + # Should accept list input + result = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + # Returns None because both dimensions have size 1 + self.assertIsNone(result) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_expert_parallelism_validation(self): + """Test expert parallelism configurations.""" 
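The configurations in this test (see the method body below) are pinned down by the EP divisibility rules that `ParallelDims._validate` enforces: EP borrows all of `cp * tp` plus part of `dp_shard`, so both `ep % (cp * tp) == 0` and `(dp_shard * cp * tp) % ep == 0` must hold. A plain-Python check of that arithmetic, mirroring the asserts rather than calling torchtitan:

```python
def ep_config_valid(dp_shard: int, cp: int, tp: int, ep: int) -> bool:
    # EP must absorb cp and tp entirely, and divide evenly into dp_shard * cp * tp.
    return ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0

assert ep_config_valid(dp_shard=2, cp=1, tp=1, ep=2)       # configs used in this test
assert not ep_config_valid(dp_shard=2, cp=1, tp=1, ep=3)   # 2 % 3 != 0 -> _validate fails
```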
+
+        # EP with ETP = 1 (valid) - world_size = dp_replicate * dp_shard * cp * tp * pp
+        parallel_dims = ParallelDims(
+            dp_replicate=1,
+            dp_shard=2,
+            cp=1,
+            tp=1,
+            pp=1,
+            ep=2,
+            etp=1,
+            world_size=2,  # 1 * 2 * 1 * 1 * 1 = 2
+        )
+        self.assertTrue(parallel_dims.ep_enabled)
+        self.assertFalse(parallel_dims.etp_enabled)
+
+        # Test with larger configuration
+        parallel_dims = ParallelDims(
+            dp_replicate=2,
+            dp_shard=2,
+            cp=1,
+            tp=1,
+            pp=1,
+            ep=2,
+            etp=1,
+            world_size=4,  # 2 * 2 * 1 * 1 * 1 = 4
+        )
+        self.assertTrue(parallel_dims.ep_enabled)
+        self.assertFalse(parallel_dims.etp_enabled)
+        self.assertTrue(parallel_dims.dp_replicate_enabled)
+        self.assertTrue(parallel_dims.dp_shard_enabled)
+
+
+class TestParallelDimsWorld8MeshOperations(DTensorTestBase):
+    """Test ParallelDims mesh operations with 8-rank distributed environment."""
+
+    @property
+    def world_size(self):
+        return 8
+
+    @with_comms
+    def test_world_size_8_mesh_operations(self):
+        """Comprehensive test for world_size=8 mesh operations.
+
+        This test validates mesh building, mesh retrieval, mesh sizes, and properties
+        for a world_size=8 configuration with multiple parallelism dimensions enabled.
+        Configuration: dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1 (2*2*1*2*1 = 8)
+        """
+        with patch(
+            "torchtitan.distributed.parallel_dims.device_type", self.device_type
+        ):
+            parallel_dims = ParallelDims(
+                dp_replicate=2,
+                dp_shard=2,
+                cp=1,
+                tp=2,
+                pp=1,
+                ep=1,
+                etp=1,
+                world_size=8,
+            )
+
+            # Test mesh building
+            world_mesh = parallel_dims.build_mesh()
+            self.assertIsNotNone(world_mesh)
+            self.assertEqual(world_mesh.size(), 8)
+
+            # Verify all expected meshes are created
+            self.assertIsNotNone(parallel_dims._meshes)
+            self.assertIn("pp", parallel_dims._meshes)
+            self.assertIn("batch", parallel_dims._meshes)
+            self.assertIn("loss", parallel_dims._meshes)
+            self.assertIn("dp_replicate", parallel_dims._meshes)
+            self.assertIn("fsdp", parallel_dims._meshes)
+            self.assertIn("cp", parallel_dims._meshes)
+            self.assertIn("tp", parallel_dims._meshes)
+            self.assertIn("ep", parallel_dims._meshes)
+            self.assertIn("etp", parallel_dims._meshes)
+            self.assertIn("efsdp", parallel_dims._meshes)
+
+            # Validate 1D mesh sizes match parallelism configuration
+            self.assertEqual(parallel_dims._meshes["pp"].size(), 1)
+            self.assertEqual(
+                parallel_dims._meshes["batch"].size(), 4
+            )  # dp_replicate * dp_shard = 2 * 2
+            self.assertEqual(
+                parallel_dims._meshes["loss"].size(), 4
+            )  # dp_replicate * dp_shard * cp = 2 * 2 * 1
+            self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 2)
+            self.assertEqual(
+                parallel_dims._meshes["fsdp"].size(), 2
+            )  # dp_shard * cp = 2 * 1
+            self.assertEqual(parallel_dims._meshes["cp"].size(), 1)
+            self.assertEqual(parallel_dims._meshes["tp"].size(), 2)
+            self.assertEqual(parallel_dims._meshes["ep"].size(), 1)
+            self.assertEqual(parallel_dims._meshes["etp"].size(), 1)
+            self.assertEqual(
+                parallel_dims._meshes["efsdp"].size(), 4
+            )  # fsdp * tp / (etp * ep) = 2 * 2 / (1 * 1) = 4
+
+            # Validate 2D mesh shapes
+            self.assertEqual(
+                parallel_dims._meshes["dp_replicate_fsdp"].shape, (2, 2)
+            )  # (dp_replicate, fsdp)
+            self.assertEqual(
+                parallel_dims._meshes["dp_replicate_efsdp"].shape, (2, 4)
+            )  # (dp_replicate, efsdp)
+            self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1))  # (ep, etp)
+
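The expected sizes asserted in this test all derive from a handful of products (the same bookkeeping `build_mesh` performs later in this series: `batch = dp_replicate * dp_shard`, `fsdp = dp_shard * cp`, `efsdp = fsdp * tp // (etp * ep)`). A small sketch that reproduces the arithmetic for this world_size=8 configuration; the helper name is illustrative, not part of the torchtitan API:

```python
def derived_mesh_sizes(dp_replicate, dp_shard, cp, tp, pp, ep, etp):
    # Mirrors the size arithmetic asserted in this test.
    batch = dp_replicate * dp_shard           # dataloading dimension
    loss = dp_replicate * dp_shard * cp       # loss all-reduce dimension
    fsdp = dp_shard * cp                      # dense FSDP/HSDP shard dimension
    efsdp = fsdp * tp // (etp * ep)           # FSDP shard dimension inside the EP region
    return {"batch": batch, "loss": loss, "fsdp": fsdp, "efsdp": efsdp}

# The world_size=8 configuration exercised above:
assert derived_mesh_sizes(
    dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1, ep=1, etp=1
) == {"batch": 4, "loss": 4, "fsdp": 2, "efsdp": 4}
```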
+            # Test get_mesh returns valid meshes for enabled dimensions (size > 1)
+            self.assertIsNotNone(parallel_dims.get_mesh("tp"))
+            self.assertIsNotNone(parallel_dims.get_mesh("dp_replicate"))
+            self.assertIsNotNone(parallel_dims.get_mesh("fsdp"))
+            self.assertIsNotNone(parallel_dims.get_mesh("batch"))
+            self.assertIsNotNone(parallel_dims.get_mesh("loss"))
+
+            # Test get_mesh returns None for disabled dimensions (size = 1)
+            self.assertIsNone(parallel_dims.get_mesh("pp"))
+            self.assertIsNone(parallel_dims.get_mesh("cp"))
+            self.assertIsNone(parallel_dims.get_mesh("ep"))
+
+            # Test get_mesh with 2D mesh names
+            self.assertIsNotNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"]))
+            hsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"])
+            self.assertEqual(hsdp_mesh.shape, (2, 2))
+
+            # Test get_all_meshes returns only meshes with size > 1
+            one_d_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=True)
+            self.assertGreater(len(one_d_meshes), 0)
+            # Should include: dp_replicate, fsdp, tp, batch, loss, efsdp (all with size > 1)
+            self.assertIn("dp_replicate", one_d_meshes)
+            self.assertIn("fsdp", one_d_meshes)
+            self.assertIn("tp", one_d_meshes)
+            self.assertIn("batch", one_d_meshes)
+            self.assertIn("loss", one_d_meshes)
+            self.assertIn("efsdp", one_d_meshes)
+            # Should not include: pp, cp, ep, etp (all with size = 1)
+            self.assertNotIn("pp", one_d_meshes)
+            self.assertNotIn("cp", one_d_meshes)
+            self.assertNotIn("ep", one_d_meshes)
+            self.assertNotIn("etp", one_d_meshes)
+
+            all_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=False)
+            self.assertGreater(len(all_meshes), len(one_d_meshes))
+            # Should also include 2D meshes
+            self.assertIn("dp_replicate_fsdp", all_meshes)
+            self.assertIn("dp_replicate_efsdp", all_meshes)
+
+            # Test world_mesh property
+            world_mesh_property = parallel_dims.world_mesh
+            self.assertIsNotNone(world_mesh_property)
+            self.assertEqual(world_mesh_property.size(), 8)
+
+            # Validate enabled properties
+            self.assertTrue(parallel_dims.dp_enabled)
+            self.assertTrue(parallel_dims.dp_replicate_enabled)
+            self.assertTrue(parallel_dims.dp_shard_enabled)
+            self.assertTrue(parallel_dims.fsdp_enabled)
+            self.assertTrue(parallel_dims.tp_enabled)
+            self.assertFalse(parallel_dims.cp_enabled)
+            self.assertFalse(parallel_dims.pp_enabled)
+            self.assertFalse(parallel_dims.ep_enabled)
+
+            # Validate calculated properties
+            self.assertEqual(
+                parallel_dims.fsdp_gradient_divide_factor, 4
+            )  # dp_replicate * dp_shard * cp = 2 * 2 * 1
+            self.assertEqual(
+                parallel_dims.non_data_parallel_size, 2
+            )  # cp * tp * pp = 1 * 2 * 1
+            self.assertEqual(
+                parallel_dims.seq_len_divisor, 4
+            )  # tp * (cp * 2) = 2 * (1 * 2) = 2 * 2 = 4
+
+
+if __name__ == "__main__":
+    unittest.main()

diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index c961c4181b..1016535c2e 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -357,8 +357,8 @@ def _update_expert_bias(
     model_parts: list[nn.Module],
     parallel_dims: ParallelDims,
 ):
-    dp_cp_mesh = (
-        parallel_dims.get_mesh("dp_cp") if parallel_dims.dp_cp_enabled else None
+    loss_mesh = (
+        parallel_dims.get_mesh("loss") if parallel_dims.dp_cp_enabled else None
     )
     # TODO: Currently this sync is blocking (thus exposed) and happens on the
     # default compute stream. Need to assess if this is OK performance-wise.
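The hunk above only renames the handle (`dp_cp_mesh` to `loss_mesh`) to track the new mesh vocabulary; the synchronization itself, shown in the next hunk, all-reduces the stacked per-layer token counts over the loss mesh's process group. A hedged sketch of that pattern, with the function name and tensor shape chosen for illustration:

```python
import torch
import torch.distributed as dist

def sync_tokens_per_expert(tokens_per_expert: torch.Tensor, loss_mesh) -> torch.Tensor:
    """All-reduce per-expert token counts across the loss mesh, if one exists."""
    if loss_mesh is None:
        # Single data-parallel group: local counts are already global.
        return tokens_per_expert
    # DeviceMesh exposes its ProcessGroup via get_group(), as used in the patch below.
    dist.all_reduce(tokens_per_expert, group=loss_mesh.get_group())
    return tokens_per_expert
```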
@@ -380,7 +380,7 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) - if dp_cp_mesh is not None: + if loss_mesh is not None: if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( placements=[Replicate()] @@ -388,7 +388,7 @@ def _update_expert_bias( ) else: # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() + pg = loss_mesh.get_group() torch.distributed.all_reduce( tokens_per_expert_by_layer, group=pg, diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 02069ab2b2..2fed6d374f 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -262,8 +262,8 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, parallel_dims.get_mesh("dp_cp")), - dist_utils.dist_max(loss, parallel_dims.get_mesh("dp_cp")), + dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss")), + dist_utils.dist_max(loss, parallel_dims.get_mesh("loss")), ) else: global_avg_loss = global_max_loss = loss.detach().item() diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py index 291cd193f3..7a3a490fb7 100644 --- a/torchtitan/experiments/vlm/infra/loss.py +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -104,7 +104,7 @@ def build_token_imbalance_ce_loss( # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: # DP split the batch dim B # CP split the sequence dim S - token_mesh = parallel_dims.get_mesh("dp_cp") + token_mesh = parallel_dims.get_mesh("loss") ft_pg = ft_manager.loss_sync_pg loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) if job_config.compile.enable and "loss" in job_config.compile.components: diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 3d06cb05b1..f0646c9719 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -258,9 +258,7 @@ def validate( loss = torch.sum(torch.stack(accumulated_losses)) loss /= num_steps if parallel_dims.dp_cp_enabled: - global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.get_mesh("dp_cp") - ) + global_avg_loss = dist_utils.dist_mean(loss, parallel_dims.get_mesh("loss")) else: global_avg_loss = loss.item() From c8667b3c16bec7b2b8d8968ff80cdb00a9b1e36d Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 16:00:38 -0800 Subject: [PATCH 22/38] fix --- torchtitan/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 91d8b07251..de0ad538f4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -577,15 +577,15 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg - batch_mesh = parallel_dims.get_mesh("batch") + loss_mesh = parallel_dims.get_mesh("loss") global_avg_loss, global_max_loss, global_ntokens_seen = ( - dist_utils.dist_mean(loss, batch_mesh, ft_pg), - dist_utils.dist_max(loss, batch_mesh, ft_pg), + dist_utils.dist_mean(loss, loss_mesh, ft_pg), + dist_utils.dist_max(loss, loss_mesh, ft_pg), dist_utils.dist_sum( torch.tensor( self.ntokens_seen, dtype=torch.int64, device=self.device ), - batch_mesh, + loss_mesh, ft_pg, ), ) From c2e851a626f9a8d37cbd3ac2e8c58e9120247ec6 Mon Sep 
17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 16:19:30 -0800 Subject: [PATCH 23/38] fix --- torchtitan/experiments/gpt_oss/infra/parallelize.py | 6 +++--- .../experiments/simple_fsdp/deepseek_v3/parallelize.py | 2 +- torchtitan/models/deepseek_v3/infra/parallelize.py | 2 +- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 2c119604e2..591e58ae45 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -95,7 +95,7 @@ def parallelize_gptoss( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - ep_tp_mesh=( + ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled @@ -253,7 +253,7 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, etp_enabled: bool, ): assert ep_mesh is not None or tp_mesh is not None @@ -298,7 +298,7 @@ def apply_moe_ep_tp( # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = GptossExpertTensorParallel() parallelize_module( diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index f3e2127c97..13c372e8d5 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -97,7 +97,7 @@ def parallelize_deepseekv3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - ep_tp_mesh=( + ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index f8a7961457..5f81a8746a 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -95,7 +95,7 @@ def parallelize_deepseekv3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - ep_tp_mesh=( + ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index b7f341da0e..e9ce47202a 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -101,7 +101,7 @@ def parallelize_qwen3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - ep_tp_mesh=( + ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled and parallel_dims.ep_enabled From e150969ff154a5e946c1b73589ebfdb71fb0dc7f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 11:46:49 -0800 Subject: [PATCH 24/38] misc --- torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py | 1 - 
torchtitan/models/deepseek_v3/infra/parallelize.py | 1 - torchtitan/models/qwen3/infra/parallelize.py | 1 - 3 files changed, 3 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 13c372e8d5..2248f2711a 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -104,7 +104,6 @@ def parallelize_deepseekv3( and parallel_dims.etp_enabled else None ), - etp_enabled=parallel_dims.etp_enabled, ) if job_config.activation_checkpoint.mode != "none": diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 5f81a8746a..14807f209f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -102,7 +102,6 @@ def parallelize_deepseekv3( and parallel_dims.etp_enabled else None ), - etp_enabled=parallel_dims.etp_enabled, ) model_compile_enabled = ( diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index e9ce47202a..544fb08f35 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -108,7 +108,6 @@ def parallelize_qwen3( and parallel_dims.etp_enabled else None ), - etp_enabled=parallel_dims.etp_enabled, ) if job_config.activation_checkpoint.mode != "none": From 398dd80af78f070f296619592f59fbd1e3ce92f1 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 13:02:47 -0800 Subject: [PATCH 25/38] fix --- .../experiments/gpt_oss/infra/parallelize.py | 9 ++++----- .../simple_fsdp/deepseek_v3/parallelize.py | 4 ++-- .../simple_fsdp/llama3/parallelize.py | 4 ++-- .../experiments/vlm/infra/parallelize.py | 9 ++++----- .../models/deepseek_v3/infra/parallelize.py | 9 ++++----- torchtitan/models/flux/infra/parallelize.py | 18 ++++++++---------- torchtitan/models/qwen3/infra/parallelize.py | 9 ++++----- 7 files changed, 28 insertions(+), 34 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 591e58ae45..613a8ef6cb 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -120,11 +120,10 @@ def parallelize_gptoss( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] - else: - dp_mesh_dim_names = ["dp_shard_cp"] - dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 2248f2711a..3719428033 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -123,13 +123,13 @@ def parallelize_deepseekv3( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] + dp_mesh_dim_names = 
["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ["dp_shard_cp"] + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 72ae20fd11..d64a8b79fc 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -126,13 +126,13 @@ def parallelize_llama( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ["dp_shard_cp"] + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" mp_policy = MixedPrecisionPolicy( diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index d9c72c3431..d87070bee6 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -73,14 +73,13 @@ def parallelize_vlm( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] - else: - dp_mesh_dim_names = ["dp_shard_cp"] + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) apply_fsdp( model, - parallel_dims.get_mesh(dp_mesh_dim_names), + parallel_dims.get_mesh(names), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 14807f209f..5b009f4253 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -123,11 +123,10 @@ def parallelize_deepseekv3( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] - else: - dp_mesh_dim_names = ["dp_shard_cp"] - dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index fa1e11aee6..e6f6d934e9 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -28,14 +28,13 @@ def parallelize_flux( apply_ac(model, job_config.activation_checkpoint) if parallel_dims.fsdp_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] - else: - dp_mesh_dim_names = ["dp_shard_cp"] + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) apply_fsdp( model, - parallel_dims.get_mesh(dp_mesh_dim_names), + parallel_dims.get_mesh(names), 
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -130,17 +129,16 @@ def parallelize_encoders( job_config: JobConfig, ): if parallel_dims.dp_shard_enabled: # apply FSDP or HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard"] - else: - dp_mesh_dim_names = ["dp_shard"] + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) fsdp_config = { - "mesh": parallel_dims.get_mesh(dp_mesh_dim_names), + "mesh": parallel_dims.get_mesh(names), "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 544fb08f35..2c403240ce 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -125,11 +125,10 @@ def parallelize_qwen3( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "dp_shard_cp"] - else: - dp_mesh_dim_names = ["dp_shard_cp"] - dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP dp_mod_ep_mesh_dim_names = [] From e7f29388d65c669a7e1106e917c961b4ace9ec20 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 13:13:00 -0800 Subject: [PATCH 26/38] fix --- tests/unit_tests/test_set_determinism.py | 45 ++++++++++++++++++------ 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index 5d5ecb1557..10230b71bb 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -13,8 +13,8 @@ from torchtitan.distributed.utils import set_determinism -class FakeDeviceMesh: - """Fake DeviceMesh for testing seed uniqueness. +class FakeParallelDims: + """Fake ParallelDims for testing seed uniqueness. 
Args: mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"]) @@ -26,25 +26,48 @@ def __init__(self, mesh_dim_names, mesh_sizes, rank_coords): self.mesh_dim_names = mesh_dim_names self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes)) self.rank_coords = dict(zip(mesh_dim_names, rank_coords)) + # Calculate world_size as product of all mesh sizes + self.world_size = 1 + for size in mesh_sizes: + self.world_size *= size - def __getitem__(self, key): - """Return a submesh for the given dimension(s).""" + # Create a world_mesh mock + self.world_mesh = MagicMock() + + def get_mesh(self, key): + """Return a submesh for the given dimension.""" if isinstance(key, str): # Single dimension + if key not in self.mesh_dim_names: + return None submesh = MagicMock() submesh.get_local_rank.return_value = self.rank_coords[key] submesh.size.return_value = self.mesh_sizes[key] submesh.get_coordinate.return_value = self.rank_coords[key] + submesh.device_type = "cpu" return submesh elif isinstance(key, list): - # Multiple dimensions + # Multiple dimensions - check if all exist + if not all(dim in self.mesh_dim_names for dim in key): + return None submesh = MagicMock() # For multiple dimensions, get_coordinate should return None # since we're not testing this path submesh.get_coordinate.return_value = None + submesh.device_type = "cpu" return submesh else: - raise ValueError(f"Unsupported key type: {type(key)}") + return None + + def get_all_meshes(self): + """Return a dict of all meshes.""" + return { + dim: self.get_mesh(dim) for dim in self.mesh_dim_names + } + + def __getitem__(self, key): + """Return a submesh for the given dimension(s) - for backward compatibility.""" + return self.get_mesh(key) def get_coordinate(self): """Return the coordinate tuple for this rank.""" @@ -85,7 +108,7 @@ def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_rank, pp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) # Call set_determinism with distinct seeds only on PP dimension debug_config = DebugConfig(seed=base_seed, deterministic=False) @@ -154,7 +177,7 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) @@ -218,8 +241,10 @@ def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size): base_seed = 42 fake_mesh = MagicMock() - fake_mesh.mesh_dim_names = None - fake_mesh.get_coordinate.return_value = None + fake_mesh.world_size = 1 + fake_mesh.world_mesh = MagicMock() + fake_mesh.get_mesh.return_value = None + fake_mesh.get_all_meshes.return_value = {} debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( From 41f750e6295e30cf290a3a84f53eb256beba93ea Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 14:18:51 -0800 Subject: [PATCH 27/38] fix --- tests/unit_tests/test_set_determinism.py | 8 ++++---- .../experiments/simple_fsdp/deepseek_v3/parallelize.py | 3 +++ torchtitan/models/deepseek_v3/infra/parallelize.py | 3 +++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git 
a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index 10230b71bb..545603e4e8 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -61,9 +61,7 @@ def get_mesh(self, key): def get_all_meshes(self): """Return a dict of all meshes.""" - return { - dim: self.get_mesh(dim) for dim in self.mesh_dim_names - } + return {dim: self.get_mesh(dim) for dim in self.mesh_dim_names} def __getitem__(self, key): """Return a submesh for the given dimension(s) - for backward compatibility.""" @@ -177,7 +175,9 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) - fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims( + mesh_dim_names, mesh_sizes, rank_coords + ) # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 3719428033..bb15466576 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -97,6 +97,9 @@ def parallelize_deepseekv3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, + etp_mesh=parallel_dims.get_mesh("etp") + if parallel_dims.etp_enabled + else None, ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 5b009f4253..bfcecfac77 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -95,6 +95,9 @@ def parallelize_deepseekv3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, + etp_mesh=parallel_dims.get_mesh("etp") + if parallel_dims.etp_enabled + else None, ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled From 595e3a7a3b876d85c5bc84e5e9cae8a68a972a8e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 14:47:13 -0800 Subject: [PATCH 28/38] misc --- .../simple_fsdp/deepseek_v3/parallelize.py | 15 +++++++-------- .../models/deepseek_v3/infra/parallelize.py | 17 +++++++---------- torchtitan/models/qwen3/infra/parallelize.py | 17 +++++++---------- 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index bb15466576..7f10dcb661 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -137,18 +137,17 @@ def parallelize_deepseekv3( dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") - dp_mod_ep_mesh = parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) + if 
parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ["dp_replicate", "efsdp"] + else: + dp_mesh_dim_names = ["efsdp"] + edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) for _, transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: experts_shard_dim = 0 - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( dp_mod_ep_mesh.size() * parallel_dims.ep @@ -165,7 +164,7 @@ def parallelize_deepseekv3( # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883 transformer_block.moe.experts = data_parallel( transformer_block.moe.experts, - dp_mod_ep_mesh, + edp_mesh, dp_mode, mp_policy=mp_policy, shard_dim=experts_shard_dim, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index bfcecfac77..b3dc7d506f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -132,11 +132,12 @@ def parallelize_deepseekv3( dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, @@ -147,11 +148,7 @@ def parallelize_deepseekv3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) - if parallel_dims.ep_enabled - else None - ), + dp_mod_ep_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 2c403240ce..4edf9d77a5 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -131,11 +131,12 @@ def parallelize_qwen3( dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, @@ -146,11 +147,7 @@ def parallelize_qwen3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) - if parallel_dims.ep_enabled - else None - ), + dp_mod_ep_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) From 0510e59efd7e22f1125f78bab5494423d5f7a8f7 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 16:13:13 -0800 Subject: [PATCH 29/38] misc --- torchtitan/distributed/parallel_dims.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 
d84cca981a..7e7d7bf876 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -58,6 +58,11 @@ def _validate(self):
         if ep > 1:
             assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1"
 
+    def _mesh_exist(self, name: str, degree: int) -> bool:
+        if name == "efsdp":
+            return self.ep > 1
+        return degree > 1
+
     def build_mesh(self) -> DeviceMesh:
         """
         Build the device mesh with the required mesh dimensions.
@@ -102,7 +107,7 @@ def unflatten_mesh(
             """
             backend_override = {}
             for name, degree in zip(dim_names, dim_degrees, strict=True):
-                if degree == 1 or name == "batch":
+                if self._mesh_exist(name, degree) or name == "batch":
                     backend_override[name] = "fake"
 
             return world_mesh._unflatten(
@@ -213,8 +218,10 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
                 'cp', 'tp', 'ep', 'etp', 'efsdp'
 
         Returns:
-            DeviceMesh for the requested dimension(s), or None if any of
-            dimension(s) has size 1 (i.e., parallelism is disabled for that dimension).
+            DeviceMesh for the requested dimension(s). The DeviceMesh exists if
+            1) the dimension size is larger than 1 (the parallelism is enabled), or
+            2) the dimension is ``efsdp`` and ep > 1, even if its size is 1.
+            The return value is None otherwise.
 
         Raises:
             ValueError: If the requested dimension name(s) is not valid.
@@ -232,7 +239,7 @@
-        if any(self._meshes[dim].size() == 1 for dim in dims):
+        if any(self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims):
             return None
 
         return self._meshes[mesh_name]

From 36a81404bde7318e3f29ce56691f5f3aa3a0c922 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 5 Nov 2025 17:37:50 -0800
Subject: [PATCH 30/38] fix

---
 torchtitan/distributed/parallel_dims.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 
routed_input = DTensor.from_local( - routed_input, device_mesh["tp"], (Replicate(),) + routed_input, device_mesh["etp"], (Replicate(),) ).to_local(grad_placements=(Partial(),)) inputs = (routed_input, num_tokens_per_expert) From 7180aa0fcafdb2c9fe2d48354b66b67d0c152151 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 19:14:50 -0800 Subject: [PATCH 32/38] misc --- .../simple_fsdp/deepseek_v3/parallelize.py | 20 ++++++------------- torchtitan/models/qwen3/infra/parallelize.py | 13 ++++-------- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index 7f10dcb661..af25bb78b2 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -95,18 +95,10 @@ def parallelize_deepseekv3( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, - ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - etp_mesh=parallel_dims.get_mesh("etp") - if parallel_dims.etp_enabled - else None, - ep_etp_mesh=( - parallel_dims.get_mesh("ep_etp") - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), + tp_mesh=parallel_dims.get_mesh("tp"), + ep_mesh=parallel_dims.get_mesh("ep"), + etp_mesh=parallel_dims.get_mesh("etp"), + ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]), ) if job_config.activation_checkpoint.mode != "none": @@ -150,13 +142,13 @@ def parallelize_deepseekv3( assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * parallel_dims.ep + edp_mesh.size() * parallel_dims.ep > transformer_block.moe.experts.num_experts ): experts_shard_dim = 1 # when EP is enable, the routed experts' gradient reduction is done over - # dp_mod_ep_mesh instead of whole dp_mesh. + # edp_mesh instead of whole dp_mesh. # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh # to be consistent with data. 
# TODO (ruisizhang123): update the logic following the link below instead diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 4edf9d77a5..a379356a5d 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -99,15 +99,10 @@ def parallelize_qwen3( if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, - ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - ep_etp_mesh=( - parallel_dims.get_mesh("ep_etp") - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), + tp_mesh=parallel_dims.get_mesh("tp"), + ep_mesh=parallel_dims.get_mesh("ep"), + etp_mesh=parallel_dims.get_mesh("etp"), + ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]), ) if job_config.activation_checkpoint.mode != "none": From 49fe5408c3f972a125f97b6bfe42ec6d25bd23cf Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 23:01:51 -0800 Subject: [PATCH 33/38] misc --- .../experiments/simple_fsdp/tests/test_numerics.py | 13 ++++++------- torchtitan/models/flux/train.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py index 76233aeb87..aaf94a5023 100644 --- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -20,13 +20,13 @@ def init_test(self): self.loss_fn = cross_entropy_loss data_parallel_shard_degree = -1 if self.mode == "replicate": - self.dp_mesh_dim_names = ("dp_replicate",) + self.dp_mesh_dim_names = ["dp_replicate"] data_parallel_replicate_degree = self.world_size elif self.mode == "fully_shard": - self.dp_mesh_dim_names = ("dp_shard_cp",) + self.dp_mesh_dim_names = ["fsdp"] data_parallel_replicate_degree = 1 elif self.mode == "hybrid_shard": - self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + self.dp_mesh_dim_names = ["dp_replicate", "fsdp"] data_parallel_replicate_degree = self.world_size // 2 else: raise ValueError(f"Unsupported mode {self.mode}") @@ -41,7 +41,6 @@ def init_test(self): etp=1, world_size=self.world_size, ) - self.device_mesh = self.parallel_dims.world_mesh def get_input(self): inputs = torch.randn(8, 8).cuda() @@ -50,7 +49,7 @@ def get_input(self): return model, inputs, labels def run_fsdp2(self, model, inputs, labels, epoch=20): - fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)]) + fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names)) optim = self.optimizer(model.parameters(), lr=1e-4) losses = [] for _ in range(epoch): @@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20): def run_simple_fsdp(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) optim = self.optimizer(model.parameters(), lr=1e-4) @@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20): def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) # TODO: Add "inductor" backend 
when it's numerical issues are fixed diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 0f2e67d7c9..382ccc577d 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -31,7 +31,7 @@ def __init__(self, job_config: JobConfig): self.parallel_dims, self.device, job_config.debug, - distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], + distinct_seed_mesh_dims=["fsdp", "dp_replicate"], ) # NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder). From a06cdbd566828ed038ea9ea843179af821ccfe05 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Nov 2025 16:27:56 -0800 Subject: [PATCH 34/38] misc --- torchtitan/distributed/parallel_dims.py | 48 ++++++++----------- .../experiments/gpt_oss/infra/parallelize.py | 22 ++++----- .../models/deepseek_v3/infra/parallelize.py | 8 ++-- torchtitan/models/llama4/infra/parallelize.py | 21 ++++---- torchtitan/models/qwen3/infra/parallelize.py | 2 +- 5 files changed, 45 insertions(+), 56 deletions(-) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 0e7420e26a..7f2d116433 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -121,7 +121,6 @@ def unflatten_mesh( ) batch = self.dp_replicate * self.dp_shard - loss = self.dp_replicate * self.dp_shard * self.cp fsdp = self.dp_shard * self.cp efsdp = fsdp * self.tp // (self.etp * self.ep) @@ -145,12 +144,12 @@ def unflatten_mesh( (self.pp, self.dp_replicate, efsdp, self.ep, self.etp), ) - # We have created all the required 1D meshes. This part is to create the - # all the 2D meshes. We pre-created 2D meshes and error out if the users - # try to access a 2D mesh that is not pre-created. - hsdp_mesh = dense_mesh["dp_replicate", "fsdp"] - ehsdp_mesh = sparse_mesh["dp_replicate", "efsdp"] - ep_etp_mesh = sparse_mesh["ep", "etp"] + self._global_meshes = { + "dataloading": dataloading_mesh, + "loss": loss_mesh, + "dense": dense_mesh, + "sparse": sparse_mesh, + } self._meshes = { "pp": dataloading_mesh["pp"], @@ -163,9 +162,6 @@ def unflatten_mesh( "ep": sparse_mesh["ep"], "efsdp": sparse_mesh["efsdp"], "etp": sparse_mesh["etp"], - "dp_replicate_fsdp": hsdp_mesh, - "dp_replicate_efsdp": ehsdp_mesh, - "ep_etp": ep_etp_mesh, } # Validate mesh sizes @@ -191,19 +187,10 @@ def _validate_meshes(self): "ep": self.ep, "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), "etp": self.etp, - "dp_replicate_fsdp": (self.dp_replicate, self.dp_shard * self.cp), - "dp_replicate_efsdp": ( - self.dp_replicate, - self.dp_shard * self.cp * self.tp // (self.etp * self.ep), - ), - "ep_etp": (self.ep, self.etp), } for mesh_name, expected_size in expected_sizes.items(): - if isinstance(expected_size, tuple): - actual_size = self._meshes[mesh_name].shape - else: - actual_size = self._meshes[mesh_name].size() + actual_size = self._meshes[mesh_name].size() assert actual_size == expected_size, ( f"Mesh '{mesh_name}' has unexpected size: " f"expected {expected_size}, got {actual_size}" @@ -232,17 +219,24 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None: if isinstance(dims, str): dims = [dims] - mesh_name = "_".join(dims) - if mesh_name not in self._meshes: - raise ValueError( - f"Invalid mesh dim: '{mesh_name}'. " - f"Valid dimensions are: {list(self._meshes.keys())}" - ) + for mesh_name in dims: + if mesh_name not in self._meshes: + raise ValueError( + f"Invalid mesh dim: '{mesh_name}'. 
" + f"Valid dimensions are: {list(self._meshes.keys())}" + ) if any(not self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims): return None - return self._meshes[mesh_name] + if len(dims) == 1: + return self._meshes[dims[0]] + else: + for global_mesh in self._global_meshes.values(): + if not set(dims).issubset(set(global_mesh.mesh_dim_names)): + continue + return global_mesh[tuple(dims)] + raise ValueError(f"Invalid mesh name combinations {dims}.") def get_all_meshes(self, one_dimensioal_only: bool = True) -> dict[str, DeviceMesh]: if not self._meshes: diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 613a8ef6cb..6ce7fee790 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -64,6 +64,11 @@ def parallelize_gptoss( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + if parallel_dims.tp_enabled: if ( job_config.parallelism.enable_async_tensor_parallel @@ -105,10 +110,6 @@ def parallelize_gptoss( etp_enabled=parallel_dims.etp_enabled, ) - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) - if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -126,11 +127,12 @@ def parallelize_gptoss( dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] + edp_mesh = None if parallel_dims.ep_enabled: if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) + else: + edp_mesh = parallel_dims.get_mesh("efsdp") apply_fsdp( model, @@ -141,11 +143,7 @@ def parallelize_gptoss( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names) - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, ) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index b3dc7d506f..10cbef809e 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -95,9 +95,9 @@ def parallelize_deepseekv3( model, tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None, ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None, - etp_mesh=parallel_dims.get_mesh("etp") - if parallel_dims.etp_enabled - else None, + etp_mesh=( + parallel_dims.get_mesh("etp") if parallel_dims.etp_enabled else None + ), ep_etp_mesh=( parallel_dims.get_mesh("ep_etp") if parallel_dims.tp_enabled @@ -148,7 +148,7 @@ def parallelize_deepseekv3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=edp_mesh, + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index a370c63514..a48f6f714e 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ 
b/torchtitan/models/llama4/infra/parallelize.py @@ -132,12 +132,12 @@ def parallelize_llama( dp_mesh = parallel_dims.get_mesh(names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh = None + edp_mesh = None if parallel_dims.ep_enabled: if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) + edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) else: - dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp") + edp_mesh = parallel_dims.get_mesh("efsdp") apply_fsdp( model, @@ -148,7 +148,7 @@ def parallelize_llama( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=dp_mod_ep_mesh, + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -275,7 +275,7 @@ def apply_fsdp( cpu_offload: bool = False, reshard_after_forward_policy: str = "default", ep_degree: int = 1, - dp_mod_ep_mesh: DeviceMesh | None = None, + edp_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, ): """ @@ -324,10 +324,10 @@ def apply_fsdp( for layer_id, transformer_block in model.layers.items(): # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping # - the router and the shared experts are sharded together with the TransformerBlock - # - the routed experts are sharded with the remaining dp_mod_ep_mesh + # - the routed experts are sharded with the remaining edp_mesh if transformer_block.moe_enabled and ep_degree > 1: fsdp_mod_ep_config = fsdp_config.copy() - fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fsdp_mod_ep_config["mesh"] = edp_mesh # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding @@ -336,12 +336,9 @@ def apply_fsdp( # on non-0 dim. For now it may not be worth the complexity to support # shard_placement_fn on the outer TransformerBlock-level FSDP. 
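# [editor's note: illustration only, not part of the patch] A toy version of
# the dim-0 vs dim-1 decision made just below, with assumed sizes. Expert
# weights are laid out as (num_experts, ...), so FSDP's default dim-0 sharding
# runs out of rows once the expert-data-parallel size times the EP degree
# exceeds the number of experts, and the code falls back to Shard(1).
num_experts, ep_degree, edp_size = 16, 8, 4  # assumed example values
experts_shard_dim = 1 if edp_size * ep_degree > num_experts else 0  # -> 1 here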
_experts_shard_placement_fn = None - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") - if ( - dp_mod_ep_mesh.size() * ep_degree - > transformer_block.moe.experts.num_experts - ): + if edp_mesh.size() * ep_degree > transformer_block.moe.experts.num_experts: _experts_shard_placement_fn = lambda param: Shard(1) fully_shard( diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index a379356a5d..b2498c5c97 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -142,7 +142,7 @@ def parallelize_qwen3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=edp_mesh, + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) From 89bd094e0c5c28b266b287fc443e65bff6d2fd9a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Nov 2025 17:51:44 -0800 Subject: [PATCH 35/38] fix comments --- .../experiments/gpt_oss/infra/parallelize.py | 16 +++++++------- .../simple_fsdp/deepseek_v3/parallelize.py | 15 ++++++------- .../models/deepseek_v3/infra/parallelize.py | 8 +++---- torchtitan/models/llama4/infra/parallelize.py | 21 +++++++++++-------- torchtitan/models/qwen3/infra/parallelize.py | 8 +++---- 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 6ce7fee790..e2b54fb51d 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -121,18 +121,18 @@ def parallelize_gptoss( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - names = ( + dp_mesh_names = ( ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] ) - dp_mesh = parallel_dims.get_mesh(names) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - edp_mesh = None - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) - else: - edp_mesh = parallel_dims.get_mesh("efsdp") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_mesh(edp_mesh_names) apply_fsdp( model, diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index af25bb78b2..f7cd4f759f 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -128,13 +128,14 @@ def parallelize_deepseekv3( dp_mode = "fully_shard" dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) - # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ["dp_replicate", "efsdp"] - else: - dp_mesh_dim_names = ["efsdp"] - edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_mesh(edp_mesh_names) for _, 
transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: @@ -142,7 +143,7 @@ def parallelize_deepseekv3( assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - edp_mesh.size() * parallel_dims.ep + edp_mesh["efsdp"].size() * parallel_dims.ep > transformer_block.moe.experts.num_experts ): experts_shard_dim = 1 diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 10cbef809e..2dd1a9ec83 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -126,18 +126,18 @@ def parallelize_deepseekv3( dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - names = ( + dp_mesh_names = ( ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] ) - dp_mesh = parallel_dims.get_mesh(names) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - names = ( + edp_mesh_names = ( ["dp_replicate", "efsdp"] if parallel_dims.dp_replicate_enabled else ["efsdp"] ) - edp_mesh = parallel_dims.get_mesh(names) + edp_mesh = parallel_dims.get_mesh(edp_mesh_names) apply_fsdp( model, diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index a48f6f714e..908ce1dc7b 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -126,18 +126,18 @@ def parallelize_llama( if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # dp_mesh is the mesh for FSDP/HSDP - names = ( + dp_mesh_names = ( ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] ) - dp_mesh = parallel_dims.get_mesh(names) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - edp_mesh = None - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) - else: - edp_mesh = parallel_dims.get_mesh("efsdp") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_mesh(edp_mesh_names) apply_fsdp( model, @@ -338,7 +338,10 @@ def apply_fsdp( _experts_shard_placement_fn = None assert edp_mesh is not None assert hasattr(transformer_block, "moe") - if edp_mesh.size() * ep_degree > transformer_block.moe.experts.num_experts: + if ( + edp_mesh["efsdp"].size() * ep_degree + > transformer_block.moe.experts.num_experts + ): _experts_shard_placement_fn = lambda param: Shard(1) fully_shard( diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index b2498c5c97..e1d91097a9 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -120,18 +120,18 @@ def parallelize_qwen3( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - names = ( + dp_mesh_names = ( ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] ) - dp_mesh = parallel_dims.get_mesh(names) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - names = ( + edp_mesh_names = ( ["dp_replicate", "efsdp"] if parallel_dims.dp_replicate_enabled else ["efsdp"] ) - 
edp_mesh = parallel_dims.get_mesh(names)
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)
 
         apply_fsdp(
             model,

From f27917c430fc208a0dff7e9409648718e65bfed1 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 17 Nov 2025 18:00:26 -0800
Subject: [PATCH 36/38] misc

---
 torchtitan/distributed/utils.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 814170ddc6..c214a647cb 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -173,14 +173,7 @@ def set_determinism(
             f"Distinct dims {distinct_seed_mesh_dims}, Global rank {c10d.get_rank()} using seed: {seed}"
         )
 
-        # Filter out all distinct dimensions to get duplicate_seed_mesh
-        duplicate_seed_meshes = list(
-            v
-            for k, v in parallel_dims.get_all_meshes().items()
-            if k not in distinct_seed_mesh_dims
-        )
     else:
-        duplicate_seed_meshes = [parallel_dims.world_mesh]
         logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")
 
     # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
@@ -188,12 +181,14 @@
     # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
     os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
 
-    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
-    # IF PP is also used, this seed is unique per PP rank.
-    # TODO: remove the need of duplicate_seed_meshes once torch.distributed.tensor._random.manual_seed
-    # doesn't require a mesh input.
-    if duplicate_seed_meshes:
-        torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_meshes[0])
+    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for
+    # all ranks of the SPMD mesh. If PP is also used, this seed is unique per PP rank.
+    # TODO: remove the need to pass in a mesh once
+    # torch.distributed.tensor._random.manual_seed doesn't require a mesh input.
+    if parallel_dims.world_size > parallel_dims.pp_size:
+        # We just need to pass the world_mesh as the device_id is the only information
+        # this API uses.
+        torch.distributed.tensor._random.manual_seed(seed, parallel_dims.world_mesh)
 
 
 def create_context_parallel_ctx(

From 68215573f71886e54b404b0bc17c769c4a5b874d Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Tue, 18 Nov 2025 17:47:02 -0800
Subject: [PATCH 37/38] fix

---
 torchtitan/distributed/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index c214a647cb..08664d82c9 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -185,7 +185,7 @@ def set_determinism(
     # all ranks of the SPMD mesh. If PP is also used, this seed is unique per PP rank.
     # TODO: remove the need to pass in a mesh once
     # torch.distributed.tensor._random.manual_seed doesn't require a mesh input.
-    if parallel_dims.world_size > parallel_dims.pp_size:
+    if parallel_dims.world_size > parallel_dims.pp:
         # We just need to pass the world_mesh as the device_id is the only information
         # this API uses.
torch.distributed.tensor._random.manual_seed(seed, parallel_dims.world_mesh) From b53341e0bcf19248248e75dd57e979aa8dd310ef Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 19 Nov 2025 18:20:54 -0800 Subject: [PATCH 38/38] Fix --- tests/unit_tests/test_parallel_dims.py | 34 ++++++++++++------- tests/unit_tests/test_set_determinism.py | 14 ++++++++ torchtitan/distributed/parallel_dims.py | 27 ++++++++++----- .../deepseek_v3/parallelize_deepseekv3.py | 27 ++++++++++----- .../autoparallel/llama3/parallelize_llama.py | 20 ++++++----- .../experiments/gpt_oss/infra/parallelize.py | 1 - .../infra/parallelize.py | 14 ++++---- .../infra/pipeline.py | 2 +- .../models/deepseek_v3/infra/parallelize.py | 2 +- 9 files changed, 93 insertions(+), 48 deletions(-) diff --git a/tests/unit_tests/test_parallel_dims.py b/tests/unit_tests/test_parallel_dims.py index 1c3276dc6c..988bd1ad4d 100644 --- a/tests/unit_tests/test_parallel_dims.py +++ b/tests/unit_tests/test_parallel_dims.py @@ -335,9 +335,12 @@ def test_single_rank_mesh_operations(self): self.assertEqual(parallel_dims._meshes["efsdp"].size(), 1) # Validate 2D mesh shapes - self.assertEqual(parallel_dims._meshes["dp_replicate_fsdp"].shape, (1, 1)) - self.assertEqual(parallel_dims._meshes["dp_replicate_efsdp"].shape, (1, 1)) - self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1)) + dp_replicate_fsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNone(dp_replicate_fsdp_mesh) # Both dimensions have size 1 + dp_replicate_efsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) + self.assertIsNone(dp_replicate_efsdp_mesh) # Both dimensions have size 1 + ep_etp_mesh = parallel_dims.get_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 # Test get_mesh returns None when all dimensions have size 1 self.assertIsNone(parallel_dims.get_mesh("tp")) @@ -483,13 +486,16 @@ def test_world_size_8_mesh_operations(self): ) # fsdp * tp / (etp * ep) = 2 * 2 / (1 * 1) = 4 # Validate 2D mesh shapes + dp_replicate_fsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNotNone(dp_replicate_fsdp_mesh) self.assertEqual( - parallel_dims._meshes["dp_replicate_fsdp"].shape, (2, 2) + dp_replicate_fsdp_mesh.shape, (2, 2) ) # (dp_replicate, fsdp) - self.assertEqual( - parallel_dims._meshes["dp_replicate_efsdp"].shape, (2, 4) - ) # (dp_replicate, efsdp) - self.assertEqual(parallel_dims._meshes["ep_etp"].shape, (1, 1)) # (ep, etp) + # efsdp mesh only exists when ep > 1, so dp_replicate_efsdp should be None when ep=1 + dp_replicate_efsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"]) + self.assertIsNone(dp_replicate_efsdp_mesh) # efsdp disabled when ep=1 + ep_etp_mesh = parallel_dims.get_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 # Test get_mesh returns valid meshes for enabled dimensions (size > 1) self.assertIsNotNone(parallel_dims.get_mesh("tp")) @@ -524,11 +530,15 @@ def test_world_size_8_mesh_operations(self): self.assertNotIn("ep", one_d_meshes) self.assertNotIn("etp", one_d_meshes) + # In the new implementation, get_all_meshes only returns 1D meshes from _meshes + # Multi-D meshes are not stored separately, but can be obtained via get_mesh() all_meshes = parallel_dims.get_all_meshes(one_dimensioal_only=False) - self.assertGreater(len(all_meshes), len(one_d_meshes)) - # Should also include 2D meshes - self.assertIn("dp_replicate_fsdp", all_meshes) - self.assertIn("dp_replicate_efsdp", all_meshes) + # Since _meshes only contains 1D 
meshes, both should return the same
+        self.assertEqual(len(all_meshes), len(one_d_meshes))
+        # Verify we can get 2D meshes via get_mesh() instead
+        dp_replicate_fsdp = parallel_dims.get_mesh(["dp_replicate", "fsdp"])
+        self.assertIsNotNone(dp_replicate_fsdp)
+        self.assertEqual(dp_replicate_fsdp.ndim, 2)
 
         # Test world_mesh property
         world_mesh_property = parallel_dims.world_mesh
diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py
index 545603e4e8..24611ff9e8 100644
--- a/tests/unit_tests/test_set_determinism.py
+++ b/tests/unit_tests/test_set_determinism.py
@@ -31,8 +31,22 @@ def __init__(self, mesh_dim_names, mesh_sizes, rank_coords):
         for size in mesh_sizes:
             self.world_size *= size
 
+        # Add individual parallelism degree attributes to match the real ParallelDims interface
+        self.pp = self.mesh_sizes.get("pp", 1)
+        self.tp = self.mesh_sizes.get("tp", 1)
+        self.cp = self.mesh_sizes.get("cp", 1)
+        self.dp_replicate = self.mesh_sizes.get("dp_replicate", 1)
+        self.dp_shard = self.mesh_sizes.get("dp_shard", 1)
+        self.ep = self.mesh_sizes.get("ep", 1)
+        self.etp = self.mesh_sizes.get("etp", 1)
+
+        # For backward compatibility with the 'dp' dimension name
+        if "dp" in self.mesh_sizes:
+            self.dp_replicate = self.mesh_sizes["dp"]
+
         # Create a world_mesh mock
         self.world_mesh = MagicMock()
+        self.world_mesh.device_type = "cpu"
 
     def get_mesh(self, key):
         """Return a submesh for the given dimension."""
diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 7f2d116433..42e1851948 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -60,6 +60,8 @@ def _validate(self):
 
     def _mesh_exist(self, name: str, degree: int) -> bool:
         if name == "efsdp":
+            # We always keep the efsdp dimension if EP is larger than 1 because
+            # we need FSDP wrapping to enable mixed precision training for the MoE layers.
             return True if self.ep > 1 else False
         return degree > 1
 
@@ -75,9 +77,13 @@ def build_mesh(self) -> DeviceMesh:
             ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for this
             dimension to avoid unnecessary process group creation.
         loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``,
-            ``dp_shard``, and ``cp`` degrees, as all are data parallelisms.
+            ``dp_shard``, and ``cp`` degrees, as all of them parallelize the data
+            and thus require weight gradient reduction.
         dp_replicate: For DDP or HSDP replicate dimension.
-        fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``.
+        fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``. Note that
+            we always assume that when ``cp`` is used, FSDP is also applied to
+            utilize its weight all-gather and gradient reduce_scatter even if
+            there may be no data parallelism (e.g., global batch size is 1).
         cp: Context Parallelism (CP).
         tp: Tensor Parallelism (TP).
         ep: Expert Parallelism (EP).
@@ -86,15 +92,18 @@
         Note: Most dimensions above are created by unflattening the world mesh,
         except for loss, which is created by flattening the batch and cp dimensions.
-        This API performs the following unflatten operations:
+        This API performs the following unflatten operations from the world mesh:
 
-        ["pp", "batch", "cp", "tp"]
-        ["pp", "dp_replicate", "fsdp", "tp"]
-        ["pp", "dp_replicate", "efsdp", "ep", "etp"]
+        ["pp", "batch", "cp", "tp"]  # dataloading_mesh
+        ["pp", "dp_replicate", "fsdp", "tp"]  # dense_mesh
+        ["pp", "dp_replicate", "efsdp", "ep", "etp"]  # sparse_mesh
 
         Note: DeviceMesh currently recreates the process group for each dimension.
         It should share the process group for the same dim group to avoid unnecessary
-        process group creation.
+        process group creation. We can also use Fake to achieve a similar goal.
+        However, using Fake to avoid the redundancy would mess up the code, so we
+        only use Fake when it is necessary. For now, we just let DeviceMesh create
+        redundant process groups and wait for DeviceMesh to fix the issue.
         """
 
         def unflatten_mesh(
@@ -197,12 +206,12 @@ def _validate_meshes(self):
             )
 
     def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
-        """Get a device mesh by dimension names.
+        """Get a device mesh by dimension name(s).
 
         Args:
             dims: Names of the mesh dimension. Valid options include:
                 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp',
-                'cp', 'tp', 'ep', 'etp', 'efsdp'
+                'cp', 'tp', 'ep', 'etp', 'efsdp'.
 
         Returns:
             DeviceMesh for the requested dimension(s). The DeviceMesh exists if
diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py
index 0f718a389b..26dbc3c6d3 100644
--- a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py
+++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py
@@ -206,7 +206,7 @@ def monkey_patch_checks(moe):
     assert not list(moe.reorderer.buffers())
 
 
-def monkey_patch_local_map_moe(model, world_mesh):
+def monkey_patch_local_map_moe(model, sparse_mesh):
     """
     TODO: fix HOPs not restoring the original signature.
TODO: fix tracing with local shapes so that we can use Shard placements @@ -239,7 +239,7 @@ def monkey_patch_local_map_moe(model, world_mesh): ), redistribute_inputs=True, in_grad_placements=None, - device_mesh=world_mesh, + device_mesh=sparse_mesh, ) for block in model.layers.children(): @@ -282,7 +282,11 @@ def parallelize_deepseekv3( job_config.experimental.comms_bucket_reorder_strategy ) - world_mesh = parallel_dims.world_mesh + dense_names = ["pp", "dp_replicate", "fsdp", "tp"] + dense_names = [ + name for name in dense_names if parallel_dims.get_mesh(name) is not None + ] + dense_mesh = parallel_dims.get_mesh(dense_names) def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -306,7 +310,12 @@ def input_fn(): assert parallel_dims.pp_enabled is False, "PP not supported yet" # apply local_map to MoE - monkey_patch_local_map_moe(model, world_mesh) + sparse_names = ["pp", "dp_replicate", "efsdp", "ep", "etp"] + sparse_names = [ + name for name in sparse_names if parallel_dims.get_mesh(name) is not None + ] + sparse_mesh = parallel_dims.get_mesh(sparse_names) + monkey_patch_local_map_moe(model, sparse_mesh) # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( # lambda bucket_idx: 500 / parallel_dims.tp @@ -326,7 +335,7 @@ def input_fn(): with AutoParallel( model, input_fn, - world_mesh, + dense_mesh, mp_policy=mp_policy, compile=job_config.compile, ) as autop: @@ -345,10 +354,10 @@ def input_fn(): "tp": Shard(2), } assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names + name in possible_input_shardings for name in dense_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names + possible_input_shardings[name] for name in dense_mesh.mesh_dim_names ) out_sharding = x_sharding loss_parallel_enabled = ( @@ -358,7 +367,7 @@ def input_fn(): if loss_parallel_enabled: out_sharding = tuple( possible_output_shardings[name] - for name in world_mesh.mesh_dim_names + for name in dense_mesh.mesh_dim_names if name != "dp_replicate" ) autop.add_input_constraints([x_sharding]) @@ -381,7 +390,7 @@ def input_fn(): # it would require putting the loss inside the model as well def _return_as_dtensor_for_loss_parallel(module, args, output): return torch.distributed.tensor.DTensor.from_local( - output, world_mesh["tp"], (Shard(2),) + output, dense_mesh["tp"], (Shard(2),) ) # not keeping a reference to the hook, don't plan on diff --git a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py index d7fbae2622..68b7458a17 100644 --- a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py @@ -44,7 +44,11 @@ def parallelize_llama( job_config.experimental.comms_bucket_reorder_strategy ) - world_mesh = parallel_dims.world_mesh + dense_names = ["pp", "dp_replicate", "fsdp", "tp"] + dense_names = [ + name for name in dense_names if parallel_dims.get_mesh(name) is not None + ] + dense_mesh = parallel_dims.get_mesh(dense_names) def input_fn(): global_batch_size = job_config.training.global_batch_size @@ -88,7 +92,7 @@ def input_fn(): with AutoParallel( model, input_fn, - world_mesh, + dense_mesh, mp_policy=mp_policy, compile=job_config.compile, ) as autop: @@ -97,20 +101,20 @@ def input_fn(): possible_input_shardings = { # maps 
relative to mesh dim names used in torchtitan "dp_replicate": Shard(0), - "dp_shard": Shard(0), + "fsdp": Shard(0), "tp": Replicate(), } # only used if loss parallel is enabled possible_output_shardings = { # maps relative to mesh dim names used in torchtitan - "dp_shard": Shard(0), + "fsdp": Shard(0), "tp": Shard(2), } assert all( - name in possible_input_shardings for name in world_mesh.mesh_dim_names + name in possible_input_shardings for name in dense_mesh.mesh_dim_names ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" x_sharding = tuple( - possible_input_shardings[name] for name in world_mesh.mesh_dim_names + possible_input_shardings[name] for name in dense_mesh.mesh_dim_names ) out_sharding = x_sharding loss_parallel_enabled = ( @@ -120,7 +124,7 @@ def input_fn(): if loss_parallel_enabled: out_sharding = tuple( possible_output_shardings[name] - for name in world_mesh.mesh_dim_names + for name in dense_mesh.mesh_dim_names if name != "dp_replicate" ) autop.add_input_constraints([x_sharding]) @@ -141,7 +145,7 @@ def input_fn(): # it would require putting the loss inside the model as well def _return_as_dtensor_for_loss_parallel(module, args, output): return torch.distributed.tensor.DTensor.from_local( - output, world_mesh["tp"], (Shard(2),) + output, dense_mesh["tp"], (Shard(2),) ) # not keeping a reference to the hook, don't plan on diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index e2b54fb51d..86e1248243 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -64,7 +64,6 @@ def parallelize_gptoss( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py index a049d88d76..58ff39a5a4 100644 --- a/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py @@ -39,7 +39,6 @@ def parallelize_hf_transformers( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. 
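# [editor's note: illustration only, not part of the diff] A minimal usage
# sketch of the get_mesh() access pattern this series migrates call sites to,
# under assumed degrees (world_size=8, dp_shard=4, tp=2, everything else 1).
# Running it requires an initialized distributed environment.
from torchtitan.distributed.parallel_dims import ParallelDims

pd = ParallelDims(
    pp=1, dp_replicate=1, dp_shard=4, cp=1, tp=2, ep=1, etp=1, world_size=8
)
pd.build_mesh()
tp_mesh = pd.get_mesh("tp")  # 1D mesh of size 2
fsdp_mesh = pd.get_mesh("fsdp")  # 1D mesh of size 4 (dp_shard * cp)
hsdp_mesh = pd.get_mesh(["dp_replicate", "fsdp"])  # None, since dp_replicate == 1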
@@ -64,11 +63,11 @@ def parallelize_hf_transformers( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -90,7 +89,7 @@ def parallelize_hf_transformers( apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(list(dp_mesh_dim_names)), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -104,17 +103,18 @@ def parallelize_hf_transformers( logger.info("Applied FSDP to the model") if parallel_dims.cp_enabled: - model.set_cp_mesh(world_mesh["cp"]) + model.set_cp_mesh(parallel_dims.get_mesh("cp")) logger.info("Applied Context Parallel to the model") if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_replicate_mesh, enable_compile=model_compile_enabled, ) diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py index f05caf9abf..f27f884014 100644 --- a/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py +++ b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py @@ -287,7 +287,7 @@ def pipeline_hf_transformers( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 2dd1a9ec83..74868cfa5f 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -99,7 +99,7 @@ def parallelize_deepseekv3( parallel_dims.get_mesh("etp") if parallel_dims.etp_enabled else None ), ep_etp_mesh=( - parallel_dims.get_mesh("ep_etp") + parallel_dims.get_mesh(["ep", "etp"]) if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled