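"""Tool call parser for xLAM-style models served through vLLM.

The model is expected to emit its tool calls as a bare JSON array, e.g.
`[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]`. The parser is
registered under the name "xlam" and is typically enabled with
`--enable-auto-tool-choice --tool-call-parser xlam` (plus
`--tool-parser-plugin <path to this file>` when loaded as an out-of-tree
plugin).
"""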
import json
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
                                                        is_complete_json,
                                                        partial_json_loads)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
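    """Parser for tool calls emitted as a bare JSON array.

    Unlike parsers for models that wrap tool calls in special tokens or
    ``` code fences, this parser reads the JSON array directly, both for
    complete outputs and incrementally while streaming.
    """
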
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # State for streaming mode
        self.prev_tool_calls: List[Dict] = []    # last parsed tool-call array
        self.current_tool_id: int = -1           # tool currently being streamed
        self.current_tools_sent: List[bool] = [] # whether each tool's name was sent
        self.streamed_args: List[str] = []       # argument text already sent per tool

        # No regex extraction is needed: the model emits the tool calls as a
        # bare JSON array.

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
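        """Parse a complete model output into tool calls.

        If the output is a JSON array of `{"name": ..., "arguments": ...}`
        objects, each element becomes a ToolCall; otherwise the text is
        returned unchanged as ordinary content.
        """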
        try:
            # The model emits tool calls as a bare JSON array, so parse it
            # directly instead of looking for ``` code fences.
            if not model_output.strip().startswith('['):
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls_data = json.loads(model_output)
            tool_calls: List[ToolCall] = []

            for idx, call in enumerate(tool_calls_data):
                tool_call = ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        # Arguments are re-serialized to a JSON string, as
                        # required by the OpenAI tool-call schema.
                        arguments=json.dumps(call["arguments"])))
                tool_calls.append(tool_call)

            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=None)

        except Exception:
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
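        """Return the tool-call delta for the latest streamed chunk.

        The parser keeps track of which tool names and how much of each
        tool's arguments have already been sent, so every chunk carries only
        the newly generated portion. Returns None when there is nothing new
        to emit yet.
        """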
        # Output that does not look like a JSON tool-call array is passed
        # through as regular content.
        if not current_text.strip().startswith('['):
            return DeltaMessage(content=delta_text)

        # Allow partially generated strings in the arguments only once the
        # current tool's name has been fully emitted.
        name_sent = (0 <= self.current_tool_id < len(self.current_tools_sent)
                     and self.current_tools_sent[self.current_tool_id])
        flags = Allow.ALL if name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # The model output is a single JSON array of tool calls;
                # parse it once, tolerating the fact that it may still be
                # incomplete while tokens are streaming in.
                obj, end_idx = partial_json_loads(current_text, flags)
                tool_call_arr = obj if isinstance(obj, list) else [obj]

                # Every tool call except the last one was necessarily parsed
                # in full; treat the last one as complete only once the whole
                # array parsed so far is complete JSON.
                is_complete = [True] * max(len(tool_call_arr) - 1, 0) + [
                    is_complete_json(current_text[:end_idx])
                ]
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # Select the tool call the parser is currently working on.
            current_tool_call: Dict = (tool_call_arr[self.current_tool_id]
                                       if len(tool_call_arr) > 0 else {})
            # Case 1: No tools parsed yet
            if len(tool_call_arr) == 0:
                return None

            # Case 2: A new tool has appeared in the array
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # Flush any remaining arguments of the previous tool before
                # moving on to the new one.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        if argument_diff:
                            delta = DeltaMessage(tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff).model_dump(
                                            exclude_none=True))
                            ])
                            self.streamed_args[
                                self.current_tool_id] += argument_diff
                            return delta

                # Set up state for the new tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tools_sent.append(False)
                self.streamed_args.append("")
                logger.debug("starting new tool %d", self.current_tool_id)
                return None
            # Case 3: Send the tool name if it has not been sent yet
            elif not self.current_tools_sent[self.current_tool_id]:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"call_{self.current_tool_id}_{random_uuid()}",
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True))
                    ])
                    self.current_tools_sent[self.current_tool_id] = True
                    return delta
                return None
            # Case 4: Stream the arguments of the current tool
            else:
                cur_arguments = current_tool_call.get("arguments")
                if cur_arguments:
                    sent = len(self.streamed_args[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)

                    # The previous parse may not contain this tool yet, so
                    # guard the lookup.
                    prev_arguments = None
                    if len(self.prev_tool_calls) > self.current_tool_id:
                        prev_arguments = self.prev_tool_calls[
                            self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # The tool call is complete: send whatever has not
                        # been streamed yet.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only the stable common prefix of the partial
                            # arguments is safe to stream.
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff).model_dump(
                                        exclude_none=True))
                        ])
                        self.streamed_args[
                            self.current_tool_id] += argument_diff
                        self.prev_tool_calls = tool_call_arr
                        return delta

            # Remember this parse so the next chunk can be diffed against it.
            self.prev_tool_calls = tool_call_arr
            return None
        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None
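

# ---------------------------------------------------------------------------
# Illustrative smoke test, not part of the parser. It only assumes that
# ToolParser.__init__ stores the tokenizer and that the model emits a bare
# JSON array of tool calls; the checkpoint name below is a placeholder, and
# the ChatCompletionRequest argument is unused by the one-shot path, so None
# is passed for brevity.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Salesforce/xLAM-7b-r")
    parser = xLAMToolParser(tokenizer)

    sample_output = ('[{"name": "get_weather", '
                     '"arguments": {"city": "Tokyo", "unit": "celsius"}}]')
    info = parser.extract_tool_calls(sample_output, request=None)

    print(info.tools_called)  # True
    for call in info.tool_calls:
        print(call.function.name, call.function.arguments)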