From 1233902a54e88851f4381349d6df1ecb67134ba7 Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Tue, 19 Aug 2025 15:49:44 -0700
Subject: [PATCH] [dsv3] patch graph break fix, works up until sharding rules

---
 .../auto_parallel/parallelize_deepseekv3.py |  24 +--
 torchtitan/models/deepseek_v3/model/moe.py  | 140 +++++++++---------
 2 files changed, 81 insertions(+), 83 deletions(-)

diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
index 946ec8a199..7ef9110acc 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
@@ -44,8 +44,7 @@ def input_fn():
         return (
             torch.randint(
                 0,
-                # job_config.training.vocab_size,
-                model.vocab_size,
+                model.model_args.vocab_size,
                 (global_batch_size, job_config.training.seq_len),
                 device=torch.device("cuda"),
             ),
@@ -63,9 +62,6 @@ def input_fn():
     #     lambda bucket_idx: 1000 / parallel_dims.tp
     # )
 
-    # bail out
-    return model
-
     # if job_config.experimental.autop_force_bf16:
     #     logger.info("Forcing bf16 on model")
     #     model = model.bfloat16()
@@ -73,13 +69,17 @@ def input_fn():
     # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
     # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
     # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    # with AutoParallel(
-    #     model,
-    #     input_fn,
-    #     world_mesh,
-    #     mp_policy=mp_policy,
-    #     compile=job_config.training.compile,
-    # ) as autop:
+    mp_policy = None
+    with AutoParallel(
+        model,
+        input_fn,
+        world_mesh,
+        mp_policy=mp_policy,
+        compile=job_config.training.compile,
+    ) as autop:
+        # currently errors due to missing sharding prop rules
+        torch.distributed.breakpoint()
+
     # autop.add_parameter_memory_constraint(low=None, high=None)
 
     # possible_input_shardings = {
diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py
index 2554d61310..86408f82c5 100644
--- a/torchtitan/models/deepseek_v3/model/moe.py
+++ b/torchtitan/models/deepseek_v3/model/moe.py
@@ -48,6 +48,73 @@ def init_weights(self, init_std: float = 0.02):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
 
 
+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
 class GroupedExperts(nn.Module):
     def __init__(
         self,
@@ -69,83 +136,14 @@ def forward(
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
         if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
         else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                 self.w1, self.w2, self.w3, x, num_tokens_per_expert
             )
 
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
     def init_weights(self, init_std: float):
         nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
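
Note (illustration, not part of the patch): the moe.py change hoists the two expert_parallel-decorated helpers out of GroupedExperts so that forward dispatches to plain module-level functions instead of decorated @staticmethods. The sketch below shows that call pattern in isolation; the pass-through expert_parallel decorator, the toy shapes, and the simplified dense path are stand-ins for illustration, not torchtitan's real implementations.

# Illustrative sketch only (not part of the patch). "expert_parallel" here is a
# stand-in pass-through decorator; torchtitan's real decorator performs
# expert-parallel token routing. Shapes are toy values.
import torch
import torch.nn as nn
import torch.nn.functional as F


def expert_parallel(fn):
    # stand-in: the real decorator wraps the expert computation
    return fn


@expert_parallel
def _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert=None):
    # dense fallback path from the patch: x is (num_experts, tokens_per_expert, dim)
    h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
    return torch.bmm(h, w2)


class GroupedExperts(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # after the patch, forward calls the module-level helper
        # rather than a decorated @staticmethod on the class
        return _run_experts_for_loop(self.w1, self.w2, self.w3, x)


if __name__ == "__main__":
    experts = GroupedExperts(dim=8, hidden_dim=16, num_experts=4)
    out = experts(torch.randn(4, 5, 8))  # (num_experts, tokens, dim)
    print(out.shape)  # torch.Size([4, 5, 8])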