
Commit 1166c31

[Bugfix]: Fix glm46 awq marlin moe wna16 compatibility (#30210)
Signed-off-by: baonudesifeizhai <[email protected]>
1 parent 03416ea commit 1166c31

2 files changed (+50 lines, -4 lines)


vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 45 additions & 0 deletions
@@ -895,6 +895,48 @@ def get_moe_configs(
     return None
 
 
+def _ensure_block_size_k_divisible(
+    size_k: int, block_size_k: int, group_size: int
+) -> int:
+    """Ensure block_size_k is a divisor of size_k and divisible by group_size.
+
+    This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
+    requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
+
+    Args:
+        size_k: The size_k dimension that must be divisible by result.
+        block_size_k: Preferred block size (will be adjusted if needed).
+        group_size: The result must be divisible by this.
+
+    Returns:
+        A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
+    """
+    # Fast path: already valid
+    if size_k % block_size_k == 0 and block_size_k % group_size == 0:
+        return block_size_k
+
+    # Find the largest value that:
+    # 1. Divides size_k (size_k % candidate == 0)
+    # 2. Is divisible by group_size (candidate % group_size == 0)
+    # 3. Is <= block_size_k (prefer smaller values close to block_size_k)
+    #
+    # Strategy: Search from min(block_size_k, size_k) down to group_size,
+    # stepping by group_size to ensure divisibility by group_size
+    max_search = min(block_size_k, size_k)
+    start = (max_search // group_size) * group_size
+    for candidate in range(start, group_size - 1, -group_size):
+        if size_k % candidate == 0:
+            return candidate
+
+    # Fallback: if group_size divides size_k, use it
+    # This should always be true with correct group_size configuration
+    if size_k % group_size == 0:
+        return group_size
+
+    # This should not happen with correct group_size, but ensure divisibility
+    return size_k
+
+
 def get_moe_wna16_block_config(
     config: dict[str, int],
     use_moe_wna16_cuda: bool,
@@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
         # at the same time.
        block_size_n = 1024
 
+    # Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
+    block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
+
     return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
 
 

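As an illustration of the helper added above, here is a minimal standalone sketch (the function body is copied from the diff, so it runs on its own) plus two made-up example calls; the values size_k=2048/1408, block_size_k=512, and group_size=128 are hypothetical and not taken from any particular model configuration:

def _ensure_block_size_k_divisible(size_k: int, block_size_k: int, group_size: int) -> int:
    # Fast path: the preferred block size already divides size_k and is a
    # multiple of group_size.
    if size_k % block_size_k == 0 and block_size_k % group_size == 0:
        return block_size_k
    # Otherwise search downward in steps of group_size for the largest
    # candidate that divides size_k.
    max_search = min(block_size_k, size_k)
    start = (max_search // group_size) * group_size
    for candidate in range(start, group_size - 1, -group_size):
        if size_k % candidate == 0:
            return candidate
    # Fallbacks, as in the committed helper.
    return group_size if size_k % group_size == 0 else size_k

print(_ensure_block_size_k_divisible(2048, 512, 128))  # 512: already valid, returned unchanged
print(_ensure_block_size_k_divisible(1408, 512, 128))  # 128: 1408 = 11 * 128, so no larger multiple of 128 divides it

The second call shows the case the fix targets: when size_k is not a multiple of the preferred BLOCK_SIZE_K, the block size is shrunk to a value that both divides size_k and stays a multiple of group_size, satisfying the MoeWNA16 CUDA kernel's constraints.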
vllm/model_executor/layers/quantization/moe_wna16.py

Lines changed: 5 additions & 4 deletions
@@ -60,7 +60,7 @@ def __init__(
 
         if self.linear_quant_method == "gptq":
             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
-        elif self.linear_quant_method == "awq":
+        elif self.linear_quant_method in ("awq", "awq_marlin"):
             capability_tuple = current_platform.get_device_capability()
             device_capability = (
                 -1 if capability_tuple is None else capability_tuple.to_int()
@@ -107,7 +107,7 @@ def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
        if linear_quant_method == "gptq":
            has_zp = not cls.get_from_keys(config, ["sym"])
            modules_to_not_convert = []
-        elif linear_quant_method == "awq":
+        elif linear_quant_method in ("awq", "awq_marlin"):
            has_zp = cls.get_from_keys(config, ["zero_point"])
            modules_to_not_convert = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
@@ -184,7 +184,7 @@ def get_quant_method(
             return GPTQConfig.from_config(self.full_config).get_quant_method(
                 layer, prefix
             )
-        elif self.linear_quant_method == "awq":
+        elif self.linear_quant_method in ("awq", "awq_marlin"):
             if self.use_marlin and check_marlin_supports_layer(
                 layer, self.group_size
             ):
@@ -468,7 +468,8 @@ def moe_wna16_weight_loader(
            shard_size = layer.intermediate_size_per_partition
 
            # convert gptq and awq weight to a standard format
-            if layer.quant_config.linear_quant_method == "awq":
+            # awq_marlin uses the same weight format as awq
+            if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qweight")

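On the config side, the effect of accepting "awq_marlin" alongside "awq" can be pictured with a small standalone sketch; this is assumed illustrative code, not the vLLM API, and the sample config keys ("bits", "group_size", "zero_point") mirror a typical AWQ-style quantization_config rather than any specific checkpoint:

from typing import Any

def read_awq_style_config(config: dict[str, Any]) -> dict[str, Any]:
    # awq_marlin checkpoints ship the same config layout as awq, so both
    # method strings take the same branch (mirroring the diff above).
    if config.get("quant_method") in ("awq", "awq_marlin"):
        return {
            "weight_bits": config["bits"],
            "group_size": config["group_size"],
            "has_zp": config["zero_point"],
        }
    raise ValueError(f"unsupported quant_method: {config.get('quant_method')}")

# Hypothetical quantization_config resembling an AWQ-Marlin checkpoint.
print(read_awq_style_config(
    {"quant_method": "awq_marlin", "bits": 4, "group_size": 128, "zero_point": True}
))

Because the weight layout is identical, the same reasoning applies in the weight loader: a checkpoint tagged "awq_marlin" is converted with the existing AWQ tensor conversion path instead of being rejected.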