 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
-from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
 from enum import Enum
-from functools import partial
 from typing import Literal, get_args, overload
 
 import torch
@@ -54,8 +52,9 @@
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 if current_platform.is_cuda_alike():
-    from .fused_moe import eplb_map_to_physical_and_record, fused_experts
     from vllm._custom_ops import moe_fused_gate
+
+    from .fused_moe import eplb_map_to_physical_and_record, fused_experts
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = object  # type: ignore
@@ -371,9 +370,10 @@ def __init__(
             dp_size_=dp_size_,
             vllm_parallel_config=vllm_config.parallel_config,
         )
-
+
         self.enable_fused_shared_experts = enable_fused_shared_experts
         if self.enable_fused_shared_experts:
+            assert n_shared_experts is not None
             num_experts += n_shared_experts
             top_k += n_shared_experts
 
@@ -414,10 +414,11 @@ def __init__(
 
         self.num_fused_shared_experts = (
             n_shared_experts
-            if (
-                n_shared_experts is not None
-                and self.aiter_fmoe_shared_expert_enabled
-            ) or self.enable_fused_shared_experts
+            if n_shared_experts is not None
+            and (
+                self.aiter_fmoe_shared_expert_enabled
+                or self.enable_fused_shared_experts
+            )
             else 0
         )
         if (
@@ -487,12 +488,15 @@ def __init__(
                 self.global_num_experts,
                 get_compressed_expert_map(self.expert_map),
             )
-            if (self.num_fused_shared_experts > 0):
+            if self.num_fused_shared_experts > 0:
                 logger.warning(
                     "With EP enabled and share expert fusion enabled"
                     ", share expert replica should be same as ep_size"
                     "got share expert replica = %d"
-                    "and ep_size = %d", self.num_fused_shared_experts, self.ep_size)
+                    "and ep_size = %d",
+                    self.num_fused_shared_experts,
+                    self.ep_size,
+                )
         else:
             self.local_num_experts, self.expert_map, self.expert_mask = (
                 self.global_num_experts,
@@ -1375,23 +1379,24 @@ def select_experts(
             assert topk_group is not None
             assert num_expert_group is not None
             if hidden_states.shape[0] == 0:
-                topk_ids = torch.full((0, top_k),
-                                      -1,
-                                      dtype=torch.int,
-                                      device=hidden_states.device)
-                topk_weights = torch.empty((0, top_k),
-                                           dtype=torch.float32,
-                                           device=hidden_states.device)
+                topk_ids = torch.full(
+                    (0, top_k), -1, dtype=torch.int, device=hidden_states.device
+                )
+                topk_weights = torch.empty(
+                    (0, top_k), dtype=torch.float32, device=hidden_states.device
+                )
             elif rocm_aiter_ops.is_fused_moe_enabled():
                 if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                     assert num_fused_shared_experts == 0
                     grouped_topk_impl = rocm_aiter_grouped_topk
                 else:
                     grouped_topk_impl = grouped_topk
 
-            if (enable_fused_moe_router
-                    and e_score_correction_bias is not None
-                    and is_power_of_two(e_score_correction_bias.shape[0])):
+            if (
+                enable_fused_moe_router
+                and e_score_correction_bias is not None
+                and is_power_of_two(e_score_correction_bias.shape[0])
+            ):
                 # The fused kernel can only work with 128/256 experts
                 topk_weights, topk_ids = moe_fused_gate(
                     input_tensor=router_logits.to(dtype=torch.float32),
@@ -1401,7 +1406,8 @@ def select_experts(
                     topk=top_k,
                     num_fused_shared_experts=num_fused_shared_experts,
                     routed_scaling_factor=routed_scaling_factor
-                    if routed_scaling_factor is not None else 1.0,
+                    if routed_scaling_factor is not None
+                    else 1.0,
                     apply_routed_scaling_factor_on_output=False,
                 )
             else:
@@ -1415,7 +1421,7 @@ def select_experts(
                     scoring_func=scoring_func,
                     routed_scaling_factor=routed_scaling_factor,
                     e_score_correction_bias=e_score_correction_bias,
-                    num_fused_shared_experts=num_fused_shared_experts
+                    num_fused_shared_experts=num_fused_shared_experts,
                 )
         if indices_type is not None:
             topk_ids = topk_ids.to(dtype=indices_type)
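
Note (not part of the diff): the regrouped conditional in the num_fused_shared_experts hunk changes which case falls through to n_shared_experts. Under the old grouping, enabling fused shared experts while n_shared_experts is None could leave the count set to None; the new grouping, together with the added assert, keeps it an integer. A minimal, self-contained sketch of the two groupings, using illustrative stand-in names rather than the real module:

def old_count(n_shared_experts, aiter_fused, fused_shared):
    # Old grouping: the `or` sits outside the None check, so the expression
    # can still select n_shared_experts when it is None.
    return (
        n_shared_experts
        if (n_shared_experts is not None and aiter_fused) or fused_shared
        else 0
    )


def new_count(n_shared_experts, aiter_fused, fused_shared):
    # New grouping: the None check guards both fusion paths.
    return (
        n_shared_experts
        if n_shared_experts is not None and (aiter_fused or fused_shared)
        else 0
    )


# Fused shared experts enabled, but no shared-expert count configured:
print(old_count(None, False, True))  # prints None
print(new_count(None, False, True))  # prints 0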