
Commit 5dcd593

[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)
Signed-off-by: quanliu <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent 5c213d2 commit 5dcd593
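
Context for the changes below: vLLM's batch-invariant mode aims for bitwise-identical outputs regardless of how requests are batched. A hedged usage sketch, assuming VLLM_BATCH_INVARIANT is the controlling environment variable and using a placeholder model name:

# Hedged sketch: enabling the batch-invariant mode these tests exercise.
# Assumption: VLLM_BATCH_INVARIANT is the controlling environment variable;
# the model name is a placeholder.
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # set before the LLM is constructed

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B", dtype="bfloat16")
outputs = llm.generate(
    ["the capital of france is"],
    SamplingParams(temperature=0.0, max_tokens=8),  # greedy decoding for determinism checks
)
print(outputs[0].outputs[0].text)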

3 files changed: 23 additions, 3 deletions


tests/v1/determinism/test_batch_invariance.py

Lines changed: 10 additions & 0 deletions
@@ -10,13 +10,16 @@
     BACKENDS,
     _extract_step_logprobs,
     _random_prompt,
+    is_device_capability_below_90,
     resolve_model_name,
     skip_unsupported,
 )
 
 import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm import LLM, SamplingParams
 
+IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
+
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
@@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
         gpu_memory_utilization=0.9,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use more realistic prompts for better token generation
@@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     prompt = "the capital of france is"
@@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use a few test prompts
@@ -925,6 +933,8 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
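
The test-side changes all follow one pattern: compute the capability flag once at module import and pass it to every LLM constructor, because batch invariance cannot use CUDA Graphs below SM90. A minimal sketch of that pattern (the model name is a placeholder; the tests resolve theirs via resolve_model_name()):

# Sketch of the pattern the tests adopt: pre-Hopper GPUs (SM < 90) must run
# eagerly because batch invariance does not support CUDA Graphs there.
from vllm import LLM
from vllm.platforms import current_platform

ENFORCE_EAGER = not current_platform.has_device_capability(90)

llm = LLM(
    model="Qwen/Qwen3-0.6B",      # placeholder; tests use resolve_model_name()
    dtype="bfloat16",
    enable_prefix_caching=False,  # avoid cache hits perturbing the compute path
    enforce_eager=ENFORCE_EAGER,  # True on SM80/SM86/SM89, False on SM90+
)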

tests/v1/determinism/utils.py

Lines changed: 8 additions & 2 deletions
@@ -11,8 +11,10 @@
 from vllm.utils.flashinfer import has_flashinfer
 
 skip_unsupported = pytest.mark.skipif(
-    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
-    reason="Requires CUDA and >= Hopper (SM90)",
+    not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
+    # Supports testing on Ampere and Ada Lovelace devices.
+    # Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
+    reason="Requires CUDA and >= Ampere (SM80)",
 )
 
 BACKENDS: list[str] = [
@@ -97,3 +99,7 @@ def _extract_step_logprobs(request_output):
         return t, inner.token_ids
 
     return None, None
+
+
+def is_device_capability_below_90() -> bool:
+    return not current_platform.has_device_capability(90)
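
Note the two distinct thresholds the utilities now encode: SM80 gates whether the determinism tests run at all, while SM90 gates whether CUDA Graphs may be used. A sketch of how the two compose (the example device SKUs are assumptions):

# Illustrative composition of the two capability gates; device SKUs are examples.
from vllm.platforms import current_platform

def determinism_test_mode() -> str | None:
    """Return None to skip the tests, else the execution mode to use."""
    if not current_platform.has_device_capability(80):
        return None        # pre-Ampere (e.g. SM75 / T4): skipped by skip_unsupported
    if not current_platform.has_device_capability(90):
        return "eager"     # Ampere/Ada (e.g. SM80 / A100, SM89 / L4): enforce_eager=True
    return "cudagraph"     # Hopper+ (e.g. SM90 / H100): CUDA Graph capture allowed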

vllm/model_executor/layers/batch_invariant.py

Lines changed: 5 additions & 1 deletion
@@ -935,7 +935,11 @@ def enable_batch_invariant_mode():
 
     # Batch invariant matmuls are no longer needed after cublas overrides
     if not is_torch_equal_or_newer("2.10.0.dev"):
-        if current_platform.is_device_capability(100):
+        if (
+            current_platform.is_device_capability(100)
+            or current_platform.is_device_capability(80)
+            or current_platform.is_device_capability(89)
+        ):
             # For PyTorch 2.9, B200 uses GEMV for bs=1
             # Requires https://github.com/pytorch/pytorch/pull/166735
             _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
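
The _batch_invariant_LIB.impl(...) call above swaps in a batch-invariant kernel for aten::mm on the CUDA dispatch key via torch.library. A standalone sketch of that mechanism, with a stand-in kernel body rather than vLLM's actual deterministic matmul:

# Sketch: overriding aten::mm for CUDA through torch.library, the same
# mechanism _batch_invariant_LIB uses above. The kernel body is a stand-in.
import torch

_lib = torch.library.Library("aten", "IMPL")  # open the aten namespace for overrides

def mm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Route through aten::bmm so this override is not re-entered; the real
    # implementation calls a matmul kernel whose reduction order is fixed,
    # making the result independent of batch size.
    return torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0)

_lib.impl("aten::mm", mm_batch_invariant, "CUDA")  # all CUDA aten::mm calls now hit this

Per the comment in the diff, the cuBLAS overrides in newer PyTorch make this per-op replacement unnecessary, which is why the registration stays gated on is_torch_equal_or_newer("2.10.0.dev").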
