File size: 8,215 Bytes
46d09fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest, DeltaMessage, DeltaToolCall,
    DeltaFunctionCall, ExtractedToolCallInformation, ToolCall, FunctionCall
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
from vllm.utils import random_uuid
from vllm.logger import init_logger
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
                                                        is_complete_json,
                                                        partial_json_loads)

# Module-level logger, created through vLLM's init_logger and named after this module.
logger = init_logger(__name__)

@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
    """Tool parser for models that emit tool calls as a bare JSON array.

    The expected model output is a JSON list of objects, each carrying a
    ``name`` and an ``arguments`` entry, for example
    ``[{"name": "get_weather", "arguments": {"city": "Paris"}}]``.
    Output that does not start with ``[`` (ignoring surrounding whitespace)
    is passed through unchanged as plain content.

    NOTE(review): ``current_tool_id`` and ``current_tool_name_sent`` are not
    defined in this class; they are presumably initialised by ``ToolParser``
    to ``-1`` and ``False`` respectively - confirm against the installed
    vLLM version.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Create the parser and reset all per-stream bookkeeping.

        Args:
            tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.
        """
        super().__init__(tokenizer)
        # State for streaming mode. All three lists are indexed by tool
        # position inside the JSON array emitted by the model.
        self.prev_tool_calls: List[Dict] = []      # last parsed snapshot
        self.current_tools_sent: List[bool] = []   # name already emitted?
        self.streamed_args: List[str] = []         # argument JSON sent so far
        # NOTE(review): the base ToolParser / serving layer appear to read
        # `prev_tool_call_arr` and `streamed_args_for_tool`; mirror our state
        # under those names so both views stay consistent - verify against
        # the vLLM version in use. `streamed_args_for_tool` shares the same
        # list object, so in-place updates are visible through both names.
        self.prev_tool_call_arr: List[Dict] = self.prev_tool_calls
        self.streamed_args_for_tool: List[str] = self.streamed_args

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat request (unused here).

        Returns:
            ``tools_called=True`` with one ``ToolCall`` per array element when
            the output is a non-empty JSON array of ``{"name", "arguments"}``
            objects. In every other case (no leading ``[``, empty array,
            invalid JSON, malformed elements) the output is returned verbatim
            as ``content`` with ``tools_called=False``.
        """
        as_content = ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output
        )

        try:
            # Direct JSON parsing: tool calls are a bare array, no code fence.
            if not model_output.strip().startswith('['):
                return as_content

            tool_calls_data = json.loads(model_output)
            # An empty array carries no calls; keep the text rather than
            # returning "tools called" with nothing in it.
            if not tool_calls_data:
                return as_content

            tool_calls: List[ToolCall] = [
                ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        # ensure_ascii=False keeps non-ASCII argument text
                        # readable instead of \uXXXX-escaping it.
                        arguments=json.dumps(call["arguments"],
                                             ensure_ascii=False)
                    )
                )
                for idx, call in enumerate(tool_calls_data)
            ]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None
            )

        except Exception:
            # Best effort: anything that is not a well-formed call list is
            # handed back to the caller as ordinary content.
            logger.exception("Error extracting tool calls")
            return as_content

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta for the text generated so far.

        The accumulated ``current_text`` is re-parsed as a partial JSON array
        on every call. At most one step is taken per call: start tracking the
        next tool, emit a tool's name, or emit newly available argument text.

        Args:
            previous_text: Accumulated text before this chunk (unused).
            current_text: Accumulated text including this chunk.
            delta_text: Text of this chunk only.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: The originating chat request (unused here).

        Returns:
            A content ``DeltaMessage`` when the text is not a JSON array, a
            tool-call ``DeltaMessage`` when there is something new to send,
            or ``None`` when this chunk yields nothing (including on parse
            or internal errors, which are logged and swallowed).
        """
        text = current_text.strip()
        if not text.startswith('['):
            return DeltaMessage(content=delta_text)

        # Until the current tool's name has been sent, partial strings are
        # disallowed so that a half-generated function name is never emitted.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            try:
                # The whole array is parsed in one go; `parsed` is the list
                # of (possibly incomplete) tool-call objects.
                parsed, end_idx = partial_json_loads(text, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            if not isinstance(parsed, list) or not all(
                    isinstance(call, dict) for call in parsed):
                # Not (yet) a list of call objects; nothing to stream.
                return None
            tool_call_arr: List[Dict] = parsed

            # Case 1: No tools parsed yet
            if not tool_call_arr:
                return None

            # Defensive: the tracked index must exist in this snapshot.
            if self.current_tool_id >= len(tool_call_arr):
                return None

            # Case 2: a later tool has appeared, so the current one is done.
            if len(tool_call_arr) > self.current_tool_id + 1:
                delta = self._finish_current_tool(tool_call_arr)
                if delta is None:
                    # Advance one tool at a time so the per-tool state lists
                    # always stay aligned with `current_tool_id`.
                    self._start_next_tool()

            # Case 3: Send tool name if not sent yet
            elif not self.current_tools_sent[self.current_tool_id]:
                delta = self._name_delta(
                    self.current_tool_id,
                    tool_call_arr[self.current_tool_id])

            # Case 4: Stream arguments of the (last) tool in the array
            else:
                # The last element is only known to be final once the
                # enclosing array itself parses as complete JSON.
                array_closed = is_complete_json(text[:end_idx])
                delta = self._argument_progress_delta(tool_call_arr,
                                                      array_closed)

            # Always remember this snapshot, whichever branch ran, so the
            # next call diffs against up-to-date arguments.
            self.prev_tool_calls = tool_call_arr
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None

    def _start_next_tool(self) -> None:
        """Begin tracking the next tool in the array with fresh state."""
        self.current_tool_id += 1
        self.current_tool_name_sent = False
        self.current_tools_sent.append(False)
        self.streamed_args.append("")
        logger.debug("starting new tool %d", self.current_tool_id)

    def _name_delta(self, index: int,
                    tool_call: Dict) -> Union[DeltaMessage, None]:
        """Return the delta announcing a tool's name, or None if not known yet.

        Marks the name as sent for ``index`` when a delta is produced.
        """
        function_name = tool_call.get("name")
        if not function_name:
            return None

        self.current_tools_sent[index] = True
        self.current_tool_name_sent = True
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=index,
                type="function",
                id=f"call_{index}_{random_uuid()}",
                function=DeltaFunctionCall(
                    name=function_name
                ).model_dump(exclude_none=True)
            )
        ])

    def _arguments_delta(self, index: int,
                         argument_diff: str) -> DeltaMessage:
        """Build an arguments delta for ``index`` and record it as streamed."""
        self.streamed_args[index] += argument_diff
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=index,
                function=DeltaFunctionCall(
                    arguments=argument_diff
                ).model_dump(exclude_none=True)
            )
        ])

    def _finish_current_tool(
            self, tool_call_arr: List[Dict]) -> Union[DeltaMessage, None]:
        """Flush whatever is still unsent for the current, completed tool.

        Returns the name delta if the name was never sent, otherwise the
        remaining argument text, otherwise None (nothing left to send, or no
        tool is being tracked yet).
        """
        index = self.current_tool_id
        if index < 0:
            return None

        tool_call = tool_call_arr[index]
        if not self.current_tools_sent[index]:
            name_delta = self._name_delta(index, tool_call)
            if name_delta is not None:
                return name_delta

        cur_arguments = tool_call.get("arguments")
        if not cur_arguments:
            return None

        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
        argument_diff = cur_args_json[len(self.streamed_args[index]):]
        if not argument_diff:
            return None
        return self._arguments_delta(index, argument_diff)

    def _argument_progress_delta(
            self, tool_call_arr: List[Dict],
            is_complete: bool) -> Union[DeltaMessage, None]:
        """Return newly available argument text for the current tool.

        When the call is complete, everything not yet streamed is sent.
        Otherwise only the prefix shared by the previous and current
        serialisations is considered stable enough to send, because the tail
        of a partially parsed object can still change.
        """
        index = self.current_tool_id
        cur_arguments = tool_call_arr[index].get("arguments")
        if not cur_arguments:
            return None

        sent = len(self.streamed_args[index])
        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)

        if is_complete:
            argument_diff = cur_args_json[sent:]
        else:
            # The previous snapshot may not contain this tool yet.
            prev_arguments = (self.prev_tool_calls[index].get("arguments")
                              if index < len(self.prev_tool_calls) else None)
            if not prev_arguments:
                return None
            prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
            if cur_args_json == prev_args_json:
                return None
            prefix = find_common_prefix(prev_args_json, cur_args_json)
            argument_diff = prefix[sent:]

        # Skip empty diffs instead of emitting no-op deltas.
        if not argument_diff:
            return None
        return self._arguments_delta(index, argument_diff)