 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
+from vllm.platforms import current_platform
 
 from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
@@ -43,15 +44,26 @@ def test_prompts():
     return prompts
 
 
-@fork_new_process_for_each_test
+use_fork_for_test = (
+    fork_new_process_for_each_test if not current_platform.is_rocm() else lambda x: x
+)
+
+
+@use_fork_for_test
 @pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
     kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
-    test_prompts: list[str],
 ):
+    if not enforce_eager and current_platform.is_rocm():
+        # Relevant context: https://github.com/vllm-project/vllm/pull/29244
+        pytest.skip(
+            "ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
+            "with tanh approximation. Use enforce_eager=True instead."
+        )
+
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
@@ -65,7 +77,11 @@ def test_kv_sharing_fast_prefill(
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+        if current_platform.is_rocm():
+            # Use spawn to prevent cuda re-initialization error
+            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+        else:
+            m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
         prompts, answer, indices = prep_prompts(batch_size)