Commit f05fea1

[Core] Enable decode of context length equal to max model length (#26168)
Signed-off-by: Yannick Schnider <[email protected]>
1 parent d0df145 commit f05fea1

File tree

4 files changed: +32 -15 lines changed

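The commit removes an off-by-one limit so that a decode step can still run when the context already spans the full model window. Below is a minimal offline sketch of the newly allowed behavior, assuming the public LLM API and the model and lengths used by the tests in this commit (JackFram/llama-160m, max_model_len=2048); it is an illustration, not code from this commit.

# Sketch (not part of the diff): prompt of max_model_len - 1 tokens, two new
# tokens requested; the second one is decoded at context length max_model_len.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

MAX_MODEL_LEN = 2048  # assumed value, matching the tests below

llm = LLM(
    model="JackFram/llama-160m",  # small model used by the e2e test below
    max_model_len=MAX_MODEL_LEN,
    max_num_seqs=1,
)

prompt = TokensPrompt(prompt_token_ids=[43] * (MAX_MODEL_LEN - 1))
params = SamplingParams(temperature=0.0, max_tokens=2)

outputs = llm.generate(prompt, params)
# With this change both requested tokens come back; previously the request
# was length-capped after the first one.
assert len(outputs[0].outputs[0].token_ids) == 2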

tests/entrypoints/llm/test_generate.py

Lines changed: 3 additions & 2 deletions
@@ -82,10 +82,11 @@ def test_max_model_len():
     for output in outputs:
         num_total_tokens = len(output.prompt_token_ids) + len(
             output.outputs[0].token_ids)
-        # Total tokens must not exceed max_model_len.
+        # Total tokens must not exceed max_model_len + 1 (the last token can be
+        # generated with the context length equal to the max model length)
         # It can be less if generation finishes due to other reasons (e.g., EOS)
         # before reaching the absolute model length limit.
-        assert num_total_tokens <= max_model_len
+        assert num_total_tokens <= max_model_len + 1


 def test_log_stats():
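Read as plain arithmetic, the relaxed bound allows the prompt to fill the window and exactly one further token to be sampled at that point; a short worked check with assumed numbers:

# Assumed numbers mirroring the relaxed assertion above.
max_model_len = 2048
prompt_tokens = 2048      # prompt already fills the context window
generated_tokens = 1      # token sampled at context length 2048

num_total_tokens = prompt_tokens + generated_tokens   # 2049
assert num_total_tokens <= max_model_len + 1          # new bound holds
assert not (num_total_tokens <= max_model_len)        # old bound would have failed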

tests/v1/e2e/test_context_length.py

Lines changed: 27 additions & 11 deletions
@@ -4,15 +4,22 @@
 end-to-end tests for context length corner cases of vLLM v1 model runner
 versus HuggingFace's transformers.

-This test verifies the following behavior: allow a prefill that fills the
-model's maximum context length and then request a single new token.
+This test verifies the following behavior: allow prefill and decodes on the
+model's maximum context length ``max_model_len`` and get one more token.

 Test strategy
-- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
-- Run vLLM generation requesting a single new token (max_tokens=1).
-- Run HF generation on the same prompt requesting a single token too.
+- Build a prompt consisting of exactly ``prompt_len`` tokens.
+- Run vLLM generation requesting ``max_tokens`` new tokens.
+- Run HF generation on the same prompt requesting the same number of tokens.
 - Assert both return the same number of generated tokens and the same ids.

+Test cases
+- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
+  will be sampled after the prefill (context length ``max_model_len``).
+- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
+  the 1st will be sampled after the prefill and the 2nd after the first decode
+  (context length ``max_model_len``).
+
 """

 import pytest
@@ -27,11 +34,16 @@

 @create_new_process_for_each_test()
 @pytest.mark.parametrize("model", ["JackFram/llama-160m"])
-@pytest.mark.parametrize("max_model_len", [2048])
-@pytest.mark.parametrize("max_tokens", [1])
-def test_prefill_max_context_length(
+@pytest.mark.parametrize(
+    "prompt_len, max_tokens",
+    [
+        (2048, 1),  # prompt_len = max_model_len
+        (2047, 2),  # prompt_len = max_model_len - 1
+    ],
+)
+def test_max_context_length(
     model: str,
-    max_model_len: int,
+    prompt_len: int,
     max_tokens: int,
 ) -> None:
     """Compare vLLM and HuggingFace when the prompt already fills the
@@ -42,8 +54,8 @@ def test_prefill_max_context_length(
     single token when given the same inputs.
     """

-    # Construct a prompt of size max_model_len
-    prompt_ids = [[43] * max_model_len]
+    # Construct a prompt of size prompt_len
+    prompt_ids = [[43] * prompt_len]

     # Generate max_tokens new tokens deterministically.
     sampling_params = [
@@ -54,6 +66,7 @@ def test_prefill_max_context_length(
     llm = LLM(
         model=model,
         tokenizer=model,
+        max_model_len=2048,
         max_num_seqs=1,
         tensor_parallel_size=1,
     )
@@ -81,6 +94,9 @@ def test_prefill_max_context_length(
     # HF returns the prompt + generated tokens. Slice off the prompt.
     hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]

+    # check that exactly max_tokens tokens were generated with vLLM and HF
+    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
+
     # check that vLLM outputs (token ids) match HF outputs
     # Note: for simplicity don't pass detokenized string
     check_outputs_equal(
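The hunks above show only the vLLM side and the final comparison; the HuggingFace reference generation itself is not visible here. As a rough, assumed sketch of what such a greedy HF baseline looks like with transformers (any names not appearing in the hunks are illustrative):

# Assumed sketch of the HF reference side (greedy decoding, same synthetic
# prompt of repeated token id 43 as in the test).
import torch
from transformers import AutoModelForCausalLM

prompt_len, max_tokens = 2047, 2  # one of the parametrized cases above

model = AutoModelForCausalLM.from_pretrained("JackFram/llama-160m")
input_ids = torch.tensor([[43] * prompt_len])

hf_generated = model.generate(input_ids,
                              max_new_tokens=max_tokens,
                              do_sample=False)  # matches temperature=0.0 in vLLM

# HF returns prompt + generated tokens; slice off the prompt.
hf_output_ids = hf_generated.cpu().tolist()[0][prompt_len:]
assert len(hf_output_ids) == max_tokens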

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def schedule(self) -> SchedulerOutput:
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
-                self.max_model_len - 1 - request.num_computed_tokens)
+                self.max_model_len - request.num_computed_tokens)

             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
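Dropping the "- 1" lets the scheduler hand out the decode step whose forward pass uses the full context window; a worked example with assumed numbers taken from the 2047-token test case:

# Assumed state: max_model_len = 2048, a 2047-token prompt already prefilled,
# one sampled token appended, and its decode step now being scheduled.
max_model_len = 2048
num_computed_tokens = 2047   # KV entries computed so far
num_new_tokens = 1           # the appended token still needs its forward pass

old_budget = max_model_len - 1 - num_computed_tokens   # 0 -> decode never scheduled
new_budget = max_model_len - num_computed_tokens       # 1 -> full-context decode allowed

assert min(num_new_tokens, old_budget) == 0
assert min(num_new_tokens, new_budget) == 1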

vllm/v1/core/sched/utils.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def remove_all(lst: list, items_to_remove: set) -> list:
 def check_stop(request: Request,
                max_model_len: int,
                pooler_output: Optional[torch.Tensor] = None) -> bool:
-    if (request.num_tokens >= max_model_len
+    if (request.num_tokens > max_model_len
             or request.num_output_tokens >= request.max_tokens):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
         return True
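The stop check now fires only once the total token count actually exceeds the window, so the token sampled at context length max_model_len is kept; a walk-through with assumed numbers for the 2047-token case:

# Assumed numbers for the (prompt_len=2047, max_tokens=2) test case.
max_model_len = 2048
prompt_len = 2047

num_tokens = prompt_len + 1              # after the first sampled token: 2048
assert num_tokens >= max_model_len       # old check: request stopped here
assert not (num_tokens > max_model_len)  # new check: one more decode allowed

num_tokens += 1                          # after the token decoded at context 2048: 2049
assert num_tokens > max_model_len        # request is now length-capped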
