
Commit de72969

aligning
Signed-off-by: Barbara Suslova <[email protected]>
1 parent f9d6dc5 commit de72969

2 files changed (+43, −29 lines):

  csrc/moe/moe_fused_gate.cu
  vllm/model_executor/layers/fused_moe/fused_moe.py

csrc/moe/moe_fused_gate.cu

Lines changed: 24 additions & 22 deletions
```diff
@@ -69,8 +69,7 @@ __device__ void moe_fused_gate_impl(void* input, void* bias, float* output_ptr,
   }
 
   // Calculate topk_excluding_share_expert_fusion from topk
-  int64_t topk_excluding_share_expert_fusion =
-      topk - (num_fused_shared_experts > 0 ? 1 : 0);
+  int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
 
   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
```
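
The old expression subtracted at most one expert no matter how many shared experts were fused; the new form subtracts the actual count. A minimal sketch of the arithmetic (the values are hypothetical):

```python
topk = 9

for num_fused_shared_experts in (0, 1, 2):
    old = topk - (1 if num_fused_shared_experts > 0 else 0)  # capped at 1
    new = topk - num_fused_shared_experts                    # actual count
    print(num_fused_shared_experts, old, new)
# 0 -> old=9, new=9
# 1 -> old=8, new=8
# 2 -> old=8, new=7   (the formulas diverge once more than one
#                      shared expert is fused)
```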
```diff
@@ -362,6 +361,9 @@ std::vector<at::Tensor> moe_fused_gate(
     at::Tensor& input, at::Tensor& bias, int64_t num_expert_group,
     int64_t topk_group, int64_t topk, int64_t num_fused_shared_experts,
     double routed_scaling_factor, bool apply_routed_scaling_factor_on_output) {
+  TORCH_CHECK(input.dtype() == bias.dtype(),
+              "input and bias should have the same dtype");
+
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options =
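```

The new `TORCH_CHECK` rejects mismatched dtypes up front: the kernel reinterprets both buffers as the same element type `T`, so a mismatch would otherwise silently misread the bias bytes. A Python-side sketch of the same guard (`check_gate_inputs` is a hypothetical helper, not part of the extension):

```python
import torch

def check_gate_inputs(input: torch.Tensor, bias: torch.Tensor) -> None:
    # Hypothetical helper mirroring the new TORCH_CHECK: both tensors
    # are cast to the same element type T inside the kernel, so their
    # dtypes must agree before launch.
    if input.dtype != bias.dtype:
        raise TypeError(
            f"input and bias should have the same dtype, "
            f"got {input.dtype} and {bias.dtype}"
        )

scores = torch.randn(4, 256, dtype=torch.bfloat16)
bias = torch.zeros(256, dtype=torch.float32)
try:
    check_gate_inputs(scores, bias)
except TypeError as e:
    print(e)  # input and bias should have the same dtype, ...
```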
```diff
@@ -410,16 +412,16 @@ std::vector<at::Tensor> moe_fused_gate(
         LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
       } else if (input.scalar_type() == at::kFloat) {
         LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
-      } else if (num_expert_group == 16) {
-        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6
-        // * 2 = 12.
-        if (input.scalar_type() == at::kBFloat16) {
-          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
-        } else if (input.scalar_type() == at::kHalf) {
-          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
-        } else if (input.scalar_type() == at::kFloat) {
-          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
-        }
+      }
+    } else if (num_expert_group == 16) {
+      // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6
+      // * 2 = 12.
+      if (input.scalar_type() == at::kBFloat16) {
+        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
+      } else if (input.scalar_type() == at::kHalf) {
+        LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
+      } else if (input.scalar_type() == at::kFloat) {
+        LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
       }
     }
     break;
```
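
This hunk and the one below apply the same structural fix: the `else if (num_expert_group == 16)` branch was chained onto the inner dtype dispatch, inside an outer block that had already matched a different group count, so it could never fire. Moving it up one brace level makes it a peer of the outer group dispatch. A minimal sketch of the bug pattern (the group counts and `"config(...)"` strings are simplified stand-ins for the real `LAUNCH_MOE_GATE_CONFIG` dispatch):

```python
def dispatch_buggy(num_expert_group, dtype):
    if num_expert_group == 8:
        if dtype == "bf16":
            return "config(256, 8)"
        elif num_expert_group == 16:  # dead: we are inside the
            return "config(256, 16)"  # num_expert_group == 8 branch
    return None

def dispatch_fixed(num_expert_group, dtype):
    if num_expert_group == 8:
        if dtype == "bf16":
            return "config(256, 8)"
    elif num_expert_group == 16:      # now a peer of the group check
        if dtype == "bf16":
            return "config(256, 16)"
    return None

assert dispatch_buggy(16, "bf16") is None              # group 16 never dispatched
assert dispatch_fixed(16, "bf16") == "config(256, 16)"
```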
```diff
@@ -433,16 +435,16 @@ std::vector<at::Tensor> moe_fused_gate(
         LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
       } else if (input.scalar_type() == at::kFloat) {
         LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
-      } else if (num_expert_group == 8) {
-        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4
-        // = 24.
-        if (input.scalar_type() == at::kBFloat16) {
-          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
-        } else if (input.scalar_type() == at::kHalf) {
-          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
-        } else if (input.scalar_type() == at::kFloat) {
-          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
-        }
+      }
+    } else if (num_expert_group == 8) {
+      // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4
+      // = 24.
+      if (input.scalar_type() == at::kBFloat16) {
+        LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
+      } else if (input.scalar_type() == at::kHalf) {
+        LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
+      } else if (input.scalar_type() == at::kFloat) {
+        LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
       }
     }
     break;
```
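
The comments in these two hunks derive the kernel's per-thread workload from the launch constants. A quick check of that arithmetic (WARP_SIZE = 32 and WARPS_PER_CTA = 6 are read off the comments' derivations, not taken from elsewhere in the file):

```python
WARP_SIZE, WARPS_PER_CTA = 32, 6  # inferred from the comments above

def gate_config(num_experts, num_expert_group):
    vpt = num_experts // num_expert_group         # values per thread
    rows_per_warp = WARP_SIZE // num_expert_group
    rows_per_cta = WARPS_PER_CTA * rows_per_warp
    return vpt, rows_per_warp, rows_per_cta

assert gate_config(256, 16) == (16, 2, 12)  # matches the 256/16 comment
assert gate_config(128, 8) == (16, 4, 24)   # matches the 128/8 comment
```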

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 19 additions & 7 deletions
```diff
@@ -1169,7 +1169,8 @@ def grouped_topk(
     num_fused_shared_experts: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     use_fused_moe_grouped_topk = envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
-    if num_fused_shared_experts > 0 and use_fused_moe_grouped_topk:
+    enable_fused_shared_experts = num_fused_shared_experts > 0
+    if enable_fused_shared_experts and use_fused_moe_grouped_topk:
         logger.info(
             "Fused MoE grouped topk is enabled with fused shared experts.",
             "Only one of these options can be used at a time",
```
```diff
@@ -1235,15 +1236,23 @@ def grouped_topk(
     tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
-        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
+        topk_ids = torch.topk(
+            tmp_scores,
+            k=topk,
+            dim=-1,
+            sorted=(use_sorted or enable_fused_shared_experts),
+        )[1]
         # Use original unbiased scores for the routing weights
         topk_weights = original_scores.gather(1, topk_ids)
     else:
         topk_weights, topk_ids = torch.topk(
-            tmp_scores, k=topk, dim=-1, sorted=use_sorted
+            tmp_scores,
+            k=topk,
+            dim=-1,
+            sorted=(use_sorted or enable_fused_shared_experts),
         )
 
-    if num_fused_shared_experts > 0:
+    if enable_fused_shared_experts:
         assert routed_scaling_factor is not None, "With num_fused_shared_experts>0"
         ", routed_scaling_factor need to be provided"
         topk_ids[:, -1] = torch.randint(
```
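
Forcing `sorted=True` whenever shared experts are fused matters because the code below overwrites column `-1` of `topk_ids`/`topk_weights` with the shared-expert slot; only a sorted top-k guarantees that the last column holds the least important routed expert rather than an arbitrary one. A small demonstration (the tensor sizes are hypothetical):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(2, 8)  # 2 tokens, 8 experts

# sorted=True guarantees descending order along dim=-1, so column -1
# is the smallest of the top-k and is safe to repurpose as the fused
# shared-expert slot.
topk_weights, topk_ids = torch.topk(scores, k=4, dim=-1, sorted=True)
assert torch.all(topk_weights[:, :-1] >= topk_weights[:, 1:])

# With sorted=False the ordering is unspecified, so overwriting
# column -1 could drop any of the top-k experts.
```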
```diff
@@ -1253,16 +1262,19 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
+        if routed_scaling_factor != 1.0:
+            topk_weights[:, -1] = (
+                topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
+            )
 
     if renormalize:
-        if num_fused_shared_experts == 0:
+        if not enable_fused_shared_experts:
             topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True)
         else:
             topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
         topk_weights = topk_weights / topk_weights_sum
 
-    if num_fused_shared_experts == 0 and routed_scaling_factor != 1.0:
+    if not enable_fused_shared_experts and routed_scaling_factor != 1.0:
         topk_weights = topk_weights * routed_scaling_factor
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
```
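Putting the tail of the routine together: the shared-expert slot gets weight `sum(routed) / routed_scaling_factor` (now skipped when the factor is exactly 1.0), and renormalization divides by the sum of the routed columns only. A worked example with hypothetical weights:

```python
import torch

topk_weights = torch.tensor([[0.8, 0.4, 0.2, 0.1]])  # last column is the
routed_scaling_factor = 2.5                           # fused shared slot

# Shared-expert weight = sum of routed weights / scaling factor
# (the new guard skips this when the factor is exactly 1.0).
if routed_scaling_factor != 1.0:
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
# -> [[0.8, 0.4, 0.2, 0.56]]   since (0.8 + 0.4 + 0.2) / 2.5 = 0.56

# Renormalize over the routed columns only:
topk_weights_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)  # 1.4
topk_weights = topk_weights / topk_weights_sum
# routed columns now sum to 1.0, and the shared slot becomes
# 0.56 / 1.4 = 0.4 = 1 / routed_scaling_factor
```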
