Upload 11 files
- .gitattributes +3 -0
- img/q1.gif +3 -0
- img/q2.gif +3 -0
- img/q3.gif +3 -0
- pyproject.toml +2 -2
- run_example.py +28 -0
- run_txagent_app.py +293 -0
- src/txagent/__init__.py +6 -0
- src/txagent/toolrag.py +60 -0
- src/txagent/txagent.py +937 -0
- src/txagent/utils.py +117 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/q1.gif filter=lfs diff=lfs merge=lfs -text
+img/q2.gif filter=lfs diff=lfs merge=lfs -text
+img/q3.gif filter=lfs diff=lfs merge=lfs -text
img/q1.gif
ADDED
(binary GIF tracked via Git LFS)
img/q2.gif
ADDED
(binary GIF tracked via Git LFS)
img/q3.gif
ADDED
(binary GIF tracked via Git LFS)
pyproject.toml
CHANGED
@@ -1,3 +1,3 @@
-[build-system]
-requires = ["setuptools", "wheel"]
+[build-system]
+requires = ["setuptools", "wheel"]
 build-backend = "setuptools.build_meta"
run_example.py
ADDED
@@ -0,0 +1,28 @@
from txagent import TxAgent
import os
os.environ["MKL_THREADING_LAYER"] = "GNU"


model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
multiagent = False
max_round = 20
init_rag_num = 0
step_rag_num = 10

agent = TxAgent(model_name,
                rag_model_name,
                enable_summary=False)
agent.init_model()

question = "Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?"

response = agent.run_multistep_agent(
    question,
    temperature=0.3,
    max_new_tokens=1024,
    max_token=90240,
    call_agent=multiagent,
    max_round=max_round)

print(f"\033[94m{response}\033[0m")
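For comparison, a minimal sketch of the multi-agent variant of the same call (an illustration, not part of the uploaded files) assumes the `agent` and `question` objects defined above and only flips the `call_agent` flag:

# Illustrative sketch: same question, but allowing sub-agent delegation.
# Assumes `agent` and `question` from run_example.py above.
response_multi = agent.run_multistep_agent(
    question,
    temperature=0.3,
    max_new_tokens=1024,
    max_token=90240,
    call_agent=True,   # exposes the CallAgent tool so the model can spawn a sub-agent
    max_round=20)
print(response_multi)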
run_txagent_app.py
ADDED
@@ -0,0 +1,293 @@
import random
import datetime
import sys
from txagent import TxAgent
import spaces
import gradio as gr
import os

# Determine the directory where the current file is located
current_dir = os.path.dirname(os.path.abspath(__file__))
os.environ["MKL_THREADING_LAYER"] = "GNU"

# Set an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", None)


DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools </h1>
</div>
'''
INTRO = """
Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations. We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions, contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions. TxAgent executes goal-oriented tool selection and iterative function calls to solve therapeutic tasks that require deep clinical understanding and cross-source validation. The ToolUniverse consolidates 211 tools linked to trusted sources, including all US FDA-approved drugs since 1939 and validated clinical insights from Open Targets.
"""

LICENSE = """
We welcome your feedback and suggestions to enhance your experience with TxAgent, and if you're interested in collaboration, please email Marinka Zitnik and Shanghua Gao.

### Medical Advice Disclaimer
DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
The information, including but not limited to, text, graphics, images and other material contained on this website are for informational purposes only. No material on this site is intended to be a substitute for professional medical advice, diagnosis or treatment. Always seek the advice of your physician or other qualified health care provider with any questions you may have regarding a medical condition or treatment and before undertaking a new health care regimen, and never disregard professional medical advice or delay in seeking it because of something you have read on this website.
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
(top-right) to remove previous context before submitting a new question.</p>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
</div>
"""

css = """
h1 {
    text-align: center;
    display: block;
}

#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
.small-button button {
    font-size: 12px !important;
    padding: 4px 8px !important;
    height: 6px !important;
    width: 4px !important;
}
.gradio-accordion {
    margin-top: 0px !important;
    margin-bottom: 0px !important;
}
"""

chat_css = """
.gr-button { font-size: 20px !important; }  /* Enlarges button icons */
.gr-button svg { width: 32px !important; height: 32px !important; }  /* Enlarges SVG icons */
"""

# model_name = '/n/holylfs06/LABS/mzitnik_lab/Lab/shgao/bioagent/bio/alignment-handbook/data_new/L8-qlora-biov49v9v7v16_32k_chat01_merged'
model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'

os.environ["TOKENIZERS_PARALLELISM"] = "false"


question_examples = [
    ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
    ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
    ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
]

new_tool_files = {
    'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
}

agent = TxAgent(model_name,
                rag_model_name,
                tool_files_dict=new_tool_files,
                force_finish=True,
                enable_checker=True,
                step_rag_num=10,
                seed=100,
                additional_default_tools=['DirectResponse', 'RequireClarification'])
agent.init_model()


def update_model_parameters(enable_finish, enable_rag, enable_summary,
                            init_rag_num, step_rag_num, skip_last_k,
                            summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
    # Update model instance parameters dynamically
    updated_params = agent.update_parameters(
        enable_finish=enable_finish,
        enable_rag=enable_rag,
        enable_summary=enable_summary,
        init_rag_num=init_rag_num,
        step_rag_num=step_rag_num,
        skip_last_k=skip_last_k,
        summary_mode=summary_mode,
        summary_skip_last_k=summary_skip_last_k,
        summary_context_length=summary_context_length,
        force_finish=force_finish,
        seed=seed,
    )

    return updated_params


def update_seed():
    # Update model instance parameters dynamically
    seed = random.randint(0, 10000)
    updated_params = agent.update_parameters(
        seed=seed,
    )
    return updated_params


def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
    print("Updated seed:", update_seed())
    new_history = history[:retry_data.index]
    previous_prompt = history[retry_data.index]['content']

    print("previous_prompt", previous_prompt)

    yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)


PASSWORD = "mypassword"

# Function to check if the password is correct


def check_password(input_password):
    if input_password == PASSWORD:
        return gr.update(visible=True), ""
    else:
        return gr.update(visible=False), "Incorrect password, try again!"


conversation_state = gr.State([])

# Gradio block
chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
                     label='TxAgent', type="messages", show_copy_button=True)

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(INTRO)
    default_temperature = 0.3
    default_max_new_tokens = 1024
    default_max_tokens = 81920
    default_max_round = 30
    temperature_state = gr.State(value=default_temperature)
    max_new_tokens_state = gr.State(value=default_max_new_tokens)
    max_tokens_state = gr.State(value=default_max_tokens)
    max_round_state = gr.State(value=default_max_round)
    chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
                  max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)

    gr.ChatInterface(
        fn=agent.run_gradio_chat,
        chatbot=chatbot,
        fill_height=True, fill_width=True, stop_btn=True,
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Inference Parameters", open=False, render=False),
        additional_inputs=[
            temperature_state, max_new_tokens_state, max_tokens_state,
            gr.Checkbox(
                label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).", value=False, render=False),
            conversation_state,
            max_round_state,
            gr.Number(label="Seed", value=100, render=False)
        ],
        examples=question_examples,
        cache_examples=False,
        css=chat_css,
    )

    with gr.Accordion("Settings", open=False):

        # Define the sliders
        temperature_slider = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.1,
            value=default_temperature,
            label="Temperature"
        )
        max_new_tokens_slider = gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=default_max_new_tokens,
            label="Max new tokens"
        )
        max_tokens_slider = gr.Slider(
            minimum=128,
            maximum=32000,
            step=1,
            value=default_max_tokens,
            label="Max tokens"
        )
        max_round_slider = gr.Slider(
            minimum=0,
            maximum=50,
            step=1,
            value=default_max_round,
            label="Max round")

        # Automatically update states when slider values change
        temperature_slider.change(
            lambda x: x, inputs=temperature_slider, outputs=temperature_state)
        max_new_tokens_slider.change(
            lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
        max_tokens_slider.change(
            lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
        max_round_slider.change(
            lambda x: x, inputs=max_round_slider, outputs=max_round_state)

    password_input = gr.Textbox(
        label="Enter Password for More Settings", type="password")
    incorrect_message = gr.Textbox(visible=False, interactive=False)
    with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Model Loading", open=False):
                    model_name_input = gr.Textbox(
                        label="Enter model path", value=model_name)
                    load_model_btn = gr.Button(value="Load Model")
                    load_model_btn.click(
                        agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
            with gr.Column(scale=1):
                with gr.Accordion("⚙️ Functional Parameters", open=False):
                    # Create Gradio components for parameter inputs
                    enable_finish = gr.Checkbox(
                        label="Enable Finish", value=True)
                    enable_rag = gr.Checkbox(
                        label="Enable RAG", value=True)
                    enable_summary = gr.Checkbox(
                        label="Enable Summary", value=False)
                    init_rag_num = gr.Number(
                        label="Initial RAG Num", value=0)
                    step_rag_num = gr.Number(
                        label="Step RAG Num", value=10)
                    skip_last_k = gr.Number(label="Skip Last K", value=0)
                    summary_mode = gr.Textbox(
                        label="Summary Mode", value='step')
                    summary_skip_last_k = gr.Number(
                        label="Summary Skip Last K", value=0)
                    summary_context_length = gr.Number(
                        label="Summary Context Length", value=None)
                    force_finish = gr.Checkbox(
                        label="Force FinalAnswer", value=True)
                    seed = gr.Number(label="Seed", value=100)
                    # Button to submit and update parameters
                    submit_btn = gr.Button("Update Parameters")

                    # Display the updated parameters
                    updated_parameters_output = gr.JSON()

                    # When button is clicked, update parameters
                    submit_btn.click(fn=update_model_parameters,
                                     inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
                                             summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
                                     outputs=updated_parameters_output)
    # Button to submit the password
    submit_button = gr.Button("Submit")

    # When the button is clicked, check if the password is correct
    submit_button.click(
        check_password,
        inputs=password_input,
        outputs=[protected_accordion, incorrect_message]
    )
    gr.Markdown(LICENSE)


if __name__ == "__main__":
    demo.launch(share=True)
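The protected Settings accordion above funnels its widgets into agent.update_parameters, defined in src/txagent/txagent.py below. As a rough sketch, the same runtime reconfiguration can be done outside Gradio; this snippet is illustrative and assumes the agent instance created in this script:

# Illustrative sketch: adjust the live agent without the UI.
# update_parameters only sets attributes that already exist on the TxAgent
# instance and returns the subset it actually applied.
updated = agent.update_parameters(enable_summary=True, step_rag_num=20, seed=42)
print(updated)  # e.g. {'enable_summary': True, 'step_rag_num': 20, 'seed': 42}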
src/txagent/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .txagent import TxAgent
from .toolrag import ToolRAGModel
__all__ = [
    "TxAgent",
    "ToolRAGModel",
]
src/txagent/toolrag.py
ADDED
@@ -0,0 +1,60 @@
from sentence_transformers import SentenceTransformer
import torch
import json
from .utils import get_md5


class ToolRAGModel:
    def __init__(self, rag_model_name):
        self.rag_model_name = rag_model_name
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        self.load_rag_model()

    def load_rag_model(self):
        self.rag_model = SentenceTransformer(self.rag_model_name)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(self, toolbox):
        self.tool_name, _ = toolbox.refresh_tool_name_desc(
            enable_full_desc=True)
        all_tools_str = [json.dumps(
            each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
        md5_value = get_md5(str(all_tools_str))
        print("MD5 value of tools:", md5_value)
        self.tool_embedding_path = self.rag_model_name.split(
            '/')[-1] + "tool_embedding_" + md5_value + ".pt"
        try:
            self.tool_desc_embedding = torch.load(
                self.tool_embedding_path, weights_only=False)
            assert len(self.tool_desc_embedding) == len(
                toolbox.all_tools), "The number of tools in the toolbox is not equal to the number of tool_desc_embedding."
        except Exception:
            self.tool_desc_embedding = None
            print("\033[92mInferring the tool_desc_embedding.\033[0m")
            self.tool_desc_embedding = self.rag_model.encode(
                all_tools_str, prompt="", normalize_embeddings=True
            )
            torch.save(self.tool_desc_embedding, self.tool_embedding_path)
            print("\033[92mFinished inferring the tool_desc_embedding.\033[0m")
            print("\033[91mExiting. Please rerun the code to avoid the OOM issue.\033[0m")
            exit()

    def rag_infer(self, query, top_k=5):
        torch.cuda.empty_cache()
        queries = [query]
        query_embeddings = self.rag_model.encode(
            queries, prompt="", normalize_embeddings=True
        )
        if self.tool_desc_embedding is None:
            print("No tool_desc_embedding")
            exit()
        scores = self.rag_model.similarity(
            query_embeddings, self.tool_desc_embedding)
        top_k = min(top_k, len(self.tool_name))
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
        return top_k_tool_names
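TxAgent.rag_infer (below) simply forwards to this class, but the retriever can also be exercised on its own. The following is a rough, illustrative sketch, assuming ToolUniverse's default tool files and the ToolRAG checkpoint used elsewhere in this upload:

# Illustrative sketch: standalone tool retrieval with ToolRAGModel.
from tooluniverse import ToolUniverse
from txagent import ToolRAGModel

toolbox = ToolUniverse()   # default tool files; a custom dict can be passed as in run_txagent_app.py
toolbox.load_tools()

rag = ToolRAGModel('mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B')
rag.load_tool_desc_embedding(toolbox)   # caches embeddings to disk; note it calls exit() right after first-time caching
print(rag.rag_infer("dose adjustment for severe hepatic impairment", top_k=5))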
src/txagent/txagent.py
ADDED
@@ -0,0 +1,937 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import json
|
5 |
+
import gc
|
6 |
+
import numpy as np
|
7 |
+
from vllm import LLM, SamplingParams
|
8 |
+
from jinja2 import Template
|
9 |
+
from typing import List
|
10 |
+
import types
|
11 |
+
from tooluniverse import ToolUniverse
|
12 |
+
from gradio import ChatMessage
|
13 |
+
from .toolrag import ToolRAGModel
|
14 |
+
|
15 |
+
from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
|
16 |
+
|
17 |
+
|
18 |
+
class TxAgent:
|
19 |
+
def __init__(self, model_name,
|
20 |
+
rag_model_name,
|
21 |
+
tool_files_dict=None, # None leads to the default tool files in ToolUniverse
|
22 |
+
enable_finish=True,
|
23 |
+
enable_rag=True,
|
24 |
+
enable_summary=False,
|
25 |
+
init_rag_num=0,
|
26 |
+
step_rag_num=10,
|
27 |
+
summary_mode='step',
|
28 |
+
summary_skip_last_k=0,
|
29 |
+
summary_context_length=None,
|
30 |
+
force_finish=True,
|
31 |
+
avoid_repeat=True,
|
32 |
+
seed=None,
|
33 |
+
enable_checker=False,
|
34 |
+
enable_chat=False,
|
35 |
+
additional_default_tools=None,
|
36 |
+
):
|
37 |
+
self.model_name = model_name
|
38 |
+
self.tokenizer = None
|
39 |
+
self.terminators = None
|
40 |
+
self.rag_model_name = rag_model_name
|
41 |
+
self.tool_files_dict = tool_files_dict
|
42 |
+
self.model = None
|
43 |
+
self.rag_model = ToolRAGModel(rag_model_name)
|
44 |
+
self.tooluniverse = None
|
45 |
+
# self.tool_desc = None
|
46 |
+
self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
|
47 |
+
self.self_prompt = "Strictly follow the instruction."
|
48 |
+
self.chat_prompt = "You are helpful assistant to chat with the user."
|
49 |
+
self.enable_finish = enable_finish
|
50 |
+
self.enable_rag = enable_rag
|
51 |
+
self.enable_summary = enable_summary
|
52 |
+
self.summary_mode = summary_mode
|
53 |
+
self.summary_skip_last_k = summary_skip_last_k
|
54 |
+
self.summary_context_length = summary_context_length
|
55 |
+
self.init_rag_num = init_rag_num
|
56 |
+
self.step_rag_num = step_rag_num
|
57 |
+
self.force_finish = force_finish
|
58 |
+
self.avoid_repeat = avoid_repeat
|
59 |
+
self.seed = seed
|
60 |
+
self.enable_checker = enable_checker
|
61 |
+
self.additional_default_tools = additional_default_tools
|
62 |
+
self.print_self_values()
|
63 |
+
|
64 |
+
def init_model(self):
|
65 |
+
self.load_models()
|
66 |
+
self.load_tooluniverse()
|
67 |
+
self.load_tool_desc_embedding()
|
68 |
+
|
69 |
+
def print_self_values(self):
|
70 |
+
for attr, value in self.__dict__.items():
|
71 |
+
print(f"{attr}: {value}")
|
72 |
+
|
73 |
+
def load_models(self, model_name=None):
|
74 |
+
if model_name is not None:
|
75 |
+
if model_name == self.model_name:
|
76 |
+
return f"The model {model_name} is already loaded."
|
77 |
+
self.model_name = model_name
|
78 |
+
|
79 |
+
self.model = LLM(model=self.model_name)
|
80 |
+
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
81 |
+
self.tokenizer = self.model.get_tokenizer()
|
82 |
+
|
83 |
+
return f"Model {model_name} loaded successfully."
|
84 |
+
|
85 |
+
def load_tooluniverse(self):
|
86 |
+
self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
|
87 |
+
self.tooluniverse.load_tools()
|
88 |
+
special_tools = self.tooluniverse.prepare_tool_prompts(
|
89 |
+
self.tooluniverse.tool_category_dicts["special_tools"])
|
90 |
+
self.special_tools_name = [tool['name'] for tool in special_tools]
|
91 |
+
|
92 |
+
def load_tool_desc_embedding(self):
|
93 |
+
self.rag_model.load_tool_desc_embedding(self.tooluniverse)
|
94 |
+
|
95 |
+
def rag_infer(self, query, top_k=5):
|
96 |
+
return self.rag_model.rag_infer(query, top_k)
|
97 |
+
|
98 |
+
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
|
99 |
+
picked_tools_prompt = []
|
100 |
+
picked_tools_prompt = self.add_special_tools(
|
101 |
+
picked_tools_prompt, call_agent=call_agent)
|
102 |
+
if call_agent:
|
103 |
+
call_agent_level += 1
|
104 |
+
if call_agent_level >= 2:
|
105 |
+
call_agent = False
|
106 |
+
|
107 |
+
if not call_agent:
|
108 |
+
picked_tools_prompt += self.tool_RAG(
|
109 |
+
message=message, rag_num=self.init_rag_num)
|
110 |
+
return picked_tools_prompt, call_agent_level
|
111 |
+
|
112 |
+
def initialize_conversation(self, message, conversation=None, history=None):
|
113 |
+
if conversation is None:
|
114 |
+
conversation = []
|
115 |
+
|
116 |
+
conversation = self.set_system_prompt(
|
117 |
+
conversation, self.prompt_multi_step)
|
118 |
+
if history is not None:
|
119 |
+
if len(history) == 0:
|
120 |
+
conversation = []
|
121 |
+
print("clear conversation successfully")
|
122 |
+
else:
|
123 |
+
for i in range(len(history)):
|
124 |
+
if history[i]['role'] == 'user':
|
125 |
+
if i-1 >= 0 and history[i-1]['role'] == 'assistant':
|
126 |
+
conversation.append(
|
127 |
+
{"role": "assistant", "content": history[i-1]['content']})
|
128 |
+
conversation.append(
|
129 |
+
{"role": "user", "content": history[i]['content']})
|
130 |
+
if i == len(history)-1 and history[i]['role'] == 'assistant':
|
131 |
+
conversation.append(
|
132 |
+
{"role": "assistant", "content": history[i]['content']})
|
133 |
+
|
134 |
+
conversation.append({"role": "user", "content": message})
|
135 |
+
|
136 |
+
return conversation
|
137 |
+
|
138 |
+
def tool_RAG(self, message=None,
|
139 |
+
picked_tool_names=None,
|
140 |
+
existing_tools_prompt=[],
|
141 |
+
rag_num=5,
|
142 |
+
return_call_result=False):
|
143 |
+
extra_factor = 30 # Factor to retrieve more than rag_num
|
144 |
+
if picked_tool_names is None:
|
145 |
+
assert picked_tool_names is not None or message is not None
|
146 |
+
picked_tool_names = self.rag_infer(
|
147 |
+
message, top_k=rag_num*extra_factor)
|
148 |
+
|
149 |
+
picked_tool_names_no_special = []
|
150 |
+
for tool in picked_tool_names:
|
151 |
+
if tool not in self.special_tools_name:
|
152 |
+
picked_tool_names_no_special.append(tool)
|
153 |
+
picked_tool_names_no_special = picked_tool_names_no_special[:rag_num]
|
154 |
+
picked_tool_names = picked_tool_names_no_special[:rag_num]
|
155 |
+
|
156 |
+
picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
|
157 |
+
picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
|
158 |
+
picked_tools)
|
159 |
+
if return_call_result:
|
160 |
+
return picked_tools_prompt, picked_tool_names
|
161 |
+
return picked_tools_prompt
|
162 |
+
|
163 |
+
def add_special_tools(self, tools, call_agent=False):
|
164 |
+
if self.enable_finish:
|
165 |
+
tools.append(self.tooluniverse.get_one_tool_by_one_name(
|
166 |
+
'Finish', return_prompt=True))
|
167 |
+
print("Finish tool is added")
|
168 |
+
if call_agent:
|
169 |
+
tools.append(self.tooluniverse.get_one_tool_by_one_name(
|
170 |
+
'CallAgent', return_prompt=True))
|
171 |
+
print("CallAgent tool is added")
|
172 |
+
else:
|
173 |
+
if self.enable_rag:
|
174 |
+
tools.append(self.tooluniverse.get_one_tool_by_one_name(
|
175 |
+
'Tool_RAG', return_prompt=True))
|
176 |
+
print("Tool_RAG tool is added")
|
177 |
+
|
178 |
+
if self.additional_default_tools is not None:
|
179 |
+
for each_tool_name in self.additional_default_tools:
|
180 |
+
tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
|
181 |
+
each_tool_name, return_prompt=True)
|
182 |
+
if tool_prompt is not None:
|
183 |
+
print(f"{each_tool_name} tool is added")
|
184 |
+
tools.append(tool_prompt)
|
185 |
+
return tools
|
186 |
+
|
187 |
+
def add_finish_tools(self, tools):
|
188 |
+
tools.append(self.tooluniverse.get_one_tool_by_one_name(
|
189 |
+
'Finish', return_prompt=True))
|
190 |
+
print("Finish tool is added")
|
191 |
+
return tools
|
192 |
+
|
193 |
+
def set_system_prompt(self, conversation, sys_prompt):
|
194 |
+
if len(conversation) == 0:
|
195 |
+
conversation.append(
|
196 |
+
{"role": "system", "content": sys_prompt})
|
197 |
+
else:
|
198 |
+
conversation[0] = {"role": "system", "content": sys_prompt}
|
199 |
+
return conversation
|
200 |
+
|
201 |
+
def run_function_call(self, fcall_str,
|
202 |
+
return_message=False,
|
203 |
+
existing_tools_prompt=None,
|
204 |
+
message_for_call_agent=None,
|
205 |
+
call_agent=False,
|
206 |
+
call_agent_level=None,
|
207 |
+
temperature=None):
|
208 |
+
|
209 |
+
function_call_json, message = self.tooluniverse.extract_function_call_json(
|
210 |
+
fcall_str, return_message=return_message, verbose=False)
|
211 |
+
call_results = []
|
212 |
+
special_tool_call = ''
|
213 |
+
if function_call_json is not None:
|
214 |
+
if isinstance(function_call_json, list):
|
215 |
+
for i in range(len(function_call_json)):
|
216 |
+
print("\033[94mTool Call:\033[0m", function_call_json[i])
|
217 |
+
if function_call_json[i]["name"] == 'Finish':
|
218 |
+
special_tool_call = 'Finish'
|
219 |
+
break
|
220 |
+
elif function_call_json[i]["name"] == 'Tool_RAG':
|
221 |
+
new_tools_prompt, call_result = self.tool_RAG(
|
222 |
+
message=message,
|
223 |
+
existing_tools_prompt=existing_tools_prompt,
|
224 |
+
rag_num=self.step_rag_num,
|
225 |
+
return_call_result=True)
|
226 |
+
existing_tools_prompt += new_tools_prompt
|
227 |
+
elif function_call_json[i]["name"] == 'CallAgent':
|
228 |
+
if call_agent_level < 2 and call_agent:
|
229 |
+
solution_plan = function_call_json[i]['arguments']['solution']
|
230 |
+
full_message = (
|
231 |
+
message_for_call_agent +
|
232 |
+
"\nYou must follow the following plan to answer the question: " +
|
233 |
+
str(solution_plan)
|
234 |
+
)
|
235 |
+
call_result = self.run_multistep_agent(
|
236 |
+
full_message, temperature=temperature,
|
237 |
+
max_new_tokens=1024, max_token=99999,
|
238 |
+
call_agent=False, call_agent_level=call_agent_level)
|
239 |
+
call_result = call_result.split(
|
240 |
+
'[FinalAnswer]')[-1].strip()
|
241 |
+
else:
|
242 |
+
call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
|
243 |
+
else:
|
244 |
+
call_result = self.tooluniverse.run_one_function(
|
245 |
+
function_call_json[i])
|
246 |
+
|
247 |
+
call_id = self.tooluniverse.call_id_gen()
|
248 |
+
function_call_json[i]["call_id"] = call_id
|
249 |
+
print("\033[94mTool Call Result:\033[0m", call_result)
|
250 |
+
call_results.append({
|
251 |
+
"role": "tool",
|
252 |
+
"content": json.dumps({"content": call_result, "call_id": call_id})
|
253 |
+
})
|
254 |
+
else:
|
255 |
+
call_results.append({
|
256 |
+
"role": "tool",
|
257 |
+
"content": json.dumps({"content": "Not a valid function call, please check the function call format."})
|
258 |
+
})
|
259 |
+
|
260 |
+
revised_messages = [{
|
261 |
+
"role": "assistant",
|
262 |
+
"content": message.strip(),
|
263 |
+
"tool_calls": json.dumps(function_call_json)
|
264 |
+
}] + call_results
|
265 |
+
|
266 |
+
# Yield the final result.
|
267 |
+
return revised_messages, existing_tools_prompt, special_tool_call
|
268 |
+
|
269 |
+
def run_function_call_stream(self, fcall_str,
|
270 |
+
return_message=False,
|
271 |
+
existing_tools_prompt=None,
|
272 |
+
message_for_call_agent=None,
|
273 |
+
call_agent=False,
|
274 |
+
call_agent_level=None,
|
275 |
+
temperature=None,
|
276 |
+
return_gradio_history=True):
|
277 |
+
|
278 |
+
function_call_json, message = self.tooluniverse.extract_function_call_json(
|
279 |
+
fcall_str, return_message=return_message, verbose=False)
|
280 |
+
call_results = []
|
281 |
+
special_tool_call = ''
|
282 |
+
if return_gradio_history:
|
283 |
+
gradio_history = []
|
284 |
+
if function_call_json is not None:
|
285 |
+
if isinstance(function_call_json, list):
|
286 |
+
for i in range(len(function_call_json)):
|
287 |
+
if function_call_json[i]["name"] == 'Finish':
|
288 |
+
special_tool_call = 'Finish'
|
289 |
+
break
|
290 |
+
elif function_call_json[i]["name"] == 'Tool_RAG':
|
291 |
+
new_tools_prompt, call_result = self.tool_RAG(
|
292 |
+
message=message,
|
293 |
+
existing_tools_prompt=existing_tools_prompt,
|
294 |
+
rag_num=self.step_rag_num,
|
295 |
+
return_call_result=True)
|
296 |
+
existing_tools_prompt += new_tools_prompt
|
297 |
+
elif function_call_json[i]["name"] == 'DirectResponse':
|
298 |
+
call_result = function_call_json[i]['arguments']['respose']
|
299 |
+
special_tool_call = 'DirectResponse'
|
300 |
+
elif function_call_json[i]["name"] == 'RequireClarification':
|
301 |
+
call_result = function_call_json[i]['arguments']['unclear_question']
|
302 |
+
special_tool_call = 'RequireClarification'
|
303 |
+
elif function_call_json[i]["name"] == 'CallAgent':
|
304 |
+
if call_agent_level < 2 and call_agent:
|
305 |
+
solution_plan = function_call_json[i]['arguments']['solution']
|
306 |
+
full_message = (
|
307 |
+
message_for_call_agent +
|
308 |
+
"\nYou must follow the following plan to answer the question: " +
|
309 |
+
str(solution_plan)
|
310 |
+
)
|
311 |
+
sub_agent_task = "Sub TxAgent plan: " + \
|
312 |
+
str(solution_plan)
|
313 |
+
# When streaming, yield responses as they arrive.
|
314 |
+
call_result = yield from self.run_gradio_chat(
|
315 |
+
full_message, history=[], temperature=temperature,
|
316 |
+
max_new_tokens=1024, max_token=99999,
|
317 |
+
call_agent=False, call_agent_level=call_agent_level,
|
318 |
+
conversation=None,
|
319 |
+
sub_agent_task=sub_agent_task)
|
320 |
+
|
321 |
+
call_result = call_result.split(
|
322 |
+
'[FinalAnswer]')[-1]
|
323 |
+
else:
|
324 |
+
call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
|
325 |
+
else:
|
326 |
+
call_result = self.tooluniverse.run_one_function(
|
327 |
+
function_call_json[i])
|
328 |
+
|
329 |
+
call_id = self.tooluniverse.call_id_gen()
|
330 |
+
function_call_json[i]["call_id"] = call_id
|
331 |
+
call_results.append({
|
332 |
+
"role": "tool",
|
333 |
+
"content": json.dumps({"content": call_result, "call_id": call_id})
|
334 |
+
})
|
335 |
+
if return_gradio_history and function_call_json[i]["name"] != 'Finish':
|
336 |
+
if function_call_json[i]["name"] == 'Tool_RAG':
|
337 |
+
gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
|
338 |
+
"title": "🧰 "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
|
339 |
+
|
340 |
+
else:
|
341 |
+
gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
|
342 |
+
"title": "⚒️ "+function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
|
343 |
+
else:
|
344 |
+
call_results.append({
|
345 |
+
"role": "tool",
|
346 |
+
"content": json.dumps({"content": "Not a valid function call, please check the function call format."})
|
347 |
+
})
|
348 |
+
|
349 |
+
revised_messages = [{
|
350 |
+
"role": "assistant",
|
351 |
+
"content": message.strip(),
|
352 |
+
"tool_calls": json.dumps(function_call_json)
|
353 |
+
}] + call_results
|
354 |
+
|
355 |
+
# Yield the final result.
|
356 |
+
if return_gradio_history:
|
357 |
+
return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
|
358 |
+
else:
|
359 |
+
return revised_messages, existing_tools_prompt, special_tool_call
|
360 |
+
|
361 |
+
def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
|
362 |
+
if conversation[-1]['role'] == 'assisant':
|
363 |
+
conversation.append(
|
364 |
+
{'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
|
365 |
+
finish_tools_prompt = self.add_finish_tools([])
|
366 |
+
|
367 |
+
last_outputs_str = self.llm_infer(messages=conversation,
|
368 |
+
temperature=temperature,
|
369 |
+
tools=finish_tools_prompt,
|
370 |
+
output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
|
371 |
+
skip_special_tokens=True,
|
372 |
+
max_new_tokens=max_new_tokens, max_token=max_token)
|
373 |
+
print(last_outputs_str)
|
374 |
+
return last_outputs_str
|
375 |
+
|
376 |
+
def run_multistep_agent(self, message: str,
|
377 |
+
temperature: float,
|
378 |
+
max_new_tokens: int,
|
379 |
+
max_token: int,
|
380 |
+
max_round: int = 20,
|
381 |
+
call_agent=False,
|
382 |
+
call_agent_level=0) -> str:
|
383 |
+
"""
|
384 |
+
Generate a streaming response using the llama3-8b model.
|
385 |
+
Args:
|
386 |
+
message (str): The input message.
|
387 |
+
temperature (float): The temperature for generating the response.
|
388 |
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
389 |
+
Returns:
|
390 |
+
str: The generated response.
|
391 |
+
"""
|
392 |
+
print("\033[1;32;40mstart\033[0m")
|
393 |
+
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
|
394 |
+
call_agent, call_agent_level, message)
|
395 |
+
conversation = self.initialize_conversation(message)
|
396 |
+
|
397 |
+
outputs = []
|
398 |
+
last_outputs = []
|
399 |
+
next_round = True
|
400 |
+
function_call_messages = []
|
401 |
+
current_round = 0
|
402 |
+
token_overflow = False
|
403 |
+
enable_summary = False
|
404 |
+
last_status = {}
|
405 |
+
|
406 |
+
if self.enable_checker:
|
407 |
+
checker = ReasoningTraceChecker(message, conversation)
|
408 |
+
try:
|
409 |
+
while next_round and current_round < max_round:
|
410 |
+
current_round += 1
|
411 |
+
if len(outputs) > 0:
|
412 |
+
function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
|
413 |
+
last_outputs, return_message=True,
|
414 |
+
existing_tools_prompt=picked_tools_prompt,
|
415 |
+
message_for_call_agent=message,
|
416 |
+
call_agent=call_agent,
|
417 |
+
call_agent_level=call_agent_level,
|
418 |
+
temperature=temperature)
|
419 |
+
|
420 |
+
if special_tool_call == 'Finish':
|
421 |
+
next_round = False
|
422 |
+
conversation.extend(function_call_messages)
|
423 |
+
if isinstance(function_call_messages[0]['content'], types.GeneratorType):
|
424 |
+
function_call_messages[0]['content'] = next(
|
425 |
+
function_call_messages[0]['content'])
|
426 |
+
return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
|
427 |
+
|
428 |
+
if (self.enable_summary or token_overflow) and not call_agent:
|
429 |
+
if token_overflow:
|
430 |
+
print("token_overflow, using summary")
|
431 |
+
enable_summary = True
|
432 |
+
last_status = self.function_result_summary(
|
433 |
+
conversation, status=last_status, enable_summary=enable_summary)
|
434 |
+
|
435 |
+
if function_call_messages is not None:
|
436 |
+
conversation.extend(function_call_messages)
|
437 |
+
outputs.append(tool_result_format(
|
438 |
+
function_call_messages))
|
439 |
+
else:
|
440 |
+
next_round = False
|
441 |
+
conversation.extend(
|
442 |
+
[{"role": "assistant", "content": ''.join(last_outputs)}])
|
443 |
+
return ''.join(last_outputs).replace("</s>", "")
|
444 |
+
if self.enable_checker:
|
445 |
+
good_status, wrong_info = checker.check_conversation()
|
446 |
+
if not good_status:
|
447 |
+
next_round = False
|
448 |
+
print(
|
449 |
+
"Internal error in reasoning: " + wrong_info)
|
450 |
+
break
|
451 |
+
last_outputs = []
|
452 |
+
outputs.append("### TxAgent:\n")
|
453 |
+
last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
|
454 |
+
temperature=temperature,
|
455 |
+
tools=picked_tools_prompt,
|
456 |
+
skip_special_tokens=False,
|
457 |
+
max_new_tokens=max_new_tokens, max_token=max_token,
|
458 |
+
check_token_status=True)
|
459 |
+
if last_outputs_str is None:
|
460 |
+
next_round = False
|
461 |
+
print(
|
462 |
+
"The number of tokens exceeds the maximum limit.")
|
463 |
+
else:
|
464 |
+
last_outputs.append(last_outputs_str)
|
465 |
+
if max_round == current_round:
|
466 |
+
print("The number of rounds exceeds the maximum limit!")
|
467 |
+
if self.force_finish:
|
468 |
+
return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
|
469 |
+
else:
|
470 |
+
return None
|
471 |
+
|
472 |
+
except Exception as e:
|
473 |
+
print(f"Error: {e}")
|
474 |
+
if self.force_finish:
|
475 |
+
return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
|
476 |
+
else:
|
477 |
+
return None
|
478 |
+
|
479 |
+
def build_logits_processor(self, messages, llm):
|
480 |
+
# Use the tokenizer from the LLM instance.
|
481 |
+
tokenizer = llm.get_tokenizer()
|
482 |
+
if self.avoid_repeat and len(messages) > 2:
|
483 |
+
assistant_messages = []
|
484 |
+
for i in range(1, len(messages) + 1):
|
485 |
+
if messages[-i]['role'] == 'assistant':
|
486 |
+
assistant_messages.append(messages[-i]['content'])
|
487 |
+
if len(assistant_messages) == 2:
|
488 |
+
break
|
489 |
+
forbidden_ids = [tokenizer.encode(
|
490 |
+
msg, add_special_tokens=False) for msg in assistant_messages]
|
491 |
+
return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
|
492 |
+
else:
|
493 |
+
return None
|
494 |
+
|
495 |
+
def llm_infer(self, messages, temperature=0.1, tools=None,
|
496 |
+
output_begin_string=None, max_new_tokens=2048,
|
497 |
+
max_token=None, skip_special_tokens=True,
|
498 |
+
model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
|
499 |
+
|
500 |
+
if model is None:
|
501 |
+
model = self.model
|
502 |
+
|
503 |
+
logits_processor = self.build_logits_processor(messages, model)
|
504 |
+
sampling_params = SamplingParams(
|
505 |
+
temperature=temperature,
|
506 |
+
max_tokens=max_new_tokens,
|
507 |
+
logits_processors=logits_processor,
|
508 |
+
seed=seed if seed is not None else self.seed,
|
509 |
+
)
|
510 |
+
|
511 |
+
prompt = self.chat_template.render(
|
512 |
+
messages=messages, tools=tools, add_generation_prompt=True)
|
513 |
+
if output_begin_string is not None:
|
514 |
+
prompt += output_begin_string
|
515 |
+
|
516 |
+
if check_token_status and max_token is not None:
|
517 |
+
token_overflow = False
|
518 |
+
num_input_tokens = len(self.tokenizer.encode(
|
519 |
+
prompt, return_tensors="pt")[0])
|
520 |
+
if max_token is not None:
|
521 |
+
if num_input_tokens > max_token:
|
522 |
+
torch.cuda.empty_cache()
|
523 |
+
gc.collect()
|
524 |
+
print("Number of input tokens before inference:",
|
525 |
+
num_input_tokens)
|
526 |
+
logger.info(
|
527 |
+
"The number of tokens exceeds the maximum limit!!!!")
|
528 |
+
token_overflow = True
|
529 |
+
return None, token_overflow
|
530 |
+
output = model.generate(
|
531 |
+
prompt,
|
532 |
+
sampling_params=sampling_params,
|
533 |
+
)
|
534 |
+
output = output[0].outputs[0].text
|
535 |
+
print("\033[92m" + output + "\033[0m")
|
536 |
+
if check_token_status and max_token is not None:
|
537 |
+
return output, token_overflow
|
538 |
+
|
539 |
+
return output
|
540 |
+
|
541 |
+
def run_self_agent(self, message: str,
|
542 |
+
temperature: float,
|
543 |
+
max_new_tokens: int,
|
544 |
+
max_token: int) -> str:
|
545 |
+
|
546 |
+
print("\033[1;32;40mstart self agent\033[0m")
|
547 |
+
conversation = []
|
548 |
+
conversation = self.set_system_prompt(conversation, self.self_prompt)
|
549 |
+
conversation.append({"role": "user", "content": message})
|
550 |
+
return self.llm_infer(messages=conversation,
|
551 |
+
temperature=temperature,
|
552 |
+
tools=None,
|
553 |
+
max_new_tokens=max_new_tokens, max_token=max_token)
|
554 |
+
|
555 |
+
def run_chat_agent(self, message: str,
|
556 |
+
temperature: float,
|
557 |
+
max_new_tokens: int,
|
558 |
+
max_token: int) -> str:
|
559 |
+
|
560 |
+
print("\033[1;32;40mstart chat agent\033[0m")
|
561 |
+
conversation = []
|
562 |
+
conversation = self.set_system_prompt(conversation, self.chat_prompt)
|
563 |
+
conversation.append({"role": "user", "content": message})
|
564 |
+
return self.llm_infer(messages=conversation,
|
565 |
+
temperature=temperature,
|
566 |
+
tools=None,
|
567 |
+
max_new_tokens=max_new_tokens, max_token=max_token)
|
568 |
+
|
569 |
+
def run_format_agent(self, message: str,
|
570 |
+
answer: str,
|
571 |
+
temperature: float,
|
572 |
+
max_new_tokens: int,
|
573 |
+
max_token: int) -> str:
|
574 |
+
|
575 |
+
print("\033[1;32;40mstart format agent\033[0m")
|
576 |
+
if '[FinalAnswer]' in answer:
|
577 |
+
possible_final_answer = answer.split("[FinalAnswer]")[-1]
|
578 |
+
elif "\n\n" in answer:
|
579 |
+
possible_final_answer = answer.split("\n\n")[-1]
|
580 |
+
else:
|
581 |
+
possible_final_answer = answer.strip()
|
582 |
+
if len(possible_final_answer) == 1:
|
583 |
+
choice = possible_final_answer[0]
|
584 |
+
if choice in ['A', 'B', 'C', 'D', 'E']:
|
585 |
+
return choice
|
586 |
+
elif len(possible_final_answer) > 1:
|
587 |
+
if possible_final_answer[1] == ':':
|
588 |
+
choice = possible_final_answer[0]
|
589 |
+
if choice in ['A', 'B', 'C', 'D', 'E']:
|
590 |
+
print("choice", choice)
|
591 |
+
return choice
|
592 |
+
|
593 |
+
conversation = []
|
594 |
+
format_prompt = f"You are helpful assistant to transform the answer of agent to the final answer of 'A', 'B', 'C', 'D'."
|
595 |
+
conversation = self.set_system_prompt(conversation, format_prompt)
|
596 |
+
conversation.append({"role": "user", "content": message +
|
597 |
+
"\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
|
598 |
+
return self.llm_infer(messages=conversation,
|
599 |
+
temperature=temperature,
|
600 |
+
tools=None,
|
601 |
+
max_new_tokens=max_new_tokens, max_token=max_token)
|
602 |
+
|
603 |
+
def run_summary_agent(self, thought_calls: str,
|
604 |
+
function_response: str,
|
605 |
+
temperature: float,
|
606 |
+
max_new_tokens: int,
|
607 |
+
max_token: int) -> str:
|
608 |
+
print("\033[1;32;40mSummarized Tool Result:\033[0m")
|
609 |
+
generate_tool_result_summary_training_prompt = """Thought and function calls:
|
610 |
+
{thought_calls}
|
611 |
+
|
612 |
+
Function calls' responses:
|
613 |
+
\"\"\"
|
614 |
+
{function_response}
|
615 |
+
\"\"\"
|
616 |
+
|
617 |
+
Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
|
618 |
+
|
619 |
+
Directly respond with the summarized sentence of the function calls' responses only.
|
620 |
+
|
621 |
+
Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
|
622 |
+
""".format(thought_calls=thought_calls, function_response=function_response)
|
623 |
+
conversation = []
|
624 |
+
conversation.append(
|
625 |
+
{"role": "user", "content": generate_tool_result_summary_training_prompt})
|
626 |
+
output = self.llm_infer(messages=conversation,
|
627 |
+
temperature=temperature,
|
628 |
+
tools=None,
|
629 |
+
max_new_tokens=max_new_tokens, max_token=max_token)
|
630 |
+
|
631 |
+
if '[' in output:
|
632 |
+
output = output.split('[')[0]
|
633 |
+
return output
|
634 |
+
|
635 |
+
def function_result_summary(self, input_list, status, enable_summary):
|
636 |
+
"""
|
637 |
+
Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
|
638 |
+
Supports 'length' and 'step' modes, and skips the last 'k' groups.
|
639 |
+
|
640 |
+
Parameters:
|
641 |
+
input_list (list): A list of dictionaries containing role and other information.
|
642 |
+
summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
|
643 |
+
summary_context_length (int): The context length threshold for the 'length' mode.
|
644 |
+
last_processed_index (tuple or int): The last processed index.
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
list: A list of extracted information from valid sequences.
|
648 |
+
"""
|
649 |
+
if 'tool_call_step' not in status:
|
650 |
+
status['tool_call_step'] = 0
|
651 |
+
|
652 |
+
for idx in range(len(input_list)):
|
653 |
+
pos_id = len(input_list)-idx-1
|
654 |
+
if input_list[pos_id]['role'] == 'assistant':
|
655 |
+
if 'tool_calls' in input_list[pos_id]:
|
656 |
+
if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
|
657 |
+
status['tool_call_step'] += 1
|
658 |
+
break
|
659 |
+
|
660 |
+
if 'step' in status:
|
661 |
+
status['step'] += 1
|
662 |
+
else:
|
663 |
+
status['step'] = 0
|
664 |
+
|
665 |
+
if not enable_summary:
|
666 |
+
return status
|
667 |
+
|
668 |
+
if 'summarized_index' not in status:
|
669 |
+
status['summarized_index'] = 0
|
670 |
+
|
671 |
+
if 'summarized_step' not in status:
|
672 |
+
status['summarized_step'] = 0
|
673 |
+
|
674 |
+
if 'previous_length' not in status:
|
675 |
+
status['previous_length'] = 0
|
676 |
+
|
677 |
+
if 'history' not in status:
|
678 |
+
status['history'] = []
|
679 |
+
|
680 |
+
function_response = ''
|
681 |
+
idx = 0
|
682 |
+
current_summarized_index = status['summarized_index']
|
683 |
+
|
684 |
+
status['history'].append(self.summary_mode == 'step' and status['summarized_step']
|
685 |
+
< status['step']-status['tool_call_step']-self.summary_skip_last_k)
|
686 |
+
|
687 |
+
idx = current_summarized_index
|
688 |
+
while idx < len(input_list):
|
689 |
+
if (self.summary_mode == 'step' and status['summarized_step'] < status['step']-status['tool_call_step']-self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
|
690 |
+
|
691 |
+
if input_list[idx]['role'] == 'assistant':
|
692 |
+
if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
|
693 |
+
this_thought_calls = None
|
694 |
+
else:
|
695 |
+
if len(function_response) != 0:
|
696 |
+
print("internal summary")
|
697 |
+
status['summarized_step'] += 1
|
698 |
+
result_summary = self.run_summary_agent(
|
699 |
+
thought_calls=this_thought_calls,
|
700 |
+
function_response=function_response,
|
701 |
+
temperature=0.1,
|
702 |
+
max_new_tokens=1024,
|
703 |
+
max_token=99999
|
704 |
+
)
|
705 |
+
|
706 |
+
input_list.insert(
|
707 |
+
last_call_idx+1, {'role': 'tool', 'content': result_summary})
|
708 |
+
status['summarized_index'] = last_call_idx + 2
|
709 |
+
idx += 1
|
710 |
+
|
711 |
+
last_call_idx = idx
|
712 |
+
this_thought_calls = input_list[idx]['content'] + \
|
713 |
+
input_list[idx]['tool_calls']
|
714 |
+
function_response = ''
|
715 |
+
|
716 |
+
elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
|
717 |
+
function_response += input_list[idx]['content']
|
718 |
+
del input_list[idx]
|
719 |
+
idx -= 1
|
720 |
+
|
721 |
+
else:
|
722 |
+
break
|
723 |
+
idx += 1
|
724 |
+
|
725 |
+
if len(function_response) != 0:
|
726 |
+
status['summarized_step'] += 1
|
727 |
+
result_summary = self.run_summary_agent(
|
728 |
+
thought_calls=this_thought_calls,
|
729 |
+
function_response=function_response,
|
730 |
+
temperature=0.1,
|
731 |
+
max_new_tokens=1024,
|
732 |
+
max_token=99999
|
733 |
+
)
|
734 |
+
|
735 |
+
tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
|
736 |
+
for tool_call in tool_calls:
|
737 |
+
del tool_call['call_id']
|
738 |
+
input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
|
739 |
+
input_list.insert(
|
740 |
+
last_call_idx+1, {'role': 'tool', 'content': result_summary})
|
741 |
+
status['summarized_index'] = last_call_idx + 2
|
742 |
+
|
743 |
+
return status
|
744 |
+
|
745 |
+
# Following are Gradio related functions
|
746 |
+
|
747 |
+
# General update method that accepts any new arguments through kwargs
|
748 |
+
def update_parameters(self, **kwargs):
|
749 |
+
for key, value in kwargs.items():
|
750 |
+
if hasattr(self, key):
|
751 |
+
setattr(self, key, value)
|
752 |
+
|
753 |
+
# Return the updated attributes
|
754 |
+
updated_attributes = {key: value for key,
|
755 |
+
value in kwargs.items() if hasattr(self, key)}
|
756 |
+
return updated_attributes
|
757 |
+
|
    def run_gradio_chat(self, message: str,
                        history: list,
                        temperature: float,
                        max_new_tokens: int,
                        max_token: int,
                        call_agent: bool,
                        conversation: gr.State,
                        max_round: int = 20,
                        seed: int = None,
                        call_agent_level: int = 0,
                        sub_agent_task: str = None) -> str:
        """
        Generate a streaming response using the llama3-8b model.
        Args:
            message (str): The input message.
            history (list): The conversation history used by ChatInterface.
            temperature (float): The temperature for generating the response.
            max_new_tokens (int): The maximum number of new tokens to generate.
        Returns:
            str: The generated response.
        """
        print("\033[1;32;40mstart\033[0m")
        print("len(message)", len(message))
        if len(message) <= 10:
            yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
            return "Please provide a valid message."
        outputs = []
        outputs_str = ''
        last_outputs = []

        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
            call_agent,
            call_agent_level,
            message)

        conversation = self.initialize_conversation(
            message,
            conversation=conversation,
            history=history)
        history = []

        next_round = True
        function_call_messages = []
        current_round = 0
        enable_summary = False
        last_status = {}  # for summary
        token_overflow = False
        if self.enable_checker:
            checker = ReasoningTraceChecker(
                message, conversation, init_index=len(conversation))

        try:
            while next_round and current_round < max_round:
                current_round += 1
                if len(last_outputs) > 0:
                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
                        last_outputs, return_message=True,
                        existing_tools_prompt=picked_tools_prompt,
                        message_for_call_agent=message,
                        call_agent=call_agent,
                        call_agent_level=call_agent_level,
                        temperature=temperature)
                    history.extend(current_gradio_history)
                    if special_tool_call == 'Finish':
                        yield history
                        next_round = False
                        conversation.extend(function_call_messages)
                        return function_call_messages[0]['content']
                    elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
                        history.append(
                            ChatMessage(role="assistant", content=history[-1].content))
                        yield history
                        next_round = False
                        return history[-1].content
                    if (self.enable_summary or token_overflow) and not call_agent:
                        if token_overflow:
                            print("token_overflow, using summary")
                        enable_summary = True
                    last_status = self.function_result_summary(
                        conversation, status=last_status,
                        enable_summary=enable_summary)
                    if function_call_messages is not None:
                        conversation.extend(function_call_messages)
                        formated_md_function_call_messages = tool_result_format(
                            function_call_messages)
                        yield history
                    else:
                        next_round = False
                        conversation.extend(
                            [{"role": "assistant", "content": ''.join(last_outputs)}])
                        return ''.join(last_outputs).replace("</s>", "")
                if self.enable_checker:
                    good_status, wrong_info = checker.check_conversation()
                    if not good_status:
                        next_round = False
                        print("Internal error in reasoning: " + wrong_info)
                        break
                last_outputs = []
                last_outputs_str, token_overflow = self.llm_infer(
                    messages=conversation,
                    temperature=temperature,
                    tools=picked_tools_prompt,
                    skip_special_tokens=False,
                    max_new_tokens=max_new_tokens,
                    max_token=max_token,
                    seed=seed,
                    check_token_status=True)
                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
                for each in history:
                    if each.metadata is not None:
                        each.metadata['status'] = 'done'
                if '[FinalAnswer]' in last_thought:
                    final_thought, final_answer = last_thought.split(
                        '[FinalAnswer]')
                    history.append(
                        ChatMessage(role="assistant",
                                    content=final_thought.strip())
                    )
                    yield history
                    history.append(
                        ChatMessage(
                            role="assistant", content="**Answer**:\n"+final_answer.strip())
                    )
                    yield history
                else:
                    history.append(ChatMessage(
                        role="assistant", content=last_thought))
                    yield history

                last_outputs.append(last_outputs_str)

            if next_round:
                if self.force_finish:
                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                        conversation, temperature, max_new_tokens, max_token)
                    for each in history:
                        if each.metadata is not None:
                            each.metadata['status'] = 'done'
                    if '[FinalAnswer]' in last_thought:
                        final_thought, final_answer = last_thought.split(
                            '[FinalAnswer]')
                        history.append(
                            ChatMessage(role="assistant",
                                        content=final_thought.strip())
                        )
                        yield history
                        history.append(
                            ChatMessage(
                                role="assistant", content="**Answer**:\n"+final_answer.strip())
                        )
                        yield history
                else:
                    yield "The number of rounds exceeds the maximum limit!"

        except Exception as e:
            print(f"Error: {e}")
            if self.force_finish:
                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                    conversation,
                    temperature,
                    max_new_tokens,
                    max_token)
                for each in history:
                    if each.metadata is not None:
                        each.metadata['status'] = 'done'
                if '[FinalAnswer]' in last_thought or '"name": "Finish",' in last_outputs_str:
                    final_thought, final_answer = last_thought.split(
                        '[FinalAnswer]')
                    history.append(
                        ChatMessage(role="assistant",
                                    content=final_thought.strip())
                    )
                    yield history
                    history.append(
                        ChatMessage(
                            role="assistant", content="**Answer**:\n"+final_answer.strip())
                    )
                    yield history
            else:
                return None
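run_gradio_chat streams gradio ChatMessage updates and expects its extra arguments to come from the UI. A minimal, hypothetical wiring is sketched below; the repository's actual interface lives in run_txagent_app.py and may differ in layout, defaults, and component choices:

# Hypothetical Gradio wiring; `agent` is assumed to be an initialized TxAgent.
import gradio as gr

conversation_state = gr.State([])
demo = gr.ChatInterface(
    fn=agent.run_gradio_chat,
    type="messages",  # run_gradio_chat yields lists of gr.ChatMessage objects
    additional_inputs=[
        gr.Slider(0.0, 1.0, value=0.3, label="temperature"),
        gr.Number(value=1024, label="max_new_tokens"),
        gr.Number(value=90240, label="max_token"),
        gr.Checkbox(value=False, label="call_agent"),
        conversation_state,
    ],
)
demo.launch()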
src/txagent/utils.py
ADDED
@@ -0,0 +1,117 @@
import sys
import json
import hashlib
import torch
from typing import List


def get_md5(input_str):
    # Create an MD5 hash object
    md5_hash = hashlib.md5()

    # Encode the string and update the hash object
    md5_hash.update(input_str.encode('utf-8'))

    # Return the hexadecimal MD5 digest
    return md5_hash.hexdigest()


def tool_result_format(function_call_messages):
    current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
    for each_message in function_call_messages:
        if each_message['role'] == 'tool':
            current_output += f"{each_message['content']}\n\n"
    current_output += "</details>\n\n\n"
    return current_output

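A small usage sketch of the two helpers above; the string and message contents are invented:

# Hypothetical usage of get_md5 and tool_result_format.
print(get_md5("tool selection query"))  # stable hex digest, e.g. for cache keys

messages = [
    {"role": "assistant", "content": "Calling a lookup tool."},
    {"role": "tool", "content": "3 matching records found."},
]
print(tool_result_format(messages))  # <details> block containing only the 'tool' message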
class NoRepeatSentenceProcessor:
    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """
        Args:
            forbidden_sequences (List[List[int]]): A list of token ID sequences corresponding to forbidden sentences.
            allowed_prefix_length (int): The number k such that if the generated tokens match the first k tokens
                of a forbidden sequence, then the candidate token that would extend the match is blocked.
        """
        self.allowed_prefix_length = allowed_prefix_length
        # Build a lookup dictionary: key is a tuple of the first k tokens, value is a set of tokens to block.
        self.forbidden_prefix_dict = {}
        for seq in forbidden_sequences:
            if len(seq) > allowed_prefix_length:
                prefix = tuple(seq[:allowed_prefix_length])
                next_token = seq[allowed_prefix_length]
                self.forbidden_prefix_dict.setdefault(
                    prefix, set()).add(next_token)

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """
        Modifies the logits to block tokens that would extend a forbidden sentence.

        Args:
            token_ids (List[int]): List of token IDs generated so far.
            logits (torch.Tensor): Logits tensor for the next token (shape: [vocab_size]).

        Returns:
            torch.Tensor: Modified logits.
        """
        if len(token_ids) >= self.allowed_prefix_length:
            prefix = tuple(token_ids[:self.allowed_prefix_length])
            if prefix in self.forbidden_prefix_dict:
                for token_id in self.forbidden_prefix_dict[prefix]:
                    logits[token_id] = -float("inf")
        return logits

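The (token_ids, logits) call signature follows the common logits-processor convention used by inference engines such as vLLM; how TxAgent registers this processor is not part of this file. A minimal sketch with made-up token IDs:

# Hypothetical usage of NoRepeatSentenceProcessor.
forbidden = [[11, 22, 33, 44, 55]]  # a token sequence we do not want repeated verbatim
processor = NoRepeatSentenceProcessor(forbidden, allowed_prefix_length=3)

logits = torch.zeros(100)                 # pretend vocabulary of 100 tokens
logits = processor([11, 22, 33], logits)  # prefix matches, so token 44 gets masked
assert logits[44] == -float("inf")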
class ReasoningTraceChecker:
    def __init__(self, question, conversation, init_index=None):
        self.question = question
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        if init_index is not None:
            self.index = init_index
        else:
            self.index = 1
        self.question = self.question.lower()
        self.new_thoughts = []
        self.new_actions = []

    def check_conversation(self):
        info = ''
        current_index = self.index
        for i in range(current_index, len(self.conversation)):
            each = self.conversation[i]
            self.index = i
            if each['role'] == 'assistant':
                print(each)
                thought = each['content']
                actions = each['tool_calls']

                good_status, current_info = self.check_repeat_thought(thought)
                info += current_info
                if not good_status:
                    return False, info

                good_status, current_info = self.check_repeat_action(actions)
                info += current_info
                if not good_status:
                    return False, info
        return True, info

    def check_repeat_thought(self, thought):
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        if type(actions) != list:
            actions = json.loads(actions)
        for each_action in actions:
            if 'call_id' in each_action:
                del each_action['call_id']
            each_action = json.dumps(each_action)
            if each_action in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(each_action)
        return True, ''
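A toy sketch of the checker on an invented conversation; it flags the second assistant turn because the thought text repeats the first:

# Hypothetical usage of ReasoningTraceChecker (relies on this module's json import).
conversation = [
    {"role": "user", "content": "question"},
    {"role": "assistant", "content": "I will look up the record.",
     "tool_calls": json.dumps([{"name": "ToolA", "arguments": {}, "call_id": "1"}])},
    {"role": "assistant", "content": "I will look up the record.",
     "tool_calls": json.dumps([{"name": "ToolA", "arguments": {}, "call_id": "2"}])},
]
checker = ReasoningTraceChecker("question", conversation, init_index=1)
ok, info = checker.check_conversation()
print(ok, info)  # False repeat_thought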