diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
index 800c8c83a6b..8073fd45601 100644
--- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,12 +11,14 @@
 )
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_per_tensor_quant_fwd,
 )
@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
     )
     return output
+
+
+def cutlass_w4a8_moe_deepep_ll(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    masked_m: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_ids_ (torch.Tensor): The top-k expert ids of each token->expert
+        mapping; only its second dimension (top-k) is read by this function.
+    - masked_m (torch.Tensor): The number of valid tokens for each local expert.
+        Shape: [num_local_experts]
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - expert_offsets (torch.Tensor): Preallocated per-expert offset buffer
+        consumed by the grouped gemm (the last element is dropped before the call).
+    - problem_sizes1 / problem_sizes2 (torch.Tensor): Preallocated
+        [num_experts, 3] buffers that receive the (N, M, K) problem shape of
+        each expert's grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 per-tensor scale to
+        quantize a. Shape: scalar ([1])
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 per-tensor scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar ([1])
+
+    Returns:
+    - torch.Tensor: The bf16 output tensor.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
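+
+    Example (illustrative only; the weight, scale, stride, and problem-size
+    buffers are assumed to be the ones prepared by W4AFp8MoEMethod, and the
+    names below are placeholders):
+        output = cutlass_w4a8_moe_deepep_ll(
+            hidden_states, w13_weight, w2_weight,
+            w13_weight_scale_inv, w2_weight_scale_inv,
+            topk_ids, masked_m,
+            a_strides1, b_strides1, c_strides1,
+            a_strides2, b_strides2, c_strides2,
+            s_strides13, s_strides2,
+            expert_offsets, problem_sizes1, problem_sizes2,
+            a1_scale=w13_input_scale, a2_scale=w2_input_scale,
+        )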
+ """ + assert w1_q.dtype == torch.int8 + assert w2_q.dtype == torch.int8 + assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1" + assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + + assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch" + assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" + assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" + assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" + num_experts = w1_q.size(0) + m = a.size(1) + k = w1_q.size(2) * 2 # w1_q is transposed and packed + n = w2_q.size(2) * 2 # w2_q is transposed and packed + topk = topk_ids_.size(1) + + device = a.device + + problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data( + masked_m, + problem_sizes1, + problem_sizes2, + num_experts, + n, + k, + ) + + gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device) + sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True) + c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16) + c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16) + + cutlass_w4a8_moe_mm( + c1, + gateup_input, + w1_q, + a1_scale.float(), + w1_scale, + expert_offsets[:-1], + problem_sizes1, + a_strides1, + b_strides1, + c_strides1, + s_strides13, + 128, + topk, + ) + + intermediate_q = torch.empty( + (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn + ) + silu_and_mul_masked_post_per_tensor_quant_fwd( + c1, intermediate_q, masked_m, a2_scale + ) + cutlass_w4a8_moe_mm( + c2, + intermediate_q, + w2_q, + a2_scale.float(), + w2_scale, + expert_offsets[:-1], + problem_sizes2, + a_strides2, + b_strides2, + c_strides2, + s_strides2, + 128, + topk, + ) + + return c2 diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 62aa943901d..166f42ea9e6 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -1014,3 +1014,197 @@ def zero_experts_compute_triton( ) return output + + +@triton.jit +def compute_problem_sizes_w4a8_kernel( + masked_m_ptr, + problem_sizes1_ptr, + problem_sizes2_ptr, + n, + k, + num_experts, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = pid < num_experts + final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0) + + ps1_idx_0 = pid * 3 + ps1_idx_1 = ps1_idx_0 + 1 + ps1_idx_2 = ps1_idx_0 + 2 + + ps2_idx_0 = pid * 3 + ps2_idx_1 = ps2_idx_0 + 1 + ps2_idx_2 = ps2_idx_0 + 2 + + ps1_mask_0 = ps1_idx_0 < num_experts * 3 + ps1_mask_1 = ps1_idx_1 < num_experts * 3 + ps1_mask_2 = ps1_idx_2 < num_experts * 3 + ps2_mask_0 = ps2_idx_0 < num_experts * 3 + ps2_mask_1 = ps2_idx_1 < num_experts * 3 + ps2_mask_2 = ps2_idx_2 < num_experts * 3 + + tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0) + tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1) + tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2) + + tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0) + tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1) + tl.store(problem_sizes2_ptr + ps2_idx_2, n, 
mask=ps2_mask_2) + + +def compute_problem_sizes_w4a8( + masked_m, problem_sizes1, problem_sizes2, n, k, num_experts +): + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),) + compute_problem_sizes_w4a8_kernel[grid]( + masked_m, + problem_sizes1, + problem_sizes2, + n, + k, + num_experts, + BLOCK_SIZE=BLOCK_SIZE, + ) + return problem_sizes1, problem_sizes2 + + +def deepep_ll_get_cutlass_w4a8_moe_mm_data( + masked_m, + problem_sizes1, + problem_sizes2, + num_experts, + n, + k, +): + problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8( + masked_m, problem_sizes1, problem_sizes2, n, k, num_experts + ) + return ( + problem_sizes1.to(torch.int32), + problem_sizes2.to(torch.int32), + ) + + +@triton.jit +def _silu_and_mul_post_per_tensor_quant_kernel( + input_ptr, + stride_input_expert, + stride_input_token, + stride_input_dim, + output_ptr, + stride_output_expert, + stride_output_token, + stride_output_dim, + scale_ptr, + masked_m_ptr, + inner_dim, + fp8_max, + fp8_min, + BLOCK_N: tl.constexpr, + NUM_STAGE: tl.constexpr, +): + """ + Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization. + + Shape: + input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D] + output: [E, T_padded, D], dtype=float8_e4m3fn + """ + expert_id = tl.program_id(2) + block_id_token = tl.program_id(1) + block_id_dim = tl.program_id(0) + + num_token_blocks = tl.num_programs(1) + + token_num_cur_expert = tl.load(masked_m_ptr + expert_id) + + scale = 1.0 / tl.load(scale_ptr).to(tl.float32) + + stride_input_expert = tl.cast(stride_input_expert, tl.int32) + stride_output_expert = tl.cast(stride_output_expert, tl.int32) + stride_input_token = tl.cast(stride_input_token, tl.int32) + stride_output_token = tl.cast(stride_output_token, tl.int32) + + offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N) + mask_d = offset_d < inner_dim + + # base pointers for current expert and dim block + input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d + output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d + + for token_idx in tl.range( + block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE + ): + gate_ptr = input_base_offs + token_idx * stride_input_token + up_ptr = gate_ptr + inner_dim + gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32) + up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32) + + # SiLU: x * sigmoid(x) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + + scaled = gate_up * scale + output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty) + out_ptr = output_base_offs + token_idx * stride_output_token + tl.store(out_ptr, output_q, mask=mask_d) + + +def silu_and_mul_masked_post_per_tensor_quant_fwd( + input: torch.Tensor, + output: torch.Tensor, + masked_m: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + """ + Fused SiLU + Mul + Per-Tensor Quantization to FP8. 
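+
+    Per channel (with gate = input[..., :inner_dim], up = input[..., inner_dim:]):
+        output = clamp((up * silu(gate)) / scale, fp8_min, fp8_max), cast to fp8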
+
+    Args:
+        input: [expert_num, token_num_padded, 2 * inner_dim]
+        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
+        masked_m: [expert_num], actual token count for each expert
+        scale: [1], per-tensor quantization scale; the kernel reads a single
+            scalar, so per-expert scales are not supported
+
+    Returns:
+        output tensor
+    """
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert input.ndim == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+    assert scale.numel() == 1, "only a single per-tensor scale is supported"
+
+    expert_num = input.shape[0]
+    inner_dim = input.shape[-1] // 2
+
+    BLOCK_N = 256
+    # grid dim 1: the number of token blocks each expert is split into
+    NUM_TOKEN_BLOCKS = 64 if expert_num < 4 else 32
+    NUM_STAGES = 3
+    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
+
+    grid = (hidden_dim_split_block_num, NUM_TOKEN_BLOCKS, expert_num)
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_per_tensor_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        scale,
+        masked_m,
+        inner_dim,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+    )
+    return output
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 12f04eb9ee7..f300eb0df1d 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -100,6 +100,7 @@ def __init__(
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
         else:
+            self.use_w4afp8 = False
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.use_w4afp8 = False
@@ -199,6 +200,8 @@ def run_moe_core(
             return self.forward_flashinfer_cutedsl(
                 dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
             )
+        elif self.use_w4afp8:
+            return self.forward_cutlass_w4afp8_masked(dispatch_output)
         assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
         assert down_gemm_overlap_args is None
         return self.forward_deepgemm_masked(dispatch_output)
@@ -513,6 +516,20 @@ def forward_deepgemm_masked(

         return down_output

+    def forward_cutlass_w4afp8_masked(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        assert self.moe_runner_config.activation == "silu"
+        assert isinstance(self.quant_method, W4AFp8MoEMethod)
+        assert get_bool_env_var(
+            "SGLANG_DEEPEP_BF16_DISPATCH"
+        ), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
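+        # The low-latency dispatcher returns hidden_states shaped
+        # [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+        # with per-expert valid-token counts in masked_m, so no token
+        # permutation is required before the grouped gemms.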
+        return self.quant_method.apply_deepep_ll(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
     def forward_npu(
         self,
         dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py
index 44deaa8af14..4a6f6def715 100644
--- a/python/sglang/srt/layers/quantization/w4afp8.py
+++ b/python/sglang/srt/layers/quantization/w4afp8.py
@@ -23,6 +23,7 @@
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
 from sglang.srt.layers.moe.token_dispatcher import (
     CombineInput,
+    DeepEPLLOutput,
     DeepEPNormalOutput,
     StandardDispatchOutput,
 )
@@ -328,6 +329,41 @@ def apply(
         output *= self.moe_runner_config.routed_scaling_factor
         return StandardCombineInput(hidden_states=output)

+    def apply_deepep_ll(
+        self,
+        layer: DeepEPMoE,
+        dispatch_output: DeepEPLLOutput,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
+
+        hidden_states, _, topk_ids, _, masked_m, _ = dispatch_output
+
+        output = cutlass_w4a8_moe_deepep_ll(
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_ids,
+            masked_m,
+            self.a_strides1,
+            self.b_strides1,
+            self.c_strides1,
+            self.a_strides2,
+            self.b_strides2,
+            self.c_strides2,
+            self.s_strides13,
+            self.s_strides2,
+            self.expert_offsets,
+            self.problem_sizes1,
+            self.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
+
+        return output
+
     def apply_deepep_normal(
         self,
         layer: DeepEPMoE,
diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh
index 8cd50c60c1d..8f29ec379c5 100644
--- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh
+++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_get_group_starts.cuh
@@ -34,6 +34,40 @@ __global__ void int4_fp8_get_group_gemm_starts(
   b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
 }

+template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
+__global__ void int4_fp8_get_group_gemm_starts_3d(
+    ElementA** a_offsets,
+    ElementB** b_offsets,
+    ElementC** out_offsets,
+    ElementAccumulator** a_scales_offsets,
+    cutlass::bfloat16_t** b_scales_offsets,
+    ElementA* a_base_as_int,
+    ElementB* b_base_as_int,
+    ElementC* out_base_as_int,
+    ElementAccumulator* a_scales_base_as_int,
+    cutlass::bfloat16_t* b_scales_base_as_int,
+    int64_t n,
+    int64_t m,
+    int64_t k,
+    bool per_act_token,
+    bool per_out_ch,
+    int num_experts) {
+  int expert_id = blockIdx.x * blockDim.x + threadIdx.x;
+  if (expert_id >= num_experts) return;
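+
+  // Per-expert base offsets for the masked 3D layout: A is [E, m, k] fp8,
+  // B is [E, n, k/2] int4-packed, C is [E, m, n]. b_scales hold n * 4 values
+  // per 512-wide k block (n * k / 128 elements per expert), and a_scales is
+  // a single per-tensor value shared by every expert.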
+  int64_t a_offset = expert_id * m * k;
+  int64_t b_offset = expert_id * k * n / 2;
+  int64_t out_offset = expert_id * m * n;
+  int64_t a_scales_offset = 0;
+  int64_t b_scales_offset = per_out_ch ? expert_id * n * 4 * k / 512 : expert_id;
+
+  a_offsets[expert_id] = a_base_as_int + a_offset;
+  b_offsets[expert_id] = b_base_as_int + b_offset;
+  out_offsets[expert_id] = out_base_as_int + out_offset;
+  a_scales_offsets[expert_id] = a_scales_base_as_int + a_scales_offset;
+  b_scales_offsets[expert_id] = b_scales_base_as_int + b_scales_offset;
+}
+
 #define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                                \
   else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                          \
     int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int4b_t, C_TYPE, float>  \
@@ -55,6 +89,28 @@ __global__ void int4_fp8_get_group_gemm_starts(
             per_out_ch);                                                                    \
   }

+#define __CALL_W4A8_GET_STARTS_KERNEL_3D(TENSOR_C_TYPE, C_TYPE)                               \
+  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                                            \
+    int4_fp8_get_group_gemm_starts_3d<cutlass::float_e4m3_t, cutlass::int4b_t, C_TYPE, float> \
+        <<<1, num_experts, 0, stream>>>(                                                      \
+            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),                          \
+            static_cast<cutlass::int4b_t**>(b_ptrs.data_ptr()),                               \
+            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                                       \
+            static_cast<float**>(a_scales_ptrs.data_ptr()),                                   \
+            static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()),                     \
+            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),                        \
+            static_cast<cutlass::int4b_t*>(b_tensors.data_ptr()),                             \
+            static_cast<C_TYPE*>(out_tensors.data_ptr()),                                     \
+            static_cast<float*>(a_scales.data_ptr()),                                         \
+            static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()),                           \
+            out_tensors.size(2),                                                              \
+            a_tensors.size(1),                                                                \
+            a_tensors.size(2),                                                                \
+            per_act_token,                                                                    \
+            per_out_ch,                                                                       \
+            num_experts);                                                                     \
+  }
+
 namespace {

 void run_int4_fp8_get_group_gemm_starts(
@@ -80,12 +136,22 @@ void run_int4_fp8_get_group_gemm_starts(
   auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

-  if (false) {
-  }
-  __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
-  __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
-  else {
-    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  if (a_tensors.dim() == 3) {
+    if (false) {
+    }
+    __CALL_W4A8_GET_STARTS_KERNEL_3D(torch::kBFloat16, cutlass::bfloat16_t)
+    __CALL_W4A8_GET_STARTS_KERNEL_3D(torch::kFloat16, half)
+    else {
+      TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+    }
+  } else {
+    if (false) {
+    }
+    __CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
+    __CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
+    else {
+      TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+    }
   }
 }
diff --git a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
index d8b794997a5..5afd1c34728 100644
--- a/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
+++ b/sgl-kernel/csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cuh
@@ -174,7 +174,7 @@ void cutlass_w4a8_group_gemm_caller(
   bool per_out_ch = b_scales.numel() != num_experts;

   // Check inputs
-  TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
+  TORCH_CHECK(a_tensors.dim() == 2 || a_tensors.dim() == 3, "A tensor must be 2D or 3D");
   TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
   TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
   TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
@@ -186,7 +186,9 @@ void cutlass_w4a8_group_gemm_caller(
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
   TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
   TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
-  TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
+  TORCH_CHECK(
+      b_tensors.size(2) * 2 == a_tensors.size(-1),
+      "B tensor K/2 dimension must match A tensor K dimension");
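+  // K is the innermost dimension of A in both layouts: [total_m, K] for the
+  // 2D (normal dispatch) path and [E, max_m, K] for the 3D (masked) path.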

   // Check tensor types
   TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
diff --git a/test/srt/quant/test_w4a8_deepseek_v3.py b/test/srt/quant/test_w4a8_deepseek_v3.py
index 30e02279685..064986a577a 100644
--- a/test/srt/quant/test_w4a8_deepseek_v3.py
+++ b/test/srt/quant/test_w4a8_deepseek_v3.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 from types import SimpleNamespace

@@ -173,5 +174,73 @@ def test_gsm8k(
         self.assertGreater(metrics["accuracy"], 0.92)


+class TestDeepseekV3W4Afp8DeepepAutoMtp(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST)
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--tp",
+            "8",
+            "--trust-remote-code",
+            "--ep-size",
+            "8",
+            "--cuda-graph-bs",
+            "256",
+            "--disable-radix-cache",
+            "--moe-a2a-backend",
+            "deepep",
+            "--deepep-mode",
+            "auto",
+            "--dp",
+            "8",
+            "--enable-dp-attention",
+            "--moe-runner-backend",
+            "cutlass",
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "1",
+            "--speculative-num-draft-tokens",
+            "4",
+        ]
+        if not is_in_amd_ci():
+            other_args += ["--mem-frac", "0.7"]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+            env={
+                **os.environ,
+                "SGLANG_DEEPEP_BF16_DISPATCH": "1",
+                "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
+            },
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(
+        self,
+    ):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Eval accuracy of GSM8K: {metrics=}")
+
+        self.assertGreater(metrics["accuracy"], 0.92)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index f1b30e6ae3b..5ab6484ea9a 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -169,7 +169,7 @@ class TestFile:
         TestFile("test_disaggregation_hybrid_attention.py", 200),
     ],
     "per-commit-8-gpu-h20": [
-        TestFile("quant/test_w4a8_deepseek_v3.py", 371),
+        TestFile("quant/test_w4a8_deepseek_v3.py", 520),
         TestFile("test_disaggregation_different_tp.py", 600),
         TestFile("test_disaggregation_pp.py", 140),
     ],