Commit 39d2810

[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)
1 parent cd719de commit 39d2810

5 files changed, 98 insertions(+), 22 deletions(-)

tests/kernels/moe/test_flashinfer_moe.py

Lines changed: 17 additions & 7 deletions
@@ -16,11 +16,11 @@
     FlashInferExperts,
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
+    create_flashinfer_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

@@ -48,9 +48,10 @@
 @pytest.mark.parametrize("e", [40, 64, 256])
 @pytest.mark.parametrize("topk", [1, 6, 8])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
 @torch.inference_mode()
 def test_flashinfer_fp4_moe_no_graph(
-    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
 ):
     current_platform.seed_everything(7)
     with set_current_vllm_config(
@@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
         a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

         quant_blocksize = 16
+        is_gated_act = activation == "silu_and_mul"

         w1_q, w2_q, quant_config = make_test_quant_config(
             e,
@@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
             quant_dtype="nvfp4",
             block_shape=None,
             per_act_token_quant=False,
+            make_gate=is_gated_act,
         )

         score = torch.randn((m, e), device="cuda", dtype=dtype)
@@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
         assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)

         flashinfer_experts = FusedMoEModularKernel(
-            MoEPrepareAndFinalizeNoEP(),
+            create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
             FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
         )

+        fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
+
         flashinfer_output = flashinfer_experts(
             hidden_states=a,
             w1=w1_q,
             w2=w2_q,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            activation=fi_activation,
         )

         # Reference check:
@@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
             block_size=quant_blocksize,
         )

-        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
+        w1_d = torch.empty(
+            (e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
+        )
         w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

         for idx in range(0, e):
@@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
             block_size=quant_blocksize,
         )

-        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
+        torch_output = torch_moe(
+            a_in_dtype, w1_d, w2_d, score, topk, activation=activation
+        )

         torch.testing.assert_close(
             torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
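
Note on the activation switch exercised by this test: with a gated activation such as silu_and_mul, the first expert GEMM produces 2 * n columns (gate and up halves) that are combined, while a non-gated activation such as relu2 (squared ReLU) keeps n columns, which is why the reference weight w1_d above is sized (2 if is_gated_act else 1) * n. The sketch below is a minimal, self-contained PyTorch illustration of that difference; it is not the vLLM kernel path, and expert_mlp is a hypothetical helper name.

import torch
import torch.nn.functional as F

def expert_mlp(x, w1, w2, gated: bool):
    # First GEMM; for gated activations, w1 stacks the gate and up projections.
    h = x @ w1.t()
    if gated:
        gate, up = h.chunk(2, dim=-1)   # requires w1 to have 2 * n output rows
        h = F.silu(gate) * up           # "silu_and_mul"
    else:
        h = F.relu(h) ** 2              # "relu2" (squared ReLU), no elementwise mul
    return h @ w2.t()                   # second GEMM back to the hidden size

m, k, n = 4, 8, 16
x = torch.randn(m, k)
w2 = torch.randn(k, n)
print(expert_mlp(x, torch.randn(2 * n, k), w2, gated=True).shape)   # torch.Size([4, 8])
print(expert_mlp(x, torch.randn(n, k), w2, gated=False).shape)      # torch.Size([4, 8])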

tests/kernels/moe/utils.py

Lines changed: 10 additions & 1 deletion
@@ -264,13 +264,20 @@ def make_test_weights(
     quant_dtype: torch.dtype | str | None = None,
     block_shape: list[int] | None = None,
     per_out_ch_quant: bool = False,
+    make_gate: bool = True,
 ) -> tuple[
     tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
     tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
 ]:
     return (
         make_test_weight(
-            e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
+            e,
+            (2 if make_gate else 1) * n,
+            k,
+            in_dtype,
+            quant_dtype,
+            block_shape,
+            per_out_ch_quant,
         ),
         make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
     )
@@ -297,6 +304,7 @@ def make_test_quant_config(
     quant_dtype: torch.dtype | str | None = None,
     per_act_token_quant: bool = False,
     block_shape: list[int] | None = None,
+    make_gate: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
     (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
         e,
@@ -306,6 +314,7 @@
         quant_dtype,
         per_out_ch_quant=per_act_token_quant,
         block_shape=block_shape,
+        make_gate=make_gate,
     )

     # Hacky/trivial scales for nvfp4.

tests/kernels/utils.py

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,7 @@

 from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.attention.backends.abstract import AttentionType
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.torch_utils import make_tensor_with_pad
@@ -839,6 +840,7 @@ def torch_experts(
     per_act_token_quant=False,
     block_shape: list[int] | None = None,
     apply_router_weights_on_input: bool = False,
+    activation: str = "silu_and_mul",
 ) -> torch.Tensor:
     assert (
         global_num_experts == -1
@@ -881,14 +883,16 @@

     f32 = torch.float32

+    act = CustomOp.op_registry[activation]
+
     for i in range(num_experts):
         mask = topk_ids == i
         if mask.sum():
             if quant_dtype is None:
                 tmp1 = a[mask] @ w1[i].transpose(0, 1)
                 if b_bias1 is not None:
                     tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
-                tmp2 = SiluAndMul()(tmp1)
+                tmp2 = act()(tmp1)
                 out[mask] = tmp2 @ w2[i].transpose(0, 1)
                 if b_bias2 is not None:
                     out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
@@ -969,6 +973,7 @@ def torch_moe(
     b_bias2: torch.Tensor | None = None,
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
+    activation: str = "silu_and_mul",
 ) -> torch.Tensor:
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
@@ -982,6 +987,7 @@
         b_bias1,
         b_bias2,
         expert_map,
+        activation=activation,
     )

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 9 additions & 3 deletions
@@ -600,14 +600,20 @@ def _get_quant_method() -> FusedMoEMethodBase:
             # Avoid circular import
             from vllm.model_executor.layers.quantization.modelopt import (
                 ModelOptFp8MoEMethod,
+                ModelOptNvFp4FusedMoE,
             )

             if not isinstance(
-                self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+                self.quant_method,
+                (
+                    UnquantizedFusedMoEMethod,
+                    ModelOptFp8MoEMethod,
+                    ModelOptNvFp4FusedMoE,
+                ),
             ):
                 raise NotImplementedError(
                     "is_act_and_mul=False is supported only for unquantized "
-                    "and ModelOpt FP8 moe for now"
+                    ", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
                 )
             if not current_platform.is_cuda():
                 raise NotImplementedError(
@@ -1277,7 +1283,7 @@ def weight_loader(
             self._load_combined_w13_weight_scale(
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
-                param=param,
+                param=expert_data,
                 tp_rank=self.tp_rank,
             )
             return True if return_success else None
vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 55 additions & 10 deletions
@@ -1216,7 +1216,7 @@ def create_weights(
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -1245,7 +1245,7 @@
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
@@ -1275,7 +1275,9 @@
         )

         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1296,7 +1298,11 @@
         global_scale_num_experts = global_num_experts if use_global_sf else num_experts

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                global_scale_num_experts,
+                2 if self.moe.is_act_and_mul else 1,
+                dtype=torch.float32,
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
@@ -1312,9 +1318,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data

-        if self.allow_flashinfer and (
-            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        if (
+            self.allow_flashinfer
+            and (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+                or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            )
+            and self.moe.is_act_and_mul
         ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
@@ -1324,7 +1334,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

         # Common processing for w13_weight_scale_2
-        if not torch.allclose(
+        if self.moe.is_act_and_mul and not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
             logger.warning_once(
@@ -1437,11 +1447,39 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 w13_blockscale_swizzled, requires_grad=False
             )

+            w13_weight = layer.w13_weight
+            intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
+            if intermediate_size_pad:
+                # padding gated activations will require to split w1 and w3
+                # and pad them individually
+                assert not self.moe.is_act_and_mul, (
+                    "The intermediate size required padding, "
+                    "but padding is not implemented for gated activations"
+                )
+
+                layer.w13_weight = Parameter(
+                    torch.nn.functional.pad(
+                        w13_weight, (0, 0, 0, intermediate_size_pad)
+                    ),
+                    requires_grad=False,
+                )
+                layer.w2_weight = Parameter(
+                    torch.nn.functional.pad(
+                        layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
+                    ),
+                    requires_grad=False,
+                )
+                layer.w2_weight_scale = Parameter(
+                    torch.nn.functional.pad(
+                        layer.w2_weight_scale, (0, intermediate_size_pad // 16)
+                    ),
+                    requires_grad=False,
+                )
+
             w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
             layer.w2_weight_scale = Parameter(
                 w2_blockscale_swizzled, requires_grad=False
             )
-            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
@@ -1484,7 +1522,14 @@ def apply(
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        assert activation == "silu", "Only SiLU activation is supported."
+        if not self.moe.is_act_and_mul:
+            assert (
+                self.allow_flashinfer
+                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            ), (
+                "Non-gated activations are only supported by the"
+                " flashinfer CUTLASS backend for modelopt checkpoints"
+            )

         if (
             self.allow_flashinfer
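
The new padding branch in process_weights_after_loading covers the case where the swizzled w13 block-scale layout is taller than the stored intermediate size: w13 is padded along its intermediate dimension, and w2 and its per-group scale are padded along their fp4-packed and group-quantized input dimensions so the GEMM shapes stay consistent. Below is a minimal shape-only sketch with made-up sizes (e, n, k, pad are arbitrary), assuming two fp4 values packed per byte and a scale group size of 16; the real code derives the pad from the swizzled block-scale tensor.

import torch
import torch.nn.functional as F

e, n, k, group = 2, 96, 64, 16  # experts, intermediate, hidden, scale group (made up)
pad = 32                        # extra intermediate rows the swizzled scales expect

w13 = torch.zeros(e, n, k // 2, dtype=torch.uint8)   # fp4 packs 2 values per byte
w2 = torch.zeros(e, k, n // 2, dtype=torch.uint8)
w2_scale = torch.zeros(e, k, n // group)

w13 = F.pad(w13, (0, 0, 0, pad))               # pad intermediate dim: (e, n + pad, k // 2)
w2 = F.pad(w2, (0, pad // 2))                  # packed input dim grows by pad // 2
w2_scale = F.pad(w2_scale, (0, pad // group))  # per-group scales grow by pad // group

print(w13.shape, w2.shape, w2_scale.shape)
# torch.Size([2, 128, 32]) torch.Size([2, 64, 64]) torch.Size([2, 64, 8])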
