Commit 561253b

jiahanc and mgoin authored
[Performance][Fix] update nvfp4 code to support renorm routing (#28569)
Signed-off-by: jiahanc <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 80b6080 commit 561253b
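
The commit title's "renorm routing" refers to renormalize-style top-k routing in the MoE router. As a point of reference only (a generic sketch of that routing scheme, not code from this commit): renormalize routing picks the top-k experts from the raw router logits and then applies a softmax over just the selected k, rather than softmaxing over all experts before selection.

import torch


def renormalize_routing(router_logits: torch.Tensor, top_k: int):
    # Select the top-k experts directly from the raw router logits...
    topk_vals, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    # ...then renormalize the gate weights with a softmax over only those k experts.
    topk_weights = torch.softmax(topk_vals, dim=-1, dtype=torch.float32)
    return topk_weights, topk_ids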

File tree

2 files changed: +15 −8 lines

vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 11 additions & 7 deletions
@@ -15,6 +15,7 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
+    RoutingMethodType,
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
@@ -1657,16 +1658,19 @@ def apply(
         use_llama4_routing = (
             custom_routing_function is Llama4MoE.custom_routing_function
         )
-        routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3
+        routing_method_type = layer.routing_method_type
         if use_llama4_routing:
-            routing_method_type = flashinfer.RoutingMethodType.Llama4
+            routing_method_type = RoutingMethodType.Llama4
+        router_logits = (
+            router_logits.to(torch.float32)
+            if routing_method_type == RoutingMethodType.DeepSeekV3
+            else router_logits
+        )
         routing_bias = e_score_correction_bias
         if routing_bias is not None:
             routing_bias = routing_bias.to(torch.bfloat16)
         out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
-            routing_logits=router_logits
-            if use_llama4_routing
-            else router_logits.to(torch.float32),
+            routing_logits=router_logits,
             routing_bias=routing_bias,
             hidden_states=hidden_states_fp4,
             hidden_states_scale=hidden_states_scale_linear_fp4.view(
@@ -1690,8 +1694,8 @@ def apply(
             output2_scale_scalar=layer.g2_alphas.data,
             num_experts=global_num_experts,
             top_k=top_k,
-            n_group=num_expert_group if num_expert_group is not None else 0,
-            topk_group=topk_group if topk_group is not None else 0,
+            n_group=num_expert_group,
+            topk_group=topk_group,
             intermediate_size=layer.intermediate_size_per_partition,
             local_expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
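
Restated linearly, the new control flow in apply is roughly the following. This is a hypothetical helper written only to summarize the diff above; layer, router_logits, and the use_llama4_routing flag come from the surrounding method in modelopt.py, which is assumed rather than shown:

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType


def _resolve_routing(layer, router_logits, use_llama4_routing: bool):
    # The routing method now comes from the layer config instead of being
    # hard-coded to DeepSeekV3, which is what lets renormalize routing through.
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = RoutingMethodType.Llama4
    # Only DeepSeekV3-style routing upcasts the logits to float32; every other
    # method (previously upcast unless Llama4) keeps its original dtype.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        router_logits = router_logits.to(torch.float32)
    return routing_method_type, router_logits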

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 4 additions & 1 deletion
@@ -291,5 +291,8 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
 
 def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
     # TODO(shuw@nvidia): Update when new backends are added.
-    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    backends_supporting_global_sf = (
+        FlashinferMoeBackend.CUTLASS,
+        FlashinferMoeBackend.TENSORRT_LLM,
+    )
     return backend in backends_supporting_global_sf
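
A hypothetical usage sketch of the widened helper (the helper and get_flashinfer_moe_backend are defined in flashinfer_utils.py per the hunk above; how callers consume the result is assumed here, not shown in this commit):

from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    get_flashinfer_moe_backend,
    is_flashinfer_supporting_global_sf,
)

# Resolve the active FlashInfer MoE backend, then check global scale-factor support.
backend = get_flashinfer_moe_backend()
supports_global_sf = is_flashinfer_supporting_global_sf(backend)
# After this commit, supports_global_sf is True for both the CUTLASS and the
# TENSORRT_LLM backends (previously CUTLASS only).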
