71 changes: 34 additions & 37 deletions torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
@@ -77,43 +77,40 @@ def input_fn():
         mp_policy=mp_policy,
         compile=job_config.training.compile,
     ) as autop:
-        # currently errors due to missing sharding prop rules
-        torch.distributed.breakpoint()
-
-        # autop.add_parameter_memory_constraint(low=None, high=None)
-
-        # possible_input_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_replicate": Shard(0),
-        #     "dp_shard": Shard(0),
-        #     "tp": Replicate(),
-        # }
-        # # only used if loss parallel is enabled
-        # possible_output_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_shard": Shard(0),
-        #     "tp": Shard(2),
-        # }
-        # assert all(
-        #     name in possible_input_shardings for name in world_mesh.mesh_dim_names
-        # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
-        # x_sharding = tuple(
-        #     possible_input_shardings[name] for name in world_mesh.mesh_dim_names
-        # )
-        # out_sharding = x_sharding
-        # if parallel_dims.loss_parallel_enabled:
-        #     out_sharding = tuple(
-        #         possible_output_shardings[name]
-        #         for name in world_mesh.mesh_dim_names
-        #         if name != "dp_replicate"
-        #     )
-        # autop.add_input_constraints([x_sharding])
-        # autop.add_output_constraints([out_sharding])
-        # t0 = time.time()
-        # sharding_placement = autop.optimize_placement()
-        # t1 = time.time()
-        # logger.info(f"AutoParallel took {t1 - t0} seconds")
-        # parallel_mod = autop.apply_placement(sharding_placement)
+        autop.add_parameter_memory_constraint(low=None, high=None)
+
+        possible_input_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_replicate": Shard(0),
+            "dp_shard": Shard(0),
+            "tp": Replicate(),
+        }
+        # only used if loss parallel is enabled
+        possible_output_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_shard": Shard(0),
+            "tp": Shard(2),
+        }
+        assert all(
+            name in possible_input_shardings for name in world_mesh.mesh_dim_names
+        ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
+        x_sharding = tuple(
+            possible_input_shardings[name] for name in world_mesh.mesh_dim_names
+        )
+        out_sharding = x_sharding
+        if parallel_dims.loss_parallel_enabled:
+            out_sharding = tuple(
+                possible_output_shardings[name]
+                for name in world_mesh.mesh_dim_names
+                if name != "dp_replicate"
+            )
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([out_sharding])
+        t0 = time.time()
+        sharding_placement = autop.optimize_placement()
+        t1 = time.time()
+        logger.info(f"AutoParallel took {t1 - t0} seconds")
+        parallel_mod = autop.apply_placement(sharding_placement)
 
     if parallel_dims.loss_parallel_enabled:
 
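Note: the constraint tuples added above are ordered by world_mesh.mesh_dim_names. The following is a minimal sketch, not part of the diff, of what x_sharding and out_sharding work out to for an assumed 2D mesh with dims ("dp_shard", "tp"); the mesh layout is hypothetical and chosen only for illustration.

# Minimal sketch (assumption: a 2D FSDP + TP mesh named ("dp_shard", "tp"));
# shows how the input/output constraint tuples are derived from the mesh dim names.
from torch.distributed.tensor import Replicate, Shard

possible_input_shardings = {
    "dp_replicate": Shard(0),
    "dp_shard": Shard(0),
    "tp": Replicate(),
}
possible_output_shardings = {
    "dp_shard": Shard(0),
    "tp": Shard(2),
}

mesh_dim_names = ("dp_shard", "tp")  # hypothetical mesh, not taken from this PR

x_sharding = tuple(possible_input_shardings[n] for n in mesh_dim_names)
# (Shard(0), Replicate()): inputs batch-sharded over dp_shard, replicated over tp

out_sharding = tuple(
    possible_output_shardings[n] for n in mesh_dim_names if n != "dp_replicate"
)
# (Shard(0), Shard(2)): with loss parallel, the batch stays sharded over dp_shard
# and the vocab/logits dim (dim 2) is sharded over tp
print(x_sharding, out_sharding)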
4 changes: 3 additions & 1 deletion torchtitan/models/deepseek_v3/model/model.py
@@ -270,7 +270,9 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):
         self.attention = Attention(model_args)
         self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.moe_enabled = layer_id >= model_args.n_dense_layers
+        # self.moe_enabled = layer_id >= model_args.n_dense_layers
+        # TODO: enable me when local_map works
+        self.moe_enabled = False
 
         if self.moe_enabled:
             self.moe = MoE(model_args)
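This change temporarily forces moe_enabled to False, so every block takes the dense feed-forward path regardless of layer_id. A small sketch of the effect, using hypothetical layer counts (the values below are not from the PR):

# Minimal sketch with illustrative values: n_dense_layers and n_layers are
# hypothetical, not taken from this diff.
n_dense_layers = 3
n_layers = 8

original = [layer_id >= n_dense_layers for layer_id in range(n_layers)]
# [False, False, False, True, True, True, True, True]: layers 3..7 would build MoE

overridden = [False] * n_layers
# every layer falls back to the dense FeedForward branch until local_map works
print(original, overridden)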