@@ -163,8 +163,7 @@ __device__ void _moe_align_block_size(
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx =
       cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
-  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
-  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+  for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
     expert_ids[expert_ids_offset + i] = inactive_expert_id;
   }
 }
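
For readers skimming this hunk: the tail-fill loop over expert_ids now takes its
upper bound, max_num_m_blocks, from the caller instead of recomputing
CEILDIV(max_num_tokens_padded, block_size) inside the device function. Below is
a minimal standalone sketch of that pattern, not the vLLM source: the helper
name fill_tail_expert_ids is hypothetical, and CEILDIV is assumed to be the
usual round-up division macro used elsewhere in this file.

    #include <cstdint>

    // Sketch only; assumes CEILDIV(a, b) == ((a) + (b) - 1) / (b).
    #define CEILDIV(a, b) (((a) + (b) - 1) / (b))

    // Pad the tail of expert_ids with a sentinel value. Each thread handles
    // every blockDim.x-th slot starting at fill_start_idx; max_num_m_blocks
    // is precomputed by the caller as CEILDIV(max_num_tokens_padded, block_size).
    __device__ void fill_tail_expert_ids(int32_t* expert_ids,
                                         size_t fill_start_idx,
                                         size_t max_num_m_blocks,
                                         int32_t inactive_expert_id) {
      for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
        expert_ids[i] = inactive_expert_id;
      }
    }
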
@@ -284,10 +283,11 @@ __global__ void moe_align_block_size_kernel(
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
     int32_t topk_num) {
-  _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
-                        total_tokens_post_pad, num_experts, padded_num_experts,
-                        experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, 0, 0, 0, topk_num, nullptr);
+  _moe_align_block_size(
+      topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
+      num_experts, padded_num_experts, experts_per_warp, block_size, numel,
+      cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
+      0, 0, topk_num, nullptr);
 }
 
 template <typename scalar_t>
@@ -328,12 +328,10 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
 
   _moe_align_block_size_small_batch_expert(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
-      num_experts, block_size, numel, max_num_tokens_padded, 0, 0, 0, topk_num,
-      nullptr);
+      num_experts, block_size, numel, max_num_tokens_padded,
+      CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr);
 }
 
-namespace lora {
-
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,
@@ -422,7 +420,6 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
       -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)]);
 }
 
-} // namespace lora
 } // namespace moe
 } // namespace vllm
 
@@ -618,8 +615,9 @@ void moe_lora_align_block_size(
   }
 
   dim3 blockDim(num_thread);
-  auto kernel = vllm::moe::lora::
-      moe_lora_align_block_size_small_batch_expert_kernel<scalar_t>;
+  auto kernel =
+      vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
+          scalar_t>;
   AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
       (void*)kernel, shared_mem));
   kernel<<<max_loras, blockDim, shared_mem, stream>>>(
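
The AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(...))
call above is what lets the kernel use more dynamic shared memory than the
default opt-in cap before it is launched with shared_mem bytes. On plain CUDA
(ignoring the HIP path the vLLM macro presumably also covers, which is an
assumption here), the equivalent boils down to one cudaFuncSetAttribute call;
the helper name allow_large_dynamic_smem below is hypothetical.

    #include <cuda_runtime.h>

    // Raise the kernel's dynamic shared memory limit to shared_mem bytes so a
    // subsequent <<<grid, block, shared_mem, stream>>> launch with that amount
    // is accepted by the driver.
    template <typename KernelT>
    cudaError_t allow_large_dynamic_smem(KernelT kernel, int shared_mem) {
      return cudaFuncSetAttribute((const void*)kernel,
                                  cudaFuncAttributeMaxDynamicSharedMemorySize,
                                  shared_mem);
    }
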
@@ -646,7 +644,7 @@ void moe_lora_align_block_size(
       torch::zeros({max_loras * (num_experts + 1)}, options_int);
 
   auto align_kernel =
-      vllm::moe::lora::moe_lora_align_block_size_kernel<scalar_t>;
+      vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
   align_kernel<<<max_loras, blockDim, shared_mem_size, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
       token_lora_mapping.data_ptr<scalar_t>(), block_size, num_experts,
@@ -667,8 +665,7 @@ void moe_lora_align_block_size(
 
   dim3 gridDims(max_loras, actual_blocks);
   auto sort_kernel =
-      vllm::moe::lora::lora_count_and_sort_expert_tokens_kernel<
-          scalar_t>;
+      vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
 
   sort_kernel<<<gridDims, block_threads, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
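
For orientation, dim3 gridDims(max_loras, actual_blocks) above builds a 2D
grid, so each block can pick out one (LoRA adapter, token-block) pair from its
block indices. The toy sketch below only illustrates that launch shape; the
kernel, output layout, and 256-thread block size are hypothetical, and how the
real sort kernel maps its indices is not shown in this hunk.

    #include <cuda_runtime.h>

    // Toy kernel: blockIdx.x selects the LoRA adapter, blockIdx.y the chunk.
    __global__ void per_lora_chunk_kernel(int* out, int chunks_per_lora) {
      if (threadIdx.x == 0) {
        out[blockIdx.x * chunks_per_lora + blockIdx.y] = blockIdx.x;
      }
    }

    // Launch one block per (lora, chunk) pair, mirroring the 2D grid used by
    // sort_kernel<<<gridDims, block_threads, 0, stream>>>(...) above.
    void launch_example(int* out, int max_loras, int actual_blocks,
                        cudaStream_t stream) {
      dim3 gridDims(max_loras, actual_blocks);
      per_lora_chunk_kernel<<<gridDims, 256, 0, stream>>>(out, actual_blocks);
    }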