
Commit 131d467

misc
1 parent 9ae9e2a commit 131d467

5 files changed: +44 -56 lines changed

torchtitan/distributed/parallel_dims.py

Lines changed: 21 additions & 27 deletions
@@ -121,7 +121,6 @@ def unflatten_mesh(
         )
 
         batch = self.dp_replicate * self.dp_shard
-        loss = self.dp_replicate * self.dp_shard * self.cp
         fsdp = self.dp_shard * self.cp
         efsdp = fsdp * self.tp // (self.etp * self.ep)
 
@@ -145,12 +144,12 @@ def unflatten_mesh(
             (self.pp, self.dp_replicate, efsdp, self.ep, self.etp),
         )
 
-        # We have created all the required 1D meshes. This part is to create the
-        # all the 2D meshes. We pre-created 2D meshes and error out if the users
-        # try to access a 2D mesh that is not pre-created.
-        hsdp_mesh = dense_mesh["dp_replicate", "fsdp"]
-        ehsdp_mesh = sparse_mesh["dp_replicate", "efsdp"]
-        ep_etp_mesh = sparse_mesh["ep", "etp"]
+        self._global_meshes = {
+            "dataloading": dataloading_mesh,
+            "loss": loss_mesh,
+            "dense": dense_mesh,
+            "sparse": sparse_mesh,
+        }
 
         self._meshes = {
             "pp": dataloading_mesh["pp"],
@@ -163,9 +162,6 @@ def unflatten_mesh(
             "ep": sparse_mesh["ep"],
             "efsdp": sparse_mesh["efsdp"],
             "etp": sparse_mesh["etp"],
-            "dp_replicate_fsdp": hsdp_mesh,
-            "dp_replicate_efsdp": ehsdp_mesh,
-            "ep_etp": ep_etp_mesh,
         }
 
         # Validate mesh sizes
@@ -191,19 +187,10 @@ def _validate_meshes(self):
             "ep": self.ep,
             "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep),
             "etp": self.etp,
-            "dp_replicate_fsdp": (self.dp_replicate, self.dp_shard * self.cp),
-            "dp_replicate_efsdp": (
-                self.dp_replicate,
-                self.dp_shard * self.cp * self.tp // (self.etp * self.ep),
-            ),
-            "ep_etp": (self.ep, self.etp),
         }
 
         for mesh_name, expected_size in expected_sizes.items():
-            if isinstance(expected_size, tuple):
-                actual_size = self._meshes[mesh_name].shape
-            else:
-                actual_size = self._meshes[mesh_name].size()
+            actual_size = self._meshes[mesh_name].size()
             assert actual_size == expected_size, (
                 f"Mesh '{mesh_name}' has unexpected size: "
                 f"expected {expected_size}, got {actual_size}"
@@ -232,17 +219,24 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
         if isinstance(dims, str):
             dims = [dims]
 
-        mesh_name = "_".join(dims)
-        if mesh_name not in self._meshes:
-            raise ValueError(
-                f"Invalid mesh dim: '{mesh_name}'. "
-                f"Valid dimensions are: {list(self._meshes.keys())}"
-            )
+        for mesh_name in dims:
+            if mesh_name not in self._meshes:
+                raise ValueError(
+                    f"Invalid mesh dim: '{mesh_name}'. "
+                    f"Valid dimensions are: {list(self._meshes.keys())}"
+                )
 
         if any(not self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims):
             return None
 
-        return self._meshes[mesh_name]
+        if len(dims) == 1:
+            return self._meshes[dims[0]]
+        else:
+            for global_mesh in self._global_meshes.values():
+                if not set(dims).issubset(set(global_mesh.mesh_dim_names)):
+                    continue
+                return global_mesh[tuple(dims)]
+            raise ValueError(f"Invalid mesh name combinations {dims}.")
 
     def get_all_meshes(self, one_dimensioal_only: bool = True) -> dict[str, DeviceMesh]:
         if not self._meshes:
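To see how the reworked get_mesh resolves requests after this change, here is a minimal, self-contained sketch. It is not torchtitan code: FakeMesh and the dim-name layouts of the dense/sparse global meshes are illustrative stand-ins for DeviceMesh, assumed only for this example. The point is that 1-D names still resolve directly from _meshes, while multi-dim requests (e.g. the former "dp_replicate_efsdp") are now matched against a pre-created global mesh and sliced on demand, so 2-D meshes no longer have to be pre-created and size-validated up front.

from dataclasses import dataclass


@dataclass
class FakeMesh:
    mesh_dim_names: tuple[str, ...]

    def __getitem__(self, dims):
        # A real DeviceMesh would return a sliced sub-mesh over the requested dims.
        return FakeMesh(tuple(dims))


# Hypothetical dim layouts, only for this sketch.
_global_meshes = {
    "dense": FakeMesh(("pp", "dp_replicate", "fsdp", "tp")),
    "sparse": FakeMesh(("pp", "dp_replicate", "efsdp", "ep", "etp")),
}
_meshes = {
    name: FakeMesh((name,))
    for mesh in _global_meshes.values()
    for name in mesh.mesh_dim_names
}


def get_mesh(dims):
    if isinstance(dims, str):
        dims = [dims]
    # Every requested name must be a known 1-D mesh dim.
    for name in dims:
        if name not in _meshes:
            raise ValueError(f"Invalid mesh dim: '{name}'")
    if len(dims) == 1:
        return _meshes[dims[0]]
    # Multi-dim requests are served by slicing a pre-created global mesh
    # whose dim names cover all requested dims.
    for global_mesh in _global_meshes.values():
        if set(dims).issubset(set(global_mesh.mesh_dim_names)):
            return global_mesh[tuple(dims)]
    raise ValueError(f"Invalid mesh name combinations {dims}.")


print(get_mesh("ep").mesh_dim_names)                       # ('ep',)
print(get_mesh(["dp_replicate", "efsdp"]).mesh_dim_names)  # ('dp_replicate', 'efsdp')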

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 9 additions & 12 deletions
@@ -64,6 +64,10 @@ def parallelize_gptoss(
     if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
+    model_compile_enabled = (
+        job_config.compile.enable and "model" in job_config.compile.components
+    )
+
     if parallel_dims.tp_enabled:
         if (
             job_config.parallelism.enable_async_tensor_parallel
@@ -105,10 +109,6 @@ def parallelize_gptoss(
             etp_enabled=parallel_dims.etp_enabled,
         )
 
-    model_compile_enabled = (
-        job_config.compile.enable and "model" in job_config.compile.components
-    )
-
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
@@ -127,11 +127,12 @@ def parallelize_gptoss(
         dp_mesh = parallel_dims.get_mesh(names)
 
         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        dp_mod_ep_mesh_dim_names = []
+        edp_mesh = None
         if parallel_dims.ep_enabled:
             if parallel_dims.dp_replicate_enabled:
-                dp_mod_ep_mesh_dim_names.append("dp_replicate")
-            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
+                edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
+            else:
+                edp_mesh = parallel_dims.get_mesh("efsdp")
 
         apply_fsdp(
             model,
@@ -142,11 +143,7 @@ def parallelize_gptoss(
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=(
-                parallel_dims.get_mesh(dp_mod_ep_mesh_dim_names)
-                if parallel_dims.ep_enabled
-                else None
-            ),
+            edp_mesh=edp_mesh,
         )
 
         if parallel_dims.dp_replicate_enabled:

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -92,9 +92,9 @@ def parallelize_deepseekv3(
             model,
             tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
             ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
-            etp_mesh=parallel_dims.get_mesh("etp")
-            if parallel_dims.etp_enabled
-            else None,
+            etp_mesh=(
+                parallel_dims.get_mesh("etp") if parallel_dims.etp_enabled else None
+            ),
             ep_etp_mesh=(
                 parallel_dims.get_mesh("ep_etp")
                 if parallel_dims.tp_enabled
@@ -146,7 +146,7 @@ def parallelize_deepseekv3(
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=edp_mesh,
+            edp_mesh=edp_mesh,
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 9 additions & 12 deletions
@@ -133,12 +133,12 @@ def parallelize_llama(
         dp_mesh = parallel_dims.get_mesh(names)
 
         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        dp_mod_ep_mesh = None
+        edp_mesh = None
         if parallel_dims.ep_enabled:
             if parallel_dims.dp_replicate_enabled:
-                dp_mod_ep_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
+                edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
             else:
-                dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp")
+                edp_mesh = parallel_dims.get_mesh("efsdp")
 
         apply_fsdp(
             model,
@@ -149,7 +149,7 @@ def parallelize_llama(
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=dp_mod_ep_mesh,
+            edp_mesh=edp_mesh,
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )
 
@@ -274,7 +274,7 @@ def apply_fsdp(
     cpu_offload: bool = False,
     reshard_after_forward_policy: str = "default",
     ep_degree: int = 1,
-    dp_mod_ep_mesh: DeviceMesh | None = None,
+    edp_mesh: DeviceMesh | None = None,
     gradient_divide_factor: int | None = None,
 ):
     """
@@ -323,10 +323,10 @@ def apply_fsdp(
     for layer_id, transformer_block in model.layers.items():
         # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping
         # - the router and the shared experts are sharded together with the TransformerBlock
-        # - the routed experts are sharded with the remaining dp_mod_ep_mesh
+        # - the routed experts are sharded with the remaining edp_mesh
        if transformer_block.moe_enabled and ep_degree > 1:
             fsdp_mod_ep_config = fsdp_config.copy()
-            fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+            fsdp_mod_ep_config["mesh"] = edp_mesh
 
             # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
             # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
@@ -335,12 +335,9 @@ def apply_fsdp(
             # on non-0 dim. For now it may not be worth the complexity to support
             # shard_placement_fn on the outer TransformerBlock-level FSDP.
             _experts_shard_placement_fn = None
-            assert dp_mod_ep_mesh is not None
+            assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
-            if (
-                dp_mod_ep_mesh.size() * ep_degree
-                > transformer_block.moe.experts.num_experts
-            ):
+            if edp_mesh.size() * ep_degree > transformer_block.moe.experts.num_experts:
                 _experts_shard_placement_fn = lambda param: Shard(1)
 
             fully_shard(
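To make the dim-1 shard placement condition in the hunk above concrete, here is a toy arithmetic sketch. The sizes (num_experts, ep_degree, edp_mesh_size) are hypothetical, chosen only to trigger the branch; they are not taken from this commit.

# Hypothetical sizes, for illustration only.
num_experts = 8      # routed experts in the MoE layer
ep_degree = 4        # expert-parallel degree
edp_mesh_size = 4    # size of the FSDP/HSDP mesh used for the routed experts

# EP already splits dim 0 (num_experts) ep_degree ways: 2 experts left per EP rank.
experts_per_ep_rank = num_experts // ep_degree

if edp_mesh_size * ep_degree > num_experts:
    # 4 * 4 = 16 > 8: FSDP's default dim-0 sharding would hand empty shards to some
    # of the 4 edp ranks, so the expert weights are sharded on dim 1 instead.
    shard_dim = 1
else:
    shard_dim = 0

print(experts_per_ep_rank, shard_dim)  # 2 1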

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ def parallelize_qwen3(
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             ep_degree=parallel_dims.ep,
-            dp_mod_ep_mesh=edp_mesh,
+            edp_mesh=edp_mesh,
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )
