Commit 74144e6

remove lora kernel file
Signed-off-by: gnovack <[email protected]>
1 parent c2959aa commit 74144e6

File tree

3 files changed: +101 -256 lines changed


CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -908,7 +908,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
-  "csrc/moe/moe_lora_align_sum_kernels.cu"
   "csrc/moe/topk_softmax_kernels.cu")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 101 additions & 81 deletions
@@ -14,7 +14,6 @@
 
 namespace vllm {
 namespace moe {
-
 namespace batched_moe_align_block_size {
 
 // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
@@ -86,13 +85,19 @@ __device__ void _moe_align_block_size(
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
-    int32_t max_num_m_blocks, int inactive_expert_id, int adapter_offset,
-    int topk_num, int32_t* token_mask) {
+    int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id,
+    int32_t topk_num, int32_t* token_mask) {
   extern __shared__ int32_t shared_counts[];
 
+  // Compute input buffer offsets. Typically these will all be 0, except when
+  // using Multi LoRA.
+  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
+  int expert_ids_offset = max_num_m_blocks * model_offset;
+  int cumsum_offset = (num_experts + 1) * model_offset;
+
   // Initialize sorted_token_ids with numel
   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) + it] = numel;
+    sorted_token_ids[sorted_token_ids_offset + it] = numel;
   }
 
   const int warp_id = threadIdx.x / WARP_SIZE;
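
The three offsets above imply one flattened layout per output buffer, with one slice per model slot: slot 0 is the base model, so every offset collapses to 0, and each LoRA adapter addresses the next slice. A minimal host-side sketch of that indexing (the struct and helper names are illustrative, not part of this commit):

#include <cstdint>

// Flattened per-slot output buffers, as _moe_align_block_size addresses them.
struct AlignBuffers {
  int32_t* sorted_token_ids;  // [num_slots, max_num_tokens_padded]
  int32_t* expert_ids;        // [num_slots, max_num_m_blocks]
  int32_t* cumsum;            // [num_slots, num_experts + 1]
};

// Returns pointers to one slot's slice of each buffer.
inline AlignBuffers slice_for_slot(AlignBuffers b, int model_offset,
                                   int max_num_tokens_padded,
                                   int max_num_m_blocks, int num_experts) {
  return {b.sorted_token_ids + max_num_tokens_padded * model_offset,
          b.expert_ids + max_num_m_blocks * model_offset,
          b.cumsum + (num_experts + 1) * model_offset};
}
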
@@ -116,8 +121,9 @@ __device__ void _moe_align_block_size(
     }
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset],
-              token_mask[i / topk_num]);
+              mask);
   }
 
   __syncthreads();
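
The `token_mask == nullptr ? 1 : token_mask[i / topk_num]` expression recurs throughout the file; it is what lets the single-model kernels pass nullptr instead of materializing an all-ones mask. A sketch of the convention it encodes (the helper name is illustrative; the commit inlines the expression at each use):

#include <cstdint>

// A null mask means "single model: count every token". Otherwise topk entry i
// belongs to token i / topk_num, and the mask holds 1 for tokens routed to
// this adapter slot and 0 for all others, so masked-out tokens contribute
// nothing to the shared counts.
__device__ __forceinline__ int token_weight(const int32_t* token_mask,
                                            size_t i, int32_t topk_num) {
  return token_mask == nullptr ? 1 : token_mask[i / topk_num];
}
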
@@ -135,33 +141,31 @@ __device__ void _moe_align_block_size(
     expert_count = CEILDIV(expert_count, block_size) * block_size;
   }
 
-  int adapter_cumsum_offset = (num_experts + 1) * adapter_offset;
   int cumsum_val;
   BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
   if (expert_id <= num_experts) {
-    cumsum[adapter_cumsum_offset + expert_id] = cumsum_val;
+    cumsum[cumsum_offset + expert_id] = cumsum_val;
   }
 
   if (expert_id == num_experts) {
-    total_tokens_post_pad[adapter_offset] = cumsum_val;
+    total_tokens_post_pad[model_offset] = cumsum_val;
   }
 
   __syncthreads();
 
   if (threadIdx.x < num_experts) {
-    for (int i = cumsum[adapter_cumsum_offset + threadIdx.x];
-         i < cumsum[adapter_cumsum_offset + threadIdx.x + 1]; i += block_size) {
-      expert_ids[(max_num_m_blocks * adapter_offset) + i / block_size] =
-          threadIdx.x;
+    for (int i = cumsum[cumsum_offset + threadIdx.x];
+         i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {
+      expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
     }
   }
 
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx =
-      cumsum[adapter_cumsum_offset + num_experts] / block_size + threadIdx.x;
+      cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
   const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
   for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
-    expert_ids[(max_num_m_blocks * adapter_offset) + i] = inactive_expert_id;
+    expert_ids[expert_ids_offset + i] = inactive_expert_id;
   }
 }
 
@@ -171,11 +175,16 @@ __device__ void _moe_align_block_size_small_batch_expert(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t max_num_tokens_padded,
-    int32_t max_num_m_blocks, int inactive_expert_id, int adapter_offset,
-    int topk_num, int32_t* token_mask) {
+    int32_t max_num_m_blocks, int32_t inactive_expert_id, int32_t model_offset,
+    int32_t topk_num, int32_t* token_mask) {
+  // Compute input buffer offsets. Typically these will all be 0, except when
+  // using Multi LoRA.
+  int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
+  int expert_ids_offset = max_num_m_blocks * model_offset;
+
   // Initialize sorted_token_ids with numel
   for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) + it] = numel;
+    sorted_token_ids[sorted_token_ids_offset + it] = numel;
   }
 
   const size_t tid = threadIdx.x;
@@ -190,8 +199,8 @@ __device__ void _moe_align_block_size_small_batch_expert(
   }
 
   for (size_t i = tid; i < numel; i += stride) {
-    tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]] +=
-        token_mask[i / topk_num];
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]] += mask;
   }
 
   __syncthreads();
@@ -214,7 +223,7 @@ __device__ void _moe_align_block_size_small_batch_expert(
           CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) *
           block_size;
     }
-    total_tokens_post_pad[adapter_offset] =
+    total_tokens_post_pad[model_offset] =
         static_cast<int32_t>(cumsum[num_experts]);
   }
 
@@ -223,25 +232,47 @@ __device__ void _moe_align_block_size_small_batch_expert(
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
          i += block_size) {
-      expert_ids[(max_num_m_blocks * adapter_offset) + i / block_size] =
-          threadIdx.x;
+      expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
     }
   }
 
   // Fill remaining expert_ids with 0
   const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
   for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
-    expert_ids[(max_num_m_blocks * adapter_offset) + i] = inactive_expert_id;
+    expert_ids[expert_ids_offset + i] = inactive_expert_id;
   }
 
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
     int32_t rank_post_pad =
         tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
-    sorted_token_ids[(adapter_offset * max_num_tokens_padded) +
-                     rank_post_pad] += ((i - numel) * token_mask[i / topk_num]);
-    tokens_cnts[threadIdx.x * num_experts + expert_id] +=
-        token_mask[i / topk_num];
+
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    sorted_token_ids[sorted_token_ids_offset + rank_post_pad] +=
+        ((i - numel) * mask);
+    tokens_cnts[threadIdx.x * num_experts + expert_id] += mask;
+  }
+}
+
+template <typename scalar_t>
+__device__ void _count_and_sort_expert_tokens(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
+    size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
+    int32_t* __restrict__ token_mask, int32_t model_offset, int32_t topk_num) {
+  const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.y;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
+
+    int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
+    int32_t rank_post_pad = atomicAdd(
+        &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], mask);
+    sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = i;
   }
 }
 
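For intuition, here is a CPU reference of what the new `_count_and_sort_expert_tokens` helper computes for one slot, assuming the unmasked base-model path: each in-range topk entry is appended to its expert's contiguous region, whose start position the align kernel left in `cumsum_buffer`; the kernel's `atomicAdd` plays the role of the post-increment here. (Sketch only, not part of the commit.)

#include <cstdint>
#include <vector>

void count_and_sort_ref(const std::vector<int32_t>& topk_ids,
                        std::vector<int32_t>& sorted_token_ids,
                        std::vector<int32_t>& cumsum_buffer,  // per-expert cursors
                        int32_t num_experts) {
  for (int32_t i = 0; i < (int32_t)topk_ids.size(); ++i) {
    int32_t expert_id = topk_ids[i];
    if (expert_id >= num_experts) continue;  // matches the kernel's `continue`
    // Claim the next slot in this expert's region, then advance the cursor.
    int32_t rank_post_pad = cumsum_buffer[expert_id]++;
    sorted_token_ids[rank_post_pad] = i;
  }
}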

@@ -252,44 +283,22 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
     size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  int num_tokens = numel / topk_num;
-  if (threadIdx.x < num_tokens) {
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      token_mask[i] = 1;
-    }
-  }
-
-  __syncthreads();
-
+    int32_t topk_num) {
   _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
                         total_tokens_post_pad, num_experts, padded_num_experts,
                         experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, 0, 0, 0, topk_num, token_mask);
+                        max_num_tokens_padded, 0, 0, 0, topk_num, nullptr);
 }
 
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
     size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  const size_t adapter_offset = blockIdx.x;
-  const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
-  const size_t stride = blockDim.x * gridDim.y;
-  int num_tokens = numel / topk_num;
-
-  for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
-    if (expert_id >= num_experts ||
-        token_mask[(adapter_offset * num_tokens) + i / topk_num] == 0) {
-      continue;
-    }
-    int32_t rank_post_pad = atomicAdd(
-        &cumsum_buffer[(adapter_offset * (num_experts + 1)) + expert_id], 1);
-    sorted_token_ids[max_num_tokens_padded * adapter_offset + rank_post_pad] =
-        i;
-  }
+    int32_t topk_num) {
+  _count_and_sort_expert_tokens(topk_ids, sorted_token_ids, cumsum_buffer,
+                                numel, num_experts, max_num_tokens_padded,
+                                nullptr, 0, topk_num);
 }
 
 template <typename scalar_t, int TOPK>
@@ -314,29 +323,24 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
     int32_t block_size, size_t numel, int32_t max_num_tokens_padded,
-    int32_t topk_num, int32_t* __restrict__ token_mask) {
-  int num_tokens = numel / topk_num;
-  if (threadIdx.x < num_tokens) {
-    for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-      token_mask[i] = 1;
-    }
-  }
-
+    int32_t topk_num) {
   __syncthreads();
 
   _moe_align_block_size_small_batch_expert(
       topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad,
       num_experts, block_size, numel, max_num_tokens_padded, 0, 0, 0, topk_num,
-      token_mask);
+      nullptr);
 }
 
+namespace lora {
+
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_kernel(
     scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
     int32_t* __restrict__ cumsum, int32_t experts_per_warp,
     int32_t padded_num_experts, int32_t* lora_ids,
     int32_t* __restrict__ token_mask) {
@@ -346,6 +350,7 @@ __global__ void moe_lora_align_block_size_kernel(
     return;
   }
 
+  // Populate the token_mask based on the token-LoRA mapping
   int num_tokens = numel / topk_num;
   if (threadIdx.x == 0) {
     total_tokens_post_pad[lora_id] = 0;
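
The population loop itself sits outside this hunk's context, but judging from the new comment and the kernel's arguments it presumably flags each token whose `token_lora_mapping` entry matches this block's `lora_id`. A hedged sketch of that step (assumed shape, not shown in the diff):

// Assumed: one block per LoRA slot; threads stride over tokens and write the
// slot's row of the flattened [max_loras, num_tokens] mask.
template <typename scalar_t>
__device__ void populate_token_mask(const scalar_t* token_lora_mapping,
                                    int32_t* token_mask, int num_tokens,
                                    int lora_id) {
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
    token_mask[(lora_id * num_tokens) + i] =
        (token_lora_mapping[i] == lora_id) ? 1 : 0;
  }
}
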
@@ -361,10 +366,30 @@
   _moe_align_block_size(topk_ids, sorted_token_ids, expert_ids,
                         total_tokens_post_pad, num_experts, padded_num_experts,
                         experts_per_warp, block_size, numel, cumsum,
-                        max_num_tokens_padded, max_num_m_blocks, -1, lora_id,
+                        max_num_tokens_padded, max_num_m_blocks, lora_id, -1,
                         topk_num, &token_mask[(lora_id * num_tokens)]);
 }
 
+template <typename scalar_t>
+__global__ void lora_count_and_sort_expert_tokens_kernel(
+    const scalar_t* __restrict__ topk_ids,
+    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
+    size_t numel, int32_t num_experts, int32_t max_num_tokens_padded,
+    int32_t topk_num, int32_t* token_mask, int32_t* lora_ids) {
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1) {
+    return;
+  }
+
+  int num_tokens = numel / topk_num;
+
+  _count_and_sort_expert_tokens(topk_ids, sorted_token_ids, cumsum_buffer,
+                                numel, num_experts, max_num_tokens_padded,
+                                &token_mask[(lora_id * num_tokens)], lora_id,
+                                topk_num);
+}
+
 template <typename scalar_t>
 __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
     scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
@@ -397,6 +422,7 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
       -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)]);
 }
 
+}  // namespace lora
 }  // namespace moe
 }  // namespace vllm
 
@@ -426,9 +452,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   torch::Tensor cumsum_buffer =
       torch::empty({num_experts + 1}, options_int);
 
-  torch::Tensor token_mask =
-      torch::empty({topk_ids.size(0)}, options_int);
-
   bool small_batch_expert_mode =
       (topk_ids.numel() < 1024) && (num_experts <= 64);
 
@@ -446,8 +469,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
           sorted_token_ids.data_ptr<int32_t>(),
           experts_ids.data_ptr<int32_t>(),
           num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-          topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1),
-          token_mask.data_ptr<int32_t>());
+          topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1));
     } else {
       auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
 
@@ -462,8 +484,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
           num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
           padded_num_experts, experts_per_warp, block_size,
          topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
-          sorted_token_ids.size(0), topk_ids.size(1),
-          token_mask.data_ptr<int32_t>());
+          sorted_token_ids.size(0), topk_ids.size(1));
 
       const int block_threads = std::min(256, (int)threads);
       const int num_blocks =
@@ -478,8 +499,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
           topk_ids.data_ptr<scalar_t>(),
           sorted_token_ids.data_ptr<int32_t>(),
           cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
-          sorted_token_ids.size(0), topk_ids.size(1),
-          token_mask.data_ptr<int32_t>());
+          sorted_token_ids.size(0), topk_ids.size(1));
     }
   });
 }
@@ -598,9 +618,8 @@ void moe_lora_align_block_size(
     }
 
     dim3 blockDim(num_thread);
-    auto kernel =
-        vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
-            scalar_t>;
+    auto kernel = vllm::moe::lora::
+        moe_lora_align_block_size_small_batch_expert_kernel<scalar_t>;
     AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
         (void*)kernel, shared_mem));
     kernel<<<max_loras, blockDim, shared_mem, stream>>>(
@@ -627,7 +646,7 @@
         torch::zeros({max_loras * (num_experts + 1)}, options_int);
 
     auto align_kernel =
-        vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
+        vllm::moe::lora::moe_lora_align_block_size_kernel<scalar_t>;
     align_kernel<<<max_loras, blockDim, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         token_lora_mapping.data_ptr<scalar_t>(), block_size, num_experts,
@@ -648,13 +667,14 @@
 
     dim3 gridDims(max_loras, actual_blocks);
     auto sort_kernel =
-        vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
+        vllm::moe::lora::lora_count_and_sort_expert_tokens_kernel<
+            scalar_t>;
 
     sort_kernel<<<gridDims, block_threads, 0, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
         topk_ids.numel(), num_experts, max_num_tokens_padded, topk_num,
-        token_mask.data_ptr<int32_t>());
+        token_mask.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
     }
   });
 }
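
For reference, the geometry this launch relies on: `gridDims.x` enumerates LoRA slots, and each block first resolves `lora_ids[blockIdx.x]`, returning immediately on -1; `gridDims.y` together with the block size forms the grid-stride loop that `_count_and_sort_expert_tokens` walks with `tid = blockIdx.y * blockDim.x + threadIdx.x` and `stride = blockDim.x * gridDim.y`. A minimal sketch of one way to pick those extents (the capping logic here is illustrative, not the commit's exact computation of `actual_blocks`):

#include <cuda_runtime.h>

// One grid row per LoRA slot; enough row-blocks to cover numel topk entries,
// capped so small batches do not over-subscribe the device.
inline dim3 lora_sort_grid(int max_loras, size_t numel, int block_threads,
                           int max_blocks = 65535) {
  int blocks = (int)((numel + block_threads - 1) / block_threads);
  return dim3(max_loras, blocks < max_blocks ? blocks : max_blocks);
}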
