Commit afb1e5b

[CI][ROCm][tests/v1/e2e] Fix multiprocessing launch for the test (#29123)
Signed-off-by: Divakar Verma <[email protected]>
1 parent 1c593e1

1 file changed

tests/v1/e2e/test_kv_sharing_fast_prefill.py

Lines changed: 19 additions & 3 deletions

@@ -7,6 +7,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
+from vllm.platforms import current_platform
 
 from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
@@ -43,15 +44,26 @@ def test_prompts():
     return prompts
 
 
-@fork_new_process_for_each_test
+use_fork_for_test = (
+    fork_new_process_for_each_test if not current_platform.is_rocm() else lambda x: x
+)
+
+
+@use_fork_for_test
 @pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
     kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
-    test_prompts: list[str],
 ):
+    if not enforce_eager and current_platform.is_rocm():
+        # Relevant context: https://github.com/vllm-project/vllm/pull/29244
+        pytest.skip(
+            "ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
+            "with tanh approximation. Use enforce_eager=True instead."
+        )
+
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
@@ -65,7 +77,11 @@ def test_kv_sharing_fast_prefill(
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+        if current_platform.is_rocm():
+            # Use spawn to prevent cuda re-initialization error
+            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+        else:
+            m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
         prompts, answer, indices = prep_prompts(batch_size)
 
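Note on the decorator change above: it works because decorators are plain callables, so the test module can pick between the forking decorator and the identity function at import time. Below is a minimal standalone sketch of that pattern; is_rocm and fork_for_each_test are simplified stand-ins for vllm's current_platform.is_rocm() and fork_new_process_for_each_test, not the real implementations.

def is_rocm() -> bool:
    # Hypothetical platform check; the real test queries current_platform.
    return False

def fork_for_each_test(fn):
    # Placeholder for a decorator that would run `fn` in a fresh process.
    def wrapper(*args, **kwargs):
        print(f"(would fork before running {fn.__name__})")
        return fn(*args, **kwargs)
    return wrapper

# On ROCm the name resolves to the identity function, so the test body runs
# in-process and multiprocessing is handled via the spawn method instead.
maybe_fork = fork_for_each_test if not is_rocm() else (lambda fn: fn)

@maybe_fork
def test_example():
    return "ok"

print(test_example())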

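Background on the spawn setting: once a parent process has initialized the GPU driver, a forked child inherits that state and torch raises "Cannot re-initialize CUDA in forked subprocess" on its first GPU call, while a spawned child starts a fresh interpreter and initializes the driver cleanly. A rough sketch of the failure mode, assuming a CUDA/ROCm-enabled torch build (this is illustrative, not code from this commit):

import torch
import torch.multiprocessing as mp

def worker(rank: int) -> None:
    # First GPU call in the child; under "fork" this would fail because the
    # parent already initialized CUDA before forking.
    print(rank, torch.cuda.is_available())

if __name__ == "__main__":
    torch.cuda.init()  # parent touches the GPU, as a test harness might
    ctx = mp.get_context("spawn")  # "fork" here would hit the re-init error
    p = ctx.Process(target=worker, args=(0,))
    p.start()
    p.join()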