Skip to content

Commit 0ee6416

Browse files
authored
[Perf] Optimize group_topk kernel, 1.9% Throughput improvement, 2.1% TPOT improvement (#30159)
Signed-off-by: yewentao256 <[email protected]>
1 parent d941709 commit 0ee6416

File tree

1 file changed

+128
-47
lines changed

1 file changed

+128
-47
lines changed

csrc/moe/grouped_topk_kernels.cu

Lines changed: 128 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
444444
return cuda_cast<T, float>(sigmoid_accurate(f));
445445
}
446446

447-
template <typename T>
447+
template <ScoringFunc SF, typename T>
448+
__device__ inline T apply_scoring(T val) {
449+
if constexpr (SF == SCORING_SIGMOID) {
450+
return apply_sigmoid(val);
451+
} else {
452+
return val;
453+
}
454+
}
455+
456+
template <typename T, ScoringFunc SF>
448457
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
449458
cg::thread_block_tile<32> const& tile,
450459
int32_t const lane_id,
451-
int const num_experts_per_group,
452-
int const scoring_func) {
460+
int const num_experts_per_group) {
453461
// Get the top2 per thread
454462
T largest = neg_inf<T>();
455463
T second_largest = neg_inf<T>();
456464

457465
if (num_experts_per_group > WARP_SIZE) {
458466
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
459-
T value = input[i];
460-
// Apply scoring function if needed
461-
if (scoring_func == SCORING_SIGMOID) {
462-
value = apply_sigmoid(value);
463-
}
467+
T value = apply_scoring<SF>(input[i]);
464468
value = value + bias[i];
465469

466470
if (value > largest) {
@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
472476
}
473477
} else {
474478
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
475-
T value = input[i];
476-
// Apply scoring function if needed
477-
if (scoring_func == SCORING_SIGMOID) {
478-
value = apply_sigmoid(value);
479-
}
479+
T value = apply_scoring<SF>(input[i]);
480480
value = value + bias[i];
481481
largest = value;
482482
}
@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
501501
}
502502
}
503503

504-
template <typename T>
504+
template <typename T, ScoringFunc SF>
505505
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
506506
int64_t const num_tokens,
507507
int64_t const num_cases,
508508
int64_t const n_group,
509-
int64_t const num_experts_per_group,
510-
int const scoring_func) {
509+
int64_t const num_experts_per_group) {
511510
int32_t warp_id = threadIdx.x / WARP_SIZE;
512511
int32_t lane_id = threadIdx.x % WARP_SIZE;
513512

@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
525524
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
526525
asm volatile("griddepcontrol.wait;");
527526
#endif
528-
topk_with_k2(output, input, group_bias, tile, lane_id,
529-
num_experts_per_group, scoring_func);
527+
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
528+
num_experts_per_group);
530529
}
531530
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
532531
asm volatile("griddepcontrol.launch_dependents;");
533532
#endif
534533
}
535534

536-
template <typename T, typename IdxT>
535+
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
537536
__global__ void group_idx_and_topk_idx_kernel(
538537
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
539538
T const* bias, int64_t const num_tokens, int64_t const n_group,
540539
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
541540
int64_t const num_experts_per_group, bool renormalize,
542-
double routed_scaling_factor, int scoring_func) {
541+
double routed_scaling_factor) {
543542
int32_t warp_id = threadIdx.x / WARP_SIZE;
544543
int32_t lane_id = threadIdx.x % WARP_SIZE;
545544
int32_t case_id =
@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
549548
topk_values += case_id * topk;
550549
topk_indices += case_id * topk;
551550

551+
constexpr bool kUseStaticNGroup = (NGroup > 0);
552+
// use int32 to avoid implicit conversion
553+
int32_t const n_group_i32 =
554+
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
555+
552556
int32_t align_num_experts_per_group =
553557
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
554558

@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
574578

575579
if (case_id < num_tokens) {
576580
// calculate group_idx
577-
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
581+
int32_t target_num_min =
582+
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
578583
// The check is necessary to avoid abnormal input
579-
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
584+
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
580585
value = group_scores[lane_id];
581586
}
582587

583-
int count_equal_to_top_value = WARP_SIZE - n_group;
588+
int count_equal_to_top_value = WARP_SIZE - n_group_i32;
584589
int pre_count_equal_to_top_value = 0;
585590
// Use loop to find the largest top_group
586591
while (count_equal_to_top_value < target_num_min) {
@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
604609
int count_equalto_topkth_group = 0;
605610
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
606611
if (case_id < num_tokens && if_proceed_next_topk) {
607-
for (int i_group = 0; i_group < n_group; i_group++) {
612+
auto process_group = [&](int i_group) {
608613
if ((group_scores[i_group] > topk_group_value) ||
609614
((group_scores[i_group] == topk_group_value) &&
610615
(count_equalto_topkth_group < num_equalto_topkth_group))) {
@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
613618
i += WARP_SIZE) {
614619
T candidates = neg_inf<T>();
615620
if (i < num_experts_per_group) {
616-
// Apply scoring function (if any) and add bias
621+
// apply scoring function (if any) and add bias
617622
T input = scores[offset + i];
618623
if (is_finite(input)) {
619-
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
620-
: input;
624+
T score = apply_scoring<SF>(input);
621625
candidates = score + bias[offset + i];
622626
}
623627
}
@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
627631
count_equalto_topkth_group++;
628632
}
629633
}
634+
};
635+
636+
if constexpr (kUseStaticNGroup) {
637+
#pragma unroll
638+
for (int i_group = 0; i_group < NGroup; ++i_group) {
639+
process_group(i_group);
640+
}
641+
} else {
642+
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
643+
process_group(i_group);
644+
}
630645
}
631646
queue.done();
632647
__syncwarp();
@@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
646661
if (i < topk) {
647662
// Load the score value (without bias) for normalization
648663
T input = scores[s_topk_idx[i]];
649-
value =
650-
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
664+
value = apply_scoring<SF>(input);
651665
s_topk_value[i] = value;
652666
}
653-
topk_sum +=
654-
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
667+
if (renormalize) {
668+
topk_sum +=
669+
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
670+
}
655671
}
656672
}
657673

