import json
import re
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.entrypoints.openai.tool_parsers.utils import (
    find_common_prefix,
    is_complete_json,
    partial_json_loads,
)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
    """Tool-call parser for models that emit a bare JSON array of calls.

    The expected model output is a JSON array whose elements are objects
    with a ``"name"`` and an ``"arguments"`` entry, e.g.
    ``[{"name": "f", "arguments": {"x": 1}}]``.  Output that does not start
    with ``[`` (ignoring surrounding whitespace) is treated as plain content.

    NOTE(review): ``self.current_tool_id`` is read here but never initialised
    in this class; it is assumed to be set (to ``-1``) by the ``ToolParser``
    base class -- confirm against the base class.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Streaming state, all indexed by tool position in the output array.
        # Tool-call array parsed on the previous streaming step.
        self.prev_tool_calls: List[Dict] = []
        # Whether the function name of each started tool has been emitted.
        self.current_tools_sent: List[bool] = []
        # Argument JSON text already emitted for each started tool.
        self.streamed_args: List[str] = []

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat request (unused here).

        Returns:
            ``tools_called=True`` with one ``ToolCall`` per array element when
            the output is a non-empty JSON array of call objects.  In every
            other case (not an array, empty array, invalid JSON, malformed
            element) the output is returned unchanged as ``content`` with
            ``tools_called=False``.
        """
        try:
            if not model_output.strip().startswith("["):
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            tool_calls_data = json.loads(model_output)
            tool_calls: List[ToolCall] = [
                ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        # A call to a parameterless function may omit
                        # "arguments"; treat that as an empty object rather
                        # than discarding every call in the array.
                        arguments=json.dumps(call.get("arguments", {})),
                    ),
                )
                for idx, call in enumerate(tool_calls_data)
            ]

            if not tool_calls:
                # An empty array carries no call; do not claim one was made
                # and do not drop the text.
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )
        except Exception:
            # Best effort: anything unparsable is handed back as content.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta for the text generated so far.

        Each call performs at most one step: start tracking the next tool in
        the array, emit the current tool's function name, or emit the newly
        stable part of the current tool's arguments.

        Args:
            previous_text: Text generated before this step (unused).
            current_text: All text generated so far, including this step.
            delta_text: Text generated in this step.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: The originating chat request (unused here).

        Returns:
            A content delta when the output is not a tool-call array, a
            tool-call delta when there is something new to emit, or ``None``
            when this step yields nothing (including on parse errors, which
            are logged and skipped).
        """
        # Strip so that the parsed span starts at the opening bracket.
        text = current_text.strip()
        if not text.startswith("["):
            return DeltaMessage(content=delta_text)

        # Until the current tool's name has been emitted, do not accept
        # unterminated strings, so a half-generated name is never sent.
        tool_id = self.current_tool_id
        name_sent = (0 <= tool_id < len(self.current_tools_sent)
                     and self.current_tools_sent[tool_id])
        flags = Allow.ALL if name_sent else Allow.ALL & ~Allow.STR

        try:
            try:
                # The whole output is ONE JSON array; its elements are the
                # tool calls.
                parsed, end_idx = partial_json_loads(text, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            if not isinstance(parsed, list) or not parsed:
                return None

            # Keep positions stable: a non-object element becomes an empty
            # call instead of being removed.
            tool_call_arr: List[Dict] = [
                call if isinstance(call, dict) else {} for call in parsed
            ]
            if tool_id >= len(tool_call_arr):
                return None

            # Every element followed by another one is closed; the last one
            # is known to be closed only once the array itself is complete.
            is_complete = [True] * (len(tool_call_arr) - 1)
            is_complete.append(bool(is_complete_json(text[:end_idx])))

            has_next = len(tool_call_arr) > tool_id + 1
            delta = None

            if tool_id < 0:
                # Nothing started yet; the array is non-empty, so begin.
                self._start_next_tool()
            elif not self.current_tools_sent[tool_id]:
                function_name = tool_call_arr[tool_id].get("name")
                if function_name and isinstance(function_name, str):
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=tool_id,
                            type="function",
                            id=f"call_{tool_id}_{random_uuid()}",
                            function=DeltaFunctionCall(
                                name=function_name,
                            ).model_dump(exclude_none=True),
                        )
                    ])
                    self.current_tools_sent[tool_id] = True
                elif has_next:
                    # This element is closed and has no usable name; move on
                    # rather than stalling the following tools.
                    self._start_next_tool()
            else:
                if tool_id < len(self.prev_tool_calls):
                    prev_call = self.prev_tool_calls[tool_id]
                else:
                    prev_call = {}
                delta = self._arguments_delta(
                    tool_call_arr[tool_id],
                    prev_call,
                    is_complete[tool_id],
                )
                if has_next:
                    # The delta above flushed whatever was left of the
                    # finished tool; advance one tool at a time so the
                    # per-tool state lists stay aligned with the index.
                    self._start_next_tool()

            # Remember this parse on every path so the next step can diff
            # against it.
            self.prev_tool_calls = tool_call_arr
            return delta
        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None

    def _start_next_tool(self) -> None:
        """Advance to the next tool in the array with fresh per-tool state."""
        self.current_tool_id += 1
        self.current_tools_sent.append(False)
        self.streamed_args.append("")
        logger.debug("starting new tool %d", self.current_tool_id)

    def _arguments_delta(
        self,
        tool_call: Dict,
        prev_tool_call: Dict,
        complete: bool,
    ) -> Union[DeltaMessage, None]:
        """Build a delta with the not-yet-sent arguments of the current tool.

        For a complete call the whole remainder of the serialised arguments
        is sent.  For a call still being generated only the prefix shared
        with the previous step's arguments is considered stable.

        Returns ``None`` when there is nothing new to send.
        """
        cur_arguments = tool_call.get("arguments")
        if cur_arguments is None:
            return None
        if not cur_arguments and not complete:
            # Empty so far; wait until there is content or the call closes.
            return None

        tool_id = self.current_tool_id
        sent = len(self.streamed_args[tool_id])
        cur_args_json = json.dumps(cur_arguments)

        if complete:
            argument_diff = cur_args_json[sent:]
        else:
            prev_arguments = prev_tool_call.get("arguments")
            if not prev_arguments:
                return None
            prev_args_json = json.dumps(prev_arguments)
            if cur_args_json == prev_args_json:
                return None
            prefix = find_common_prefix(prev_args_json, cur_args_json)
            argument_diff = prefix[sent:]

        if not argument_diff:
            return None

        self.streamed_args[tool_id] += argument_diff
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=tool_id,
                function=DeltaFunctionCall(
                    arguments=argument_diff,
                ).model_dump(exclude_none=True),
            )
        ])