138 changes: 138 additions & 0 deletions python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,12 +11,14 @@
)

from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_ll_get_cutlass_w4a8_moe_mm_data,
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_moe_ep_preproess,
silu_and_mul_masked_post_per_tensor_quant_fwd,
)


@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
)

return output


def cutlass_w4a8_moe_deepep_ll(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_ids_: torch.Tensor,
masked_m: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and a top-k gating
mechanism, for the DeepEP low-latency (masked) dispatch layout. The matrix
multiplications are implemented with CUTLASS grouped GEMM.

Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
Shape: [num_experts, N * 2, K // 2]
(the weights are passed transposed and int4-packed)
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
Shape: [num_experts, K, N // 2]
(the weights are passed transposed and int4-packed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts, K // 512, N * 8]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_ids_ (torch.Tensor): The top-k expert ids of each token->expert mapping.
- masked_m (torch.Tensor): The number of valid tokens for each local expert.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [1, K]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [1, N]

Returns:
- torch.Tensor: The bf16 output tensor after applying the MoE layer.
"""
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
m = a.size(1)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)

device = a.device

problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
masked_m,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
)

gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)

cutlass_w4a8_moe_mm(
c1,
gateup_input,
w1_q,
a1_scale.float(),
w1_scale,
expert_offsets[:-1],
problem_sizes1,
a_strides1,
b_strides1,
c_strides1,
s_strides13,
128,
topk,
)

intermediate_q = torch.empty(
(num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
)
silu_and_mul_masked_post_per_tensor_quant_fwd(
c1, intermediate_q, masked_m, a2_scale
)
cutlass_w4a8_moe_mm(
c2,
intermediate_q,
w2_q,
a2_scale.float(),
w2_scale,
expert_offsets[:-1],
problem_sizes2,
a_strides2,
b_strides2,
c_strides2,
s_strides2,
128,
topk,
)

return c2
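
For orientation, a small shape walk-through of cutlass_w4a8_moe_deepep_ll. This is a sketch: the concrete sizes below are made up for illustration, and only the relationships between them come from the code above.

# Hypothetical sizes: E experts, m padded token slots per expert
# (num_max_dispatch_tokens_per_rank * num_ranks), hidden size k, intermediate size n.
E, m, k, n = 8, 128, 7168, 2048
# a:              [E, m, k]       bf16 DeepEP low-latency dispatch buffer
# w1_q:           [E, 2*n, k//2]  int4-packed gate/up projection (transposed)
# w2_q:           [E, k, n//2]    int4-packed down projection (transposed)
# c1:             [E, m, 2*n]     bf16, output of the first grouped GEMM
# intermediate_q: [E, m, n]       fp8, after fused SiLU-and-mul + per-tensor quant
# c2:             [E, m, k]       bf16, output of the second grouped GEMM (returned)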
194 changes: 194 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
)

return output


@triton.jit
def compute_problem_sizes_w4a8_kernel(
masked_m_ptr,
problem_sizes1_ptr,
problem_sizes2_ptr,
n,
k,
num_experts,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = pid < num_experts
final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)

ps1_idx_0 = pid * 3
ps1_idx_1 = ps1_idx_0 + 1
ps1_idx_2 = ps1_idx_0 + 2

ps2_idx_0 = pid * 3
ps2_idx_1 = ps2_idx_0 + 1
ps2_idx_2 = ps2_idx_0 + 2

ps1_mask_0 = ps1_idx_0 < num_experts * 3
ps1_mask_1 = ps1_idx_1 < num_experts * 3
ps1_mask_2 = ps1_idx_2 < num_experts * 3
ps2_mask_0 = ps2_idx_0 < num_experts * 3
ps2_mask_1 = ps2_idx_1 < num_experts * 3
ps2_mask_2 = ps2_idx_2 < num_experts * 3

tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)

tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)


def compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
BLOCK_SIZE = 256
grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
compute_problem_sizes_w4a8_kernel[grid](
masked_m,
problem_sizes1,
problem_sizes2,
n,
k,
num_experts,
BLOCK_SIZE=BLOCK_SIZE,
)
return problem_sizes1, problem_sizes2


def deepep_ll_get_cutlass_w4a8_moe_mm_data(
masked_m,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
):
problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
)
return (
problem_sizes1.to(torch.int32),
problem_sizes2.to(torch.int32),
)
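
As a cross-check, an eager PyTorch sketch of the layout the problem-sizes kernel writes; this reference is not part of the change and assumes problem_sizes1/2 are viewed as [num_experts, 3] rows.

import torch

def problem_sizes_w4a8_reference(masked_m: torch.Tensor, n: int, k: int):
    # Sketch of what compute_problem_sizes_w4a8 fills: for expert e, the first
    # grouped GEMM solves [2*n, masked_m[e], k] and the second [k, masked_m[e], n];
    # each triple is stored contiguously per expert.
    num_experts = masked_m.numel()
    ps1 = torch.empty((num_experts, 3), dtype=torch.int32, device=masked_m.device)
    ps2 = torch.empty_like(ps1)
    ps1[:, 0], ps1[:, 1], ps1[:, 2] = 2 * n, masked_m, k
    ps2[:, 0], ps2[:, 1], ps2[:, 2] = k, masked_m, n
    return ps1, ps2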


@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
input_ptr,
stride_input_expert,
stride_input_token,
stride_input_dim,
output_ptr,
stride_output_expert,
stride_output_token,
stride_output_dim,
scale_ptr,
masked_m_ptr,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
"""
Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.

Shape:
input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
output: [E, T_padded, D], dtype=float8_e4m3fn
"""
expert_id = tl.program_id(2)
block_id_token = tl.program_id(1)
block_id_dim = tl.program_id(0)

num_token_blocks = tl.num_programs(1)

token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

scale = 1.0 / tl.load(scale_ptr).to(tl.float32)

stride_input_expert = tl.cast(stride_input_expert, tl.int32)
stride_output_expert = tl.cast(stride_output_expert, tl.int32)
stride_input_token = tl.cast(stride_input_token, tl.int32)
stride_output_token = tl.cast(stride_output_token, tl.int32)

offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
mask_d = offset_d < inner_dim

# base pointers for current expert and dim block
input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d

for token_idx in tl.range(
block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
):
gate_ptr = input_base_offs + token_idx * stride_input_token
up_ptr = gate_ptr + inner_dim
gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)

# SiLU: x * sigmoid(x)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate

scaled = gate_up * scale
output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
out_ptr = output_base_offs + token_idx * stride_output_token
tl.store(out_ptr, output_q, mask=mask_d)


def silu_and_mul_masked_post_per_tensor_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
masked_m: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
"""
Fused SiLU + Mul + Per-Tensor Quantization to FP8.

Args:
input: [expert_num, token_num_padded, 2 * inner_dim]
output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
masked_m: [expert_num], actual token count for each expert
scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)

Returns:
output tensor
"""
assert input.is_contiguous()
assert output.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert input.ndim == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
assert scale.numel() == 1 or scale.shape[0] == input.shape[0]

expert_num = input.shape[0]
# gate and up are concatenated along the last dim, so the activation dim is half of it
inner_dim = input.shape[-1] // 2

BLOCK_N = 256
BLOCK_M = 64 if expert_num < 4 else 32
NUM_STAGES = 3
hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)

grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max

_silu_and_mul_post_per_tensor_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
scale,
masked_m,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
)
return output
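
For readability (and as a possible unit-test oracle), an eager equivalent of the fused kernel above. This is a sketch assuming a single per-tensor scale; unlike the kernel, it also processes the padded rows beyond masked_m instead of skipping them.

import torch

def silu_and_mul_quant_reference(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: [E, T_padded, 2*D]; gate is the first half of the last dim, up the second.
    d = x.shape[-1] // 2
    gate, up = x[..., :d].float(), x[..., d:].float()
    y = torch.nn.functional.silu(gate) * up
    # Per-tensor quantization: divide by the scale, clamp to the fp8 range, cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    y = (y / scale.float()).clamp(min=-finfo.max, max=finfo.max)
    return y.to(torch.float8_e4m3fn)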
17 changes: 17 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -100,6 +100,7 @@ def __init__(
self.use_fp8_w8a8 = False
self.use_block_quant = False
else:
self.use_w4afp8 = False
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.use_w4afp8 = False
@@ -199,6 +200,8 @@ def run_moe_core(
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
elif self.use_w4afp8:
return self.forward_cutlass_w4afp8_masked(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
@@ -513,6 +516,20 @@ def forward_deepgemm_masked(

return down_output

def forward_cutlass_w4afp8_masked(
self,
dispatch_output: DeepEPNormalOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
assert get_bool_env_var(
"SGLANG_DEEPEP_BF16_DISPATCH"
), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
return self.quant_method.apply_deepep_ll(
layer=self,
dispatch_output=dispatch_output,
)

def forward_npu(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
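
Since forward_cutlass_w4afp8_masked asserts BF16 dispatch, deployments using this path need SGLANG_DEEPEP_BF16_DISPATCH enabled. A minimal sketch follows; how the flag is actually set depends on the launcher, so treat this as an assumption rather than a documented recipe.

import os

# Must be set in the serving process before the MoE forward runs; otherwise the
# new W4AFP8 low-latency path raises its assertion.
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"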