@@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
660676
if (case_id < num_tokens) {
661677
if (if_proceed_next_topk) {
662678
for (int i = lane_id; i < topk; i += WARP_SIZE) {
663-
float value;
664-
if (renormalize) {
665-
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
666-
routed_scaling_factor;
667-
} else {
668-
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
669-
}
679+
float base = cuda_cast<float, T>(s_topk_value[i]);
680+
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
681+
: (base * routed_scaling_factor);
670682
topk_indices[i] = s_topk_idx[i];
671683
topk_values[i] = value;
672684
}
@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
684696
#endif
685697
}
686698

699+
template <typename T, typename IdxT, ScoringFunc SF>
700+
inline void launch_group_idx_and_topk_kernel(
701+
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
702+
float* topk_values, IdxT* topk_indices, T const* bias,
703+
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
704+
int64_t const topk, int64_t const num_experts,
705+
int64_t const num_experts_per_group, bool const renormalize,
706+
double const routed_scaling_factor) {
707+
auto launch = [&](auto* kernel_instance2) {
708+
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
709+
topk_values, topk_indices, bias, num_tokens, n_group,
710+
topk_group, topk, num_experts, num_experts_per_group,
711+
renormalize, routed_scaling_factor);
712+
};
713+
714+
switch (n_group) {
715+
case 4: {
716+
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
717+
break;
718+
}
719+
case 8: {
720+
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
721+
break;
722+
}
723+
case 16: {
724+
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
725+
break;
726+
}
727+
case 32: {
728+
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
729+
break;
730+
}
731+
default: {
732+
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
733+
break;
734+
}
735+
}
736+
}
737+
687738
template <typename T, typename IdxT>
688739
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
689740
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
694745
cudaStream_t const stream = 0) {
695746
int64_t num_cases = num_tokens * n_group;
696747
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
697-
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
698748
cudaLaunchConfig_t config;
699749
config.gridDim = topk_with_k2_num_blocks;
700750
config.blockDim = BLOCK_SIZE;
@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
705755
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
706756
config.numAttrs = 1;
707757
config.attrs = attrs;
708-
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
709-
num_tokens, num_cases, n_group, num_experts / n_group,
710-
scoring_func);
758+
auto const sf = static_cast<ScoringFunc>(scoring_func);
759+
int64_t const num_experts_per_group = num_experts / n_group;
760+
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
761+
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
762+
num_tokens, num_cases, n_group, num_experts_per_group);
763+
};
764+
switch (sf) {
765+
case SCORING_NONE: {
766+
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
767+
launch_topk_with_k2(kernel_instance1);
768+
break;
769+
}
770+
case SCORING_SIGMOID: {
771+
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
772+
launch_topk_with_k2(kernel_instance1);
773+
break;
774+
}
775+
default:
776+
// should be guarded by higher level checks.
777+
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
778+
}
711779

712780
int64_t topk_with_k_group_num_blocks =
713781
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
714782
size_t dynamic_smem_in_bytes =
715783
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
716784
topk);
717-
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
718785
config.gridDim = topk_with_k_group_num_blocks;
719786
config.blockDim = BLOCK_SIZE;
720787
config.dynamicSmemBytes = dynamic_smem_in_bytes;
@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
723790
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
724791
config.numAttrs = 1;
725792
config.attrs = attrs;
726-
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
727-
topk_values, topk_indices, bias, num_tokens, n_group,
728-
topk_group, topk, num_experts, num_experts / n_group,
729-
renormalize, routed_scaling_factor, scoring_func);
793+
switch (sf) {
794+
case SCORING_NONE: {
795+
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
796+
config, scores, group_scores, topk_values, topk_indices, bias,
797+
num_tokens, n_group, topk_group, topk, num_experts,
798+
num_experts_per_group, renormalize, routed_scaling_factor);
799+
break;
800+
}
801+
case SCORING_SIGMOID: {
802+
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
803+
config, scores, group_scores, topk_values, topk_indices, bias,
804+
num_tokens, n_group, topk_group, topk, num_experts,
805+
num_experts_per_group, renormalize, routed_scaling_factor);
806+
break;
807+
}
808+
default:
809+
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
810+
}
730811
}
731812

732813
#define INSTANTIATE_NOAUX_TC(T, IdxT) \

0 commit comments

Comments
 (0)