
Commit 96b06fb

Merge remote-tracking branch 'happyamazonian/restart' into restart
2 parents: dc2bd2a + 1bfb898

File tree: 3 files changed, +81 −36 lines


vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -222,6 +222,7 @@
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+    VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
     VLLM_FLAT_LOGPROBS: bool = False

@@ -1476,6 +1477,13 @@ def get_vllm_port() -> int | None:
     "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
         int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
     ),
+    # Limits when we run shared_experts in a separate stream.
+    # We found out that for large batch sizes, the separate stream
+    # execution is not beneficial (most likely because of the input clone)
+    # TODO(alexm-redhat): Tune to be more dynamic based on GPU type
+    "VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD": lambda: int(
+        int(os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", 256))
+    ),
     # Format for saving torch.compile cache artifacts
     # - "binary": saves as binary file
     # Safe for multiple vllm serve processes accessing the same torch compile cache.
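
For reference, a minimal sketch of how the new knob is meant to be consumed. The helper below is illustrative only (the real gating lives in FusedMoE.forward_impl in the next file's diff); the default of 256 tokens matches the entry above.

import os

import torch

# Same lookup as envs.py: unset -> 256 tokens.
_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD = int(
    os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", "256")
)


def should_overlap_shared_experts(
    hidden_states: torch.Tensor,
    shared_experts_stream: torch.cuda.Stream | None,
) -> bool:
    # Overlap only pays off for small batches; for large ones the extra
    # clone of hidden_states outweighs the benefit of the second stream.
    return (
        shared_experts_stream is not None
        and hidden_states.shape[0] <= _SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
    )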

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 49 additions & 36 deletions
@@ -48,7 +48,11 @@
 )
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.torch_utils import current_stream, direct_register_custom_op
+from vllm.utils.torch_utils import (
+    aux_stream,
+    current_stream,
+    direct_register_custom_op,
+)
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():
@@ -331,7 +335,11 @@ def __init__(
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            self.shared_experts_stream = torch.cuda.Stream()
+            # TODO(rob): enable shared expert overlap with non-cuda.
+            # aux_stream() returns None on non-cuda platforms.
+            self.shared_experts_stream = aux_stream()
+            if self.shared_experts_stream is not None:
+                logger.info_once("Enabled separate cuda stream for MoE shared_experts")

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1606,7 +1614,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
                 shared_output = self.shared_experts(staged_hidden_states)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -1684,13 +1694,34 @@ def forward_impl(

         use_chunked_impl = self.use_dp_chunking

-        if (
+        use_shared_experts_stream = (
             has_separate_shared_experts
             and not use_chunked_impl
             and self.shared_experts_stream is not None
-        ):
-            # Start the separate shared experts stream here since we want
-            # to run in parallel with the router/gate (next op below)
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        if use_shared_experts_stream:
+            assert self.shared_experts_stream is not None
+
+            # Clone BEFORE switching streams to avoid race condition
+            # where routed_expert kernel may mutate hidden_states.
+            hidden_states_clone = hidden_states.clone()
+
+            # Record that the clone will be used by shared_experts_stream
+            # to avoid gc issue from deallocation of hidden_states_clone
+            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+            # NOTE: We dont need shared_output.record_stream(current_stream())
+            # because we synch the streams before using shared_output.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+
+            # Mark sync start point for the separate shared experts
+            # stream here since we want to run in parallel with the
+            # router/gate (next op below)
+            assert self.shared_experts_stream is not None
             self.shared_experts_stream.wait_stream(current_stream())

         # If router/gate provided, then apply it here.
@@ -1709,33 +1740,6 @@ def forward_impl(
             self.quant_method, FusedMoEModularMethod
         )

-        # If there are shared experts but we are not using a modular kernel, the
-        # shared experts must be called here
-        if has_separate_shared_experts:
-            assert self.shared_experts is not None
-
-            if self.shared_experts_stream is not None:
-                # Clone BEFORE switching streams to avoid race condition
-                # where routed_expert kernel may mutate hidden_states.
-                hidden_states_clone = hidden_states.clone()
-                self.shared_experts_stream.wait_stream(current_stream())
-
-                # Run shared experts in parallel on a separate stream
-                with torch.cuda.stream(self.shared_experts_stream):
-                    shared_output = self.shared_experts(hidden_states_clone)
-
-                # Record that the clone will be used by shared_experts_stream
-                # to avoid gc issue from deallocation of hidden_states_clone
-                # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
-                # NOTE: we dont need shared_output.record_stream(current_stream())
-                # because we synch the streams before using shared_output.
-                hidden_states_clone.record_stream(self.shared_experts_stream)
-
-            else:
-                shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
-
         ctx = get_forward_context()
         sp_ctx = (
             ctx.dp_metadata.sp_local_sizes(self.sp_size)
@@ -1776,12 +1780,21 @@ def forward_impl(
         )

         if has_separate_shared_experts:
-            assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None

-            # Wait for the parallel shared experts stream to finish here
-            if self.shared_experts_stream is not None:
+            if use_shared_experts_stream:
+                # Run shared experts in parallel on a separate stream
+                # NOTE: We start the separate stream here and mark the
+                # sync end point immediately after it is done. This is
+                # important to avoid excessive stream allocations by the cuda
+                # graph replay later.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that hidden_states clone() is necessary here to avoid
+                    # conflict with the main stream
+                    shared_output = self.shared_experts(hidden_states_clone)
                 current_stream().wait_stream(self.shared_experts_stream)
+            else:
+                shared_output = self.shared_experts(hidden_states)

             final_hidden_states = (
                 shared_output,
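
Taken together, the forward_impl hunks implement a produce-on-one-stream, consume-on-another pattern: clone the input, register the clone with the auxiliary stream, fence the start, run the overlapped work, fence the end. A condensed, self-contained sketch of that pattern follows; the callables and names are stand-ins rather than vLLM APIs, only plain torch stream primitives are used, and the ordering is simplified relative to forward_impl.

import torch


def overlap_on_aux_stream(
    hidden_states: torch.Tensor,
    routed_experts,                      # heavy work kept on the main stream
    shared_experts,                      # work to overlap on the aux stream
    aux: torch.cuda.Stream,
) -> tuple[torch.Tensor, torch.Tensor]:
    main = torch.cuda.current_stream()

    # Clone before the aux stream touches the data, so kernels launched on
    # the main stream cannot mutate what the aux stream is still reading.
    hidden_states_clone = hidden_states.clone()
    # Tell the caching allocator the clone is consumed on `aux`, so its
    # memory is not recycled for another tensor while `aux` still needs it.
    hidden_states_clone.record_stream(aux)

    # The aux stream must not start before the clone (issued on `main`) is done.
    aux.wait_stream(main)

    routed_out = routed_experts(hidden_states)            # main stream
    with torch.cuda.stream(aux):
        shared_out = shared_experts(hidden_states_clone)  # aux stream, concurrent

    # The main stream must not read shared_out until the aux stream finishes.
    main.wait_stream(aux)
    return routed_out, shared_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(64, 1024, device="cuda")
    w_routed = torch.randn(1024, 1024, device="cuda")
    w_shared = torch.randn(1024, 1024, device="cuda")
    r, s = overlap_on_aux_stream(
        x, lambda t: t @ w_routed, lambda t: t @ w_shared, torch.cuda.Stream()
    )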

vllm/utils/torch_utils.py

Lines changed: 24 additions & 0 deletions
@@ -409,6 +409,30 @@ def current_stream() -> torch.cuda.Stream:
     return _current_stream_tls.value


+# Global auxilary stream for running operations in background streams.
+# We have single global auxilary stream to avoid an explosion of streams
+# for every layer (and make profiling look sane).
+#
+# aux_stream() is currently used for:
+# - MoE shared_expert overlap with router
+_aux_stream: torch.cuda.Stream | None = None
+
+
+def aux_stream() -> torch.cuda.Stream | None:
+    """
+    Ensures aux_stream is initialized only once
+    """
+    global _aux_stream
+
+    from vllm.platforms import current_platform
+
+    # TODO: validate this works properly on ROCm platform.
+    if _aux_stream is None and current_platform.is_cuda():
+        _aux_stream = torch.cuda.Stream()
+
+    return _aux_stream
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
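
A minimal usage sketch of the new helper, mirroring how layer.py consumes it above. The tensor and side_work function are placeholders; the only vLLM symbols assumed are aux_stream and current_stream from this file, both shown in the diff.

import torch

from vllm.utils.torch_utils import aux_stream, current_stream

x = torch.randn(32, 1024, device="cuda")     # assumes a CUDA device is present


def side_work(t: torch.Tensor) -> torch.Tensor:
    return t @ t.T                           # stand-in for shared_experts


stream = aux_stream()                        # one process-wide stream, or None off CUDA
if stream is None:
    out = side_work(x)                       # non-CUDA platforms: run inline
else:
    stream.wait_stream(current_stream())     # do not start before x is ready
    with torch.cuda.stream(stream):
        out = side_work(x)
    current_stream().wait_stream(stream)     # main stream may now read `out`

# If `x` could be freed before the aux stream finishes with it, it would also
# need x.record_stream(stream), which is exactly what layer.py does with its clone.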
