
Commit d44a63c

[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)
Signed-off-by: Nick Hill <[email protected]>
1 parent: 066209a

File tree

3 files changed: +22 −15 lines changed

tests/v1/sample/test_logprobs.py

Lines changed: 9 additions & 4 deletions

@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
     pytest.param(
         (
             "eagle",
-            "meta-llama/Llama-3.1-8B-Instruct",
-            "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+            "meta-llama/Llama-3.2-1B-Instruct",
+            "nm-testing/Llama3_2_1B_speculator.eagle3",
         ),
         marks=large_gpu_mark(min_gb=32),
     ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM

-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
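The widened comparison reflects that, with prefill chunking forced, spec-decode logprobs can drift slightly from the reference run, so the previous absolute-only tolerance of 1e-3 was too strict. As a minimal illustration of how the two tolerances combine (values here are hypothetical, not from the test):

import math

# math.isclose(a, b, rel_tol=r, abs_tol=t) is equivalent to:
#   abs(a - b) <= max(r * max(abs(a), abs(b)), t)
ref_logprob, spec_logprob = -2.431, -2.512  # hypothetical per-token logprobs
assert math.isclose(ref_logprob, spec_logprob, rel_tol=5e-2, abs_tol=1e-1)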

vllm/v1/sample/sampler.py

Lines changed: 4 additions & 1 deletion

@@ -81,7 +81,10 @@ def forward(
         if logprobs_mode == "raw_logprobs":
             raw_logprobs = self.compute_logprobs(logits)
         elif logprobs_mode == "raw_logits":
-            raw_logprobs = logits.clone()
+            if logits.dtype == torch.float32:
+                raw_logprobs = logits.clone()
+            else:
+                raw_logprobs = logits.to(torch.float32)

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
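A note on the raw_logits branch: a plain clone() keeps the logits' original dtype (e.g. bfloat16), whereas the change above always hands back a float32 copy, cloning only when no cast is needed, since a dtype-changing .to() already allocates a new tensor. A minimal sketch of the same idea outside vLLM (function name is hypothetical):

import torch

def raw_logits_copy(logits: torch.Tensor) -> torch.Tensor:
    # A dtype-changing .to() already returns a new tensor, so an explicit
    # clone() is only required when the logits are already float32.
    if logits.dtype == torch.float32:
        return logits.clone()
    return logits.to(torch.float32)

raw = raw_logits_copy(torch.randn(2, 8).to(torch.bfloat16))
assert raw.dtype == torch.float32 and raw.shape == (2, 8)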

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 10 deletions

@@ -2466,7 +2466,9 @@ def _bookkeeping_sync(

         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
+        cu_num_new_tokens: list[int] | None = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@ def _bookkeeping_sync(
                 sampled_token_ids,
                 self.input_batch.vocab_size,
             )
+            if logprobs_tensors:
+                # Needed for extracting logprobs when spec decoding.
+                # This must be done prior to discarding sampled tokens.
+                cu_num_new_tokens = [0]
+                for toks in valid_sampled_token_ids:
+                    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[int(i)].clear()
@@ -2506,10 +2514,6 @@ def _bookkeeping_sync(
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        logprobs_tensors = sampler_output.logprobs_tensors
-        cu_num_accepted_tokens = (
-            [0] if spec_decode_metadata and logprobs_tensors else None
-        )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@ def _bookkeeping_sync(

             num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0

-            if cu_num_accepted_tokens is not None:
-                cu_num_accepted_tokens.append(
-                    cu_num_accepted_tokens[-1] + num_sampled_ids
-                )
-
             if not sampled_ids:
                 continue

@@ -2544,7 +2543,7 @@ def _bookkeeping_sync(
             req_state.output_token_ids.extend(sampled_ids)

         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            logprobs_tensors.tolists(cu_num_new_tokens)
             if not self.use_async_scheduling and logprobs_tensors is not None
             else None
         )
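The substance of the fix is in the model runner: the cumulative per-request offsets used to slice the flat logprobs tensors are now built from valid_sampled_token_ids before tokens of requests that should not sample this step (e.g. mid-chunked-prefill) are discarded, rather than accumulated per request later in the loop. A minimal standalone sketch of that slicing mechanism, with hypothetical names and values rather than the vLLM implementation:

# Per-request sampled tokens; spec decode can accept a variable number
# of tokens per request in one step.
valid_sampled_token_ids = [[11, 12, 13], [21], [31, 32]]

# Cumulative offsets, computed before any tokens are discarded so they
# stay aligned with the rows of the flat logprobs tensors.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# cu_num_new_tokens == [0, 3, 4, 6]

# Flat per-token logprobs (one entry per sampled token), sliced per request.
flat_logprobs = [-0.10, -0.52, -0.23, -1.31, -0.07, -0.88]
per_request = [
    flat_logprobs[lo:hi]
    for lo, hi in zip(cu_num_new_tokens, cu_num_new_tokens[1:])
]
# per_request == [[-0.10, -0.52, -0.23], [-1.31], [-0.07, -0.88]]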
