def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """
    Assert that two lists of tool calls describe the same function calls.

    Tool calls are compared pairwise by position. Each actual call must have
    type "function", the same function name as its expected counterpart, and
    arguments that decode to equal JSON values (so key order and whitespace
    differences in the argument strings are ignored). Call ids are not
    compared.

    Args:
        actual_tool_calls: Tool calls produced by the parser under test
        expected_tool_calls: Tool calls the test expects

    Raises:
        AssertionError: If the lengths differ, a call does not match, or the
            arguments of either side are not valid JSON
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert actual_tool_call.type == "function"
        assert actual_tool_call.function.name == expected_tool_call.function.name
        try:
            actual_arguments = json.loads(actual_tool_call.function.arguments)
            expected_arguments = json.loads(expected_tool_call.function.arguments)
        except json.JSONDecodeError as e:
            # Malformed arguments must fail the test. Merely printing them
            # would let a parser that emits broken JSON pass unnoticed.
            raise AssertionError(
                f"Tool call arguments are not valid JSON ({e}): "
                f"actual={actual_tool_call.function.arguments!r}, "
                f"expected={expected_tool_call.function.arguments!r}"
            ) from e
        assert actual_arguments == expected_arguments
deepseek_tool_parser.extract_tool_calls(model_output, request) + + assert extracted.tools_called + assert extracted.content is None + assert extracted.tool_calls[0].id != extracted.tool_calls[1].id + assert_tool_calls(extracted.tool_calls, expected_tool_calls) + + +def test_extract_tool_calls_with_end_of_sentence_token( + deepseek_tool_parser, + sample_tools, +): + """Test extracting function calls with end-of-sentence token""" + model_output = """<|DSML|function_calls> +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Hangzhou +<|DSML|parameter name="date" string="true">2024-01-16 + +<|end▁of▁sentence|>""" + + expected_tool_calls = [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps({"location": "Hangzhou", "date": "2024-01-16"}), + ) + ), + ] + + request = ChatCompletionRequest( + model="deepseek-v3", + messages=[{"role": "user", "content": "What is the weather?"}], + tools=sample_tools, + ) + + extracted = deepseek_tool_parser.extract_tool_calls(model_output, request) + + assert extracted.tools_called + assert extracted.content is None + assert_tool_calls(extracted.tool_calls, expected_tool_calls) + + +def test_extract_tool_calls_streaming( + deepseek_tool_parser, + sample_tools, +): + """Test streaming extraction of function calls""" + # Simulate streaming chunks + chunks = [ + "<|DSML|function_calls>", + '\n<|DSML|invoke name="get_weather">', + '\n<|DSML|parameter name="location" string="true">', + "Hangzhou", + "", + '\n<|DSML|parameter name="date" string="true">', + "2024-01-16", + "", + "\n", + "\n", + ] + + request = ChatCompletionRequest( + model="deepseek-v3", + messages=[{"role": "user", "content": "What is the weather?"}], + tools=sample_tools, + ) + + # Track accumulated state + tool_states = {} + previous_text = "" + + for chunk in chunks: + current_text = previous_text + chunk + delta_text = chunk + + # Call streaming extraction + delta_message = 
deepseek_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=request, + ) + + if delta_message and delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index if tool_call.index is not None else 0 + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + # Update state + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + previous_text = current_text + + # Verify final state + assert len(tool_states) == 1, f"Expected 1 tool call, got {len(tool_states)}" + + state = tool_states[0] + assert state["id"] is not None, "Tool call ID should be set" + assert state["type"] == "function", f"Expected type 'function', got {state['type']}" + assert state["name"] == "get_weather", ( + f"Expected name 'get_weather', got {state['name']}" + ) + assert state["arguments"] is not None + # Verify arguments + arguments = json.loads(state["arguments"]) + assert arguments == { + "location": "Hangzhou", + "date": "2024-01-16", + }, f"Unexpected arguments: {arguments}" + + +def test_extract_tool_calls_streaming_multiple_functions( + deepseek_tool_parser, + sample_tools, +): + """Test streaming extraction of multiple function calls""" + # Simulate streaming chunks for two function calls + chunks = [ + "<|DSML|function_calls>", + '\n<|DSML|invoke name="get_weather">', + '\n<|DSML|parameter name="location" string="true">Hangzhou', # noqa: E501 + '\n<|DSML|parameter name="date" string="true">2024-01-16', + "\n", 
+ '\n<|DSML|invoke name="get_weather">', + '\n<|DSML|parameter name="location" string="true">Beijing', # noqa: E501 + '\n<|DSML|parameter name="date" string="true">2024-01-16', + "\n", + "\n", + ] + + request = ChatCompletionRequest( + model="deepseek-v3", + messages=[{"role": "user", "content": "What is the weather?"}], + tools=sample_tools, + ) + + # Track accumulated state + tool_states = {} + previous_text = "" + + for chunk in chunks: + current_text = previous_text + chunk + delta_text = chunk + # Call streaming extraction + delta_message = deepseek_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=request, + ) + + if delta_message and delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index if tool_call.index is not None else 0 + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + # Update state + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + previous_text = current_text + + # Verify final state + assert len(tool_states) == 2, f"Expected 2 tool calls, got {len(tool_states)}" + + # Verify first tool call + state0 = tool_states[0] + assert state0["id"] is not None + assert state0["type"] == "function" + assert state0["name"] == "get_weather" + assert state0["arguments"] is not None + arguments0 = json.loads(state0["arguments"]) + assert arguments0 == {"location": "Hangzhou", "date": "2024-01-16"} + + # Verify second tool call + state1 = tool_states[1] + assert 
state1["id"] is not None + assert state1["id"] != state0["id"] + assert state1["type"] == "function" + assert state1["name"] == "get_weather" + assert state1["arguments"] is not None + arguments1 = json.loads(state1["arguments"]) + assert arguments1 == {"location": "Beijing", "date": "2024-01-16"} + + +def test_extract_tool_calls_streaming_incomplete_chunk_functions( + deepseek_tool_parser, + sample_tools, +): + """Test streaming extraction of multiple function calls""" + # Simulate streaming chunks for two function calls + chunks = [ + "<|DSML", + "|function_calls>", + '\n<|DSML|invoke name="get_current_weather">', + '\n<|DSML|parameter name="location" string="true">北京', + '\n<|DSML|parameter name="time" string="true">2025-03-21', + "\n", + "", + ] + + request = ChatCompletionRequest( + model="deepseek-v3", + messages=[{"role": "user", "content": "What is the weather?"}], + tools=sample_tools, + ) + + # Track accumulated state + tool_states = {} + previous_text = "" + + for chunk in chunks: + current_text = previous_text + chunk + delta_text = chunk + # Call streaming extraction + delta_message = deepseek_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=request, + ) + + if delta_message and delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index if tool_call.index is not None else 0 + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + # Update state + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += 
tool_call.function.arguments + + previous_text = current_text + + # Verify final state + assert len(tool_states) == 1, f"Expected 1 tool calls, got {len(tool_states)}" + + # Verify first tool call + state0 = tool_states[0] + assert state0["id"] is not None + assert state0["type"] == "function" + assert state0["name"] == "get_current_weather" + assert state0["arguments"] is not None + arguments0 = json.loads(state0["arguments"]) + assert arguments0 == {"location": "北京", "time": "2025-03-21"} + + +def test_extract_tool_calls_streaming_incomplete_chunk_function2( + deepseek_tool_parser, + sample_tools, +): + """Test streaming extraction of multiple function calls""" + # Simulate streaming chunks for two function calls + chunks = [ + "<|DSML|function_calls>", + '<|DSML|invoke name="get_current_weather">', + '<|DSML|parameter name="location" string="true">北京', + '<|DSML|parameter name="time" string="true">2025-03-21', + "", + '<|DSML|invoke name="get_current_weather">', + '<|DSML|parameter name="location" string="true">北京', + "", + " ", + ] + + request = ChatCompletionRequest( + model="deepseek-v3", + messages=[{"role": "user", "content": "What is the weather?"}], + tools=sample_tools, + ) + + # Track accumulated state + tool_states = {} + previous_text = "" + + for chunk in chunks: + current_text = previous_text + chunk + delta_text = chunk + # Call streaming extraction + delta_message = deepseek_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=request, + ) + print("aaa", delta_message) + if delta_message and delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index if tool_call.index is not None else 0 + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + # Update state 
+ if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + previous_text = current_text + + # Verify final state + assert len(tool_states) == 2, f"Expected 2 tool calls, got {len(tool_states)}" + + # Verify first tool call + state0 = tool_states[0] + assert state0["id"] is not None + assert state0["type"] == "function" + assert state0["name"] == "get_current_weather" + assert state0["arguments"] is not None + arguments0 = json.loads(state0["arguments"]) + assert arguments0 == {"location": "北京", "time": "2025-03-21"} + + # Verify second tool call + state0 = tool_states[1] + assert state0["id"] is not None + assert state0["type"] == "function" + assert state0["name"] == "get_current_weather" + assert state0["arguments"] is not None + arguments0 = json.loads(state0["arguments"]) + assert arguments0 == {"location": "北京"} diff --git a/vllm/entrypoints/openai/tool_parsers/deepseek_v32_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseek_v32_tool_parser.py new file mode 100644 index 000000000000..fd85b0e8f5b9 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseek_v32_tool_parser.py @@ -0,0 +1,1418 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import json +from collections.abc import Sequence +from typing import Any +from xml.parsers.expat import ParserCreate + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from 
def reset_streaming_state(self):
    """
    Put every piece of per-stream parsing state back to its initial value
    and install a fresh XML parser.
    """
    # Deltas emitted so far and bookkeeping for the call being streamed
    self.deltas = []
    self.tool_call_index = 0
    self.current_call_id = self.last_completed_call_id = None
    self.current_function_name = None
    self.current_function_open = False

    # Parameters of the current call and the parameter being streamed
    self.parameters = {}
    self.current_param_name = None
    self.current_param_value = self.current_param_value_converted = ""
    self.current_param_is_first = False
    self.should_emit_end_newline = False
    self.start_quote_emitted = False

    # Raw input accumulated so far and how much of it has been consumed
    self.streaming_buffer = ""
    self.last_processed_pos = 0

    # Plain text seen outside of tool calls
    self.text_content_buffer = ""

    # State for preprocessing and deferred parsing
    self._pre_inside_parameter = False
    self._pre_param_buffer = ""
    self._pre_current_param_name = None
    self.defer_current_parameter = False
    self.deferred_param_raw_value = ""

    # DeepSeek explicit type support
    self.current_param_explicit_type = None

    # An expat parser cannot be reused for a new document, so build a new
    # one and wire its handlers up again
    self.parser = ParserCreate()
    self.setup_parser()
return Delta response + This is the actual streaming interface that receives chunks + one by one and maintains internal state + + Args: + xml_chunk: Single XML chunk string + Returns: + DeltaMessage: Contains delta information generated by this chunk, + returns empty response if no complete elements + """ + # Remove DeepSeek end-of-sentence token if present + if self.end_of_sentence_token and self.end_of_sentence_token in xml_chunk: + xml_chunk = xml_chunk.replace(self.end_of_sentence_token, "") + + # Record delta count before processing + initial_delta_count = len(self.deltas) + + self.streaming_buffer += xml_chunk + + found_elements = self._process_complete_xml_elements() + if found_elements: + # If complete elements found, check if end events were missed + # some tags may not have been triggered + try: + new_deltas = self.deltas[initial_delta_count:] + # If this chunk contains + # but didn't generate '}', then complete it + if ( + self.current_call_id is not None + and "" in xml_chunk + ): + # - Added '}' (non-empty parameter ending) + # - Added '{}' (empty parameter function) + has_function_close = any( + ( + td.tool_calls + and any( + ( + tc.function + and tc.id == self.current_call_id + and isinstance(tc.function.arguments, str) + and (tc.function.arguments in ("}", "{}")) + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_function_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("DSML_parameter") + if self.current_function_name: + self._end_element("DSML_invoke") + # If this chunk contains + # but didn't generate final empty delta, then complete it + if ( + self.current_call_id is not None + and "" in xml_chunk + ): + has_container_close = any( + ( + td.tool_calls + and any( + ( + tc.type == "function" + and tc.function + and tc.function.arguments == "" + and tc.id == self.current_call_id + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_container_close: + # 
Close potentially unclosed element + if self.current_param_name: + self._end_element("DSML_parameter") + if self.current_function_name: + self._end_element("DSML_invoke") + self._end_element("DSML_function_calls") + except Exception as e: + logger.error("Error with fallback parsing: %s", e) + # Merge newly generated deltas into single response + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + else: + # No complete elements, check if there's unoutput text content + if self.text_content_buffer and self.tool_call_index == 0: + # Has text content but no tool_call yet, output text content + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + # Clear buffer to avoid duplicate output + self.text_content_buffer = "" + return text_delta + + # If this chunk contains end tags but wasn't triggered by parser, + # manually complete end events + # Only execute when still on the same call as when entered, + # to prevent accidentally closing new calls + # in multi invoke scenarios + if self.current_call_id is not None and ( + "" in xml_chunk + or "" in xml_chunk + ): + # Close potentially unclosed element + if self.current_param_name: + self._end_element("DSML_parameter") + if "" in xml_chunk and self.current_function_name: + self._end_element("DSML_invoke") + if "" in xml_chunk: + self._end_element("DSML_function_calls") + # Return the merged delta result generated by this fallback + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + + # No complete elements, return empty response + return DeltaMessage(content=None) + + def _escape_xml_special_chars(self, text: str) -> str: + """ + Escape XML special characters + Args: + text: Original text + Returns: + Escaped text + """ + xml_escapes = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", + } + + for char, escape in xml_escapes.items(): + text = text.replace(char, 
escape) + + return text + + def _process_complete_xml_elements(self) -> bool: + """ + Process complete XML elements in buffer + + Returns: + bool: Whether complete elements were found and processed + """ + found_any = False + while self.last_processed_pos < len(self.streaming_buffer): + # Find next complete xml element + element, end_pos = self._find_next_complete_element(self.last_processed_pos) + if element is None: + # No complete element found, wait for more data + break + + # Check if this element should be skipped + if self._should_skip_element(element): + self.last_processed_pos = end_pos + continue + + # Found complete XML element, process it + try: + preprocessed_element = self._preprocess_xml_chunk(element) + # Check if this is the first tool_call start + # Note: after preprocessing, <|DSML|invoke becomes 0 + and self.current_call_id + ): + # Reset parser state but preserve generated deltas + if self.current_param_name: + self._end_element("DSML_parameter") + if self.current_function_open or self.current_function_name: + self._end_element("DSML_invoke") + # Output final invoke tail delta + final_delta = DeltaMessage( + role=None, + content=None, + reasoning=None, + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ], + ) + self._emit_delta(final_delta) + + # Reset current call state (lightweight reset for DeepSeek) + self._reset_current_call_state() + # Parse preprocessed element + self.parser.Parse(preprocessed_element, False) + found_any = True + + except Exception as e: + logger.warning("Error when parsing XML elements: %s", e) + + # Update processed position + self.last_processed_pos = end_pos + + return found_any + + def _should_skip_element(self, element: str) -> bool: + """ + Determine whether an element should be skipped + + Args: + element: Element to evaluate + + Returns: + bool: True means should skip, False means should 
process + """ + + # Check if it's a DeepSeek tool_call XML tag, don't skip + if ( + element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token) + ): + return False + + # If currently not parsing tool calls and not blank, + # collect this text instead of skipping + # Only process other XML elements after tool_call appears, + # otherwise treat as plain text + if self.current_call_id is None and element: + # Collect text content to buffer + self.text_content_buffer += element + return True # Still skip, but content has been collected + + # If currently parsing tool calls, + # this might be parameter value, don't skip + if self.current_call_id is not None: + return False + + # Skip blank content + return not element + + def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: + """ + Find next complete XML element from specified position + + Args: + start_pos: Position to start searching + + Returns: + (Complete element string, element end position), + returns (None, start_pos) if no complete element found + """ + buffer = self.streaming_buffer[start_pos:] + + if not buffer: + return None, start_pos + + if buffer.startswith("<"): + # Need to ensure no new < appears, + # find the nearest one between < and > + tag_end = buffer.find("<", 1) + tag_end2 = buffer.find(">", 1) + if tag_end != -1 and tag_end2 != -1: + # Next nearest is < + if tag_end < tag_end2: + return buffer[:tag_end], start_pos + tag_end + # Next nearest is >, means found XML element + else: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + elif tag_end != -1: + return buffer[:tag_end], start_pos + tag_end + elif tag_end2 != -1: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + else: + # If currently not parsing tool calls (entering a tool_call), + # check if starts with DeepSeek tags + if self.current_call_id is None: + # Check if might be start of DeepSeek tags + if ( + 
buffer.startswith(" DeltaMessage: + """ + Merge newly generated deltas from this processing + into a single DeltaMessage + + Args: + initial_count: Delta count before processing + + Returns: + Merged DeltaMessage containing all newly generated delta information + """ + if len(self.deltas) <= initial_count: + return DeltaMessage(content=None) + + # Get newly generated deltas + new_deltas = self.deltas[initial_count:] + + if len(new_deltas) == 1: + # Only one new delta, return directly + return new_deltas[0] + + # Merge multiple new deltas + merged_tool_calls: list[DeltaToolCall] = [] + merged_content: str = "" + + for delta in new_deltas: + if delta.content: + merged_content += delta.content + if delta.tool_calls: + # For tool_calls, we need to intelligently merge arguments + for tool_call in delta.tool_calls: + # Find if there's already a tool_call with the same call_id + existing_call = None + for existing in merged_tool_calls: + if existing.id == tool_call.id: + existing_call = existing + break + + if existing_call and existing_call.function: + # Merge to existing tool_call + if tool_call.function and tool_call.function.name: + existing_call.function.name = tool_call.function.name + if ( + tool_call.function + and tool_call.function.arguments is not None + ): + if existing_call.function.arguments is None: + existing_call.function.arguments = "" + + # For streaming JSON parameters, + # simply concatenate in order + new_args = tool_call.function.arguments + existing_call.function.arguments += new_args + if tool_call.type: + existing_call.type = tool_call.type + else: + # Add new tool_call + merged_tool_calls.append(tool_call) + + return DeltaMessage( + content=merged_content if merged_content else None, + tool_calls=merged_tool_calls, + ) + + def _preprocess_xml_chunk(self, chunk: str) -> str: + """ + Preprocess XML chunk, handle DeepSeek special characters, + and escape special characters + + Args: + chunk: Original XML chunk + + Returns: + Processed XML chunk + 
""" + + # Step 1: Normalize DeepSeek special characters to XML-safe characters + processed = chunk.replace("|DSML|", "DSML_") + + # Check if this is a tool_call related element + is_tool_call = False + if ( + "DSML_function_calls" in processed + or "DSML_invoke" in processed + or "DSML_parameter" in processed + ): + is_tool_call = True + + # DeepSeek already uses standard attribute format, no conversion needed + # ✓ + # ✓ + + original_chunk = chunk + # If in parameter value accumulation mode + if self._pre_inside_parameter: + # Parameter end: output accumulated raw text + # safely then return + if processed.startswith(""): + body_text = self._pre_param_buffer + # Trigger deferred parsing mode + # literal_eval+json output in end_element + self.defer_current_parameter = True + self.deferred_param_raw_value = body_text + # Clean up state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + safe_text = self._escape_xml_special_chars(body_text) + return f"{safe_text}" + else: + # If this is the first block of content after entering parameter + # evaluate if deferred parsing is needed; + # If not needed, exit accumulation mode + # and pass through directly + if self._pre_param_buffer == "": + # Get current parameter type + param_type = ( + self._get_param_type(self._pre_current_param_name) + if self._pre_current_param_name + else "string" + ) + # Only these types need deferred parsing to + # handle Python literals containing single quotes + is_object_type = param_type in ["object"] + is_complex_type = ( + param_type in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + + # Only delay when contains container symbols + # and has single quotes and is complex type + has_container_hint = ( + ("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk) + ) + + # Determine if deferred parsing is needed + need_defer = False + if is_complex_type: + # 
def _start_element(self, name: str, attrs: dict[str, str]) -> None:
    """Handle XML start element events for DeepSeek format.

    Dispatches on the element name:

    - ``root``: ignored.
    - ``DSML_function_calls``: auto-closes whatever is still open and sets
      a placeholder call id.
    - ``DSML_invoke``: starts a new tool call (fresh parameter dict, new
      call id, incremented ``tool_call_index``) and, when the element has
      a ``name`` attribute, emits a delta carrying the function name.
    - ``DSML_parameter``: records the parameter name and explicit type and
      emits the JSON key fragment that opens or continues the arguments
      object.

    Args:
        name: Element name reported by the XML parser.
        attrs: Attribute mapping of the element.
    """
    if name == "root":
        return
    # Handle function calls
    if name == "DSML_function_calls":
        # Before opening new tool_call,
        # automatically complete previous unclosed tags
        self._auto_close_open_parameter_if_needed("DSML_function_calls")
        # add a dummy call id to indicate it is in function
        # NOTE(review): -1 is truthy, so the `not self.current_call_id`
        # fallback in the DSML_parameter branch does not fire while this
        # placeholder is set - confirm that is intended.
        self.current_call_id = -1

    # Handle invoke (equivalent to Qwen3's tool_call + function combined)
    # This is a plain `if`, not `elif`; the two names are mutually exclusive
    # so only one of the branches can run for a given element.
    if name == "DSML_invoke":
        self.parameters = {}
        self.current_call_id = make_tool_call_id()
        self.current_param_is_first = True
        # Deltas below address this call as `tool_call_index - 1`.
        self.tool_call_index += 1

        # # If missing tool_call, manually complete
        # if not self.current_call_id:
        #     self._start_element("DSML_function_calls", {})

        # Extract function name from invoke's name attribute
        function_name = attrs.get("name")
        if function_name:
            self.current_function_name = function_name
            self.current_function_open = True

            # Emit delta with function name (arguments start out empty and
            # are streamed by later deltas)
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=function_name, arguments=""
                        ),
                    )
                ]
            )
            self._emit_delta(delta)

    # Handle parameter
    elif name == "DSML_parameter":
        # If no invoke yet, create one (fault tolerance)
        if not self.current_call_id:
            self._start_element("DSML_invoke", {"name": "unknown"})

        # Auto-close previous parameter if exists
        self._auto_close_open_parameter_if_needed("DSML_parameter")

        # Extract parameter name and type; reset the per-parameter buffers
        param_name = attrs.get("name")
        self.current_param_name = param_name
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False

        # Save explicit type information: the first attribute in this
        # order whose value is exactly "true" wins; None when none is set.
        if attrs.get("string") == "true":
            self.current_param_explicit_type = "string"
        elif attrs.get("number") == "true":
            self.current_param_explicit_type = "number"
        elif attrs.get("boolean") == "true":
            self.current_param_explicit_type = "boolean"
        elif attrs.get("object") == "true":
            self.current_param_explicit_type = "object"
        elif attrs.get("array") == "true":
            self.current_param_explicit_type = "array"
        else:
            self.current_param_explicit_type = None

        # Output JSON parameter name and colon
        # NOTE(review): param_name is interpolated verbatim, without JSON
        # escaping - a name containing `"` or `\` would produce invalid
        # JSON. Confirm names are constrained upstream.
        if param_name:
            if not self.parameters:
                # First parameter - start JSON object
                json_start = f'{{"{param_name}": '
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.tool_call_index - 1,
                            id=self.current_call_id,
                            type="function",
                            function=DeltaFunctionCall(
                                name=None, arguments=json_start
                            ),
                        )
                    ]
                )
                self._emit_delta(delta)
                self.current_param_is_first = True
            else:
                # Subsequent parameters - add comma
                json_continue = f', "{param_name}": '
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.tool_call_index - 1,
                            id=self.current_call_id,
                            type="function",
                            function=DeltaFunctionCall(
                                name=None, arguments=json_continue
                            ),
                        )
                    ]
                )
                self._emit_delta(delta)
                self.current_param_is_first = False
def _char_data(self, data: str) -> None:
    """Handle XML character data events.

    Character data is only processed while a parameter is open
    (``self.current_param_name`` is set). In deferred mode the text is
    merely accumulated; otherwise the accumulated value is converted
    according to the effective parameter type and only the newly produced
    suffix of the JSON text is emitted as a delta.

    Args:
        data: Text chunk reported by the XML parser.
    """
    if data and self.current_param_name:
        # If preprocessing stage determines deferred parsing is needed,
        # only cache character data, no streaming output
        if self.defer_current_parameter:
            original_data = data
            # A newline held back from the previous chunk turned out not to
            # be the last character, so put it back in front.
            if self.should_emit_end_newline:
                original_data = "\n" + original_data
                self.should_emit_end_newline = False
            # Hold back a trailing newline until more data proves it is not
            # the final character of the value.
            if original_data.endswith("\n"):
                self.should_emit_end_newline = True
                original_data = original_data[:-1]
            self.current_param_value += original_data
            return

        param_type = self._get_effective_param_type(self.current_param_name)

        # Check if this is the first time receiving data for this parameter
        # If this is the first packet of data and starts with \n, remove \n
        if not self.current_param_value and data.startswith("\n"):
            data = data[1:]

        # Output start quote for string type (if not already output)
        if (
            param_type in ["string", "str", "text", "varchar", "char", "enum"]
            and not self.start_quote_emitted
        ):
            quote_delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments='"'),
                    )
                ]
            )
            self._emit_delta(quote_delta)
            self.start_quote_emitted = True

        # Nothing left after stripping the leading newline
        if not data:
            return

        original_data = data
        # Delay output of trailing newline
        if self.should_emit_end_newline:
            original_data = "\n" + original_data
            self.should_emit_end_newline = False
        if original_data.endswith("\n"):
            self.should_emit_end_newline = True
            original_data = original_data[:-1]
        self.current_param_value += original_data

        # convert parameter value by param_type
        # (the whole accumulated value is converted again on every chunk)
        converted_value = self._convert_param_value(
            self.current_param_value, param_type
        )
        output_data = self._convert_for_json_streaming(converted_value, param_type)

        # Emit only the part that extends what was already streamed
        delta_data = output_data[len(self.current_param_value_converted) :]
        self.current_param_value_converted = output_data

        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.tool_call_index - 1,
                    id=self.current_call_id,
                    type="function",
                    function=DeltaFunctionCall(name=None, arguments=delta_data),
                )
            ]
        )
        self._emit_delta(delta)


def _end_element(self, name: str) -> None:
    """Handle XML end element events for DeepSeek format.

    - ``DSML_parameter``: finalises the open parameter (deferred values are
      parsed and emitted in one piece; string values get their closing
      quote), stores the value in ``self.parameters`` and clears the
      per-parameter state.
    - ``DSML_invoke``: closes the JSON arguments object with ``}`` or emits
      ``{}`` when no parameter was stored.
    - ``DSML_function_calls``: closes a still-open invoke, emits a final
      empty-arguments delta, flushes buffered text content and resets the
      per-call state.

    Args:
        name: Element name reported by the XML parser.
    """

    if name == "root":
        return

    # If invoke ends and there are still unclosed parameters,
    # complete parameter end first
    if (
        name == "DSML_invoke" or name == "DSML_function_calls"
    ) and self.current_param_name:
        self._auto_close_open_parameter_if_needed()

    # Handle parameter ending
    if name == "DSML_parameter" and self.current_param_name:
        # End current parameter
        param_name = self.current_param_name
        param_value = self.current_param_value

        # If in deferred parsing mode,
        # perform overall parsing on raw content
        # accumulated in preprocessing stage and output once
        if self.defer_current_parameter:
            # Prefer the raw text captured during preprocessing; fall back
            # to the character data accumulated by `_char_data`.
            raw_text = (
                self.deferred_param_raw_value
                if self.deferred_param_raw_value
                else param_value
            )
            parsed_value = None
            output_arguments = None
            try:
                # If previously delayed trailing newline,
                # add it back before parsing
                if self.should_emit_end_newline:
                    raw_for_parse = raw_text + "\n"
                else:
                    raw_for_parse = raw_text
                # literal_eval only evaluates Python literals
                parsed_value = ast.literal_eval(raw_for_parse)
                output_arguments = json.dumps(parsed_value, ensure_ascii=False)
            except Exception:
                # Fallback: output as string as-is
                # (covers both a failed literal_eval and a value json.dumps
                # cannot serialise)
                output_arguments = json.dumps(raw_text, ensure_ascii=False)
                parsed_value = raw_text

            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=None, arguments=output_arguments
                        ),
                    )
                ]
            )
            self._emit_delta(delta)

            # Clean up and store
            self.should_emit_end_newline = False
            self.parameters[param_name] = parsed_value
            self.current_param_name = None
            self.current_param_value = ""
            self.current_param_value_converted = ""
            self.start_quote_emitted = False
            self.defer_current_parameter = False
            self.deferred_param_raw_value = ""
            # Reset explicit type
            self.current_param_explicit_type = None
            return

        param_type = self._get_effective_param_type(param_name)

        # convert complete parameter value by param_type
        converted_value = self._convert_param_value(param_value, param_type)

        # Decide whether to add end quote based on parameter type
        # (non-string types emit nothing here; their text was streamed by
        # `_char_data`)
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            # For empty string parameters, need special handling
            if not param_value and not self.start_quote_emitted:
                # No start quote output,
                # directly output complete empty string
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.tool_call_index - 1,
                            id=self.current_call_id,
                            type="function",
                            function=DeltaFunctionCall(name=None, arguments='""'),
                        )
                    ]
                )
                self._emit_delta(delta)
            else:
                # Non-empty parameter value, output end quote
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.tool_call_index - 1,
                            id=self.current_call_id,
                            type="function",
                            function=DeltaFunctionCall(name=None, arguments='"'),
                        )
                    ]
                )
                self._emit_delta(delta)

        # A held-back trailing newline is dropped, not emitted
        self.should_emit_end_newline = False
        # Store converted value
        self.parameters[param_name] = converted_value
        self.current_param_name = None
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False
        # Reset explicit type
        self.current_param_explicit_type = None

    # Handle invoke ending (equivalent to Qwen3's function + tool_call)
    elif name == "DSML_invoke":
        # Ensure parameter is closed first
        if self.current_param_name:
            self._end_element("DSML_parameter")

        # Close JSON object
        if self.parameters:
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments="}"),
                    )
                ]
            )
            self._emit_delta(delta)
        else:
            # Empty parameters - output {}
            delta = DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments="{}"),
                    )
                ]
            )
            self._emit_delta(delta)
        self.current_function_open = False

    elif name == "DSML_function_calls":
        # Before ending tool_call,
        # ensure function is closed to complete missing right brace
        if self.current_function_open:
            # If there are still unclosed parameters, close them first
            if self.current_param_name:
                self._end_element("DSML_parameter")
            # Close function, ensure output '}' or '{}'
            self._end_element("DSML_invoke")
        # Final delta for this container, carrying empty arguments
        delta = DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.tool_call_index - 1,
                    id=self.current_call_id,
                    type="function",
                    function=DeltaFunctionCall(name=None, arguments=""),
                )
            ]
        )
        self._emit_delta(delta)
        # Check if there's text content to output (between invokes)
        if self.text_content_buffer.strip():
            text_delta = DeltaMessage(content=self.text_content_buffer)
            self._emit_delta(text_delta)

        # Reset the per-call state for whatever follows this container
        self._reset_current_call_state()
self._char_data + + def set_tools(self, tools: list[ChatCompletionToolsParam] | None): + """Set tool configuration information""" + self.tools = tools + + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: + """Extract function name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "function": + return parts[1] + + return None + + def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: + """Extract parameter name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "parameter": + return parts[1] + + return None + + def _get_param_type(self, param_name: str) -> str: + """Get parameter type based on tool configuration, defaults to string + Args: + param_name: Parameter name + + Returns: + Parameter type + """ + if not self.tools or not self.current_function_name: + return "string" + + for tool in self.tools: + if not hasattr(tool, "type") or not ( + hasattr(tool, "function") and hasattr(tool.function, "name") + ): + continue + if ( + tool.type == "function" + and tool.function.name == self.current_function_name + ): + if not hasattr(tool.function, "parameters"): + return "string" + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + properties = params["properties"] + if param_name in properties and isinstance( + properties[param_name], dict + ): + return self.repair_param_type( + str(properties[param_name].get("type", "string")) + ) + elif isinstance(params, dict) and param_name in params: + param_config = params[param_name] + if isinstance(param_config, dict): + return self.repair_param_type( + str(param_config.get("type", "string")) + ) + break + return "string" + + def _get_effective_param_type(self, param_name: str) -> str: + """Get 
effective parameter type + Args: + param_name: Parameter name + + Returns: + Effective parameter type + """ + # DeepSeek: if explicit type annotation exists, use it + if ( + hasattr(self, "current_param_explicit_type") + and self.current_param_explicit_type + ): + return self.current_param_explicit_type + + # Otherwise infer from tool schema + return self._get_param_type(param_name) + + def repair_param_type(self, param_type: str) -> str: + """Repair unknown parameter types by treating them as string + Args: + param_type: Parameter type + + Returns: + Repaired parameter type + """ + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or ( + param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + ): + return param_type + else: + return "string" + + def _convert_param_value(self, param_value: str, param_type: str) -> Any: + """Convert value based on parameter type + Args: + param_value: Parameter value + param_type: Parameter type + + Returns: + Converted value + """ + if param_value.lower() == "null": + return None + + param_type = param_type.strip().lower() + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an integer " + "in tool '%s', degenerating to string.", + param_value, + ) + return param_value + elif 
param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value: float = float(param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", + param_value, + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + return param_value == "true" + else: + return param_value + + def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: + """Convert converted_value based on + whether it's empty and if type is string + Args: + converted_value: Converted value + param_type: Parameter type + + Returns: + Converted string for streaming output + """ + # Check if value is empty, but exclude numeric 0 + if converted_value is None or converted_value == "": + return "" + + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # String type, remove double quotes + return json.dumps(converted_value, ensure_ascii=False)[1:-1] + else: + # Non-string type, return complete JSON string + if not isinstance(converted_value, str): + return json.dumps(converted_value, ensure_ascii=False) + else: + return converted_value + + def _reset_xml_parser_after_tool_call(self): + """ + Each tool_call is treated as a separate XML document, + so we need to reset the parser after each tool_call. 
+ """ + + # recreate XML parser + self.parser = ParserCreate() + self.setup_parser() + + # Reset current tool_call state + if self.current_call_id: + self.last_completed_call_id = self.current_call_id + self.current_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + self.text_content_buffer = "" + + # Reset preprocessing and deferred parsing state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + def _reset_current_call_state(self): + """ + Lightweight reset for current invoke state. + Used for DeepSeek to handle multiple invokes within the same + function_calls container. + Does NOT reset the parser or deltas. 
class DeepSeekV32ToolParser(ToolParser):
    """Tool parser that delegates parsing to a ``StreamingXMLToolCallParser``.

    Besides forwarding text to the streaming parser, it maintains
    ``prev_tool_call_arr`` and ``streamed_args_for_tool``, which record the
    name and argument text seen for each tool-call index.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.parser = StreamingXMLToolCallParser()

        # Add missing attributes for compatibility with serving_chat.py
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

        logger.info(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _begin_session(self, request) -> None:
        """Reset parser and tracking arrays, then load the request's tools."""
        self.parser.reset_streaming_state()
        self.prev_tool_call_arr = []
        self.streamed_args_for_tool = []
        if request:
            self.parser.set_tools(request.tools)

    def _tracking_slot(self, call) -> int:
        """Return the tracking index for ``call``, growing the arrays to fit.

        Falls back to the last existing entry when the call has no index.
        """
        slot = (
            call.index
            if call.index is not None
            else len(self.prev_tool_call_arr) - 1
        )
        while len(self.prev_tool_call_arr) <= slot:
            self.prev_tool_call_arr.append({"name": "", "arguments": ""})
        while len(self.streamed_args_for_tool) <= slot:
            self.streamed_args_for_tool.append("")
        return slot

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output in one pass.

        Only tool calls that carry a function name are returned; each of
        them also overwrites its entry in the tracking arrays.
        """
        self._begin_session(request)
        parsed = self.parser.parse_single_streaming_chunks(model_output)
        if not parsed.tool_calls:
            return ExtractedToolCallInformation(
                tool_calls=[],
                tools_called=False,
                content=parsed.content,
            )

        collected = []
        for call in parsed.tool_calls:
            fn = call.function
            if not (fn and fn.name):
                continue
            collected.append(
                ToolCall(
                    id=call.id,
                    type=call.type,
                    function=FunctionCall(name=fn.name, arguments=fn.arguments),
                )
            )

            # Update tool call tracking arrays for compatibility
            slot = self._tracking_slot(call)
            entry = self.prev_tool_call_arr[slot]
            entry["name"] = fn.name
            entry["arguments"] = fn.arguments
            if fn.arguments:
                self.streamed_args_for_tool[slot] = fn.arguments

        return ExtractedToolCallInformation(
            tool_calls=collected,
            tools_called=bool(collected),
            content=parsed.content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Parse one streaming delta and update the tracking arrays.

        An empty ``previous_text`` marks the start of a new streaming
        session and resets all state first.
        """
        if not previous_text:
            self._begin_session(request)

        # Model sometimes outputs separately causing delta_text to be empty.
        # If there were tool_calls before and all current tool_calls have
        # ended, return an empty message so the outer streaming code can
        # correctly output the tool_call field
        if not delta_text and delta_token_ids:
            started = current_text.count(self.parser.tool_call_start_token)
            ended = current_text.count(self.parser.tool_call_end_token)
            if (started - ended == 0 and self.parser.tool_call_index > 0) or (
                not self.parser.tool_call_index and current_text
            ):
                return DeltaMessage(content="")
            return None

        # Parse the delta text and get the result
        result = self.parser.parse_single_streaming_chunks(delta_text)
        if not (result and result.tool_calls):
            return result

        # Update tool call tracking arrays based on incremental results
        for call in result.tool_calls:
            fn = call.function
            if not fn:
                continue
            slot = self._tracking_slot(call)

            # Update tool name if provided
            if fn.name:
                self.prev_tool_call_arr[slot]["name"] = fn.name

            # Arguments arrive incrementally: append to what was streamed
            if fn.arguments is not None:
                self.prev_tool_call_arr[slot]["arguments"] += fn.arguments
                self.streamed_args_for_tool[slot] += fn.arguments
        return result