@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
444444 return cuda_cast<T, float >(sigmoid_accurate (f));
445445}
446446
447- template <typename T>
// Apply the compile-time-selected scoring function to one raw router logit.
// SCORING_SIGMOID routes the value through apply_sigmoid; every other mode
// is an identity pass-through. The branch is `if constexpr`, so the
// non-selected path is discarded during instantiation and the function
// compiles down to either a sigmoid or a no-op — no runtime dispatch.
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
  if constexpr (SF == SCORING_SIGMOID) {
    val = apply_sigmoid(val);
  }
  return val;
}
455+
456+ template <typename T, ScoringFunc SF>
448457__device__ void topk_with_k2 (T* output, T const * input, T const * bias,
449458 cg::thread_block_tile<32 > const & tile,
450459 int32_t const lane_id,
451- int const num_experts_per_group,
452- int const scoring_func) {
460+ int const num_experts_per_group) {
453461 // Get the top2 per thread
454462 T largest = neg_inf<T>();
455463 T second_largest = neg_inf<T>();
456464
457465 if (num_experts_per_group > WARP_SIZE) {
458466 for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
459- T value = input[i];
460- // Apply scoring function if needed
461- if (scoring_func == SCORING_SIGMOID) {
462- value = apply_sigmoid (value);
463- }
467+ T value = apply_scoring<SF>(input[i]);
464468 value = value + bias[i];
465469
466470 if (value > largest) {
@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
472476 }
473477 } else {
474478 for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
475- T value = input[i];
476- // Apply scoring function if needed
477- if (scoring_func == SCORING_SIGMOID) {
478- value = apply_sigmoid (value);
479- }
479+ T value = apply_scoring<SF>(input[i]);
480480 value = value + bias[i];
481481 largest = value;
482482 }
@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
501501 }
502502}
503503
504- template <typename T>
504+ template <typename T, ScoringFunc SF >
505505__global__ void topk_with_k2_kernel (T* output, T* input, T const * bias,
506506 int64_t const num_tokens,
507507 int64_t const num_cases,
508508 int64_t const n_group,
509- int64_t const num_experts_per_group,
510- int const scoring_func) {
509+ int64_t const num_experts_per_group) {
511510 int32_t warp_id = threadIdx .x / WARP_SIZE;
512511 int32_t lane_id = threadIdx .x % WARP_SIZE;
513512
@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
525524#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
526525 asm volatile (" griddepcontrol.wait;" );
527526#endif
528- topk_with_k2 (output, input, group_bias, tile, lane_id,
529- num_experts_per_group, scoring_func );
527+ topk_with_k2<T, SF> (output, input, group_bias, tile, lane_id,
528+ num_experts_per_group );
530529 }
531530#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
532531 asm volatile (" griddepcontrol.launch_dependents;" );
533532#endif
534533}
535534
536- template <typename T, typename IdxT>
535+ template <typename T, typename IdxT, ScoringFunc SF, int NGroup = - 1 >
537536__global__ void group_idx_and_topk_idx_kernel (
538537 T* scores, T const * group_scores, float * topk_values, IdxT* topk_indices,
539538 T const * bias, int64_t const num_tokens, int64_t const n_group,
540539 int64_t const topk_group, int64_t const topk, int64_t const num_experts,
541540 int64_t const num_experts_per_group, bool renormalize,
542- double routed_scaling_factor, int scoring_func ) {
541+ double routed_scaling_factor) {
543542 int32_t warp_id = threadIdx .x / WARP_SIZE;
544543 int32_t lane_id = threadIdx .x % WARP_SIZE;
545544 int32_t case_id =
@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
549548 topk_values += case_id * topk;
550549 topk_indices += case_id * topk;
551550
551+ constexpr bool kUseStaticNGroup = (NGroup > 0 );
552+ // use int32 to avoid implicit conversion
553+ int32_t const n_group_i32 =
554+ kUseStaticNGroup ? NGroup : static_cast <int32_t >(n_group);
555+
552556 int32_t align_num_experts_per_group =
553557 warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
554558
@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
574578
575579 if (case_id < num_tokens) {
576580 // calculate group_idx
577- int32_t target_num_min = WARP_SIZE - n_group + topk_group;
581+ int32_t target_num_min =
582+ WARP_SIZE - n_group_i32 + static_cast <int32_t >(topk_group);
578583 // The check is necessary to avoid abnormal input
579- if (lane_id < n_group && is_finite (group_scores[lane_id])) {
584+ if (lane_id < n_group_i32 && is_finite (group_scores[lane_id])) {
580585 value = group_scores[lane_id];
581586 }
582587
583- int count_equal_to_top_value = WARP_SIZE - n_group ;
588+ int count_equal_to_top_value = WARP_SIZE - n_group_i32 ;
584589 int pre_count_equal_to_top_value = 0 ;
585590 // Use loop to find the largset top_group
586591 while (count_equal_to_top_value < target_num_min) {
@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
604609 int count_equalto_topkth_group = 0 ;
605610 bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
606611 if (case_id < num_tokens && if_proceed_next_topk) {
607- for ( int i_group = 0 ; i_group < n_group; i_group++ ) {
612+ auto process_group = [&]( int i_group) {
608613 if ((group_scores[i_group] > topk_group_value) ||
609614 ((group_scores[i_group] == topk_group_value) &&
610615 (count_equalto_topkth_group < num_equalto_topkth_group))) {
@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
613618 i += WARP_SIZE) {
614619 T candidates = neg_inf<T>();
615620 if (i < num_experts_per_group) {
616- // Apply scoring function (if any) and add bias
621+ // apply scoring function (if any) and add bias
617622 T input = scores[offset + i];
618623 if (is_finite (input)) {
619- T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid (input)
620- : input;
624+ T score = apply_scoring<SF>(input);
621625 candidates = score + bias[offset + i];
622626 }
623627 }
@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
627631 count_equalto_topkth_group++;
628632 }
629633 }
634+ };
635+
636+ if constexpr (kUseStaticNGroup ) {
637+ #pragma unroll
638+ for (int i_group = 0 ; i_group < NGroup; ++i_group) {
639+ process_group (i_group);
640+ }
641+ } else {
642+ for (int i_group = 0 ; i_group < n_group_i32; ++i_group) {
643+ process_group (i_group);
644+ }
630645 }
631646 queue.done ();
632647 __syncwarp ();
@@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
646661 if (i < topk) {
647662 // Load the score value (without bias) for normalization
648663 T input = scores[s_topk_idx[i]];
649- value =
650- (scoring_func == SCORING_SIGMOID) ? apply_sigmoid (input) : input;
664+ value = apply_scoring<SF>(input);
651665 s_topk_value[i] = value;
652666 }
653- topk_sum +=
654- cg::reduce (tile, cuda_cast<float , T>(value), cg::plus<float >());
667+ if (renormalize) {
668+ topk_sum +=
669+ cg::reduce (tile, cuda_cast<float , T>(value), cg::plus<float >());
670+ }
655671 }
656672 }
657673
@@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
660676 if (case_id < num_tokens) {
661677 if (if_proceed_next_topk) {
662678 for (int i = lane_id; i < topk; i += WARP_SIZE) {
663- float value;
664- if (renormalize) {
665- value = cuda_cast<float , T>(s_topk_value[i]) / topk_sum *
666- routed_scaling_factor;
667- } else {
668- value = cuda_cast<float , T>(s_topk_value[i]) * routed_scaling_factor;
669- }
679+ float base = cuda_cast<float , T>(s_topk_value[i]);
680+ float value = renormalize ? (base / topk_sum * routed_scaling_factor)
681+ : (base * routed_scaling_factor);
670682 topk_indices[i] = s_topk_idx[i];
671683 topk_values[i] = value;
672684 }
@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
684696#endif
685697}
686698
// Launch group_idx_and_topk_idx_kernel with the best-matching specialization.
// For the common group counts (4 / 8 / 16 / 32) we dispatch to a variant whose
// NGroup template argument is baked in at compile time, which lets the kernel
// fully unroll its per-group loop; any other count falls back to the dynamic
// (NGroup == -1) instantiation that reads n_group at runtime.
// `config` carries grid/block dims, dynamic smem, and PDL attributes set by
// the caller; all remaining arguments are forwarded to the kernel unchanged.
template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
    float* topk_values, IdxT* topk_indices, T const* bias,
    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
    int64_t const topk, int64_t const num_experts,
    int64_t const num_experts_per_group, bool const renormalize,
    double const routed_scaling_factor) {
  // Every instantiation shares the same function type, so we can pick the
  // kernel pointer first and perform a single launch at the end.
  auto kernel_fn = &group_idx_and_topk_idx_kernel<T, IdxT, SF>;
  if (n_group == 4) {
    kernel_fn = &group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>;
  } else if (n_group == 8) {
    kernel_fn = &group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>;
  } else if (n_group == 16) {
    kernel_fn = &group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>;
  } else if (n_group == 32) {
    kernel_fn = &group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>;
  }
  cudaLaunchKernelEx(&config, kernel_fn, scores, group_scores, topk_values,
                     topk_indices, bias, num_tokens, n_group, topk_group, topk,
                     num_experts, num_experts_per_group, renormalize,
                     routed_scaling_factor);
}
737+
687738template <typename T, typename IdxT>
688739void invokeNoAuxTc (T* scores, T* group_scores, float * topk_values,
689740 IdxT* topk_indices, T const * bias, int64_t const num_tokens,
@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
694745 cudaStream_t const stream = 0 ) {
695746 int64_t num_cases = num_tokens * n_group;
696747 int64_t topk_with_k2_num_blocks = (num_cases - 1 ) / NUM_WARPS_PER_BLOCK + 1 ;
697- auto * kernel_instance1 = &topk_with_k2_kernel<T>;
698748 cudaLaunchConfig_t config;
699749 config.gridDim = topk_with_k2_num_blocks;
700750 config.blockDim = BLOCK_SIZE;
@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
705755 attrs[0 ].val .programmaticStreamSerializationAllowed = enable_pdl;
706756 config.numAttrs = 1 ;
707757 config.attrs = attrs;
708- cudaLaunchKernelEx (&config, kernel_instance1, group_scores, scores, bias,
709- num_tokens, num_cases, n_group, num_experts / n_group,
710- scoring_func);
758+ auto const sf = static_cast <ScoringFunc>(scoring_func);
759+ int64_t const num_experts_per_group = num_experts / n_group;
760+ auto launch_topk_with_k2 = [&](auto * kernel_instance1) {
761+ cudaLaunchKernelEx (&config, kernel_instance1, group_scores, scores, bias,
762+ num_tokens, num_cases, n_group, num_experts_per_group);
763+ };
764+ switch (sf) {
765+ case SCORING_NONE: {
766+ auto * kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
767+ launch_topk_with_k2 (kernel_instance1);
768+ break ;
769+ }
770+ case SCORING_SIGMOID: {
771+ auto * kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
772+ launch_topk_with_k2 (kernel_instance1);
773+ break ;
774+ }
775+ default :
776+ // should be guarded by higher level checks.
777+ TORCH_CHECK (false , " Unsupported scoring_func in invokeNoAuxTc" );
778+ }
711779
712780 int64_t topk_with_k_group_num_blocks =
713781 (num_tokens - 1 ) / NUM_WARPS_PER_BLOCK + 1 ;
714782 size_t dynamic_smem_in_bytes =
715783 warp_topk::calc_smem_size_for_block_wide<T, int32_t >(NUM_WARPS_PER_BLOCK,
716784 topk);
717- auto * kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
718785 config.gridDim = topk_with_k_group_num_blocks;
719786 config.blockDim = BLOCK_SIZE;
720787 config.dynamicSmemBytes = dynamic_smem_in_bytes;
@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
723790 attrs[0 ].val .programmaticStreamSerializationAllowed = enable_pdl;
724791 config.numAttrs = 1 ;
725792 config.attrs = attrs;
726- cudaLaunchKernelEx (&config, kernel_instance2, scores, group_scores,
727- topk_values, topk_indices, bias, num_tokens, n_group,
728- topk_group, topk, num_experts, num_experts / n_group,
729- renormalize, routed_scaling_factor, scoring_func);
793+ switch (sf) {
794+ case SCORING_NONE: {
795+ launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
796+ config, scores, group_scores, topk_values, topk_indices, bias,
797+ num_tokens, n_group, topk_group, topk, num_experts,
798+ num_experts_per_group, renormalize, routed_scaling_factor);
799+ break ;
800+ }
801+ case SCORING_SIGMOID: {
802+ launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
803+ config, scores, group_scores, topk_values, topk_indices, bias,
804+ num_tokens, n_group, topk_group, topk, num_experts,
805+ num_experts_per_group, renormalize, routed_scaling_factor);
806+ break ;
807+ }
808+ default :
809+ TORCH_CHECK (false , " Unsupported scoring_func in invokeNoAuxTc" );
810+ }
730811}
731812
732813#define INSTANTIATE_NOAUX_TC (T, IdxT ) \
0 commit comments