 
 namespace vllm {
 namespace moe {
-
 namespace batched_moe_align_block_size {
 
 // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
@@ -86,13 +85,19 @@ __device__ void _moe_align_block_size(
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
-    int32_t max_num_m_blocks, int inactive_expert_id, int adapter_offset,
-    int topk_num, int32_t* token_mask) {
+    int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id,
+    int32_t topk_num, int32_t* token_mask) {
   extern __shared__ int32_t shared_counts[];
 
+  // Compute input buffer offsets. Typically these will all be 0, except when
+  // using Multi LoRA.
+  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
+  int expert_ids_offset = max_num_m_blocks * model_offset;
+  int cumsum_offset = (num_experts + 1) * model_offset;
+
   // Initialize sorted_token_ids with numel
   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) + it] = numel;
+    sorted_token_ids[sorted_token_ids_offset + it] = numel;
   }
 
   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -116,8 +121,9 @@ __device__ void _moe_align_block_size(
     }
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset],
-              token_mask[i / topk_num]);
+              mask);
   }
 
   __syncthreads();
@@ -135,33 +141,31 @@ __device__ void _moe_align_block_size(
     expert_count = CEILDIV(expert_count, block_size) * block_size;
   }
 
-  int adapter_cumsum_offset = (num_experts + 1) * adapter_offset;
   int cumsum_val;
   BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
   if (expert_id <= num_experts) {
-    cumsum[adapter_cumsum_offset + expert_id] = cumsum_val;
+    cumsum[cumsum_offset + expert_id] = cumsum_val;
   }
 
   if (expert_id == num_experts) {
-    total_tokens_post_pad[adapter_offset] = cumsum_val;
+    total_tokens_post_pad[model_offset] = cumsum_val;
   }
 
   __syncthreads();
 
   if (threadIdx.x < num_experts) {
-    for (int i = cumsum[adapter_cumsum_offset + threadIdx.x];
-         i < cumsum[adapter_cumsum_offset + threadIdx.x + 1]; i += block_size) {
-      expert_ids[(max_num_m_blocks * adapter_offset) + i / block_size] =
-          threadIdx.x;
+    for (int i = cumsum[cumsum_offset + threadIdx.x];
+         i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {
+      expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
     }
   }
 
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx =
-      cumsum[adapter_cumsum_offset + num_experts] / block_size + threadIdx.x;
+      cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
   const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
   for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
-    expert_ids[(max_num_m_blocks * adapter_offset) + i] = inactive_expert_id;
+    expert_ids[expert_ids_offset + i] = inactive_expert_id;
   }
 }
 
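The three offsets introduced above just select one row of buffers that are flattened along a leading LoRA dimension. As a rough illustration (not part of this diff; the struct and function names are hypothetical), the assumed layout can be summarized with a small host-side helper:

#include <cstdint>

// Hypothetical sketch of the layout _moe_align_block_size indexes into:
//   sorted_token_ids : [max_loras, max_num_tokens_padded]
//   expert_ids       : [max_loras, max_num_m_blocks]
//   cumsum           : [max_loras, num_experts + 1]
// model_offset selects one row of each; the base-model path passes 0, so all
// three offsets collapse to 0 there.
struct MoeAlignRow {
  int32_t* sorted_token_ids;
  int32_t* expert_ids;
  int32_t* cumsum;
};

inline MoeAlignRow row_for_model(int32_t* sorted_token_ids, int32_t* expert_ids,
                                 int32_t* cumsum, int model_offset,
                                 int max_num_tokens_padded,
                                 int max_num_m_blocks, int num_experts) {
  return {sorted_token_ids + model_offset * max_num_tokens_padded,
          expert_ids + model_offset * max_num_m_blocks,
          cumsum + model_offset * (num_experts + 1)};
}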
@@ -171,11 +175,16 @@ __device__ void _moe_align_block_size_small_batch_expert(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t max_num_tokens_padded,
-    int32_t max_num_m_blocks, int inactive_expert_id, int adapter_offset,
-    int topk_num, int32_t* token_mask) {
+    int32_t max_num_m_blocks, int32_t inactive_expert_id, int32_t model_offset,
+    int32_t topk_num, int32_t* token_mask) {
+  // Compute input buffer offsets. Typically these will all be 0, except when
+  // using Multi LoRA.
+  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
+  int expert_ids_offset = max_num_m_blocks * model_offset;
+
   // Initialize sorted_token_ids with numel
   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) + it] = numel;
+    sorted_token_ids[sorted_token_ids_offset + it] = numel;
   }
 
   const size_t tid = threadIdx.x;
@@ -190,8 +199,8 @@ __device__ void _moe_align_block_size_small_batch_expert(
   }
 
   for (size_t i = tid; i < numel; i += stride) {
-    tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]] +=
-        token_mask[i / topk_num];
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]] += mask;
   }
 
   __syncthreads();
@@ -214,7 +223,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
           CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) *
               block_size;
     }
-    total_tokens_post_pad[adapter_offset] =
+    total_tokens_post_pad[model_offset] =
         static_cast<int32_t>(cumsum[num_experts]);
   }
 
@@ -223,25 +232,47 @@ __device__ void _moe_align_block_size_small_batch_expert(
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
-      expert_ids[(max_num_m_blocks * adapter_offset) + i / block_size] =
-          threadIdx.x;
+      expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
     }
   }
 
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
   for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
-    expert_ids[(max_num_m_blocks * adapter_offset) + i] = inactive_expert_id;
+    expert_ids[expert_ids_offset + i] = inactive_expert_id;
   }
 
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad =
         tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) +
-                     rank_post_pad] += ((i - numel) * token_mask[i / topk_num]);
-    tokens_cnts[threadIdx.x * num_experts + expert_id] +=
-        token_mask[i / topk_num];
+
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    sorted_token_ids[sorted_token_ids_offset + rank_post_pad] +=
+        ((i - numel) * mask);
+    tokens_cnts[threadIdx.x * num_experts + expert_id] += mask;
+  }
+}
+
+template <typename scalar_t>
+__device__ void _count_and_sort_expert_tokens(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
+    size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
+    int32_t* __restrict__ token_mask, int32_t model_offset, int32_t topk_num) {
+  const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.y;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
+
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    int32_t rank_post_pad = atomicAdd(
+        &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], mask);
+    sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i;
   }
 }
 
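For reference, a serial CPU analogue of the new _count_and_sort_expert_tokens helper (a sketch for illustration only, with model_offset fixed to 0): each token's slot is the current counter value for its expert, and the counter advances by the token's mask, so tokens masked to 0 do not consume a slot.

#include <cstddef>
#include <cstdint>

// Serial sketch of the device helper above (model_offset == 0). cumsum_buffer
// is assumed to hold the per-expert exclusive prefix sums produced earlier.
void count_and_sort_serial(const int32_t* topk_ids, int32_t* sorted_token_ids,
                           int32_t* cumsum_buffer, size_t numel,
                           int32_t num_experts, const int32_t* token_mask,
                           int32_t topk_num) {
  for (size_t i = 0; i < numel; ++i) {
    int32_t expert_id = topk_ids[i];
    if (expert_id >= num_experts) continue;
    // A null mask means every token counts; otherwise tokens not owned by the
    // current LoRA slot carry mask 0.
    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
    int32_t rank_post_pad = cumsum_buffer[expert_id];  // atomicAdd on the GPU
    cumsum_buffer[expert_id] += mask;
    sorted_token_ids[rank_post_pad] = static_cast<int32_t>(i);
  }
}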
@@ -252,44 +283,22 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  int num_tokens = numel / topk_num;
-  if (threadIdx.x < num_tokens) {
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      token_mask[i] = 1;
-    }
-  }
-
-  __syncthreads();
-
+    int32_t topk_num) {
   _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
                         total_tokens_post_pad, num_experts, padded_num_experts,
                         experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, 0, 0, 0, topk_num, token_mask);
+                        max_num_tokens_padded, 0, 0, 0, topk_num, nullptr);
 }
 
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
     size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  const size_t adapter_offset = blockIdx.x;
-  const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
-  const size_t stride = blockDim.x * gridDim.y;
-  int num_tokens = numel / topk_num;
-
-  for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
-    if (expert_id >= num_experts ||
-        token_mask[(adapter_offset * num_tokens) + i / topk_num] == 0) {
-      continue;
-    }
-    int32_t rank_post_pad = atomicAdd(
-        &cumsum_buffer[(adapter_offset * (num_experts + 1)) + expert_id], 1);
-    sorted_token_ids[max_num_tokens_padded * adapter_offset + rank_post_pad] =
-        i;
-  }
+    int32_t topk_num) {
+  _count_and_sort_expert_tokens(topk_ids, sorted_token_ids, cumsum_buffer,
+                                numel, num_experts, max_num_tokens_padded,
+                                nullptr, 0, topk_num);
 }
 
 template <typename scalar_t, int TOPK>
@@ -314,29 +323,24 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  int num_tokens = numel / topk_num;
-  if (threadIdx.x < num_tokens) {
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      token_mask[i] = 1;
-    }
-  }
-
+    int32_t topk_num) {
   __syncthreads();
 
   _moe_align_block_size_small_batch_expert(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
       num_experts, block_size, numel, max_num_tokens_padded, 0, 0, 0, topk_num,
-      token_mask);
+      nullptr);
 }
 
+namespace lora {
+
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
     int32_t* __restrict__ cumsum, int32_t experts_per_warp,
     int32_t padded_num_experts, int32_t* lora_ids,
     int32_t* __restrict__ token_mask) {
@@ -346,6 +350,7 @@ __global__ void moe_lora_align_block_size_kernel(
     return;
   }
 
+  // Populate the token_mask based on the token-LoRA mapping
   int num_tokens = numel / topk_num;
   if (threadIdx.x == 0) {
     total_tokens_post_pad[lora_id] = 0;
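The mask-population loop referenced by the new comment sits outside this hunk. A plausible serial equivalent is sketched below, assuming token_mask is laid out as [max_loras, num_tokens] (as the later indexing suggests) and that a token is active for a slot exactly when its token_lora_mapping entry equals that slot's lora_id; both the function name and the exact condition are assumptions, not taken from the diff.

#include <cstdint>

// Hypothetical sketch; the real device loop is not shown in this diff.
void populate_token_mask(const int32_t* token_lora_mapping,
                         int32_t* token_mask, int num_tokens, int lora_id) {
  for (int t = 0; t < num_tokens; ++t) {
    token_mask[lora_id * num_tokens + t] =
        (token_lora_mapping[t] == lora_id) ? 1 : 0;
  }
}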
@@ -361,10 +366,30 @@ __global__ void moe_lora_align_block_size_kernel(
   _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
                         total_tokens_post_pad, num_experts, padded_num_experts,
                         experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, max_num_m_blocks, -1, lora_id,
+                        max_num_tokens_padded, max_num_m_blocks, lora_id, -1,
                         topk_num, &token_mask[(lora_id * num_tokens)]);
 }
 
+template <typename scalar_t>
+__global__ void lora_count_and_sort_expert_tokens_kernel(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
+    size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
+    int32_t topk_num, int32_t* token_mask, int32_t* lora_ids) {
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1) {
+    return;
+  }
+
+  int num_tokens = numel / topk_num;
+
+  _count_and_sort_expert_tokens(topk_ids, sorted_token_ids, cumsum_buffer,
+                                numel, num_experts, max_num_tokens_padded,
+                                &token_mask[(lora_id * num_tokens)], lora_id,
+                                topk_num);
+}
+
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
     scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
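Both lora kernels above follow the same per-slot dispatch pattern: blockIdx.x indexes a LoRA slot, slots whose lora_id is -1 return immediately, and the per-LoRA buffer rows are then selected by lora_id rather than by the slot index. A serial sketch of that pattern (illustration only; do_align_for_slot is a hypothetical stand-in for the device-side work):

#include <cstdint>
#include <functional>

void for_each_active_lora(const int32_t* lora_ids, int max_loras,
                          const std::function<void(int32_t)>& do_align_for_slot) {
  for (int slot = 0; slot < max_loras; ++slot) {  // plays the role of blockIdx.x
    int32_t lora_id = lora_ids[slot];
    if (lora_id == -1) continue;   // slot disabled for this step
    do_align_for_slot(lora_id);    // rows are indexed by lora_id
  }
}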
@@ -397,6 +422,7 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
       -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)]);
 }
 
+}  // namespace lora
 }  // namespace moe
 }  // namespace vllm
 
@@ -426,9 +452,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   torch::Tensor cumsum_buffer =
       torch::empty({num_experts + 1}, options_int);
 
-  torch::Tensor token_mask =
-      torch::empty({topk_ids.size(0)}, options_int);
-
   bool small_batch_expert_mode =
       (topk_ids.numel() < 1024) && (num_experts <= 64);
 
@@ -446,8 +469,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1),
-              token_mask.data_ptr<int32_t>());
+              topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1));
         } else {
           auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
 
@@ -462,8 +484,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
               padded_num_experts, experts_per_warp, block_size,
               topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
-              sorted_token_ids.size(0), topk_ids.size(1),
-              token_mask.data_ptr<int32_t>());
+              sorted_token_ids.size(0), topk_ids.size(1));
 
           const int block_threads = std::min(256, (int)threads);
           const int num_blocks =
@@ -478,8 +499,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               topk_ids.data_ptr<scalar_t>(),
               sorted_token_ids.data_ptr<int32_t>(),
               cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
-              sorted_token_ids.size(0), topk_ids.size(1),
-              token_mask.data_ptr<int32_t>());
+              sorted_token_ids.size(0), topk_ids.size(1));
         }
       });
 }
@@ -598,9 +618,8 @@ void moe_lora_align_block_size(
         }
 
         dim3 blockDim(num_thread);
-        auto kernel =
-            vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
-                scalar_t>;
+        auto kernel = vllm::moe::lora::
+            moe_lora_align_block_size_small_batch_expert_kernel<scalar_t>;
         AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
             (void*)kernel, shared_mem));
         kernel<<<max_loras, blockDim, shared_mem, stream>>>(
@@ -627,7 +646,7 @@ void moe_lora_align_block_size(
             torch::zeros({max_loras * (num_experts + 1)}, options_int);
 
         auto align_kernel =
-            vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
+            vllm::moe::lora::moe_lora_align_block_size_kernel<scalar_t>;
         align_kernel<<<max_loras, blockDim, shared_mem_size, stream>>>(
             topk_ids.data_ptr<scalar_t>(),
             token_lora_mapping.data_ptr<scalar_t>(), block_size, num_experts,
@@ -648,13 +667,14 @@ void moe_lora_align_block_size(
 
         dim3 gridDims(max_loras, actual_blocks);
         auto sort_kernel =
-            vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
+            vllm::moe::lora::lora_count_and_sort_expert_tokens_kernel<
+                scalar_t>;
 
         sort_kernel<<<gridDims, block_threads, 0, stream>>>(
             topk_ids.data_ptr<scalar_t>(),
             sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
             topk_ids.numel(), num_experts, max_num_tokens_padded, topk_num,
-            token_mask.data_ptr<int32_t>());
+            token_mask.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
       }
     });
 }
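
Putting the host side together, the LoRA path appears to assume flattened buffers with one row per LoRA slot. Only the cumsum size is visible in this diff (torch::zeros({max_loras * (num_experts + 1)})); the other sizes below are inferred from the kernel indexing and should be treated as assumptions:

#include <cstdint>

// Inferred flattened sizes for the per-LoRA buffers (sketch; only cumsum is
// confirmed by the diff).
struct LoraAlignSizes {
  int64_t sorted_token_ids;  // max_loras * max_num_tokens_padded
  int64_t expert_ids;        // max_loras * max_num_m_blocks
  int64_t cumsum;            // max_loras * (num_experts + 1)
  int64_t token_mask;        // max_loras * num_tokens
};

inline LoraAlignSizes lora_align_sizes(int64_t max_loras,
                                       int64_t max_num_tokens_padded,
                                       int64_t max_num_m_blocks,
                                       int64_t num_experts, int64_t num_tokens) {
  return {max_loras * max_num_tokens_padded, max_loras * max_num_m_blocks,
          max_loras * (num_experts + 1), max_loras * num_tokens};
}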