Skip to content

Commit c72ea10

Browse files
authored
[Structured Output][Reasoning] Improves decoding throughput for models using single-token reasoning endings. (#30056)
1 parent 67475a6 commit c72ea10

File tree

10 files changed

+89
-1
lines changed

10 files changed

+89
-1
lines changed

docs/features/reasoning_outputs.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
299299

300300
def is_reasoning_end(self, input_ids: list[int]) -> bool:
301301
return self.end_token_id in input_ids
302+
303+
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
304+
return self.end_token_id in delta_ids
302305
...
303306
```
304307

tests/reasoning/test_base_thinking_reasoning_parser.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,41 @@ def test_is_reasoning_end(self, test_tokenizer):
132132
is False
133133
)
134134

135+
def test_is_reasoning_end_streaming(self, test_tokenizer):
136+
"""Test the is_reasoning_end_streaming method."""
137+
parser = TestThinkingReasoningParser(test_tokenizer)
138+
end_token_id = parser.end_token_id
139+
start_token_id = parser.start_token_id
140+
141+
assert (
142+
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
143+
is True
144+
)
145+
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
146+
assert parser.is_reasoning_end_streaming([], []) is False
147+
assert (
148+
parser.is_reasoning_end_streaming(
149+
[1, start_token_id, 2, end_token_id], [end_token_id]
150+
)
151+
is True
152+
)
153+
assert (
154+
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
155+
)
156+
assert (
157+
parser.is_reasoning_end_streaming(
158+
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
159+
[2],
160+
)
161+
is False
162+
)
163+
assert (
164+
parser.is_reasoning_end_streaming(
165+
[1, start_token_id, 2, end_token_id, 2, 2], [2]
166+
)
167+
is False
168+
)
169+
135170
def test_extract_content_ids(self, test_tokenizer):
136171
"""Test the extract_content_ids method."""
137172
parser = TestThinkingReasoningParser(test_tokenizer)

tests/reasoning/test_deepseekv3_reasoning_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
4040
input_tokens = tokenizer.tokenize(input_text)
4141
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
4242
assert parser.is_reasoning_end(input_ids) is True
43+
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
4344

4445
# Test extract_content_ids returns all input_ids
4546
assert parser.extract_content_ids(input_ids) == input_ids

tests/v1/structured_output/test_reasoning_structured_output.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def mock_request_with_structured_output(self):
7070
request.use_structured_output = True
7171
request.prompt_token_ids = [1, 2, 3, 4, 5]
7272
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
73+
request.num_computed_tokens = 5
7374
return request
7475

7576
def test_should_fill_bitmask_with_enable_in_reasoning(

vllm/reasoning/abs_reasoning_parsers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
6363
True if the reasoning content ends in the input_ids.
6464
"""
6565

66+
def is_reasoning_end_streaming(
67+
self, input_ids: list[int], delta_ids: list[int]
68+
) -> bool:
69+
"""
70+
Check if the reasoning content ends in the input_ids on a
71+
decode step.
72+
73+
It is used in structured engines like `xgrammar` to check if the
74+
reasoning content ends in the model output during a decode step.
75+
`input_ids` is the entire model output and `delta_ids` are the last few
76+
computed tokens of the model output (like during a decode step).
77+
78+
Parameters:
79+
input_ids: list[int]
80+
The entire model output.
81+
delta_ids: list[int]
82+
The last few computed tokens of the model output at the current decode step.
83+
84+
Returns:
85+
bool
86+
True if the reasoning content ends in the `delta_ids` on a
87+
decode step.
88+
"""
89+
return self.is_reasoning_end(input_ids)
90+
6691
@abstractmethod
6792
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
6893
"""

vllm/reasoning/basic_parsers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
7474
return True
7575
return False
7676

77+
def is_reasoning_end_streaming(
78+
self, input_ids: list[int], delta_ids: list[int]
79+
) -> bool:
80+
end_token_id = self.end_token_id
81+
return end_token_id in delta_ids
82+
7783
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
7884
"""
7985
Extract the content after the end tokens

vllm/reasoning/deepseek_v3_reasoning_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
3535
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
3636
return self._parser.is_reasoning_end(input_ids)
3737

38+
def is_reasoning_end_streaming(
39+
self, input_ids: list[int], delta_ids: list[int]
40+
) -> bool:
41+
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
42+
3843
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
3944
return self._parser.extract_content_ids(input_ids)
4045

vllm/reasoning/holo2_reasoning_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
5656
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
5757
return self._parser.is_reasoning_end(input_ids)
5858

59+
def is_reasoning_end_streaming(
60+
self, input_ids: list[int], delta_ids: list[int]
61+
) -> bool:
62+
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
63+
5964
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
6065
return self._parser.extract_content_ids(input_ids)
6166

vllm/reasoning/identity_reasoning_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool:
3232
# Always return True, since we never treat reasoning specially
3333
return True
3434

35+
def is_reasoning_end_streaming(
36+
self, input_ids: list[int], delta_ids: list[int]
37+
) -> bool:
38+
return True
39+
3540
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
3641
# Identity: return all tokens as content
3742
return input_ids

vllm/v1/structured_output/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ def should_advance(self, request: Request) -> bool:
339339
return True
340340

341341
# Check if reasoning ends in *this* step
342-
if self.reasoner.is_reasoning_end(request.all_token_ids):
342+
if self.reasoner.is_reasoning_end_streaming(
343+
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
344+
):
343345
# Reasoning just ended, so we shouldn't advance until
344346
# next pass
345347
structured_req.reasoning_ended = True

0 commit comments

Comments
 (0)