 )
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.torch_utils import current_stream, direct_register_custom_op
+from vllm.utils.torch_utils import (
+    aux_stream,
+    current_stream,
+    direct_register_custom_op,
+)
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -331,7 +335,11 @@ def __init__(
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            self.shared_experts_stream = torch.cuda.Stream()
+            # TODO(rob): enable shared expert overlap with non-cuda.
+            # aux_stream() returns None on non-cuda platforms.
+            self.shared_experts_stream = aux_stream()
+            if self.shared_experts_stream is not None:
+                logger.info_once("Enabled separate cuda stream for MoE shared_experts")

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
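Aside, not part of the diff: the hunk above relies on aux_stream() returning a usable CUDA side stream only on CUDA platforms and None elsewhere. A minimal stand-in with the same contract, written against plain PyTorch (the helper name maybe_aux_stream is hypothetical, not vLLM's), could look like this:

import torch

def maybe_aux_stream():
    # Hypothetical stand-in for vllm.utils.torch_utils.aux_stream(): return a
    # CUDA side stream when CUDA is available, otherwise None so that callers
    # fall back to running shared experts inline on the current stream.
    return torch.cuda.Stream() if torch.cuda.is_available() else None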
@@ -1606,7 +1614,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
                 shared_output = self.shared_experts(staged_hidden_states)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -1684,13 +1694,34 @@ def forward_impl(

         use_chunked_impl = self.use_dp_chunking

-        if (
+        use_shared_experts_stream = (
             has_separate_shared_experts
             and not use_chunked_impl
             and self.shared_experts_stream is not None
-        ):
-            # Start the separate shared experts stream here since we want
-            # to run in parallel with the router/gate (next op below)
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        if use_shared_experts_stream:
+            assert self.shared_experts_stream is not None
+
+            # Clone BEFORE switching streams to avoid a race condition
+            # where the routed-expert kernels may mutate hidden_states.
+            hidden_states_clone = hidden_states.clone()
+
+            # Record that the clone will be used by shared_experts_stream,
+            # to avoid a GC issue when hidden_states_clone is deallocated.
+            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+            # NOTE: We don't need shared_output.record_stream(current_stream())
+            # because we sync the streams before using shared_output.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+
+            # Mark the sync start point for the separate shared experts
+            # stream here, since we want it to run in parallel with the
+            # router/gate (the next op below).
+            assert self.shared_experts_stream is not None
             self.shared_experts_stream.wait_stream(current_stream())

         # If router/gate provided, then apply it here.
@@ -1709,33 +1740,6 @@ def forward_impl(
             self.quant_method, FusedMoEModularMethod
         )

-        # If there are shared experts but we are not using a modular kernel, the
-        # shared experts must be called here
-        if has_separate_shared_experts:
-            assert self.shared_experts is not None
-
-            if self.shared_experts_stream is not None:
-                # Clone BEFORE switching streams to avoid race condition
-                # where routed_expert kernel may mutate hidden_states.
-                hidden_states_clone = hidden_states.clone()
-                self.shared_experts_stream.wait_stream(current_stream())
-
-                # Run shared experts in parallel on a separate stream
-                with torch.cuda.stream(self.shared_experts_stream):
-                    shared_output = self.shared_experts(hidden_states_clone)
-
-                # Record that the clone will be used by shared_experts_stream
-                # to avoid gc issue from deallocation of hidden_states_clone
-                # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
-                # NOTE: we dont need shared_output.record_stream(current_stream())
-                # because we synch the streams before using shared_output.
-                hidden_states_clone.record_stream(self.shared_experts_stream)
-
-            else:
-                shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
-
         ctx = get_forward_context()
         sp_ctx = (
             ctx.dp_metadata.sp_local_sizes(self.sp_size)
@@ -1776,12 +1780,21 @@ def forward_impl(
         )

         if has_separate_shared_experts:
-            assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None

-            # Wait for the parallel shared experts stream to finish here
-            if self.shared_experts_stream is not None:
+            if use_shared_experts_stream:
+                # Run shared experts in parallel on a separate stream.
+                # NOTE: We start the separate stream here and mark the
+                # sync end point immediately after it is done. This is
+                # important to avoid excessive stream allocations during
+                # cuda graph replay later.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that the hidden_states clone() above is necessary
+                    # to avoid a conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states_clone)
                 current_stream().wait_stream(self.shared_experts_stream)
+            else:
+                shared_output = self.shared_experts(hidden_states)

             final_hidden_states = (
                 shared_output,
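For reference, not part of the diff: a self-contained sketch of the stream-overlap pattern the new forward_impl code uses (clone the input, hand the clone to an auxiliary stream via record_stream/wait_stream, run the shared experts there while the routed experts run on the main stream, then join the streams before touching the shared output). It is written against plain PyTorch stream APIs; the function and expert-module names below are illustrative, not vLLM's.

import torch

def overlapped_shared_and_routed(hidden_states, shared_experts, routed_experts, aux_stream):
    # Clone BEFORE switching streams so routed-expert kernels queued on the
    # main stream cannot mutate the tensor the auxiliary stream will read.
    hidden_states_clone = hidden_states.clone()
    # Tell the caching allocator the clone is consumed on aux_stream, so its
    # memory is not reused by the main stream while aux_stream still needs it
    # (see torch.Tensor.record_stream).
    hidden_states_clone.record_stream(aux_stream)
    # The auxiliary stream waits for work already queued on the main stream
    # (including the clone above) before it starts.
    aux_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(aux_stream):
        shared_output = shared_experts(hidden_states_clone)
    # The routed experts run concurrently on the main stream.
    routed_output = routed_experts(hidden_states)
    # Join: the main stream must not read shared_output until aux_stream is done.
    torch.cuda.current_stream().wait_stream(aux_stream)
    return shared_output + routed_output

if torch.cuda.is_available():
    aux = torch.cuda.Stream()
    x = torch.randn(8, 256, device="cuda")
    shared = torch.nn.Linear(256, 256).cuda()
    routed = torch.nn.Linear(256, 256).cuda()
    out = overlapped_shared_and_routed(x, shared, routed, aux)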