@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
         pytest.param(
             (
                 "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                "meta-llama/Llama-3.2-1B-Instruct",
+                "nm-testing/Llama3_2_1B_speculator.eagle3",
             ),
             marks=large_gpu_mark(min_gb=32),
         ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM

-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
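For reference on the relaxed assertion: `math.isclose(a, b, rel_tol, abs_tol)` passes when `abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)`, so the new bounds tolerate up to 5% relative drift or 0.1 absolute drift between reference and spec-decode logprobs, instead of the old 1e-3 absolute-only bound. A minimal sketch with purely illustrative (hypothetical) logprob pairs, not values taken from the test run:

```python
import math


def old_check(a: float, b: float) -> bool:
    # Previous assertion: purely absolute tolerance, very tight.
    return math.isclose(a, b, abs_tol=1e-3)


def new_check(a: float, b: float) -> bool:
    # Updated assertion: passes if the difference is within 5% of the larger
    # magnitude OR within 0.1 absolutely, whichever bound is looser.
    return math.isclose(a, b, rel_tol=5e-2, abs_tol=1e-1)


# Hypothetical logprob pairs, purely illustrative.
pairs = [(-0.0005, -0.0012), (-2.50, -2.55), (-8.00, -8.30)]
for a, b in pairs:
    print(f"a={a:8.4f} b={b:8.4f} old={old_check(a, b)} new={new_check(a, b)}")
```

With these values the old check only accepts the first pair, while the new check accepts all three.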