
Commit e0cce68

support triton experts
Signed-off-by: gnovack <[email protected]>
1 parent 74144e6 commit e0cce68

File tree

  csrc/moe/moe_align_sum_kernels.cu
  vllm/lora/layers/fused_moe.py
  vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
  vllm/model_executor/layers/fused_moe/fused_moe.py

4 files changed: +30 -22 lines

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 13 additions & 16 deletions
@@ -163,8 +163,7 @@ __device__ void _moe_align_block_size(
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx =
       cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
-  const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
-  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+  for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
    expert_ids[expert_ids_offset + i] = inactive_expert_id;
  }
 }
@@ -284,10 +283,11 @@ __global__ void moe_align_block_size_kernel(
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
     int32_t topk_num) {
-  _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
-                        total_tokens_post_pad, num_experts, padded_num_experts,
-                        experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, 0, 0, 0, topk_num, nullptr);
+  _moe_align_block_size(
+      topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
+      num_experts, padded_num_experts, experts_per_warp, block_size, numel,
+      cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
+      0, 0, topk_num, nullptr);
 }

 template <typename scalar_t>
@@ -328,12 +328,10 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(

   _moe_align_block_size_small_batch_expert(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
-      num_experts, block_size, numel, max_num_tokens_padded, 0, 0, 0, topk_num,
-      nullptr);
+      num_experts, block_size, numel, max_num_tokens_padded,
+      CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr);
 }

-namespace lora {
-
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,
@@ -422,7 +420,6 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
       -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)]);
 }

-}  // namespace lora
 }  // namespace moe
 }  // namespace vllm

@@ -618,8 +615,9 @@ void moe_lora_align_block_size(
   }

   dim3 blockDim(num_thread);
-  auto kernel = vllm::moe::lora::
-      moe_lora_align_block_size_small_batch_expert_kernel<scalar_t>;
+  auto kernel =
+      vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
+          scalar_t>;
   AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
       (void*)kernel, shared_mem));
   kernel<<<max_loras, blockDim, shared_mem, stream>>>(
@@ -646,7 +644,7 @@ void moe_lora_align_block_size(
       torch::zeros({max_loras * (num_experts + 1)}, options_int);

   auto align_kernel =
-      vllm::moe::lora::moe_lora_align_block_size_kernel<scalar_t>;
+      vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
   align_kernel<<<max_loras, blockDim, shared_mem_size, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
       token_lora_mapping.data_ptr<scalar_t>(), block_size, num_experts,
@@ -667,8 +665,7 @@ void moe_lora_align_block_size(

   dim3 gridDims(max_loras, actual_blocks);
   auto sort_kernel =
-      vllm::moe::lora::lora_count_and_sort_expert_tokens_kernel<
-          scalar_t>;
+      vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;

   sort_kernel<<<gridDims, block_threads, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(),
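
The net change in this file: callers of _moe_align_block_size now compute CEILDIV(max_num_tokens_padded, block_size) once and pass it in as the max_num_m_blocks bound used to fill the trailing expert_ids slots, and the LoRA kernels move out of the vllm::moe::lora namespace into vllm::moe. As a rough, hypothetical illustration (not vLLM code), here is a Python sketch of that precomputed bound and fill; the sequential loop stands in for the kernel's strided thread loop.

# Illustrative sketch only; the names mirror the kernel arguments above,
# but this is not the vLLM implementation.
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def fill_trailing_expert_ids(
    expert_ids: list[int],
    fill_start_idx: int,
    max_num_tokens_padded: int,
    block_size: int,
    inactive_expert_id: int,
) -> None:
    # What callers now precompute and pass as max_num_m_blocks.
    max_num_m_blocks = ceildiv(max_num_tokens_padded, block_size)
    for i in range(fill_start_idx, max_num_m_blocks):
        expert_ids[i] = inactive_expert_id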

vllm/lora/layers/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -327,8 +327,8 @@ def wrapper(*args, **kwargs):
             self.base_layer, fused_experts.moe_sum
         )

-        fused_experts.moe_align = moe_align_decorator(
-            self.base_layer, fused_experts.moe_align
+        fused_experts.moe_align_block_size = moe_align_decorator(
+            self.base_layer, fused_experts.moe_align_block_size
         )

         self.base_layer.quant_method = FusedMoEModularMethod(
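
This hunk keeps the existing wrap-and-replace pattern (already used for moe_sum) but targets the renamed moe_align_block_size attribute. Below is a hypothetical sketch of that pattern; the real moe_align_decorator is defined elsewhere in vLLM and may have a different signature.

from functools import wraps


def wrap_method(base_layer, fn):
    # Hypothetical stand-in for moe_align_decorator: wrap a method on the
    # fused-experts object so the LoRA layer can hook the call.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # A real decorator would consult base_layer (e.g. LoRA state) here.
        return fn(*args, **kwargs)

    return wrapper


# Usage mirroring the diff: the wrapped attribute is now moe_align_block_size.
# fused_experts.moe_align_block_size = wrap_method(
#     base_layer, fused_experts.moe_align_block_size
# )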

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 2 additions & 2 deletions
@@ -683,13 +683,13 @@ def apply(
             sort_indices1=self.w13_g_idx_sort_indices,
             sort_indices2=self.w2_g_idx_sort_indices,
             is_k_full=self.is_k_full,
-            moe_align=self.moe_align,
+            moe_align=self.moe_align_block_size,
         )

     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
         ops.moe_sum(input, output)

-    def moe_align(
+    def moe_align_block_size(
         self,
         topk_ids: torch.Tensor,
         block_size_m: int,
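
Only a rename here: the Marlin experts class now exposes moe_align_block_size and passes that bound method as the moe_align callback. A hypothetical sketch of the callback pattern (names are illustrative, not vLLM's API):

from typing import Callable


def run_fused_moe(moe_align: Callable[[int], str]) -> str:
    # The fused-MoE entry point only sees a callable; it does not care which
    # experts class supplied it.
    return moe_align(64)


class MarlinLikeExperts:
    def moe_align_block_size(self, block_size_m: int) -> str:
        return f"aligned with block_size_m={block_size_m}"


print(run_fused_moe(moe_align=MarlinLikeExperts().moe_align_block_size))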

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 13 additions & 2 deletions
@@ -2087,8 +2087,10 @@ def apply(
         )
         intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

-        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            self.moe_align_block_size(
+                topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
+            )
         )

         invoke_fused_moe_kernel(
@@ -2159,6 +2161,15 @@ def apply(
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
         ops.moe_sum(input, output)

+    def moe_align_block_size(
+        self,
+        topk_ids: torch.Tensor,
+        block_size: int,
+        num_experts: int,
+        expert_map: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return moe_align_block_size(topk_ids, block_size, num_experts, expert_map)
+

 def modular_triton_fused_moe(
     quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None