
Commit f72a817

Authored by wenscarl, root, and mgoin
[MoE] CuteDSL MoE with Nvfp4 DeepEP dispatch (#27141)
Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent ec38a73 commit f72a817

File tree: 3 files changed (+113, -47 lines)

- vllm/envs.py
- vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -147,6 +147,7 @@
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
     VLLM_MXFP4_USE_MARLIN: bool | None = None
+    VLLM_DEEPEPLL_NVFP4_DISPATCH: bool = False
     VLLM_V1_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
@@ -1127,6 +1128,12 @@ def get_vllm_port() -> int | None:
     "VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
         "VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
     ),
+    # Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
+    # only supported on Blackwell GPUs and with
+    # https://github.com/deepseek-ai/DeepEP/pull/341
+    "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
+        int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
+    ),
     # Whether to turn on the outlines cache for V1
     # This cache is unbounded and on disk, so it's not safe to use in
     # an environment with potentially malicious users.
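For reference, a minimal sketch of how the new flag is consumed. The launch line is illustrative, and pairing it with the DeepEP low-latency all2all backend is an assumption inferred from the files touched below, not something stated in the commit:

    # Illustrative launch (assumption: used together with the DeepEP
    # low-latency all2all backend on a Blackwell GPU):
    #   VLLM_DEEPEPLL_NVFP4_DISPATCH=1 VLLM_ALL2ALL_BACKEND=deepep_low_latency \
    #       vllm serve <model>
    import os

    # The flag is resolved lazily through vllm.envs, so it can also be set
    # programmatically before vLLM first reads it.
    os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"

    from vllm import envs

    if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
        print("NVFP4 quantization will be fused into the DeepEP dispatch")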

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 57 additions & 28 deletions
@@ -184,31 +184,47 @@ def _do_quant(
         x_fp8, x_scales = x
         x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

-    assert isinstance(x, torch.Tensor)
-
-    num_experts, max_tokens, hidden_dim = x.size()
-
-    # TODO (varun): Optimization - Use a batched version of quant
-    x = x.view((-1, hidden_dim))
+    assert isinstance(x, (torch.Tensor, tuple))
     q_dtype = quant_config.quant_dtype

-    if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+    if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
         logger.info_once(
-            "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
-            "for ModelOptNvFp4FusedMoE."
+            "Since VLLM_DEEPEPLL_NVFP4_DISPATCH==1, make sure "
+            "using the hybrid-ep branch of DeepEP"
+            "(https://github.com/deepseek-ai/DeepEP/tree/hybrid-ep)"
         )
-        q_dtype = None
-
-    x, x_scales = moe_kernel_quantize_input(
-        x,
-        quant_config.a1_scale,
-        q_dtype,
-        quant_config.per_act_token_quant,
-        quant_config.block_shape,
-    )
-    x = x.view((num_experts, -1, hidden_dim))
+        assert isinstance(x, tuple)
+        x_scales = x[1]
+        x = x[0].permute(2, 0, 1)
+        num_experts, max_tokens, hidden_dim_by_2 = x.shape
+        hidden_dim = hidden_dim_by_2 * 2
+        assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
+        logger.info_once(
+            "Quantization is fused with DeepEP nvfp4 dispatch for "
+            "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
+        )
+    else:
+        if q_dtype == "nvfp4":
+            q_dtype = None
+            logger.info_once(
+                "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as "
+                "VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
+            )
+        assert isinstance(x, torch.Tensor)
+        num_experts, max_tokens, hidden_dim = x.size()
+
+        # TODO (varun): Optimization - Use a batched version of quant
+        x = x.view((-1, hidden_dim))
+        x, x_scales = moe_kernel_quantize_input(
+            x,
+            quant_config.a1_scale,
+            q_dtype,
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+        )
+        x = x.view((num_experts, -1, hidden_dim))

-    if q_dtype is not None:
+    if q_dtype is not None and q_dtype != "nvfp4":
         assert x_scales is not None
         x_scales = normalize_batched_scales_shape(x_scales, num_experts)
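For intuition, a minimal, self-contained sketch of the shape handling in the nvfp4 branch above. The `(max_tokens, hidden_dim // 2, num_experts)` layout assumed for the packed payload is only inferred from the `permute(2, 0, 1)` call, and the sizes are made up:

    import torch

    # Assumed layout of the packed fp4 payload coming out of the nvfp4 dispatch:
    # (max_tokens, hidden_dim // 2, num_experts), two fp4 values per uint8 byte.
    num_experts, max_tokens, hidden_dim = 4, 8, 256
    packed = torch.zeros(max_tokens, hidden_dim // 2, num_experts, dtype=torch.uint8)

    # Mirrors the branch above: bring experts to the front, then recover the
    # logical hidden size from the packed width.
    x = packed.permute(2, 0, 1)               # (num_experts, max_tokens, hidden_dim // 2)
    e, m, hidden_dim_by_2 = x.shape
    assert hidden_dim_by_2 * 2 == hidden_dim  # fp4 packs two elements per byte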

@@ -240,18 +256,28 @@ def prepare_async(
         "DeepEP kernels quantize the inputs in blocks of shape 128"
     )

+    use_nvfp4 = False
+    nvfp4_dispatch = (
+        quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+    )
+    if nvfp4_dispatch:
+        use_nvfp4 = True
+    qc_a1_gscale_or_scale = (
+        quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
+    )
     has_per_token_scales = (
-        quant_config.a1_scale.numel() != 1
-        if quant_config.a1_scale is not None
+        qc_a1_gscale_or_scale.numel() != 1
+        if qc_a1_gscale_or_scale is not None
         else (
             quant_config.a2_scale.numel() != 1
             if quant_config.a2_scale is not None
             else False
         )
     )
-    assert not has_per_token_scales, (
-        "low_latency kernels doesn't support dispatching per-token scales"
-    )
+    if not use_nvfp4:
+        assert not has_per_token_scales, (
+            "low_latency kernels doesn't support dispatching per-token scales"
+        )

     if apply_router_weight_on_input:
         topk = topk_ids.size(1)

@@ -269,9 +295,12 @@ def prepare_async(
         self.max_tokens_per_rank,
         num_experts,
         use_fp8=self.use_fp8_dispatch,
-        # round_scale needs to be set to dispatch in ue8m0
-        round_scale=self.use_ue8m0_dispatch,
-        use_ue8m0=self.use_ue8m0_dispatch,
+        **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+        **(
+            dict(x_global_scale=qc_a1_gscale_or_scale)
+            if qc_a1_gscale_or_scale is not None
+            else dict()
+        ),
         async_finish=False,
         return_recv_hook=True,
     )
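The `**(dict(...) if ... else dict())` splats keep the call site compatible with DeepEP builds that do not accept the new arguments: the keywords are only passed when the NVFP4 path is active. A standalone illustration of the pattern; `dispatch` here is a placeholder, not the DeepEP API:

    def dispatch(x, *, use_fp8=False, **extra):
        # Placeholder callee: it only ever sees `use_nvfp4` / `x_global_scale`
        # when the caller actually splats them in.
        return {"use_fp8": use_fp8, **extra}

    use_nvfp4 = True
    global_scale = 1.0  # stand-in for quant_config.a1_gscale

    result = dispatch(
        object(),
        use_fp8=False,
        **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
        **(dict(x_global_scale=global_scale) if global_scale is not None else dict()),
    )
    print(result)  # {'use_fp8': False, 'use_nvfp4': True, 'x_global_scale': 1.0}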

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py

Lines changed: 49 additions & 19 deletions
@@ -4,6 +4,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (

@@ -109,7 +110,8 @@ def workspace_shapes(
     - Note: in order for activation chunking to work, the first dimension
       of each tuple must be the number of tokens.
     """
-    output_shape = (local_num_experts, M, K)
+    K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K
+    output_shape = (local_num_experts, M, K_dim)
     workspace2 = (local_num_experts, M, N)
     workspace1 = output_shape
     return (workspace1, workspace2, output_shape)
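Presumably, with NVFP4 dispatch enabled the `K` seen here is the packed width (two fp4 values per uint8 byte), so the bf16 output workspace needs twice as many columns; this reading is inferred from the diff, not stated in the commit. A toy check of the arithmetic with made-up sizes:

    # Made-up sizes; K plays the role of the packed width when nvfp4 dispatch is on.
    local_num_experts, M, hidden_size = 4, 8, 256
    K = hidden_size // 2                      # packed uint8 width of the dispatched input

    K_dim = K * 2                             # what the patch computes
    output_shape = (local_num_experts, M, K_dim)
    assert output_shape == (local_num_experts, M, hidden_size)  # full-width bf16 output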
@@ -144,9 +146,18 @@ def apply(
     assert hidden_states.ndim == 3
     assert self.w1_scale.ndim == 3
     assert self.w2_scale.ndim == 3
+
+    input_global_scale = (
+        None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale
+    )
+    flashinfer_hidden_states = (
+        (hidden_states, a1q_scale)
+        if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+        else hidden_states
+    )
     flashinfer_cutedsl_moe_masked(
-        hidden_states=hidden_states,
-        input_global_scale=self.a1_gscale,
+        hidden_states=flashinfer_hidden_states,
+        input_global_scale=input_global_scale,
         w1=w1,
         w1_blockscale=self.w1_scale,
         w1_alpha=self.g1_alphas,

@@ -172,7 +183,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:


 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,

@@ -190,7 +201,10 @@ def flashinfer_cutedsl_moe_masked(
     kernels.

     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2],
+              uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,

@@ -207,9 +221,6 @@ def flashinfer_cutedsl_moe_masked(
     """

     # === Assertions on dtypes ===
-    assert input_global_scale.dtype == torch.float32, (
-        f"input_global_scale must be float32, got {input_global_scale.dtype}"
-    )
     assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
     assert w1_blockscale.dtype == torch.float8_e4m3fn, (
         f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"

@@ -230,7 +241,32 @@ def flashinfer_cutedsl_moe_masked(

     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert input_global_scale is None, (
+            "input_global_scale is needed when input needs quant"
+        )
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        # m, k_by_2, num_experts = aq.shape
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1, 2, 0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert input_global_scale.dtype == torch.float32, (
+            f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        )
+        assert input_global_scale.shape == (num_experts,), (
+            f"input_global_scale must be (l,), got {input_global_scale.shape}"
+        )
+
+        aq, aq_sf = scaled_fp4_grouped_quantize(
+            hidden_states,
+            masked_m,
+            input_global_scale,
+        )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert w1.shape[-1] * 2 == k, (
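To make the two accepted input forms concrete, here is a standalone sketch with dummy tensors that follow the shapes from the docstring above; the sizes are made up and nothing here calls the actual kernel:

    import torch

    num_experts, m, k = 4, 16, 256

    # Pre-quantized path (nvfp4 dispatch): packed fp4 payload plus per-16-element
    # e4m3 scales, matching the documented [E, m, k // 2] / [E, m, k // 16] shapes.
    aq = torch.zeros(num_experts, m, k // 2, dtype=torch.uint8)
    aq_sf = torch.zeros(num_experts, m, k // 16, dtype=torch.uint8).view(torch.float8_e4m3fn)
    prequantized = (aq, aq_sf)

    # Unquantized path: plain bf16 activations plus a per-expert global scale.
    bf16_input = torch.zeros(num_experts, m, k, dtype=torch.bfloat16)
    input_global_scale = torch.ones(num_experts, dtype=torch.float32)

    for hidden_states in (prequantized, bf16_input):
        if isinstance(hidden_states, tuple):
            e, m_, k_by_2 = hidden_states[0].shape
            assert k_by_2 * 2 == k                      # fp4: two values per byte
        else:
            e, m_, k_ = hidden_states.shape
            assert input_global_scale.shape == (e,)     # one fp32 scale per expert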
@@ -241,9 +277,6 @@
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

-    assert input_global_scale.shape == (num_experts,), (
-        f"input_global_scale must be (l,), got {input_global_scale.shape}"
-    )
     assert w1_alpha.shape == (num_experts,), (
         f"w1_alpha must be (l,), got {w1_alpha.shape}"
     )

@@ -254,20 +287,17 @@
         f"w2_alpha must be (l,), got {w2_alpha.shape}"
     )

-    aq, aq_sf = scaled_fp4_grouped_quantize(
-        hidden_states,
-        masked_m,
-        input_global_scale,
-    )
-
     workspace = workspace.permute(1, 2, 0)  # requirement of kernel
     sf_vec_size = 16
     assert aq_sf.dtype == torch.float8_e4m3fn
     assert aq.dtype == torch.uint8
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"

-    c_dtype = get_cute_dtype(hidden_states)
+    if isinstance(hidden_states, tuple):
+        c_dtype = "bfloat16"
+    else:
+        c_dtype = get_cute_dtype(hidden_states)

     # Gemm1
     flashinfer_cutedsl_grouped_gemm_nt_masked(
