
Commit 70d6c03

fix comments

1 parent 131d467 commit 70d6c03

5 files changed (+36, -32 lines)

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 8 additions & 8 deletions
@@ -121,18 +121,18 @@ def parallelize_gptoss(
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
-        names = (
+        dp_mesh_names = (
             ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
         )
-        dp_mesh = parallel_dims.get_mesh(names)
+        dp_mesh = parallel_dims.get_mesh(dp_mesh_names)

         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        edp_mesh = None
-        if parallel_dims.ep_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
-            else:
-                edp_mesh = parallel_dims.get_mesh("efsdp")
+        edp_mesh_names = (
+            ["dp_replicate", "efsdp"]
+            if parallel_dims.dp_replicate_enabled
+            else ["efsdp"]
+        )
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)

         apply_fsdp(
             model,
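The same rename-and-consolidate pattern lands in the other parallelize files below. As a minimal standalone sketch of the pattern the commit converges on (the helper name pick_edp_mesh_names is hypothetical and not torchtitan API; only the name lists come from the diff): build the expert-FSDP mesh-dim name list with one conditional expression, then resolve the mesh from it in a single get_mesh call.

# Hypothetical helper illustrating the pattern above; not part of torchtitan.
def pick_edp_mesh_names(dp_replicate_enabled: bool) -> list[str]:
    # With HSDP replication enabled the expert-FSDP mesh spans both dims,
    # otherwise only the "efsdp" dim.
    return ["dp_replicate", "efsdp"] if dp_replicate_enabled else ["efsdp"]


print(pick_edp_mesh_names(True))   # ['dp_replicate', 'efsdp']
print(pick_edp_mesh_names(False))  # ['efsdp']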

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 8 additions & 7 deletions
@@ -130,21 +130,22 @@ def parallelize_deepseekv3(
             dp_mode = "fully_shard"

         dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names)
-        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP

-        if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ["dp_replicate", "efsdp"]
-        else:
-            dp_mesh_dim_names = ["efsdp"]
-        edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names)
+        # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
+        edp_mesh_names = (
+            ["dp_replicate", "efsdp"]
+            if parallel_dims.dp_replicate_enabled
+            else ["efsdp"]
+        )
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)

         for _, transformer_block in model.layers.items():
             if transformer_block.moe_enabled and parallel_dims.ep_enabled:
                 experts_shard_dim = 0
                 assert edp_mesh is not None
                 assert hasattr(transformer_block, "moe")
                 if (
-                    edp_mesh.size() * parallel_dims.ep
+                    edp_mesh["efsdp"].size() * parallel_dims.ep
                     > transformer_block.moe.experts.num_experts
                 ):
                     experts_shard_dim = 1
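This file (and the llama4 apply_fsdp hunk below) also switches the experts-sharding check from edp_mesh.size() to edp_mesh["efsdp"].size(), so the replicate dimension of an HSDP mesh no longer inflates the count. A minimal arithmetic sketch of the difference, with made-up degrees (the 2/4/2/8 values are assumptions for illustration, not values from the commit):

# Hypothetical degrees, chosen only to illustrate the two checks.
dp_replicate_degree = 2   # assumed HSDP replication factor
efsdp_degree = 4          # assumed ranks that shard the expert weights
ep_degree = 2             # assumed expert-parallel degree
num_experts = 8           # assumed number of experts per MoE layer

# DeviceMesh.size() with no argument is the total device count of the mesh,
# so with HSDP it includes the replicate dim; mesh["efsdp"].size() is just
# the shard dim.
old_check = (dp_replicate_degree * efsdp_degree) * ep_degree > num_experts  # 16 > 8 -> True
new_check = efsdp_degree * ep_degree > num_experts                          # 8 > 8  -> False

print(old_check, new_check)  # True False: only the old check would move sharding to dim 1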

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -124,18 +124,18 @@ def parallelize_deepseekv3(
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
-        names = (
+        dp_mesh_names = (
             ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
         )
-        dp_mesh = parallel_dims.get_mesh(names)
+        dp_mesh = parallel_dims.get_mesh(dp_mesh_names)

         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        names = (
+        edp_mesh_names = (
             ["dp_replicate", "efsdp"]
             if parallel_dims.dp_replicate_enabled
             else ["efsdp"]
         )
-        edp_mesh = parallel_dims.get_mesh(names)
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)

         apply_fsdp(
             model,

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 12 additions & 9 deletions
@@ -127,18 +127,18 @@ def parallelize_llama(

     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
         # dp_mesh is the mesh for FSDP/HSDP
-        names = (
+        dp_mesh_names = (
             ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
         )
-        dp_mesh = parallel_dims.get_mesh(names)
+        dp_mesh = parallel_dims.get_mesh(dp_mesh_names)

         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        edp_mesh = None
-        if parallel_dims.ep_enabled:
-            if parallel_dims.dp_replicate_enabled:
-                edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
-            else:
-                edp_mesh = parallel_dims.get_mesh("efsdp")
+        edp_mesh_names = (
+            ["dp_replicate", "efsdp"]
+            if parallel_dims.dp_replicate_enabled
+            else ["efsdp"]
+        )
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)

         apply_fsdp(
             model,
@@ -337,7 +337,10 @@ def apply_fsdp(
             _experts_shard_placement_fn = None
             assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
-            if edp_mesh.size() * ep_degree > transformer_block.moe.experts.num_experts:
+            if (
+                edp_mesh["efsdp"].size() * ep_degree
+                > transformer_block.moe.experts.num_experts
+            ):
                 _experts_shard_placement_fn = lambda param: Shard(1)

             fully_shard(

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -116,18 +116,18 @@ def parallelize_qwen3(

     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
-        names = (
+        dp_mesh_names = (
             ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
         )
-        dp_mesh = parallel_dims.get_mesh(names)
+        dp_mesh = parallel_dims.get_mesh(dp_mesh_names)

         # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
-        names = (
+        edp_mesh_names = (
             ["dp_replicate", "efsdp"]
             if parallel_dims.dp_replicate_enabled
             else ["efsdp"]
         )
-        edp_mesh = parallel_dims.get_mesh(names)
+        edp_mesh = parallel_dims.get_mesh(edp_mesh_names)

         apply_fsdp(
             model,
