@@ -895,6 +895,48 @@ def get_moe_configs(
895895 return None
896896
897897
898+ def _ensure_block_size_k_divisible (
899+ size_k : int , block_size_k : int , group_size : int
900+ ) -> int :
901+ """Ensure block_size_k is a divisor of size_k and divisible by group_size.
902+
903+ This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
904+ requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
905+
906+ Args:
907+ size_k: The size_k dimension that must be divisible by result.
908+ block_size_k: Preferred block size (will be adjusted if needed).
909+ group_size: The result must be divisible by this.
910+
911+ Returns:
912+ A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
913+ """
914+ # Fast path: already valid
915+ if size_k % block_size_k == 0 and block_size_k % group_size == 0 :
916+ return block_size_k
917+
918+ # Find the largest value that:
919+ # 1. Divides size_k (size_k % candidate == 0)
920+ # 2. Is divisible by group_size (candidate % group_size == 0)
921+ # 3. Is <= block_size_k (prefer smaller values close to block_size_k)
922+ #
923+ # Strategy: Search from min(block_size_k, size_k) down to group_size,
924+ # stepping by group_size to ensure divisibility by group_size
925+ max_search = min (block_size_k , size_k )
926+ start = (max_search // group_size ) * group_size
927+ for candidate in range (start , group_size - 1 , - group_size ):
928+ if size_k % candidate == 0 :
929+ return candidate
930+
931+ # Fallback: if group_size divides size_k, use it
932+ # This should always be true with correct group_size configuration
933+ if size_k % group_size == 0 :
934+ return group_size
935+
936+ # This should not happen with correct group_size, but ensure divisibility
937+ return size_k
938+
939+
898940def get_moe_wna16_block_config (
899941 config : dict [str , int ],
900942 use_moe_wna16_cuda : bool ,
@@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
9601002 # at the same time.
9611003 block_size_n = 1024
9621004
1005+ # Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
1006+ block_size_k = _ensure_block_size_k_divisible (size_k , block_size_k , group_size )
1007+
9631008 return {"BLOCK_SIZE_N" : block_size_n , "BLOCK_SIZE_K" : block_size_k }
9641009
9651010
0 commit comments