-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
add toolparser for deepseek v3.2 reusing qwen xml parser #29874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
add toolparser for deepseek v3.2 reusing qwen xml parser #29874
Conversation
Signed-off-by: wenmengzhou <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if ( | ||
| buffer.startswith("<DSML_function_calls") | ||
| or buffer == "<DSML_function_calls"[: len(buffer)] | ||
| or buffer.startswith("<DSML_invoke") | ||
| or buffer == "<DSML_invoke"[: len(buffer)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle partial DeepSeek tags with original token prefixes
When the stream delivers a split token such as <|DSML|function_calls without the closing >, _find_next_complete_element only waits if the buffer starts with <DSML_…> and otherwise treats the fragment as plain text. DeepSeek actually emits <|DSML|…> prefixes, so any tag broken across chunks will be consumed as text and the tool call never parsed, breaking streaming extraction whenever tags are split mid-token.
Useful? React with 👍 / 👎.
| logger.warning( | ||
| "Parsed value '%s' of parameter '%s' is not an integer " | ||
| "in tool '%s', degenerating to string.", | ||
| param_value, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix logging argument mismatch in numeric parse fallback
The warning emitted when integer parsing fails supplies only param_value to a format string with three %s placeholders, so hitting this path raises TypeError inside logging and aborts parsing instead of gracefully falling back to a string. The same mismatch appears in the float branch below; any non-numeric output for numeric parameters will crash the parser rather than returning the raw value.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a tool parser for DeepSeek v3.2 by adapting the existing Qwen XML parser. While the implementation is comprehensive for both streaming and non-streaming scenarios, I have significant concerns about the robustness and maintainability of the core parsing logic. The StreamingXMLToolCallParser contains a critical flaw in its stream-splitting mechanism that can lead to parser failures on valid tool calls. Furthermore, the overall complexity of this class is very high, making it difficult to maintain. The accompanying tests also lack coverage for crucial edge cases that could expose these fragilities. My review includes detailed feedback on these critical and high-severity issues with recommendations for simplification and improved testing.
| def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: | ||
| """ | ||
| Find next complete XML element from specified position | ||
| Args: | ||
| start_pos: Position to start searching | ||
| Returns: | ||
| (Complete element string, element end position), | ||
| returns (None, start_pos) if no complete element found | ||
| """ | ||
| buffer = self.streaming_buffer[start_pos:] | ||
|
|
||
| if not buffer: | ||
| return None, start_pos | ||
|
|
||
| if buffer.startswith("<"): | ||
| # Need to ensure no new < appears, | ||
| # find the nearest one between < and > | ||
| tag_end = buffer.find("<", 1) | ||
| tag_end2 = buffer.find(">", 1) | ||
| if tag_end != -1 and tag_end2 != -1: | ||
| # Next nearest is < | ||
| if tag_end < tag_end2: | ||
| return buffer[:tag_end], start_pos + tag_end | ||
| # Next nearest is >, means found XML element | ||
| else: | ||
| return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 | ||
| elif tag_end != -1: | ||
| return buffer[:tag_end], start_pos + tag_end | ||
| elif tag_end2 != -1: | ||
| return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 | ||
| else: | ||
| # If currently not parsing tool calls (entering a tool_call), | ||
| # check if starts with DeepSeek tags | ||
| if self.current_call_id is None: | ||
| # Check if might be start of DeepSeek tags | ||
| if ( | ||
| buffer.startswith("<DSML_function_calls") | ||
| or buffer == "<DSML_function_calls"[: len(buffer)] | ||
| or buffer.startswith("<DSML_invoke") | ||
| or buffer == "<DSML_invoke"[: len(buffer)] | ||
| ): | ||
| # Might be start of DeepSeek tag, wait for more data | ||
| return None, start_pos | ||
| else: | ||
| # Not start of DeepSeek tag, treat as text | ||
| return buffer, start_pos + len(buffer) | ||
| else: | ||
| # When parsing tool calls, | ||
| # wait for more data to get complete tag | ||
| return None, start_pos | ||
| else: | ||
| # Find text content (until next < or buffer end) | ||
| next_tag_pos = buffer.find("<") | ||
| if next_tag_pos != -1: | ||
| # Found text content | ||
| text_content = buffer[:next_tag_pos] | ||
| return text_content, start_pos + next_tag_pos | ||
| else: | ||
| # Buffer end is all text, process | ||
| # (no longer wait for more data) | ||
| remaining = buffer | ||
| return remaining, start_pos + len(remaining) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function _find_next_complete_element has a fundamental flaw in its parsing logic. It splits the input stream by finding the next < or > character. This approach does not account for < or > characters appearing inside quoted attribute values, which is valid in XML. For example, an input like <invoke name="a<b"> would be incorrectly split into <invoke name="a<, which is not a valid tag and will cause the expat parser to fail with an "unclosed token" error. This makes the entire tool call parsing mechanism fragile and prone to errors with certain model outputs.
A more robust approach should be used. Instead of manually trying to find complete XML elements, I recommend simplifying the parsing logic to leverage expat's streaming capabilities more directly. The expat parser is designed to handle chunked data. The pre-processing should be limited to what's necessary, like replacing DeepSeek's special tokens.
If the goal is to sanitize the input before it reaches expat, a more sophisticated state machine that is aware of XML quoting rules is needed for splitting the stream. However, this adds a lot of complexity, and it's usually better to rely on the battle-tested expat parser as much as possible.
| class StreamingXMLToolCallParser: | ||
| """ | ||
| Simplified streaming XML tool call parser | ||
| Supports streaming input, parsing, and output | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self.reset_streaming_state() | ||
|
|
||
| # Tool configuration information | ||
| self.tools: list[ChatCompletionToolsParam] | None = None | ||
|
|
||
| # DeepSeek v3.2 token definitions | ||
| self.tool_call_start_token: str = "<|DSML|function_calls>" | ||
| self.tool_call_end_token: str = "</|DSML|function_calls>" | ||
| self.function_start_token: str = "<|DSML|invoke" | ||
| self.function_end_token: str = "</|DSML|invoke>" | ||
| self.parameter_start_token: str = "<|DSML|parameter" | ||
| self.parameter_end_token: str = "</|DSML|parameter>" | ||
| self.end_of_sentence_token: str = "<|end▁of▁sentence|>" | ||
|
|
||
| def reset_streaming_state(self): | ||
| """Reset streaming parsing state""" | ||
|
|
||
| self.deltas = [] | ||
| # state for streaming | ||
| self.tool_call_index = 0 | ||
| self.current_call_id = None | ||
| self.last_completed_call_id = None | ||
| self.current_function_name = None | ||
| self.current_function_open = False | ||
| self.parameters = {} | ||
| self.current_param_name = None | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.current_param_is_first = False | ||
| self.should_emit_end_newline = False | ||
| self.start_quote_emitted = False | ||
|
|
||
| self.streaming_buffer = "" | ||
| self.last_processed_pos = 0 | ||
|
|
||
| self.text_content_buffer = "" | ||
|
|
||
| # state for preprocessing and deferred parsing | ||
| self._pre_inside_parameter = False | ||
| self._pre_param_buffer = "" | ||
| self._pre_current_param_name = None | ||
| self.defer_current_parameter = False | ||
| self.deferred_param_raw_value = "" | ||
|
|
||
| # DeepSeek explicit type support | ||
| self.current_param_explicit_type = None | ||
|
|
||
| # recreate parser | ||
| self.parser = ParserCreate() | ||
| self.setup_parser() | ||
|
|
||
| def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: | ||
| """ | ||
| Parse single streaming XML chunk and return Delta response | ||
| This is the actual streaming interface that receives chunks | ||
| one by one and maintains internal state | ||
| Args: | ||
| xml_chunk: Single XML chunk string | ||
| Returns: | ||
| DeltaMessage: Contains delta information generated by this chunk, | ||
| returns empty response if no complete elements | ||
| """ | ||
| # Remove DeepSeek end-of-sentence token if present | ||
| if self.end_of_sentence_token and self.end_of_sentence_token in xml_chunk: | ||
| xml_chunk = xml_chunk.replace(self.end_of_sentence_token, "") | ||
|
|
||
| # Record delta count before processing | ||
| initial_delta_count = len(self.deltas) | ||
|
|
||
| self.streaming_buffer += xml_chunk | ||
|
|
||
| found_elements = self._process_complete_xml_elements() | ||
|
|
||
| if found_elements: | ||
| # If complete elements found, check if end events were missed | ||
| # some tags may not have been triggered | ||
| try: | ||
| new_deltas = self.deltas[initial_delta_count:] | ||
| # If this chunk contains </|DSML|invoke> | ||
| # but didn't generate '}', then complete it | ||
| if ( | ||
| self.current_call_id is not None | ||
| and "</|DSML|invoke>" in xml_chunk | ||
| ): | ||
| # - Added '}' (non-empty parameter ending) | ||
| # - Added '{}' (empty parameter function) | ||
| has_function_close = any( | ||
| ( | ||
| td.tool_calls | ||
| and any( | ||
| ( | ||
| tc.function | ||
| and tc.id == self.current_call_id | ||
| and isinstance(tc.function.arguments, str) | ||
| and (tc.function.arguments in ("}", "{}")) | ||
| ) | ||
| for tc in td.tool_calls | ||
| ) | ||
| ) | ||
| for td in new_deltas | ||
| ) | ||
| if not has_function_close: | ||
| # Close potentially unclosed element | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
| if self.current_function_name: | ||
| self._end_element("DSML_invoke") | ||
| # If this chunk contains </|DSML|function_calls> | ||
| # but didn't generate final empty delta, then complete it | ||
| if ( | ||
| self.current_call_id is not None | ||
| and "</|DSML|function_calls>" in xml_chunk | ||
| ): | ||
| has_container_close = any( | ||
| ( | ||
| td.tool_calls | ||
| and any( | ||
| ( | ||
| tc.type == "function" | ||
| and tc.function | ||
| and tc.function.arguments == "" | ||
| and tc.id == self.current_call_id | ||
| ) | ||
| for tc in td.tool_calls | ||
| ) | ||
| ) | ||
| for td in new_deltas | ||
| ) | ||
| if not has_container_close: | ||
| # Close potentially unclosed element | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
| if self.current_function_name: | ||
| self._end_element("DSML_invoke") | ||
| self._end_element("DSML_function_calls") | ||
| except Exception as e: | ||
| logger.warning("Error with fallback parsing: %s", e) | ||
| # Merge newly generated deltas into single response | ||
| result_delta = self._merge_new_deltas_to_single_response( | ||
| initial_delta_count | ||
| ) | ||
| return result_delta | ||
| else: | ||
| # No complete elements, check if there's unoutput text content | ||
| if self.text_content_buffer and self.tool_call_index == 0: | ||
| # Has text content but no tool_call yet, output text content | ||
| text_delta = DeltaMessage(content=self.text_content_buffer) | ||
| self._emit_delta(text_delta) | ||
| # Clear buffer to avoid duplicate output | ||
| self.text_content_buffer = "" | ||
| return text_delta | ||
|
|
||
| # If this chunk contains end tags but wasn't triggered by parser, | ||
| # manually complete end events | ||
| # Only execute when still on the same call as when entered, | ||
| # to prevent accidentally closing new calls | ||
| # in multi invoke scenarios | ||
| if self.current_call_id is not None and ( | ||
| "</|DSML|invoke>" in xml_chunk | ||
| or "</|DSML|function_calls>" in xml_chunk | ||
| ): | ||
| # Close potentially unclosed element | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
| if "</|DSML|invoke>" in xml_chunk and self.current_function_name: | ||
| self._end_element("DSML_invoke") | ||
| if "</|DSML|function_calls>" in xml_chunk: | ||
| self._end_element("DSML_function_calls") | ||
| # Return the merged delta result generated by this fallback | ||
| result_delta = self._merge_new_deltas_to_single_response( | ||
| initial_delta_count | ||
| ) | ||
| return result_delta | ||
|
|
||
| # No complete elements, return empty response | ||
| return DeltaMessage(content=None) | ||
|
|
||
| def _escape_xml_special_chars(self, text: str) -> str: | ||
| """ | ||
| Escape XML special characters | ||
| Args: | ||
| text: Original text | ||
| Returns: | ||
| Escaped text | ||
| """ | ||
| xml_escapes = { | ||
| "&": "&", | ||
| "<": "<", | ||
| ">": ">", | ||
| '"': """, | ||
| "'": "'", | ||
| } | ||
|
|
||
| for char, escape in xml_escapes.items(): | ||
| text = text.replace(char, escape) | ||
|
|
||
| return text | ||
|
|
||
| def _process_complete_xml_elements(self) -> bool: | ||
| """ | ||
| Process complete XML elements in buffer | ||
| Returns: | ||
| bool: Whether complete elements were found and processed | ||
| """ | ||
| found_any = False | ||
| while self.last_processed_pos < len(self.streaming_buffer): | ||
| # Find next complete xml element | ||
| element, end_pos = self._find_next_complete_element(self.last_processed_pos) | ||
| if element is None: | ||
| # No complete element found, wait for more data | ||
| break | ||
|
|
||
| # Check if this element should be skipped | ||
| if self._should_skip_element(element): | ||
| self.last_processed_pos = end_pos | ||
| continue | ||
|
|
||
| # Found complete XML element, process it | ||
| try: | ||
| preprocessed_element = self._preprocess_xml_chunk(element) | ||
| # Check if this is the first tool_call start | ||
| # Note: after preprocessing, <|DSML|invoke becomes <DSML_invoke | ||
| if ( | ||
| ( | ||
| preprocessed_element.strip().startswith("<DSML_invoke") | ||
| or preprocessed_element.strip().startswith( | ||
| "<DSML_function_calls" | ||
| ) | ||
| ) | ||
| and self.tool_call_index == 0 | ||
| ) and self.text_content_buffer: | ||
| # First tool_call starts, | ||
| # output previously collected text content first | ||
| text_delta = DeltaMessage(content=self.text_content_buffer) | ||
| self._emit_delta(text_delta) | ||
| # Clear buffer for potential subsequent text content | ||
| self.text_content_buffer = "" | ||
|
|
||
| # If a new invoke starts and | ||
| # there are already completed invokes | ||
| if ( | ||
| preprocessed_element.strip().startswith("<DSML_invoke") | ||
| and self.tool_call_index > 0 | ||
| and self.current_call_id | ||
| ): | ||
| # Reset parser state but preserve generated deltas | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
| if self.current_function_open or self.current_function_name: | ||
| self._end_element("DSML_invoke") | ||
| # Output final invoke tail delta | ||
| final_delta = DeltaMessage( | ||
| role=None, | ||
| content=None, | ||
| reasoning=None, | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments=""), | ||
| ) | ||
| ], | ||
| ) | ||
| self._emit_delta(final_delta) | ||
| # Reset current call state (lightweight reset for DeepSeek) | ||
| self._reset_current_call_state() | ||
| # Parse preprocessed element | ||
| self.parser.Parse(preprocessed_element, False) | ||
| found_any = True | ||
|
|
||
| except Exception as e: | ||
| logger.warning("Error when parsing XML elements: %s", e) | ||
|
|
||
| # Update processed position | ||
| self.last_processed_pos = end_pos | ||
|
|
||
| return found_any | ||
|
|
||
| def _should_skip_element(self, element: str) -> bool: | ||
| """ | ||
| Determine whether an element should be skipped | ||
| Args: | ||
| element: Element to evaluate | ||
| Returns: | ||
| bool: True means should skip, False means should process | ||
| """ | ||
|
|
||
| # Check if it's a DeepSeek tool_call XML tag, don't skip | ||
| if ( | ||
| element.startswith(self.tool_call_start_token) | ||
| or element.startswith(self.function_start_token) | ||
| or element.startswith(self.parameter_start_token) | ||
| ): | ||
| return False | ||
|
|
||
| # If currently not parsing tool calls and not blank, | ||
| # collect this text instead of skipping | ||
| # Only process other XML elements after tool_call appears, | ||
| # otherwise treat as plain text | ||
| if self.current_call_id is None and element: | ||
| # Collect text content to buffer | ||
| self.text_content_buffer += element | ||
| return True # Still skip, but content has been collected | ||
|
|
||
| # If currently parsing tool calls, | ||
| # this might be parameter value, don't skip | ||
| if self.current_call_id is not None: | ||
| return False | ||
|
|
||
| # Skip blank content | ||
| return not element | ||
|
|
||
| def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: | ||
| """ | ||
| Find next complete XML element from specified position | ||
| Args: | ||
| start_pos: Position to start searching | ||
| Returns: | ||
| (Complete element string, element end position), | ||
| returns (None, start_pos) if no complete element found | ||
| """ | ||
| buffer = self.streaming_buffer[start_pos:] | ||
|
|
||
| if not buffer: | ||
| return None, start_pos | ||
|
|
||
| if buffer.startswith("<"): | ||
| # Need to ensure no new < appears, | ||
| # find the nearest one between < and > | ||
| tag_end = buffer.find("<", 1) | ||
| tag_end2 = buffer.find(">", 1) | ||
| if tag_end != -1 and tag_end2 != -1: | ||
| # Next nearest is < | ||
| if tag_end < tag_end2: | ||
| return buffer[:tag_end], start_pos + tag_end | ||
| # Next nearest is >, means found XML element | ||
| else: | ||
| return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 | ||
| elif tag_end != -1: | ||
| return buffer[:tag_end], start_pos + tag_end | ||
| elif tag_end2 != -1: | ||
| return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 | ||
| else: | ||
| # If currently not parsing tool calls (entering a tool_call), | ||
| # check if starts with DeepSeek tags | ||
| if self.current_call_id is None: | ||
| # Check if might be start of DeepSeek tags | ||
| if ( | ||
| buffer.startswith("<DSML_function_calls") | ||
| or buffer == "<DSML_function_calls"[: len(buffer)] | ||
| or buffer.startswith("<DSML_invoke") | ||
| or buffer == "<DSML_invoke"[: len(buffer)] | ||
| ): | ||
| # Might be start of DeepSeek tag, wait for more data | ||
| return None, start_pos | ||
| else: | ||
| # Not start of DeepSeek tag, treat as text | ||
| return buffer, start_pos + len(buffer) | ||
| else: | ||
| # When parsing tool calls, | ||
| # wait for more data to get complete tag | ||
| return None, start_pos | ||
| else: | ||
| # Find text content (until next < or buffer end) | ||
| next_tag_pos = buffer.find("<") | ||
| if next_tag_pos != -1: | ||
| # Found text content | ||
| text_content = buffer[:next_tag_pos] | ||
| return text_content, start_pos + next_tag_pos | ||
| else: | ||
| # Buffer end is all text, process | ||
| # (no longer wait for more data) | ||
| remaining = buffer | ||
| return remaining, start_pos + len(remaining) | ||
|
|
||
| def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage: | ||
| """ | ||
| Merge newly generated deltas from this processing | ||
| into a single DeltaMessage | ||
| Args: | ||
| initial_count: Delta count before processing | ||
| Returns: | ||
| Merged DeltaMessage containing all newly generated delta information | ||
| """ | ||
| if len(self.deltas) <= initial_count: | ||
| return DeltaMessage(content=None) | ||
|
|
||
| # Get newly generated deltas | ||
| new_deltas = self.deltas[initial_count:] | ||
|
|
||
| if len(new_deltas) == 1: | ||
| # Only one new delta, return directly | ||
| return new_deltas[0] | ||
|
|
||
| # Merge multiple new deltas | ||
| merged_tool_calls: list[DeltaToolCall] = [] | ||
| merged_content: str = "" | ||
|
|
||
| for delta in new_deltas: | ||
| if delta.content: | ||
| merged_content += delta.content | ||
| if delta.tool_calls: | ||
| # For tool_calls, we need to intelligently merge arguments | ||
| for tool_call in delta.tool_calls: | ||
| # Find if there's already a tool_call with the same call_id | ||
| existing_call = None | ||
| for existing in merged_tool_calls: | ||
| if existing.id == tool_call.id: | ||
| existing_call = existing | ||
| break | ||
|
|
||
| if existing_call and existing_call.function: | ||
| # Merge to existing tool_call | ||
| if tool_call.function and tool_call.function.name: | ||
| existing_call.function.name = tool_call.function.name | ||
| if ( | ||
| tool_call.function | ||
| and tool_call.function.arguments is not None | ||
| ): | ||
| if existing_call.function.arguments is None: | ||
| existing_call.function.arguments = "" | ||
|
|
||
| # For streaming JSON parameters, | ||
| # simply concatenate in order | ||
| new_args = tool_call.function.arguments | ||
| existing_call.function.arguments += new_args | ||
| if tool_call.type: | ||
| existing_call.type = tool_call.type | ||
| else: | ||
| # Add new tool_call | ||
| merged_tool_calls.append(tool_call) | ||
|
|
||
| return DeltaMessage( | ||
| content=merged_content if merged_content else None, | ||
| tool_calls=merged_tool_calls, | ||
| ) | ||
|
|
||
| def _preprocess_xml_chunk(self, chunk: str) -> str: | ||
| """ | ||
| Preprocess XML chunk, handle DeepSeek special characters, | ||
| and escape special characters | ||
| Args: | ||
| chunk: Original XML chunk | ||
| Returns: | ||
| Processed XML chunk | ||
| """ | ||
|
|
||
| # Step 1: Normalize DeepSeek special characters to XML-safe characters | ||
| processed = chunk.replace("|DSML|", "DSML_") | ||
|
|
||
| # Check if this is a tool_call related element | ||
| is_tool_call = False | ||
| if ( | ||
| "DSML_function_calls" in processed | ||
| or "DSML_invoke" in processed | ||
| or "DSML_parameter" in processed | ||
| ): | ||
| is_tool_call = True | ||
|
|
||
| # DeepSeek already uses standard attribute format, no conversion needed | ||
| # <DSML_invoke name="get_weather"> ✓ | ||
| # <DSML_parameter name="location" string="true"> ✓ | ||
|
|
||
| original_chunk = chunk | ||
| # If in parameter value accumulation mode | ||
| if self._pre_inside_parameter: | ||
| # Parameter end: output accumulated raw text | ||
| # safely then return </parameter> | ||
| if processed.startswith("</DSML_parameter>"): | ||
| body_text = self._pre_param_buffer | ||
| # Trigger deferred parsing mode | ||
| # literal_eval+json output in end_element | ||
| self.defer_current_parameter = True | ||
| self.deferred_param_raw_value = body_text | ||
| # Clean up state | ||
| self._pre_inside_parameter = False | ||
| self._pre_param_buffer = "" | ||
| self._pre_current_param_name = None | ||
| safe_text = self._escape_xml_special_chars(body_text) | ||
| return f"{safe_text}</DSML_parameter>" | ||
| else: | ||
| # If this is the first block of content after entering parameter | ||
| # evaluate if deferred parsing is needed; | ||
| # If not needed, exit accumulation mode | ||
| # and pass through directly | ||
| if self._pre_param_buffer == "": | ||
| # Get current parameter type | ||
| param_type = ( | ||
| self._get_param_type(self._pre_current_param_name) | ||
| if self._pre_current_param_name | ||
| else "string" | ||
| ) | ||
| # Only these types need deferred parsing to | ||
| # handle Python literals containing single quotes | ||
| is_object_type = param_type in ["object"] | ||
| is_complex_type = ( | ||
| param_type in ["array", "arr", "sequence"] | ||
| or param_type.startswith("dict") | ||
| or param_type.startswith("list") | ||
| ) | ||
|
|
||
| # Only delay when contains container symbols | ||
| # and has single quotes and is complex type | ||
| has_container_hint = ( | ||
| ("[" in original_chunk) | ||
| or ("{" in original_chunk) | ||
| or ("(" in original_chunk) | ||
| ) | ||
|
|
||
| # Determine if deferred parsing is needed | ||
| need_defer = False | ||
| if is_complex_type: | ||
| # Complex type, always need deferred parsing | ||
| need_defer = True | ||
| elif ( | ||
| is_object_type | ||
| and has_container_hint | ||
| and ("'" in original_chunk) | ||
| ): | ||
| # Object type with container symbols | ||
| # and single quotes, need deferred parsing | ||
| need_defer = True | ||
|
|
||
| if not need_defer: | ||
| # No need for deferred parsing, | ||
| # exit parameter mode directly | ||
| self._pre_inside_parameter = False | ||
| return self._escape_xml_special_chars(original_chunk) | ||
| self._pre_param_buffer += original_chunk | ||
| return "" | ||
|
|
||
| # Parameter start: enable accumulation | ||
| if processed.startswith("<DSML_parameter"): | ||
| # Extract parameter name and type attributes | ||
| m = re.match( | ||
| r'<DSML_parameter\s+name="([^"]+)"(?:\s+string="(true|false)")?(?:\s+number="(true|false)")?(?:\s+boolean="(true|false)")?(?:\s+object="(true|false)")?(?:\s+array="(true|false)")?', | ||
| processed, | ||
| ) | ||
| if m: | ||
| self._pre_current_param_name = m.group(1) | ||
| self._pre_inside_parameter = True | ||
| self._pre_param_buffer = "" | ||
| return processed | ||
|
|
||
| # If processed doesn't contain special_token, escape processed | ||
| # This is because XML parsing encounters special characters | ||
| # and reports errors, so escaping is needed | ||
| if not is_tool_call: | ||
| processed = self._escape_xml_special_chars(processed) | ||
| return processed | ||
|
|
||
| def _emit_delta(self, delta: DeltaMessage): | ||
| """Emit Delta response (streaming output)""" | ||
| self.deltas.append(delta) | ||
|
|
||
| def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None): | ||
| """Before starting to process new elements, | ||
| if there are unclosed tags from before, | ||
| automatically complete their endings to the parser. | ||
| - If there are unclosed parameters, | ||
| it's equivalent to feeding `</parameter>` | ||
| - When about to start a new function or tool_call, | ||
| if there are unclosed functions, complete `</function>`. | ||
| - When about to start a new tool_call, | ||
| if there are unclosed tool_calls, complete `</tool_call>`. | ||
| """ | ||
| # First close unclosed parameters | ||
| if self.current_param_name: | ||
| self._end_element("parameter") | ||
|
|
||
| # If about to start new function or tool_call, | ||
| # and there are unclosed functions, close function first | ||
| if incoming_tag in ("function", "tool_call") and self.current_function_name: | ||
| self._end_element("function") | ||
|
|
||
| # If about to start new tool_call, | ||
| # and there are unclosed tool_calls, close tool_call first | ||
| if incoming_tag == "tool_call" and self.current_call_id: | ||
| self._end_element("tool_call") | ||
|
|
||
| def _start_element(self, name: str, attrs: dict[str, str]): | ||
| """Handle XML start element events for DeepSeek format""" | ||
|
|
||
| if name == "root": | ||
| return | ||
|
|
||
| # Handle function calls | ||
| if name == "DSML_function_calls": | ||
| # Before opening new tool_call, | ||
| # automatically complete previous unclosed tags | ||
| self._auto_close_open_parameter_if_needed("DSML_function_calls") | ||
|
|
||
| self.parameters = {} | ||
| self.current_call_id = make_tool_call_id() | ||
| self.current_param_is_first = True | ||
| self.tool_call_index += 1 | ||
|
|
||
| # Handle invoke (equivalent to Qwen3's tool_call + function combined) | ||
| if name == "DSML_invoke": | ||
| # If missing tool_call, manually complete | ||
| if not self.current_call_id: | ||
| self._start_element("DSML_function_calls", {}) | ||
|
|
||
| # Extract function name from invoke's name attribute | ||
| function_name = attrs.get("name") | ||
| if function_name: | ||
| self.current_function_name = function_name | ||
| self.current_function_open = True | ||
|
|
||
| # Emit delta with function name | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall( | ||
| name=function_name, arguments="" | ||
| ), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
|
|
||
| # Handle parameter | ||
| elif name == "DSML_parameter": | ||
| # If no invoke yet, create one (fault tolerance) | ||
| if not self.current_call_id: | ||
| self._start_element("DSML_invoke", {"name": "unknown"}) | ||
|
|
||
| # Auto-close previous parameter if exists | ||
| self._auto_close_open_parameter_if_needed("DSML_parameter") | ||
|
|
||
| # Extract parameter name and type | ||
| param_name = attrs.get("name") | ||
| self.current_param_name = param_name | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.start_quote_emitted = False | ||
|
|
||
| # Save explicit type information | ||
| if attrs.get("string") == "true": | ||
| self.current_param_explicit_type = "string" | ||
| elif attrs.get("number") == "true": | ||
| self.current_param_explicit_type = "number" | ||
| elif attrs.get("boolean") == "true": | ||
| self.current_param_explicit_type = "boolean" | ||
| elif attrs.get("object") == "true": | ||
| self.current_param_explicit_type = "object" | ||
| elif attrs.get("array") == "true": | ||
| self.current_param_explicit_type = "array" | ||
| else: | ||
| self.current_param_explicit_type = None | ||
|
|
||
| # Output JSON parameter name and colon | ||
| if param_name: | ||
| if not self.parameters: | ||
| # First parameter - start JSON object | ||
| json_start = f'{{"{param_name}": ' | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall( | ||
| name=None, arguments=json_start | ||
| ), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
| self.current_param_is_first = True | ||
| else: | ||
| # Subsequent parameters - add comma | ||
| json_continue = f', "{param_name}": ' | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall( | ||
| name=None, arguments=json_continue | ||
| ), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
| self.current_param_is_first = False | ||
|
|
||
| def _char_data(self, data: str): | ||
| """Handle XML character data events""" | ||
| if data and self.current_param_name: | ||
| # If preprocessing stage determines deferred parsing is needed, | ||
| # only cache character data, no streaming output | ||
| if self.defer_current_parameter: | ||
| original_data = data | ||
| if self.should_emit_end_newline: | ||
| original_data = "\n" + original_data | ||
| self.should_emit_end_newline = False | ||
| if original_data.endswith("\n"): | ||
| self.should_emit_end_newline = True | ||
| original_data = original_data[:-1] | ||
| self.current_param_value += original_data | ||
| return | ||
|
|
||
| param_type = self._get_effective_param_type(self.current_param_name) | ||
|
|
||
| # Check if this is the first time receiving data for this parameter | ||
| # If this is the first packet of data and starts with \n, remove \n | ||
| if not self.current_param_value and data.startswith("\n"): | ||
| data = data[1:] | ||
|
|
||
| # Output start quote for string type (if not already output) | ||
| if ( | ||
| param_type in ["string", "str", "text", "varchar", "char", "enum"] | ||
| and not self.start_quote_emitted | ||
| ): | ||
| quote_delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments='"'), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(quote_delta) | ||
| self.start_quote_emitted = True | ||
|
|
||
| if not data: | ||
| return | ||
|
|
||
| original_data = data | ||
| # Delay output of trailing newline | ||
| if self.should_emit_end_newline: | ||
| original_data = "\n" + original_data | ||
| self.should_emit_end_newline = False | ||
| if original_data.endswith("\n"): | ||
| self.should_emit_end_newline = True | ||
| original_data = original_data[:-1] | ||
| self.current_param_value += original_data | ||
|
|
||
| # convert parameter value by param_type | ||
| converted_value = self._convert_param_value( | ||
| self.current_param_value, param_type | ||
| ) | ||
| output_data = self._convert_for_json_streaming(converted_value, param_type) | ||
|
|
||
| delta_data = output_data[len(self.current_param_value_converted) :] | ||
| self.current_param_value_converted = output_data | ||
|
|
||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments=delta_data), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
|
|
||
| def _end_element(self, name: str): | ||
| """Handle XML end element events for DeepSeek format""" | ||
|
|
||
| if name == "root": | ||
| return | ||
|
|
||
| # If invoke ends and there are still unclosed parameters, | ||
| # complete parameter end first | ||
| if ( | ||
| name == "DSML_invoke" or name == "DSML_function_calls" | ||
| ) and self.current_param_name: | ||
| self._auto_close_open_parameter_if_needed() | ||
|
|
||
| # Handle parameter ending | ||
| if name == "DSML_parameter" and self.current_param_name: | ||
| # End current parameter | ||
| param_name = self.current_param_name | ||
| param_value = self.current_param_value | ||
|
|
||
| # If in deferred parsing mode, | ||
| # perform overall parsing on raw content | ||
| # accumulated in preprocessing stage and output once | ||
| if self.defer_current_parameter: | ||
| raw_text = ( | ||
| self.deferred_param_raw_value | ||
| if self.deferred_param_raw_value | ||
| else param_value | ||
| ) | ||
| parsed_value = None | ||
| output_arguments = None | ||
| try: | ||
| # If previously delayed trailing newline, | ||
| # add it back before parsing | ||
| if self.should_emit_end_newline: | ||
| raw_for_parse = raw_text + "\n" | ||
| else: | ||
| raw_for_parse = raw_text | ||
| parsed_value = ast.literal_eval(raw_for_parse) | ||
| output_arguments = json.dumps(parsed_value, ensure_ascii=False) | ||
| except Exception: | ||
| # Fallback: output as string as-is | ||
| output_arguments = json.dumps(raw_text, ensure_ascii=False) | ||
| parsed_value = raw_text | ||
|
|
||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall( | ||
| name=None, arguments=output_arguments | ||
| ), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
|
|
||
| # Clean up and store | ||
| self.should_emit_end_newline = False | ||
| self.parameters[param_name] = parsed_value | ||
| self.current_param_name = None | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.start_quote_emitted = False | ||
| self.defer_current_parameter = False | ||
| self.deferred_param_raw_value = "" | ||
| # Reset explicit type | ||
| self.current_param_explicit_type = None | ||
| return | ||
|
|
||
| param_type = self._get_effective_param_type(param_name) | ||
|
|
||
| # convert complete parameter value by param_type | ||
| converted_value = self._convert_param_value(param_value, param_type) | ||
|
|
||
| # Decide whether to add end quote based on parameter type | ||
| if param_type in ["string", "str", "text", "varchar", "char", "enum"]: | ||
| # For empty string parameters, need special handling | ||
| if not param_value and not self.start_quote_emitted: | ||
| # No start quote output, | ||
| # directly output complete empty string | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments='""'), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
| else: | ||
| # Non-empty parameter value, output end quote | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments='"'), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
|
|
||
| self.should_emit_end_newline = False | ||
| # Store converted value | ||
| self.parameters[param_name] = converted_value | ||
| self.current_param_name = None | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.start_quote_emitted = False | ||
| # Reset explicit type | ||
| self.current_param_explicit_type = None | ||
|
|
||
| # Handle invoke ending (equivalent to Qwen3's function + tool_call) | ||
| elif name == "DSML_invoke": | ||
| # Ensure parameter is closed first | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
|
|
||
| # Close JSON object | ||
| if self.parameters: | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments="}"), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
| else: | ||
| # Empty parameters - output {} | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments="{}"), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
| self.current_function_open = False | ||
|
|
||
| elif name == "DSML_function_calls": | ||
| # Before ending tool_call, | ||
| # ensure function is closed to complete missing right brace | ||
| if self.current_function_open: | ||
| # If there are still unclosed parameters, close them first | ||
| if self.current_param_name: | ||
| self._end_element("DSML_parameter") | ||
| # Close function, ensure output '}' or '{}' | ||
| self._end_element("DSML_invoke") | ||
| delta = DeltaMessage( | ||
| tool_calls=[ | ||
| DeltaToolCall( | ||
| index=self.tool_call_index - 1, | ||
| id=self.current_call_id, | ||
| type="function", | ||
| function=DeltaFunctionCall(name=None, arguments=""), | ||
| ) | ||
| ] | ||
| ) | ||
| self._emit_delta(delta) | ||
|
|
||
| # Check if there's text content to output (between invokes) | ||
| if self.text_content_buffer.strip(): | ||
| text_delta = DeltaMessage(content=self.text_content_buffer) | ||
| self._emit_delta(text_delta) | ||
|
|
||
| # Lightweight reset for next invoke (don't reset parser) | ||
| self._reset_current_call_state() | ||
|
|
||
| def setup_parser(self): | ||
| """Set up XML parser event handlers""" | ||
| self.parser.buffer_text = True | ||
| self.parser.StartElementHandler = self._start_element | ||
| self.parser.EndElementHandler = self._end_element | ||
| self.parser.CharacterDataHandler = self._char_data | ||
|
|
||
| def set_tools(self, tools: list[ChatCompletionToolsParam] | None): | ||
| """Set tool configuration information""" | ||
| self.tools = tools | ||
|
|
||
| def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: | ||
| """Extract function name from various formats""" | ||
| if attrs and "name" in attrs: | ||
| return attrs["name"] | ||
|
|
||
| if "=" in name: | ||
| parts = name.split("=", 1) | ||
| if len(parts) == 2 and parts[0] == "function": | ||
| return parts[1] | ||
|
|
||
| return None | ||
|
|
||
| def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: | ||
| """Extract parameter name from various formats""" | ||
| if attrs and "name" in attrs: | ||
| return attrs["name"] | ||
|
|
||
| if "=" in name: | ||
| parts = name.split("=", 1) | ||
| if len(parts) == 2 and parts[0] == "parameter": | ||
| return parts[1] | ||
|
|
||
| return None | ||
|
|
||
| def _get_param_type(self, param_name: str) -> str: | ||
| """Get parameter type based on tool configuration, defaults to string | ||
| Args: | ||
| param_name: Parameter name | ||
| Returns: | ||
| Parameter type | ||
| """ | ||
| if not self.tools or not self.current_function_name: | ||
| return "string" | ||
|
|
||
| for tool in self.tools: | ||
| if not hasattr(tool, "type") or not ( | ||
| hasattr(tool, "function") and hasattr(tool.function, "name") | ||
| ): | ||
| continue | ||
| if ( | ||
| tool.type == "function" | ||
| and tool.function.name == self.current_function_name | ||
| ): | ||
| if not hasattr(tool.function, "parameters"): | ||
| return "string" | ||
| params = tool.function.parameters | ||
| if isinstance(params, dict) and "properties" in params: | ||
| properties = params["properties"] | ||
| if param_name in properties and isinstance( | ||
| properties[param_name], dict | ||
| ): | ||
| return self.repair_param_type( | ||
| str(properties[param_name].get("type", "string")) | ||
| ) | ||
| elif isinstance(params, dict) and param_name in params: | ||
| param_config = params[param_name] | ||
| if isinstance(param_config, dict): | ||
| return self.repair_param_type( | ||
| str(param_config.get("type", "string")) | ||
| ) | ||
| break | ||
| return "string" | ||
|
|
||
| def _get_effective_param_type(self, param_name: str) -> str: | ||
| """Get effective parameter type | ||
| Args: | ||
| param_name: Parameter name | ||
| Returns: | ||
| Effective parameter type | ||
| """ | ||
| # DeepSeek: if explicit type annotation exists, use it | ||
| if ( | ||
| hasattr(self, "current_param_explicit_type") | ||
| and self.current_param_explicit_type | ||
| ): | ||
| return self.current_param_explicit_type | ||
|
|
||
| # Otherwise infer from tool schema | ||
| return self._get_param_type(param_name) | ||
|
|
||
| def repair_param_type(self, param_type: str) -> str: | ||
| """Repair unknown parameter types by treating them as string | ||
| Args: | ||
| param_type: Parameter type | ||
| Returns: | ||
| Repaired parameter type | ||
| """ | ||
| if ( | ||
| param_type in ["string", "str", "text", "varchar", "char", "enum"] | ||
| or param_type.startswith("int") | ||
| or param_type.startswith("uint") | ||
| or param_type.startswith("long") | ||
| or param_type.startswith("short") | ||
| or param_type.startswith("unsigned") | ||
| or param_type.startswith("num") | ||
| or param_type.startswith("float") | ||
| or param_type in ["boolean", "bool", "binary"] | ||
| or ( | ||
| param_type in ["object", "array", "arr", "sequence"] | ||
| or param_type.startswith("dict") | ||
| or param_type.startswith("list") | ||
| ) | ||
| ): | ||
| return param_type | ||
| else: | ||
| return "string" | ||
|
|
||
| def _convert_param_value(self, param_value: str, param_type: str) -> Any: | ||
| """Convert value based on parameter type | ||
| Args: | ||
| param_value: Parameter value | ||
| param_type: Parameter type | ||
| Returns: | ||
| Converted value | ||
| """ | ||
| if param_value.lower() == "null": | ||
| return None | ||
|
|
||
| param_type = param_type.strip().lower() | ||
| if param_type in ["string", "str", "text", "varchar", "char", "enum"]: | ||
| return param_value | ||
| elif ( | ||
| param_type.startswith("int") | ||
| or param_type.startswith("uint") | ||
| or param_type.startswith("long") | ||
| or param_type.startswith("short") | ||
| or param_type.startswith("unsigned") | ||
| ): | ||
| try: | ||
| return int(param_value) | ||
| except (ValueError, TypeError): | ||
| logger.warning( | ||
| "Parsed value '%s' of parameter '%s' is not an integer " | ||
| "in tool '%s', degenerating to string.", | ||
| param_value, | ||
| ) | ||
| return param_value | ||
| elif param_type.startswith("num") or param_type.startswith("float"): | ||
| try: | ||
| float_param_value: float = float(param_value) | ||
| return ( | ||
| float_param_value | ||
| if float_param_value - int(float_param_value) != 0 | ||
| else int(float_param_value) | ||
| ) | ||
| except (ValueError, TypeError): | ||
| logger.warning( | ||
| "Parsed value '%s' of parameter '%s' is not a float " | ||
| "in tool '%s', degenerating to string.", | ||
| param_value, | ||
| ) | ||
| return param_value | ||
| elif param_type in ["boolean", "bool", "binary"]: | ||
| param_value = param_value.lower() | ||
| return param_value == "true" | ||
| else: | ||
| return param_value | ||
|
|
||
| def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: | ||
| """Convert converted_value based on | ||
| whether it's empty and if type is string | ||
| Args: | ||
| converted_value: Converted value | ||
| param_type: Parameter type | ||
| Returns: | ||
| Converted string for streaming output | ||
| """ | ||
| # Check if value is empty, but exclude numeric 0 | ||
| if converted_value is None or converted_value == "": | ||
| return "" | ||
|
|
||
| if param_type in ["string", "str", "text", "varchar", "char", "enum"]: | ||
| # String type, remove double quotes | ||
| return json.dumps(converted_value, ensure_ascii=False)[1:-1] | ||
| else: | ||
| # Non-string type, return complete JSON string | ||
| if not isinstance(converted_value, str): | ||
| return json.dumps(converted_value, ensure_ascii=False) | ||
| else: | ||
| return converted_value | ||
|
|
||
| def _reset_xml_parser_after_tool_call(self): | ||
| """ | ||
| Each tool_call is treated as a separate XML document, | ||
| so we need to reset the parser after each tool_call. | ||
| """ | ||
|
|
||
| # recreate XML parser | ||
| self.parser = ParserCreate() | ||
| self.setup_parser() | ||
|
|
||
| # Reset current tool_call state | ||
| if self.current_call_id: | ||
| self.last_completed_call_id = self.current_call_id | ||
| self.current_call_id = None | ||
| self.current_function_name = None | ||
| self.current_function_open = False | ||
| self.parameters = {} | ||
| self.current_param_name = None | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.current_param_is_first = False | ||
| self.should_emit_end_newline = False | ||
| self.start_quote_emitted = False | ||
| self.text_content_buffer = "" | ||
|
|
||
| # Reset preprocessing and deferred parsing state | ||
| self._pre_inside_parameter = False | ||
| self._pre_param_buffer = "" | ||
| self._pre_current_param_name = None | ||
| self.defer_current_parameter = False | ||
| self.deferred_param_raw_value = "" | ||
|
|
||
| def _reset_current_call_state(self): | ||
| """ | ||
| Lightweight reset for current invoke state. | ||
| Used for DeepSeek to handle multiple invokes within the same | ||
| function_calls container. | ||
| Does NOT reset the parser or deltas. | ||
| """ | ||
| # Reset current tool_call state | ||
| if self.current_call_id: | ||
| self.last_completed_call_id = self.current_call_id | ||
| self.current_call_id = None | ||
| self.current_function_name = None | ||
| self.current_function_open = False | ||
| self.parameters = {} | ||
| self.current_param_name = None | ||
| self.current_param_value = "" | ||
| self.current_param_value_converted = "" | ||
| self.current_param_is_first = False | ||
| self.should_emit_end_newline = False | ||
| self.start_quote_emitted = False | ||
| self.text_content_buffer = "" | ||
|
|
||
| # Reset preprocessing and deferred parsing state | ||
| self._pre_inside_parameter = False | ||
| self._pre_param_buffer = "" | ||
| self._pre_current_param_name = None | ||
| self.defer_current_parameter = False | ||
| self.deferred_param_raw_value = "" | ||
|
|
||
| # Reset explicit type | ||
| self.current_param_explicit_type = None | ||
|
|
||
| # Note: Do NOT reset: | ||
| # - self.parser (XML parser instance) | ||
| # - self.deltas (accumulated delta list) | ||
| # - self.tool_call_index (tool call index continues to increment) | ||
| # - self.streaming_buffer (streaming buffer) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The StreamingXMLToolCallParser class is extremely complex. It implements a stateful, streaming XML parser with deferred parsing logic, fallback mechanisms, and manual buffer management, essentially building a parser on top of Python's expat parser. This complexity makes the code very difficult to understand, debug, and maintain. For instance, the interaction between _find_next_complete_element, _process_complete_xml_elements, and _preprocess_xml_chunk creates a convoluted data flow that is hard to trace.
While parsing model-generated tool calls in a streaming fashion is inherently complex, this implementation seems overly so. It's worth investigating if a simpler design is possible that relies more on the expat parser's built-in streaming capabilities and reduces the amount of manual state management and pre-parsing logic. A simpler, more declarative implementation would be more robust and easier to maintain in the long run. Since this code is reused from another parser, consider refactoring the common logic into a shared base class to improve maintainability and avoid code duplication.
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import json | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm.entrypoints.openai.protocol import ( | ||
| ChatCompletionRequest, | ||
| ChatCompletionToolsParam, | ||
| FunctionCall, | ||
| ToolCall, | ||
| ) | ||
| from vllm.entrypoints.openai.tool_parsers.deepseek_v32_tool_parser import ( | ||
| DeepSeekV32ToolParser, | ||
| ) | ||
| from vllm.tokenizers import get_tokenizer | ||
|
|
||
| pytestmark = pytest.mark.cpu_test | ||
|
|
||
| MODEL = "deepseek-ai/DeepSeek-V3" | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def deepseek_tokenizer(): | ||
| return get_tokenizer(tokenizer_name=MODEL) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def deepseek_tool_parser(deepseek_tokenizer): | ||
| return DeepSeekV32ToolParser(deepseek_tokenizer) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sample_tools(): | ||
| return [ | ||
| ChatCompletionToolsParam( | ||
| type="function", | ||
| function={ | ||
| "name": "get_weather", | ||
| "description": "Get the weather for a location", | ||
| "parameters": { | ||
| "type": "object", | ||
| "properties": { | ||
| "location": {"type": "string", "description": "The city name"}, | ||
| "date": {"type": "string", "description": "The date"}, | ||
| }, | ||
| "required": ["location", "date"], | ||
| }, | ||
| }, | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
| def assert_tool_calls( | ||
| actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] | ||
| ): | ||
| assert len(actual_tool_calls) == len(expected_tool_calls) | ||
|
|
||
| for actual_tool_call, expected_tool_call in zip( | ||
| actual_tool_calls, expected_tool_calls | ||
| ): | ||
| assert actual_tool_call.type == "function" | ||
| assert actual_tool_call.function.name == expected_tool_call.function.name | ||
| try: | ||
| assert json.loads(actual_tool_call.function.arguments) == json.loads( | ||
| expected_tool_call.function.arguments | ||
| ) | ||
| except json.JSONDecodeError as e: | ||
| print(e) | ||
| print("actual_tool_call", actual_tool_call.function.arguments) | ||
| print("expected_tool_call", expected_tool_call.function.arguments) | ||
|
|
||
|
|
||
| def test_extract_tool_calls_single_function( | ||
| deepseek_tool_parser, | ||
| sample_tools, | ||
| ): | ||
| """Test extracting a single function call""" | ||
| model_output = """<|DSML|function_calls> | ||
| <|DSML|invoke name="get_weather"> | ||
| <|DSML|parameter name="location" string="true">Hangzhou</|DSML|parameter> | ||
| <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> | ||
| </|DSML|invoke> | ||
| </|DSML|function_calls>""" | ||
|
|
||
| expected_tool_calls = [ | ||
| ToolCall( | ||
| function=FunctionCall( | ||
| name="get_weather", | ||
| arguments=json.dumps({"location": "Hangzhou", "date": "2024-01-16"}), | ||
| ) | ||
| ), | ||
| ] | ||
|
|
||
| request = ChatCompletionRequest( | ||
| model="deepseek-v3", | ||
| messages=[{"role": "user", "content": "What is the weather?"}], | ||
| tools=sample_tools, | ||
| ) | ||
|
|
||
| extracted = deepseek_tool_parser.extract_tool_calls(model_output, request) | ||
|
|
||
| assert extracted.tools_called | ||
| assert extracted.content is None | ||
| assert_tool_calls(extracted.tool_calls, expected_tool_calls) | ||
|
|
||
|
|
||
| def test_extract_tool_calls_multiple_functions( | ||
| deepseek_tool_parser, | ||
| sample_tools, | ||
| ): | ||
| """Test extracting multiple function calls""" | ||
| model_output = """<|DSML|function_calls> | ||
| <|DSML|invoke name="get_weather"> | ||
| <|DSML|parameter name="location" string="true">Hangzhou</|DSML|parameter> | ||
| <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> | ||
| </|DSML|invoke> | ||
| <|DSML|invoke name="get_weather"> | ||
| <|DSML|parameter name="location" string="true">Beijing</|DSML|parameter> | ||
| <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> | ||
| </|DSML|invoke> | ||
| </|DSML|function_calls>""" | ||
|
|
||
| expected_tool_calls = [ | ||
| ToolCall( | ||
| function=FunctionCall( | ||
| name="get_weather", | ||
| arguments=json.dumps({"location": "Hangzhou", "date": "2024-01-16"}), | ||
| ) | ||
| ), | ||
| ToolCall( | ||
| function=FunctionCall( | ||
| name="get_weather", | ||
| arguments=json.dumps({"location": "Beijing", "date": "2024-01-16"}), | ||
| ) | ||
| ), | ||
| ] | ||
|
|
||
| request = ChatCompletionRequest( | ||
| model="deepseek-v3", | ||
| messages=[{"role": "user", "content": "What is the weather?"}], | ||
| tools=sample_tools, | ||
| ) | ||
|
|
||
| extracted = deepseek_tool_parser.extract_tool_calls(model_output, request) | ||
|
|
||
| assert extracted.tools_called | ||
| assert extracted.content is None | ||
| assert_tool_calls(extracted.tool_calls, expected_tool_calls) | ||
|
|
||
|
|
||
| def test_extract_tool_calls_with_end_of_sentence_token( | ||
| deepseek_tool_parser, | ||
| sample_tools, | ||
| ): | ||
| """Test extracting function calls with end-of-sentence token""" | ||
| model_output = """<|DSML|function_calls> | ||
| <|DSML|invoke name="get_weather"> | ||
| <|DSML|parameter name="location" string="true">Hangzhou</|DSML|parameter> | ||
| <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> | ||
| </|DSML|invoke> | ||
| </|DSML|function_calls><|end▁of▁sentence|>""" | ||
|
|
||
| expected_tool_calls = [ | ||
| ToolCall( | ||
| function=FunctionCall( | ||
| name="get_weather", | ||
| arguments=json.dumps({"location": "Hangzhou", "date": "2024-01-16"}), | ||
| ) | ||
| ), | ||
| ] | ||
|
|
||
| request = ChatCompletionRequest( | ||
| model="deepseek-v3", | ||
| messages=[{"role": "user", "content": "What is the weather?"}], | ||
| tools=sample_tools, | ||
| ) | ||
|
|
||
| extracted = deepseek_tool_parser.extract_tool_calls(model_output, request) | ||
|
|
||
| assert extracted.tools_called | ||
| assert extracted.content is None | ||
| assert_tool_calls(extracted.tool_calls, expected_tool_calls) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The added tests cover the basic cases of single and multiple function calls. However, they do not cover more complex edge cases that the parser is designed to handle, or which might expose its fragility. Given the complexity of the XML streaming parser, it's important to have comprehensive tests for:
- Parameter values containing special XML characters (e.g.,
<,>,&,',"). - Parameter values that are complex objects or arrays, especially those that would trigger the deferred parsing logic (e.g., containing Python literals with single quotes).
- Malformed XML output from the model (e.g., unclosed tags).
- Text content mixed with tool calls.
Adding these tests would significantly improve the robustness and maintainability of the tool parser.
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
Signed-off-by: wenmengzhou <[email protected]>
Signed-off-by: wenmengzhou <[email protected]>
Purpose
add tool parser support for deepseek v3.2
Test Plan
linter passed
add UT passed
Test Result
Essential Elements of an Effective PR Description Checklist