end-to-end tests for context length corner cases of vLLM v1 model runner
versus HuggingFace's transformers.

-This test verifies the following behavior: allow a prefill that fills the
-model's maximum context length and then request a single new token.
+This test verifies the following behavior: allow prefills and decodes at the
+model's maximum context length ``max_model_len`` and still sample one more token.

Test strategy
-- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
-- Run vLLM generation requesting a single new token (max_tokens=1).
-- Run HF generation on the same prompt requesting a single token too.
+- Build a prompt consisting of exactly ``prompt_len`` tokens.
+- Run vLLM generation requesting ``max_tokens`` new tokens.
+- Run HF generation on the same prompt requesting the same number of tokens.
- Assert both return the same number of generated tokens and the same ids.

+Test cases
+- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
+  will be sampled after the prefill (context length ``max_model_len``).
+- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
+  the 1st will be sampled after the prefill and the 2nd after the first decode
+  (context length ``max_model_len``).
+
"""

import pytest

@create_new_process_for_each_test()
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
-@pytest.mark.parametrize("max_model_len", [2048])
-@pytest.mark.parametrize("max_tokens", [1])
-def test_prefill_max_context_length(
+@pytest.mark.parametrize(
+    "prompt_len, max_tokens",
+    [
+        (2048, 1),  # prompt_len = max_model_len
+        (2047, 2),  # prompt_len = max_model_len - 1
+    ],
+)
+def test_max_context_length(
    model: str,
-    max_model_len: int,
+    prompt_len: int,
    max_tokens: int,
) -> None:
    """Compare vLLM and HuggingFace when the prompt already fills the
@@ -42,8 +54,8 @@ def test_prefill_max_context_length(
    single token when given the same inputs.
    """

-    # Construct a prompt of size max_model_len
-    prompt_ids = [[43] * max_model_len]
+    # Construct a prompt of size prompt_len
+    prompt_ids = [[43] * prompt_len]

    # Generate max_tokens new tokens deterministically.
    sampling_params = [
@@ -54,6 +66,7 @@ def test_prefill_max_context_length(
    llm = LLM(
        model=model,
        tokenizer=model,
+        max_model_len=2048,
        max_num_seqs=1,
        tensor_parallel_size=1,
    )
@@ -81,6 +94,9 @@ def test_prefill_max_context_length(
    # HF returns the prompt + generated tokens. Slice off the prompt.
    hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]

+    # check that exactly max_tokens tokens were generated with vLLM and HF
+    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
+
    # check that vLLM outputs (token ids) match HF outputs
    # Note: for simplicity don't pass detokenized string
    check_outputs_equal(
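
For reference, both parametrized cases end with a sampling step at a context length
of exactly ``max_model_len``: 2048 prompt tokens plus 1 generated token, or 2047
prompt tokens plus 2 generated tokens. Below is a minimal standalone sketch of
reproducing that boundary directly with vLLM's offline LLM API, assuming a vLLM
build that includes this change; the model name and the repeated token id 43 are
taken from the test above, while the TokensPrompt usage and the rest of the
scaffolding are assumptions and not part of this PR.

# Standalone sketch (assumed scaffolding, not part of the PR).
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

MAX_MODEL_LEN = 2048

llm = LLM(
    model="JackFram/llama-160m",
    max_model_len=MAX_MODEL_LEN,
    max_num_seqs=1,
    tensor_parallel_size=1,
)

for prompt_len, max_tokens in [(2048, 1), (2047, 2)]:
    # Prompt of exactly prompt_len copies of an arbitrary token id.
    prompt = TokensPrompt(prompt_token_ids=[43] * prompt_len)
    # temperature=0.0 makes generation greedy and deterministic.
    params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
    out = llm.generate([prompt], params)[0]
    # In both cases the final token is sampled at context length MAX_MODEL_LEN.
    assert len(out.outputs[0].token_ids) == max_tokens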