43 changes: 34 additions & 9 deletions torchtitan/components/optimizer.py
@@ -351,12 +351,35 @@ def _update_expert_bias(
dp_cp_mesh = (
parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
)

################################################################
# AutoParallel (AP) friendly helpers

def is_moe_block(block):
moe_enabled = getattr(block, "moe_enabled", False)
has_moe_submod = hasattr(block, "moe")  # AutoParallel-transformed blocks may drop the flag but keep the .moe submodule
return moe_enabled or has_moe_submod

def get_transformer_blocks(model_part):
if isinstance(model_part.layers, nn.ModuleDict):
# regular torchtitan
blocks = model_part.layers.values()
else:
# TODO: fix autoparallel to preserve the module dict
blocks = model_part.layers.children()
return blocks

def should_manual_allreduce(tokens_per_expert_by_layer):
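# Under AutoParallel the stacked counts can come back as a DTensor; in that
# case the collective is left to the DTensor machinery and the manual
# all-reduce below is skipped.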
return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor)
################################################################

# TODO: Currently this sync is blocking (thus exposed) and happens on the
# default compute stream. Need to assess if this is OK performance-wise.
tokens_per_expert_list = []
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if not transformer_block.moe_enabled:
blocks = get_transformer_blocks(model_part)
for transformer_block in blocks:
if not is_moe_block(transformer_block):
continue
if transformer_block.moe.load_balance_coeff is None:
return
@@ -372,17 +395,19 @@ def _update_expert_bias(
tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
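# tokens_per_expert_by_layer: one row per MoE layer, shape (num_moe_layers, num_experts)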

if dp_cp_mesh is not None:
# Perform single all-reduce to get global statistics across all processes
pg = dp_cp_mesh.get_group()
torch.distributed.all_reduce(
tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
)
if should_manual_allreduce(tokens_per_expert_by_layer):
# Perform single all-reduce to get global statistics across all processes
pg = dp_cp_mesh.get_group()
torch.distributed.all_reduce(
tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
)

moe_layer_idx = 0
with torch.no_grad():
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if not transformer_block.moe_enabled:
blocks = get_transformer_blocks(model_part)
for transformer_block in blocks:
if not is_moe_block(transformer_block):
continue
moe = transformer_block.moe

95 changes: 93 additions & 2 deletions torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
@@ -19,14 +19,59 @@
from torchtitan.tools.logging import logger


def apply_local_map_to_moe():
"""
TODO: fix HOPs not restoring the original signature.
TODO: fix tracing with local shapes so that we can use Shard placements

Current HOP signature we get:

class subgraph_0(torch.nn.Module):
def forward(self,
rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0",
self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"):
"""
from torchtitan.models import moe
from torch.distributed._tensor.experimental import local_map
# Replicate is needed for the placements below (it may already be imported at module scope).
from torch.distributed.tensor.placement_types import Replicate
moe._moe_forward = local_map(
moe._moe_forward,
out_placements=(
(Replicate(),), # (Shard(0),),
(Replicate(),),
),
in_placements=(
(Replicate(),), # (Shard(0),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)


# Run workflow with:
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel
def parallelize_deepseekv3(
model,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.
Apply Autoparallel to the model

NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
@@ -54,6 +99,9 @@ def input_fn():
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

# apply local_map to MoE
apply_local_map_to_moe()

# torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
# lambda bucket_idx: 500 / parallel_dims.tp
# )
@@ -131,4 +179,47 @@ def _return_as_dtensor_for_loss_parallel(module, args, output):
# removing it at any point
parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

_preserve_moe_attributes(model, parallel_mod)

return parallel_mod


def _preserve_moe_attributes(original_model, parallel_model):
"""
Preserve MoE custom attributes from the original model to the parallel model.
This is only needed for attributes that aren't used in the graph, so they aren't
lifted as graph inputs and fetched by the pre-graph runtime wrapper.

`moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify
this block as an MoE block. This should be safe as they are read-only.
"""
def get_moe_modules(model):
"""Extract all MoE modules from the model."""
moe_modules = []
if hasattr(model, 'layers'):
if isinstance(model.layers, torch.nn.ModuleDict):
# regular torchtitan structure
blocks = model.layers.values()
else:
# autoparallel might change structure
blocks = model.layers.children() if hasattr(model.layers, 'children') else []

for block in blocks:
if hasattr(block, 'moe_enabled') and block.moe_enabled and hasattr(block, 'moe'):
moe_modules.append(block.moe)
elif hasattr(block, 'moe'): # fallback for autoparallel
moe_modules.append(block.moe)
return moe_modules

original_moe_modules = get_moe_modules(original_model)
parallel_moe_modules = get_moe_modules(parallel_model)

# Copy custom attributes from original to parallel MoE modules
# This is fine to do since these attributes are read only
for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
# Copy the moe_enabled flag
if hasattr(orig_moe, 'moe_enabled'):
par_moe.moe_enabled = orig_moe.moe_enabled

# Copy load_balance_coeff
if hasattr(orig_moe, 'load_balance_coeff'):
par_moe.load_balance_coeff = orig_moe.load_balance_coeff
137 changes: 74 additions & 63 deletions torchtitan/models/moe.py
@@ -12,6 +12,7 @@
from torch import nn

from torchtitan.distributed.expert_parallel import expert_parallel
from torch.distributed.tensor.placement_types import Shard, Replicate


@dataclass
@@ -310,6 +311,77 @@ def forward(
num_tokens_per_expert,
)

def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts):
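"""Functional MoE forward, kept free of `self` and of buffer mutation so it
can be wrapped with `local_map` (see apply_local_map_to_moe) and traced as a
HOP. Returns (out, num_tokens_per_expert)."""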
# x has shape (bs, slen, dim), e.g. (64, 2048, 256)
bs, slen, dim = x.shape
x = x.view(-1, dim)

# top_scores and selected_experts_indices shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
(
top_scores,
selected_experts_indices,
num_tokens_per_expert,
) = router(x, expert_bias)

# tokens_per_expert will be used to update the expert bias for load balancing.
# and also to count the expert usage
# TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
# first in the forward pass, and then in the backward pass. However, this has no
# effect on the expert bias update thanks to the torch.sign() operator.
# moved out to remove mutation
# with torch.no_grad():
# tokens_per_expert.add_(num_tokens_per_expert)

# top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
# NOTE: the reason we need to compute num_tokens_per_expert again is:
# 1st computation in router is to update self.tokens_per_expert
# which would be the same across all TP ranks.
# 2nd computation in reorderer is for the actual routing and experts computation
# which would be sharded over TP ranks if expert_tensor_parallel_degree==1.
# If tensor_parallel_degree == expert_tensor_parallel_degree, they agree.
(
top_scores_experts_sorted,
token_indices_experts_sorted,
num_tokens_per_expert,
) = reorderer(top_scores, selected_experts_indices)

# shape (bs*slen*top_k, dim)
token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-1, 1
).expand(-1, dim)

# shape (bs*slen*top_k, dim)
routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)

if score_before_experts:
routed_input = (
routed_input.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)

# shape (bs*slen*top_k, dim)
routed_output = experts(routed_input, num_tokens_per_expert)

if not score_before_experts:
routed_output = (
routed_output.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)

# shared expert
if shared_experts is not None:
out = shared_experts(x)
else:
out = torch.zeros_like(x)

out = out.scatter_add(
dim=0, index=token_indices_experts_sorted, src=routed_output
)
out = out.reshape(bs, slen, dim)
return out, num_tokens_per_expert


class MoE(nn.Module):
def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
@@ -367,72 +439,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
"""
bs, slen, dim = x.shape
x = x.view(-1, dim)

# top_scores and selected_experts_indices shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
(
top_scores,
selected_experts_indices,
num_tokens_per_expert,
) = self.router(x, self.expert_bias)
out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts)

# tokens_per_expert will be used to update the expert bias for load balancing.
# and also to count the expert usage
# TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
# first in the forward pass, and then in the backward pass. However, this has no
# effect on the expert bias update thanks to the torch.sign() operator.
# HOPs don't support buffer mutations, keep this outside
with torch.no_grad():
self.tokens_per_expert.add_(num_tokens_per_expert)

# top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
# num_tokens_per_expert shape (num_experts,)
# NOTE: the reason we need to compute num_tokens_per_expert again is:
# 1st computation in router is to update self.tokens_per_expert
# which would be the same across all TP ranks.
# 2nd computation in reorderer is for the actual routing and experts computation
# which would be sharded over TP ranks if expert_tensor_parallel_degree==1.
# If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree.
(
top_scores_experts_sorted,
token_indices_experts_sorted,
num_tokens_per_expert,
) = self.reorderer(top_scores, selected_experts_indices)

# shape (bs*slen*top_k, dim)
token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-1, 1
).expand(-1, dim)

# shape (bs*slen*top_k, dim)
routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)

if self.score_before_experts:
routed_input = (
routed_input.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)

# shape (bs*slen*top_k, dim)
routed_output = self.experts(routed_input, num_tokens_per_expert)

if not self.score_before_experts:
routed_output = (
routed_output.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)

# shared expert
if self.shared_experts is not None:
out = self.shared_experts(x)
else:
out = torch.zeros_like(x)

out = out.scatter_add(
dim=0, index=token_indices_experts_sorted, src=routed_output
)
out = out.reshape(bs, slen, dim)
return out

def init_weights(
2 changes: 1 addition & 1 deletion torchtitan/train.py
@@ -307,7 +307,7 @@ def __init__(self, job_config: JobConfig):
# confirm that user will be able to view loss metrics on the console
ensure_pp_loss_visible(parallel_dims, job_config, color)
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
# apply Autoparallel
model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)

model.to_empty(device=init_device)