
Commit d941709

[Feature] Batch invariant: Enable TRITON_MLA without prefix-caching (#29125)
Signed-off-by: yewentao256 <[email protected]>
1 parent 9d6235c commit d941709

5 files changed: +43 −7 lines
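
In practice, a batch-invariant run with the TRITON_MLA backend no longer needs prefix caching disabled by hand; the new check in vllm/attention/layer.py turns it off automatically. A minimal usage sketch follows, assuming batch invariance is toggled via the VLLM_BATCH_INVARIANT environment variable and the backend via VLLM_ATTENTION_BACKEND, with a hypothetical MLA model name (none of these choices are taken from this diff):

# Usage sketch only: env var names and model choice are assumptions, not part of this commit.
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"             # assumed toggle for batch-invariant kernels
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"  # select the Triton MLA backend

from vllm import LLM, SamplingParams

# Mirrors the test setup in test_batch_invariance.py; enable_prefix_caching is
# no longer passed explicitly -- the attention layer now disables it itself.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical MLA model
    max_num_seqs=32,
    max_model_len=8192,
    dtype="bfloat16",
)
out = llm.generate(["the capital of france is"], SamplingParams(temperature=0.0, max_tokens=8))
print(out[0].outputs[0].text)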

tests/v1/determinism/test_batch_invariance.py

Lines changed: 1 addition & 5 deletions
@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
+        # enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
-        enable_prefix_caching=False,
     )

     prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enable_prefix_caching=False,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )

tests/v1/determinism/test_online_batch_invariance.py

Lines changed: 4 additions & 1 deletion
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     }

     tp_size = os.getenv("VLLM_TP_SIZE", "1")
-    server_args: list[str] = []
+    server_args: list[str] = [
+        "--max-model-len=8192",
+        "--max-num-seqs=32",
+    ]
     if tp_size:
         server_args += ["-tp", tp_size]

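For reference, the new server_args correspond to roughly the following standalone server launch, sketched here in Python; the model name and TP size are placeholders, not part of this diff:

# Sketch of the equivalent server invocation using the vllm CLI.
import subprocess

subprocess.run(
    [
        "vllm", "serve", "Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical model
        "--max-model-len=8192",
        "--max-num-seqs=32",
        "-tp", "1",
    ],
    check=True,
)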

tests/v1/determinism/utils.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@

 BACKENDS: list[str] = [
     "FLASH_ATTN",
+    "TRITON_MLA",
 ]

 if has_flashinfer():

vllm/attention/layer.py

Lines changed: 36 additions & 0 deletions
@@ -25,6 +25,7 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     UnquantizedLinearMethod,
@@ -251,6 +252,24 @@ def __init__(
         else:
             self.attn_backend = attn_backend

+        # prefix caching + batch invariance is currently not supported for
+        # FLASHINFER and TRITON_MLA.
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "FLASHINFER"
+                or self.attn_backend.get_name() == "TRITON_MLA"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(
             num_heads,
@@ -628,6 +647,23 @@ def __init__(
             use_mla=True,
             use_sparse=use_sparse,
         )
+
+        if (
+            cache_config is not None
+            and cache_config.enable_prefix_caching
+            and vllm_is_batch_invariant()
+            and (
+                self.attn_backend.get_name() == "TRITON_MLA"
+                or self.attn_backend.get_name() == "FLASHINFER"
+            )
+        ):
+            logger.warning_once(
+                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
+                "with batch invariance, as it is not yet supported.",
+                scope="local",
+            )
+            cache_config.enable_prefix_caching = False
+
         impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
         self.impl = impl_cls(
             num_heads=self.num_heads,
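
The same guard appears verbatim in both the Attention and MLAAttention constructors. Condensed into a single standalone helper, the logic reads as follows; this is a sketch only, with a hypothetical name, since the diff keeps the check inlined rather than factoring it out:

# Hypothetical helper equivalent to the two inlined checks above.
_NO_PREFIX_CACHE_WITH_BATCH_INVARIANCE = {"FLASHINFER", "TRITON_MLA"}

def _maybe_disable_prefix_caching(cache_config, backend_name: str, batch_invariant: bool) -> None:
    """Prefix caching cannot yet be combined with batch invariance on these backends."""
    if (
        cache_config is not None
        and cache_config.enable_prefix_caching
        and batch_invariant
        and backend_name in _NO_PREFIX_CACHE_WITH_BATCH_INVARIANCE
    ):
        cache_config.enable_prefix_caching = False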

vllm/model_executor/layers/batch_invariant.py

Lines changed: 1 addition & 1 deletion
@@ -1006,11 +1006,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN",  # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
+        "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION",  # IMA issue even if we disable batch invariance
         # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
-        # "TRITON_MLA",
     ]
     if curr_attn_backend not in supported_backends:
         error = (
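
With "TRITON_MLA" moved into supported_backends, the check in override_envs_for_invariance() no longer rejects that backend. Schematically, and only as an illustration since the actual error text is truncated in this diff:

# Simplified view of the check above; the error message here is illustrative only.
curr_attn_backend = "TRITON_MLA"
supported_backends = ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "TRITON_MLA"]

if curr_attn_backend not in supported_backends:
    raise RuntimeError(f"Attention backend {curr_attn_backend} is not supported with batch invariance")
# Before this commit, "TRITON_MLA" was commented out of the list and hit the error path.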
