@@ -1216,7 +1216,7 @@ def create_weights(
         w13_weight = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
                 dtype=weight_dtype,
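Context for the shape change: w13 fuses the gate (w1) and up (w3) projections only when the MoE uses a gated activation (is_act_and_mul), so its intermediate dimension doubles in that case; a non-gated MoE keeps a single projection. A rough shape sketch with made-up sizes and a standalone helper, not part of the actual layer code:

import torch

def w13_shape(num_experts: int, intermediate: int, hidden: int, is_act_and_mul: bool):
    # fp4 packs two values per byte, so the input (hidden) dimension is halved
    return (num_experts, (2 if is_act_and_mul else 1) * intermediate, hidden // 2)

print(w13_shape(8, 1024, 4096, True))   # (8, 2048, 2048): gated, w1 and w3 fused
print(w13_shape(8, 1024, 4096, False))  # (8, 1024, 2048): non-gated, single projection

The same 2-vs-1 factor is applied to w13_weight_scale, w13_weight_scale_2, and w13_input_scale in the hunks below.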
@@ -1245,7 +1245,7 @@ def create_weights(
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 num_experts,
-                2 * intermediate_size_per_partition,
+                (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
                 dtype=weight_scale_dtype,
@@ -1275,7 +1275,9 @@ def create_weights(
         )

         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
@@ -1296,7 +1298,11 @@ def create_weights(
         global_scale_num_experts = global_num_experts if use_global_sf else num_experts

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
+            data=torch.empty(
+                global_scale_num_experts,
+                2 if self.moe.is_act_and_mul else 1,
+                dtype=torch.float32,
+            ),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
@@ -1312,9 +1318,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         gemm1_weight = layer.w13_weight.data
         gemm1_weight_scale = layer.w13_weight_scale.data

-        if self.allow_flashinfer and (
-            self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-            or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+        if (
+            self.allow_flashinfer
+            and (
+                self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+                or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+            )
+            and self.moe.is_act_and_mul
         ):
             gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
                 gemm1_weight, gemm1_weight_scale, dim=-2
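reorder_w1w3_to_w3w1 swaps the w1 and w3 halves of the fused tensor along dim=-2, which only makes sense when both halves exist, hence the added is_act_and_mul condition. An illustrative sketch of such a half-swap, not vLLM's actual reorder_w1w3_to_w3w1 implementation:

import torch

def swap_halves(t: torch.Tensor, dim: int = -2) -> torch.Tensor:
    # split the fused [w1; w3] tensor into equal halves and re-concatenate as [w3; w1]
    w1, w3 = t.chunk(2, dim=dim)
    return torch.cat([w3, w1], dim=dim)

fused = torch.arange(8).reshape(1, 4, 2)  # toy (experts, 2 * intermediate, hidden)
print(swap_halves(fused))                 # the second half of the rows now comes first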
@@ -1324,7 +1334,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)

         # Common processing for w13_weight_scale_2
-        if not torch.allclose(
+        if self.moe.is_act_and_mul and not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
             logger.warning_once(
@@ -1437,11 +1447,39 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             w13_blockscale_swizzled, requires_grad=False
         )

+        w13_weight = layer.w13_weight
+        intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
+        if intermediate_size_pad:
+            # padding gated activations will require to split w1 and w3
+            # and pad them individually
+            assert not self.moe.is_act_and_mul, (
+                "The intermediate size required padding, "
+                "but padding is not implemented for gated activations"
+            )
+
+            layer.w13_weight = Parameter(
+                torch.nn.functional.pad(
+                    w13_weight, (0, 0, 0, intermediate_size_pad)
+                ),
+                requires_grad=False,
+            )
+            layer.w2_weight = Parameter(
+                torch.nn.functional.pad(
+                    layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
+                ),
+                requires_grad=False,
+            )
+            layer.w2_weight_scale = Parameter(
+                torch.nn.functional.pad(
+                    layer.w2_weight_scale, (0, intermediate_size_pad // 16)
+                ),
+                requires_grad=False,
+            )
+
         w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
         layer.w2_weight_scale = Parameter(
             w2_blockscale_swizzled, requires_grad=False
         )
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
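The pad amounts above follow the fp4 packing: w13 grows by intermediate_size_pad rows, w2's last (input) dimension stores two fp4 values per byte so it grows by intermediate_size_pad // 2, and the w2 block scale keeps one value per 16-element NVFP4 group so it grows by intermediate_size_pad // 16. torch.nn.functional.pad takes (left, right) pairs starting from the last dimension. A toy shape check with assumed sizes, not the real checkpoint shapes:

import torch
import torch.nn.functional as F

pad = 64                              # hypothetical intermediate_size_pad
w13 = torch.zeros(8, 1024, 2048)      # (experts, intermediate, hidden // 2)
w2 = torch.zeros(8, 4096, 512)        # (experts, hidden, intermediate // 2)
w2_scale = torch.zeros(8, 4096, 64)   # (experts, hidden, intermediate // 16)

print(F.pad(w13, (0, 0, 0, pad)).shape)       # intermediate dim: 1024 -> 1088
print(F.pad(w2, (0, pad // 2, 0, 0)).shape)   # packed input dim: 512 -> 544
print(F.pad(w2_scale, (0, pad // 16)).shape)  # scale groups: 64 -> 68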
@@ -1484,7 +1522,14 @@ def apply(
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        assert activation == "silu", "Only SiLU activation is supported."
+        if not self.moe.is_act_and_mul:
+            assert (
+                self.allow_flashinfer
+                and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
+            ), (
+                "Non-gated activations are only supported by the"
+                " flashinfer CUTLASS backend for modelopt checkpoints"
+            )

         if (
             self.allow_flashinfer