@@ -69,8 +69,7 @@ __device__ void moe_fused_gate_impl(void* input, void* bias, float* output_ptr,
   }
 
   // Calculate topk_excluding_share_expert_fusion from topk
-  int64_t topk_excluding_share_expert_fusion =
-      topk - (num_fused_shared_experts > 0 ? 1 : 0);
+  int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
 
   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
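A worked example of the change above, with illustrative values that are not from the patch: when more than one shared expert is fused, the old expression still reserved only a single top-k slot, while the new expression reserves one slot per fused shared expert.

    // Sketch only (assumed values): topk = 9, num_fused_shared_experts = 2.
    int64_t topk = 9;
    int64_t num_fused_shared_experts = 2;
    int64_t old_result = topk - (num_fused_shared_experts > 0 ? 1 : 0);  // 8
    int64_t new_result = topk - num_fused_shared_experts;                // 7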
@@ -362,6 +361,9 @@ std::vector<at::Tensor> moe_fused_gate(
     at::Tensor& input, at::Tensor& bias, int64_t num_expert_group,
     int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts,
     double routed_scaling_factor, bool apply_routed_scaling_factor_on_output) {
+  TORCH_CHECK(input.dtype() == bias.dtype(),
+              "input and bias should have the same dtype");
+
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options =
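The new check fails fast on mismatched dtypes instead of dispatching on the input dtype alone. A minimal sketch of the behavior at a hypothetical call site (tensors below are assumptions, not from the patch):

    // Hypothetical caller: dtypes intentionally mismatched.
    at::Tensor input = at::randn({4, 256}, at::device(at::kCUDA).dtype(at::kBFloat16));
    at::Tensor bias  = at::zeros({256},    at::device(at::kCUDA).dtype(at::kFloat));
    // moe_fused_gate(input, bias, ...) would now throw a c10::Error whose message
    // contains "input and bias should have the same dtype".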
@@ -410,16 +412,16 @@ std::vector<at::Tensor> moe_fused_gate(
         LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
       } else if (input.scalar_type() == at::kFloat) {
         LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
-      } else if (num_expert_group == 16) {
-        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6
-        // * 2 = 12.
-        if (input.scalar_type() == at::kBFloat16) {
-          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
-        } else if (input.scalar_type() == at::kHalf) {
-          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
-        } else if (input.scalar_type() == at::kFloat) {
-          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
-        }
+      }
+    } else if (num_expert_group == 16) {
+      // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6
+      // * 2 = 12.
+      if (input.scalar_type() == at::kBFloat16) {
+        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
+      } else if (input.scalar_type() == at::kHalf) {
+        LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
+      } else if (input.scalar_type() == at::kFloat) {
+        LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
       }
     }
     break;
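The hunk above un-nests the group-size branch: previously `else if (num_expert_group == 16)` was chained onto the inner dtype if/else, so it was reachable only when no dtype branch matched. A condensed sketch of the corrected shape for the 256-expert case (the enclosing `num_expert_group == 8` condition is an assumption consistent with the `(256, 8)` launches in the context lines; the next hunk applies the same un-nesting to the 128-expert case):

    // Outer branch on num_expert_group picks VPT; inner branch picks the C++ type.
    if (num_expert_group == 8) {
      if (input.scalar_type() == at::kBFloat16) {
        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
      }  // ... kHalf and kFloat branches elided
    } else if (num_expert_group == 16) {
      // VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12
      if (input.scalar_type() == at::kBFloat16) {
        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
      }  // ... kHalf and kFloat branches elided
    }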
@@ -433,16 +435,16 @@ std::vector<at::Tensor> moe_fused_gate(
         LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
       } else if (input.scalar_type() == at::kFloat) {
         LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
-      } else if (num_expert_group == 8) {
-        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4
-        // = 24.
-        if (input.scalar_type() == at::kBFloat16) {
-          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
-        } else if (input.scalar_type() == at::kHalf) {
-          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
-        } else if (input.scalar_type() == at::kFloat) {
-          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
-        }
+      }
+    } else if (num_expert_group == 8) {
+      // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4
+      // = 24.
+      if (input.scalar_type() == at::kBFloat16) {
+        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
+      } else if (input.scalar_type() == at::kHalf) {
+        LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
+      } else if (input.scalar_type() == at::kFloat) {
+        LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
       }
     }
     break;