Ali2206 committed
Commit c47b2de · verified · 1 Parent(s): 41dec39

Upload 11 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ img/q1.gif filter=lfs diff=lfs merge=lfs -text
+ img/q2.gif filter=lfs diff=lfs merge=lfs -text
+ img/q3.gif filter=lfs diff=lfs merge=lfs -text
img/q1.gif ADDED

Git LFS Details

  • SHA256: f0cbda2e1ec46defdae51233c03aee0ddea1ad1f28ad9ed79e4ea72a8f13edf9
  • Pointer size: 132 Bytes
  • Size of remote file: 7.65 MB
img/q2.gif ADDED

Git LFS Details

  • SHA256: a453c339ddcc333e28bc9626b287d9d6fa1554edec7b127611617bcb27b90591
  • Pointer size: 132 Bytes
  • Size of remote file: 6.31 MB
img/q3.gif ADDED

Git LFS Details

  • SHA256: ded44920ea272367247ac4f1a222c0a55932ad6b1173c2b14009a3ec4a79f524
  • Pointer size: 132 Bytes
  • Size of remote file: 8.83 MB
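
For context, each GIF above is stored with Git LFS, so the repository itself only carries a small text pointer (hence the ~132-byte pointer size). A pointer file follows the LFS v1 spec; for img/q1.gif it would look like the sketch below, except that the size line records the exact byte count, which the page only reports rounded to 7.65 MB (the number here is a placeholder):

version https://git-lfs.github.com/spec/v1
oid sha256:f0cbda2e1ec46defdae51233c03aee0ddea1ad1f28ad9ed79e4ea72a8f13edf9
size 7650000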
pyproject.toml CHANGED
@@ -1,3 +1,3 @@
- [build-system]
- requires = ["setuptools", "wheel"]
+ [build-system]
+ requires = ["setuptools", "wheel"]
  build-backend = "setuptools.build_meta"
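
(The removed and re-added lines are textually identical, so this was likely a whitespace or line-ending change.) This [build-system] table is all a PEP 517/518 frontend needs: it pins the build requirements and names setuptools as the backend. With it in place, an editable install from the repository root is simply:

pip install -e .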
run_example.py ADDED
@@ -0,0 +1,28 @@
+ from txagent import TxAgent
+ import os
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
+
+
+ model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
+ rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
+ multiagent = False
+ max_round = 20
+ init_rag_num = 0
+ step_rag_num = 10
+
+ agent = TxAgent(model_name,
+                 rag_model_name,
+                 enable_summary=False)
+ agent.init_model()
+
+ question = "Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?"
+
+ response = agent.run_multistep_agent(
+     question,
+     temperature=0.3,
+     max_new_tokens=1024,
+     max_token=90240,
+     call_agent=multiagent,
+     max_round=max_round)
+
+ print(f"\033[94m{response}\033[0m")
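
A note on running this example: it assumes the txagent package from this repository is installed (for instance via an editable install from the repo root) and that a GPU with enough memory for vLLM to serve the 8B model is available. Also note that init_rag_num and step_rag_num are assigned above but never passed to the TxAgent constructor, so the agent falls back to its own defaults for those two settings:

pip install -e .
python run_example.py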
run_txagent_app.py ADDED
@@ -0,0 +1,293 @@
+ import random
+ import datetime
+ import sys
+ from txagent import TxAgent
+ import spaces
+ import gradio as gr
+ import os
+
+ # Determine the directory where the current file is located
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
+
+ # Set an environment variable
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+
+ DESCRIPTION = '''
+ <div>
+ <h1 style="text-align: center;">TxAgent: An AI Agent for Therapeutic Reasoning Across a Universe of Tools</h1>
+ </div>
+ '''
+ INTRO = """
+ Precision therapeutics require multimodal adaptive models that provide personalized treatment recommendations. We introduce TxAgent, an AI agent that leverages multi-step reasoning and real-time biomedical knowledge retrieval across a toolbox of 211 expert-curated tools to navigate complex drug interactions, contraindications, and patient-specific treatment strategies, delivering evidence-grounded therapeutic decisions. TxAgent executes goal-oriented tool selection and iterative function calls to solve therapeutic tasks that require deep clinical understanding and cross-source validation. The ToolUniverse consolidates 211 tools linked to trusted sources, including all US FDA-approved drugs since 1939 and validated clinical insights from Open Targets.
+ """
+
+ LICENSE = """
+ We welcome your feedback and suggestions to enhance your experience with TxAgent. If you are interested in collaboration, please email Marinka Zitnik and Shanghua Gao.
+
+ ### Medical Advice Disclaimer
+ DISCLAIMER: THIS WEBSITE DOES NOT PROVIDE MEDICAL ADVICE
+ The information, including but not limited to, text, graphics, images and other material contained on this website is for informational purposes only. No material on this site is intended to be a substitute for professional medical advice, diagnosis or treatment. Always seek the advice of your physician or other qualified health care provider with any questions you may have regarding a medical condition or treatment and before undertaking a new health care regimen, and never disregard professional medical advice or delay in seeking it because of something you have read on this website.
+ """
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">TxAgent</h1>
+    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using TxAgent:</p>
+    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
+ (top-right) to remove previous context before submitting a new question.</p>
+    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
+ </div>
+ """
+
+ css = """
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: white;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
+ .small-button button {
+   font-size: 12px !important;
+   padding: 4px 8px !important;
+   height: 6px !important;
+   width: 4px !important;
+ }
+ .gradio-accordion {
+   margin-top: 0px !important;
+   margin-bottom: 0px !important;
+ }
+ """
+
+ chat_css = """
+ .gr-button { font-size: 20px !important; }  /* Enlarges button icons */
+ .gr-button svg { width: 32px !important; height: 32px !important; }  /* Enlarges SVG icons */
+ """
+
+ # model_name = '/n/holylfs06/LABS/mzitnik_lab/Lab/shgao/bioagent/bio/alignment-handbook/data_new/L8-qlora-biov49v9v7v16_32k_chat01_merged'
+ model_name = 'mims-harvard/TxAgent-T1-Llama-3.1-8B'
+ rag_model_name = 'mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B'
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ question_examples = [
+     ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of moderate hepatic impairment?'],
+     ['Given a 50-year-old patient experiencing severe acute pain and considering the use of the newly approved medication, Journavx, how should the dosage be adjusted considering the presence of severe hepatic impairment?'],
+     ['A 30-year-old patient is taking Prozac to treat their depression. They were recently diagnosed with WHIM syndrome and require a treatment for that condition as well. Is Xolremdi suitable for this patient, considering contraindications?'],
+ ]
+
+ new_tool_files = {
+     'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
+ }
+
+ agent = TxAgent(model_name,
+                 rag_model_name,
+                 tool_files_dict=new_tool_files,
+                 force_finish=True,
+                 enable_checker=True,
+                 step_rag_num=10,
+                 seed=100,
+                 additional_default_tools=['DirectResponse', 'RequireClarification'])
+ agent.init_model()
+
+
+ def update_model_parameters(enable_finish, enable_rag, enable_summary,
+                             init_rag_num, step_rag_num, skip_last_k,
+                             summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
+     # Update model instance parameters dynamically
+     updated_params = agent.update_parameters(
+         enable_finish=enable_finish,
+         enable_rag=enable_rag,
+         enable_summary=enable_summary,
+         init_rag_num=init_rag_num,
+         step_rag_num=step_rag_num,
+         skip_last_k=skip_last_k,
+         summary_mode=summary_mode,
+         summary_skip_last_k=summary_skip_last_k,
+         summary_context_length=summary_context_length,
+         force_finish=force_finish,
+         seed=seed,
+     )
+
+     return updated_params
+
+
+ def update_seed():
+     # Update the model's sampling seed dynamically
+     seed = random.randint(0, 10000)
+     updated_params = agent.update_parameters(
+         seed=seed,
+     )
+     return updated_params
+
+
+ def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
+     print("Updated seed:", update_seed())
+     new_history = history[:retry_data.index]
+     previous_prompt = history[retry_data.index]['content']
+
+     print("previous_prompt", previous_prompt)
+
+     yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
+
+
+ PASSWORD = "mypassword"
+
+
+ # Function to check if the password is correct
+ def check_password(input_password):
+     if input_password == PASSWORD:
+         return gr.update(visible=True), ""
+     else:
+         return gr.update(visible=False), "Incorrect password, try again!"
+
+
+ conversation_state = gr.State([])
+
+ # Gradio block
+ chatbot = gr.Chatbot(height=800, placeholder=PLACEHOLDER,
+                      label='TxAgent', type="messages", show_copy_button=True)
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.Markdown(INTRO)
+     default_temperature = 0.3
+     default_max_new_tokens = 1024
+     default_max_tokens = 81920
+     default_max_round = 30
+     temperature_state = gr.State(value=default_temperature)
+     max_new_tokens_state = gr.State(value=default_max_new_tokens)
+     max_tokens_state = gr.State(value=default_max_tokens)
+     max_round_state = gr.State(value=default_max_round)
+     chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
+                   max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
+
+     gr.ChatInterface(
+         fn=agent.run_gradio_chat,
+         chatbot=chatbot,
+         fill_height=True, fill_width=True, stop_btn=True,
+         additional_inputs_accordion=gr.Accordion(
+             label="⚙️ Inference Parameters", open=False, render=False),
+         additional_inputs=[
+             temperature_state, max_new_tokens_state, max_tokens_state,
+             gr.Checkbox(
+                 label="Activate multi-agent reasoning mode (it requires additional time but offers a more comprehensive analysis).", value=False, render=False),
+             conversation_state,
+             max_round_state,
+             gr.Number(label="Seed", value=100, render=False)
+         ],
+         examples=question_examples,
+         cache_examples=False,
+         css=chat_css,
+     )
+
+     with gr.Accordion("Settings", open=False):
+
+         # Define the sliders
+         temperature_slider = gr.Slider(
+             minimum=0,
+             maximum=1,
+             step=0.1,
+             value=default_temperature,
+             label="Temperature"
+         )
+         max_new_tokens_slider = gr.Slider(
+             minimum=128,
+             maximum=4096,
+             step=1,
+             value=default_max_new_tokens,
+             label="Max new tokens"
+         )
+         max_tokens_slider = gr.Slider(
+             minimum=128,
+             maximum=32000,
+             step=1,
+             value=default_max_tokens,
+             label="Max tokens"
+         )
+         max_round_slider = gr.Slider(
+             minimum=0,
+             maximum=50,
+             step=1,
+             value=default_max_round,
+             label="Max round")
+
+         # Automatically update states when slider values change
+         temperature_slider.change(
+             lambda x: x, inputs=temperature_slider, outputs=temperature_state)
+         max_new_tokens_slider.change(
+             lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
+         max_tokens_slider.change(
+             lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
+         max_round_slider.change(
+             lambda x: x, inputs=max_round_slider, outputs=max_round_state)
+
+     password_input = gr.Textbox(
+         label="Enter Password for More Settings", type="password")
+     incorrect_message = gr.Textbox(visible=False, interactive=False)
+     with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
+         with gr.Row():
+             with gr.Column(scale=1):
+                 with gr.Accordion("⚙️ Model Loading", open=False):
+                     model_name_input = gr.Textbox(
+                         label="Enter model path", value=model_name)
+                     load_model_btn = gr.Button(value="Load Model")
+                     load_model_btn.click(
+                         agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
+             with gr.Column(scale=1):
+                 with gr.Accordion("⚙️ Functional Parameters", open=False):
+                     # Create Gradio components for parameter inputs
+                     enable_finish = gr.Checkbox(
+                         label="Enable Finish", value=True)
+                     enable_rag = gr.Checkbox(
+                         label="Enable RAG", value=True)
+                     enable_summary = gr.Checkbox(
+                         label="Enable Summary", value=False)
+                     init_rag_num = gr.Number(
+                         label="Initial RAG Num", value=0)
+                     step_rag_num = gr.Number(
+                         label="Step RAG Num", value=10)
+                     skip_last_k = gr.Number(label="Skip Last K", value=0)
+                     summary_mode = gr.Textbox(
+                         label="Summary Mode", value='step')
+                     summary_skip_last_k = gr.Number(
+                         label="Summary Skip Last K", value=0)
+                     summary_context_length = gr.Number(
+                         label="Summary Context Length", value=None)
+                     force_finish = gr.Checkbox(
+                         label="Force FinalAnswer", value=True)
+                     seed = gr.Number(label="Seed", value=100)
+                     # Button to submit and update parameters
+                     submit_btn = gr.Button("Update Parameters")
+
+                     # Display the updated parameters
+                     updated_parameters_output = gr.JSON()
+
+                     # When the button is clicked, update parameters
+                     submit_btn.click(fn=update_model_parameters,
+                                      inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
+                                              summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
+                                      outputs=updated_parameters_output)
+     # Button to submit the password
+     submit_button = gr.Button("Submit")
+
+     # When the button is clicked, check if the password is correct
+     submit_button.click(
+         check_password,
+         inputs=password_input,
+         outputs=[protected_accordion, incorrect_message]
+     )
+     gr.Markdown(LICENSE)
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
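
Launching the app is a plain script invocation; note that demo.launch(share=True) asks Gradio to additionally create a temporary public *.gradio.live tunnel, which you may want to drop for private deployments:

python run_txagent_app.py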
src/txagent/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .txagent import TxAgent
+ from .toolrag import ToolRAGModel
+ __all__ = [
+     "TxAgent",
+     "ToolRAGModel",
+ ]
src/txagent/toolrag.py ADDED
@@ -0,0 +1,60 @@
+ from sentence_transformers import SentenceTransformer
+ import torch
+ import json
+ from .utils import get_md5
+
+
+ class ToolRAGModel:
+     def __init__(self, rag_model_name):
+         self.rag_model_name = rag_model_name
+         self.rag_model = None
+         self.tool_desc_embedding = None
+         self.tool_name = None
+         self.tool_embedding_path = None
+         self.load_rag_model()
+
+     def load_rag_model(self):
+         self.rag_model = SentenceTransformer(self.rag_model_name)
+         self.rag_model.max_seq_length = 4096
+         self.rag_model.tokenizer.padding_side = "right"
+
+     def load_tool_desc_embedding(self, toolbox):
+         self.tool_name, _ = toolbox.refresh_tool_name_desc(
+             enable_full_desc=True)
+         all_tools_str = [json.dumps(
+             each) for each in toolbox.prepare_tool_prompts(toolbox.all_tools)]
+         # Cache the embeddings under a filename keyed by the MD5 of the tool
+         # set, so the cache is invalidated whenever the tools change.
+         md5_value = get_md5(str(all_tools_str))
+         print("md5 value of tools:", md5_value)
+         self.tool_embedding_path = self.rag_model_name.split(
+             '/')[-1] + "tool_embedding_" + md5_value + ".pt"
+         try:
+             self.tool_desc_embedding = torch.load(
+                 self.tool_embedding_path, weights_only=False)
+             assert len(self.tool_desc_embedding) == len(
+                 toolbox.all_tools), "The number of tools in the toolbox is not equal to the number of tool_desc_embedding."
+         except Exception:
+             self.tool_desc_embedding = None
+             print("\033[92mInferring the tool_desc_embedding.\033[0m")
+             self.tool_desc_embedding = self.rag_model.encode(
+                 all_tools_str, prompt="", normalize_embeddings=True
+             )
+             # Save the cache and exit; rerunning loads the cached embeddings
+             # instead of keeping two models resident (avoids the OOM issue).
+             torch.save(self.tool_desc_embedding, self.tool_embedding_path)
+             print("\033[92mFinished inferring the tool_desc_embedding.\033[0m")
+             print("\033[91mExiting. Please rerun the code to avoid the OOM issue.\033[0m")
+             exit()
+
+     def rag_infer(self, query, top_k=5):
+         torch.cuda.empty_cache()
+         queries = [query]
+         query_embeddings = self.rag_model.encode(
+             queries, prompt="", normalize_embeddings=True
+         )
+         if self.tool_desc_embedding is None:
+             print("No tool_desc_embedding")
+             exit()
+         scores = self.rag_model.similarity(
+             query_embeddings, self.tool_desc_embedding)
+         top_k = min(top_k, len(self.tool_name))
+         top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
+         top_k_tool_names = [self.tool_name[i] for i in top_k_indices]
+         return top_k_tool_names
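
The retrieval step in rag_infer reduces to a dot product followed by top-k selection, because both the query and tool-description embeddings are L2-normalized. A toy, self-contained sketch of that math (random vectors and invented tool names stand in for the ToolRAG-T1-GTE-Qwen2-1.5B embeddings):

import torch

tool_names = ["drug_interaction_lookup", "dose_adjustment", "contraindication_check"]  # toy names
tool_emb = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)   # stand-in tool embeddings
query_emb = torch.nn.functional.normalize(torch.randn(1, 8), dim=-1)  # stand-in query embedding
scores = query_emb @ tool_emb.T                        # cosine similarity, shape (1, 3)
top_k_indices = torch.topk(scores, k=2).indices.tolist()[0]
print([tool_names[i] for i in top_k_indices])          # the two closest tools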
src/txagent/txagent.py ADDED
@@ -0,0 +1,937 @@
+ import gradio as gr
+ import os
+ import sys
+ import json
+ import gc
+ import logging
+ import numpy as np
+ import torch
+ from vllm import LLM, SamplingParams
+ from jinja2 import Template
+ from typing import List
+ import types
+ from tooluniverse import ToolUniverse
+ from gradio import ChatMessage
+ from .toolrag import ToolRAGModel
+
+ from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
+
+ # torch and a module-level logger are used in llm_infer below but were
+ # missing from the original imports.
+ logger = logging.getLogger(__name__)
+
+
+ class TxAgent:
+     def __init__(self, model_name,
+                  rag_model_name,
+                  tool_files_dict=None,  # None leads to the default tool files in ToolUniverse
+                  enable_finish=True,
+                  enable_rag=True,
+                  enable_summary=False,
+                  init_rag_num=0,
+                  step_rag_num=10,
+                  summary_mode='step',
+                  summary_skip_last_k=0,
+                  summary_context_length=None,
+                  force_finish=True,
+                  avoid_repeat=True,
+                  seed=None,
+                  enable_checker=False,
+                  enable_chat=False,
+                  additional_default_tools=None,
+                  ):
+         self.model_name = model_name
+         self.tokenizer = None
+         self.terminators = None
+         self.rag_model_name = rag_model_name
+         self.tool_files_dict = tool_files_dict
+         self.model = None
+         self.rag_model = ToolRAGModel(rag_model_name)
+         self.tooluniverse = None
+         # self.tool_desc = None
+         self.prompt_multi_step = "You are a helpful assistant that will solve problems through detailed, step-by-step reasoning and actions based on your reasoning. Typically, your actions will use the provided functions. You have access to the following functions."
+         self.self_prompt = "Strictly follow the instruction."
+         self.chat_prompt = "You are a helpful assistant for chatting with the user."
+         self.enable_finish = enable_finish
+         self.enable_rag = enable_rag
+         self.enable_summary = enable_summary
+         self.summary_mode = summary_mode
+         self.summary_skip_last_k = summary_skip_last_k
+         self.summary_context_length = summary_context_length
+         self.init_rag_num = init_rag_num
+         self.step_rag_num = step_rag_num
+         self.force_finish = force_finish
+         self.avoid_repeat = avoid_repeat
+         self.seed = seed
+         self.enable_checker = enable_checker
+         self.additional_default_tools = additional_default_tools
+         self.print_self_values()
+
+     def init_model(self):
+         self.load_models()
+         self.load_tooluniverse()
+         self.load_tool_desc_embedding()
+
+     def print_self_values(self):
+         for attr, value in self.__dict__.items():
+             print(f"{attr}: {value}")
+
+     def load_models(self, model_name=None):
+         if model_name is not None:
+             if model_name == self.model_name:
+                 return f"The model {model_name} is already loaded."
+             self.model_name = model_name
+
+         self.model = LLM(model=self.model_name)
+         self.chat_template = Template(self.model.get_tokenizer().chat_template)
+         self.tokenizer = self.model.get_tokenizer()
+
+         return f"Model {self.model_name} loaded successfully."
+
+     def load_tooluniverse(self):
+         self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
+         self.tooluniverse.load_tools()
+         special_tools = self.tooluniverse.prepare_tool_prompts(
+             self.tooluniverse.tool_category_dicts["special_tools"])
+         self.special_tools_name = [tool['name'] for tool in special_tools]
+
+     def load_tool_desc_embedding(self):
+         self.rag_model.load_tool_desc_embedding(self.tooluniverse)
+
+     def rag_infer(self, query, top_k=5):
+         return self.rag_model.rag_infer(query, top_k)
+
+     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
+         picked_tools_prompt = []
+         picked_tools_prompt = self.add_special_tools(
+             picked_tools_prompt, call_agent=call_agent)
+         if call_agent:
+             call_agent_level += 1
+             if call_agent_level >= 2:
+                 call_agent = False
+
+         if not call_agent:
+             picked_tools_prompt += self.tool_RAG(
+                 message=message, rag_num=self.init_rag_num)
+         return picked_tools_prompt, call_agent_level
+
+     def initialize_conversation(self, message, conversation=None, history=None):
+         if conversation is None:
+             conversation = []
+
+         conversation = self.set_system_prompt(
+             conversation, self.prompt_multi_step)
+         if history is not None:
+             if len(history) == 0:
+                 conversation = []
+                 print("clear conversation successfully")
+             else:
+                 for i in range(len(history)):
+                     if history[i]['role'] == 'user':
+                         if i - 1 >= 0 and history[i - 1]['role'] == 'assistant':
+                             conversation.append(
+                                 {"role": "assistant", "content": history[i - 1]['content']})
+                         conversation.append(
+                             {"role": "user", "content": history[i]['content']})
+                     if i == len(history) - 1 and history[i]['role'] == 'assistant':
+                         conversation.append(
+                             {"role": "assistant", "content": history[i]['content']})
+
+         conversation.append({"role": "user", "content": message})
+
+         return conversation
+
+     def tool_RAG(self, message=None,
+                  picked_tool_names=None,
+                  existing_tools_prompt=None,  # was a mutable default ([]) in the original
+                  rag_num=5,
+                  return_call_result=False):
+         extra_factor = 30  # Factor to retrieve more than rag_num
+         if picked_tool_names is None:
+             assert picked_tool_names is not None or message is not None
+             picked_tool_names = self.rag_infer(
+                 message, top_k=rag_num * extra_factor)
+
+         picked_tool_names_no_special = []
+         for tool in picked_tool_names:
+             if tool not in self.special_tools_name:
+                 picked_tool_names_no_special.append(tool)
+         picked_tool_names = picked_tool_names_no_special[:rag_num]
+
+         picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
+         picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(
+             picked_tools)
+         if return_call_result:
+             return picked_tools_prompt, picked_tool_names
+         return picked_tools_prompt
+
+     def add_special_tools(self, tools, call_agent=False):
+         if self.enable_finish:
+             tools.append(self.tooluniverse.get_one_tool_by_one_name(
+                 'Finish', return_prompt=True))
+             print("Finish tool is added")
+         if call_agent:
+             tools.append(self.tooluniverse.get_one_tool_by_one_name(
+                 'CallAgent', return_prompt=True))
+             print("CallAgent tool is added")
+         else:
+             if self.enable_rag:
+                 tools.append(self.tooluniverse.get_one_tool_by_one_name(
+                     'Tool_RAG', return_prompt=True))
+                 print("Tool_RAG tool is added")
+
+             if self.additional_default_tools is not None:
+                 for each_tool_name in self.additional_default_tools:
+                     tool_prompt = self.tooluniverse.get_one_tool_by_one_name(
+                         each_tool_name, return_prompt=True)
+                     if tool_prompt is not None:
+                         print(f"{each_tool_name} tool is added")
+                         tools.append(tool_prompt)
+         return tools
+
+     def add_finish_tools(self, tools):
+         tools.append(self.tooluniverse.get_one_tool_by_one_name(
+             'Finish', return_prompt=True))
+         print("Finish tool is added")
+         return tools
+
+     def set_system_prompt(self, conversation, sys_prompt):
+         if len(conversation) == 0:
+             conversation.append(
+                 {"role": "system", "content": sys_prompt})
+         else:
+             conversation[0] = {"role": "system", "content": sys_prompt}
+         return conversation
+
+     def run_function_call(self, fcall_str,
+                           return_message=False,
+                           existing_tools_prompt=None,
+                           message_for_call_agent=None,
+                           call_agent=False,
+                           call_agent_level=None,
+                           temperature=None):
+
+         function_call_json, message = self.tooluniverse.extract_function_call_json(
+             fcall_str, return_message=return_message, verbose=False)
+         call_results = []
+         special_tool_call = ''
+         if function_call_json is not None:
+             if isinstance(function_call_json, list):
+                 for i in range(len(function_call_json)):
+                     print("\033[94mTool Call:\033[0m", function_call_json[i])
+                     if function_call_json[i]["name"] == 'Finish':
+                         special_tool_call = 'Finish'
+                         break
+                     elif function_call_json[i]["name"] == 'Tool_RAG':
+                         new_tools_prompt, call_result = self.tool_RAG(
+                             message=message,
+                             existing_tools_prompt=existing_tools_prompt,
+                             rag_num=self.step_rag_num,
+                             return_call_result=True)
+                         existing_tools_prompt += new_tools_prompt
+                     elif function_call_json[i]["name"] == 'CallAgent':
+                         if call_agent_level < 2 and call_agent:
+                             solution_plan = function_call_json[i]['arguments']['solution']
+                             full_message = (
+                                 message_for_call_agent +
+                                 "\nYou must follow the following plan to answer the question: " +
+                                 str(solution_plan)
+                             )
+                             call_result = self.run_multistep_agent(
+                                 full_message, temperature=temperature,
+                                 max_new_tokens=1024, max_token=99999,
+                                 call_agent=False, call_agent_level=call_agent_level)
+                             call_result = call_result.split(
+                                 '[FinalAnswer]')[-1].strip()
+                         else:
+                             call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
+                     else:
+                         call_result = self.tooluniverse.run_one_function(
+                             function_call_json[i])
+
+                     call_id = self.tooluniverse.call_id_gen()
+                     function_call_json[i]["call_id"] = call_id
+                     print("\033[94mTool Call Result:\033[0m", call_result)
+                     call_results.append({
+                         "role": "tool",
+                         "content": json.dumps({"content": call_result, "call_id": call_id})
+                     })
+         else:
+             call_results.append({
+                 "role": "tool",
+                 "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
+             })
+
+         revised_messages = [{
+             "role": "assistant",
+             "content": message.strip(),
+             "tool_calls": json.dumps(function_call_json)
+         }] + call_results
+
+         # Return the revised messages plus the (possibly expanded) tool prompt.
+         return revised_messages, existing_tools_prompt, special_tool_call
+
+     def run_function_call_stream(self, fcall_str,
+                                  return_message=False,
+                                  existing_tools_prompt=None,
+                                  message_for_call_agent=None,
+                                  call_agent=False,
+                                  call_agent_level=None,
+                                  temperature=None,
+                                  return_gradio_history=True):
+
+         function_call_json, message = self.tooluniverse.extract_function_call_json(
+             fcall_str, return_message=return_message, verbose=False)
+         call_results = []
+         special_tool_call = ''
+         if return_gradio_history:
+             gradio_history = []
+         if function_call_json is not None:
+             if isinstance(function_call_json, list):
+                 for i in range(len(function_call_json)):
+                     if function_call_json[i]["name"] == 'Finish':
+                         special_tool_call = 'Finish'
+                         break
+                     elif function_call_json[i]["name"] == 'Tool_RAG':
+                         new_tools_prompt, call_result = self.tool_RAG(
+                             message=message,
+                             existing_tools_prompt=existing_tools_prompt,
+                             rag_num=self.step_rag_num,
+                             return_call_result=True)
+                         existing_tools_prompt += new_tools_prompt
+                     elif function_call_json[i]["name"] == 'DirectResponse':
+                         # 'respose' is left as-is: it matches the (misspelled)
+                         # argument name defined in the DirectResponse tool schema.
+                         call_result = function_call_json[i]['arguments']['respose']
+                         special_tool_call = 'DirectResponse'
+                     elif function_call_json[i]["name"] == 'RequireClarification':
+                         call_result = function_call_json[i]['arguments']['unclear_question']
+                         special_tool_call = 'RequireClarification'
+                     elif function_call_json[i]["name"] == 'CallAgent':
+                         if call_agent_level < 2 and call_agent:
+                             solution_plan = function_call_json[i]['arguments']['solution']
+                             full_message = (
+                                 message_for_call_agent +
+                                 "\nYou must follow the following plan to answer the question: " +
+                                 str(solution_plan)
+                             )
+                             sub_agent_task = "Sub TxAgent plan: " + \
+                                 str(solution_plan)
+                             # When streaming, yield responses as they arrive.
+                             call_result = yield from self.run_gradio_chat(
+                                 full_message, history=[], temperature=temperature,
+                                 max_new_tokens=1024, max_token=99999,
+                                 call_agent=False, call_agent_level=call_agent_level,
+                                 conversation=None,
+                                 sub_agent_task=sub_agent_task)
+
+                             call_result = call_result.split(
+                                 '[FinalAnswer]')[-1]
+                         else:
+                             call_result = "Error: The CallAgent has been disabled. Please proceed with your reasoning process to solve this question."
+                     else:
+                         call_result = self.tooluniverse.run_one_function(
+                             function_call_json[i])
+
+                     call_id = self.tooluniverse.call_id_gen()
+                     function_call_json[i]["call_id"] = call_id
+                     call_results.append({
+                         "role": "tool",
+                         "content": json.dumps({"content": call_result, "call_id": call_id})
+                     })
+                     if return_gradio_history and function_call_json[i]["name"] != 'Finish':
+                         if function_call_json[i]["name"] == 'Tool_RAG':
+                             gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
+                                 "title": "🧰 " + function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
+                         else:
+                             gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata={
+                                 "title": "⚒️ " + function_call_json[i]['name'], "log": str(function_call_json[i]['arguments'])}))
+         else:
+             call_results.append({
+                 "role": "tool",
+                 "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
+             })
+
+         revised_messages = [{
+             "role": "assistant",
+             "content": message.strip(),
+             "tool_calls": json.dumps(function_call_json)
+         }] + call_results
+
+         # Return the final result (and the Gradio history when requested).
+         if return_gradio_history:
+             return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
+         else:
+             return revised_messages, existing_tools_prompt, special_tool_call
+
+     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
+         if conversation[-1]['role'] == 'assistant':  # fixed: was misspelled 'assisant'
+             conversation.append(
+                 {'role': 'tool', 'content': 'Errors happen during the function call, please come up with the final answer with the current information.'})
+         finish_tools_prompt = self.add_finish_tools([])
+
+         last_outputs_str = self.llm_infer(messages=conversation,
+                                           temperature=temperature,
+                                           tools=finish_tools_prompt,
+                                           output_begin_string='Since I cannot continue reasoning, I will provide the final answer based on the current information and general knowledge.\n\n[FinalAnswer]',
+                                           skip_special_tokens=True,
+                                           max_new_tokens=max_new_tokens, max_token=max_token)
+         print(last_outputs_str)
+         return last_outputs_str
+
+     def run_multistep_agent(self, message: str,
+                             temperature: float,
+                             max_new_tokens: int,
+                             max_token: int,
+                             max_round: int = 20,
+                             call_agent=False,
+                             call_agent_level=0) -> str:
+         """
+         Run the multi-step reasoning loop and return the final answer.
+         Args:
+             message (str): The input message.
+             temperature (float): The temperature for generating the response.
+             max_new_tokens (int): The maximum number of new tokens to generate per round.
+             max_token (int): The context-size limit before a finish is forced.
+             max_round (int): The maximum number of reasoning rounds.
+         Returns:
+             str: The generated response.
+         """
+         print("\033[1;32;40mstart\033[0m")
+         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+             call_agent, call_agent_level, message)
+         conversation = self.initialize_conversation(message)
+
+         outputs = []
+         last_outputs = []
+         next_round = True
+         function_call_messages = []
+         current_round = 0
+         token_overflow = False
+         enable_summary = False
+         last_status = {}
+
+         if self.enable_checker:
+             checker = ReasoningTraceChecker(message, conversation)
+         try:
+             while next_round and current_round < max_round:
+                 current_round += 1
+                 if len(outputs) > 0:
+                     function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
+                         last_outputs, return_message=True,
+                         existing_tools_prompt=picked_tools_prompt,
+                         message_for_call_agent=message,
+                         call_agent=call_agent,
+                         call_agent_level=call_agent_level,
+                         temperature=temperature)
+
+                     if special_tool_call == 'Finish':
+                         next_round = False
+                         conversation.extend(function_call_messages)
+                         if isinstance(function_call_messages[0]['content'], types.GeneratorType):
+                             function_call_messages[0]['content'] = next(
+                                 function_call_messages[0]['content'])
+                         return function_call_messages[0]['content'].split('[FinalAnswer]')[-1]
+
+                     if (self.enable_summary or token_overflow) and not call_agent:
+                         if token_overflow:
+                             print("token_overflow, using summary")
+                         enable_summary = True
+                         last_status = self.function_result_summary(
+                             conversation, status=last_status, enable_summary=enable_summary)
+
+                     if function_call_messages is not None:
+                         conversation.extend(function_call_messages)
+                         outputs.append(tool_result_format(
+                             function_call_messages))
+                     else:
+                         next_round = False
+                         conversation.extend(
+                             [{"role": "assistant", "content": ''.join(last_outputs)}])
+                         return ''.join(last_outputs).replace("</s>", "")
+                 if self.enable_checker:
+                     good_status, wrong_info = checker.check_conversation()
+                     if not good_status:
+                         next_round = False
+                         print("Internal error in reasoning: " + wrong_info)
+                         break
+                 last_outputs = []
+                 outputs.append("### TxAgent:\n")
+                 last_outputs_str, token_overflow = self.llm_infer(messages=conversation,
+                                                                   temperature=temperature,
+                                                                   tools=picked_tools_prompt,
+                                                                   skip_special_tokens=False,
+                                                                   max_new_tokens=max_new_tokens, max_token=max_token,
+                                                                   check_token_status=True)
+                 if last_outputs_str is None:
+                     next_round = False
+                     print("The number of tokens exceeds the maximum limit.")
+                 else:
+                     last_outputs.append(last_outputs_str)
+             if max_round == current_round:
+                 print("The number of rounds exceeds the maximum limit!")
+                 if self.force_finish:
+                     return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
+                 else:
+                     return None
+
+         except Exception as e:
+             print(f"Error: {e}")
+             if self.force_finish:
+                 return self.get_answer_based_on_unfinished_reasoning(conversation, temperature, max_new_tokens, max_token)
+             else:
+                 return None
+
+     def build_logits_processor(self, messages, llm):
+         # Use the tokenizer from the LLM instance.
+         tokenizer = llm.get_tokenizer()
+         if self.avoid_repeat and len(messages) > 2:
+             assistant_messages = []
+             for i in range(1, len(messages) + 1):
+                 if messages[-i]['role'] == 'assistant':
+                     assistant_messages.append(messages[-i]['content'])
+                     if len(assistant_messages) == 2:
+                         break
+             forbidden_ids = [tokenizer.encode(
+                 msg, add_special_tokens=False) for msg in assistant_messages]
+             return [NoRepeatSentenceProcessor(forbidden_ids, 5)]
+         else:
+             return None
+
+     def llm_infer(self, messages, temperature=0.1, tools=None,
+                   output_begin_string=None, max_new_tokens=2048,
+                   max_token=None, skip_special_tokens=True,
+                   model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
+
+         if model is None:
+             model = self.model
+
+         logits_processor = self.build_logits_processor(messages, model)
+         sampling_params = SamplingParams(
+             temperature=temperature,
+             max_tokens=max_new_tokens,
+             logits_processors=logits_processor,
+             seed=seed if seed is not None else self.seed,
+         )
+
+         prompt = self.chat_template.render(
+             messages=messages, tools=tools, add_generation_prompt=True)
+         if output_begin_string is not None:
+             prompt += output_begin_string
+
+         if check_token_status and max_token is not None:
+             token_overflow = False
+             num_input_tokens = len(self.tokenizer.encode(
+                 prompt, return_tensors="pt")[0])
+             if num_input_tokens > max_token:
+                 torch.cuda.empty_cache()
+                 gc.collect()
+                 print("Number of input tokens before inference:",
+                       num_input_tokens)
+                 logger.info(
+                     "The number of tokens exceeds the maximum limit!!!!")
+                 token_overflow = True
+                 return None, token_overflow
+         output = model.generate(
+             prompt,
+             sampling_params=sampling_params,
+         )
+         output = output[0].outputs[0].text
+         print("\033[92m" + output + "\033[0m")
+         if check_token_status and max_token is not None:
+             return output, token_overflow
+
+         return output
+
+     def run_self_agent(self, message: str,
+                        temperature: float,
+                        max_new_tokens: int,
+                        max_token: int) -> str:
+
+         print("\033[1;32;40mstart self agent\033[0m")
+         conversation = []
+         conversation = self.set_system_prompt(conversation, self.self_prompt)
+         conversation.append({"role": "user", "content": message})
+         return self.llm_infer(messages=conversation,
+                               temperature=temperature,
+                               tools=None,
+                               max_new_tokens=max_new_tokens, max_token=max_token)
+
+     def run_chat_agent(self, message: str,
+                        temperature: float,
+                        max_new_tokens: int,
+                        max_token: int) -> str:
+
+         print("\033[1;32;40mstart chat agent\033[0m")
+         conversation = []
+         conversation = self.set_system_prompt(conversation, self.chat_prompt)
+         conversation.append({"role": "user", "content": message})
+         return self.llm_infer(messages=conversation,
+                               temperature=temperature,
+                               tools=None,
+                               max_new_tokens=max_new_tokens, max_token=max_token)
+
+     def run_format_agent(self, message: str,
+                          answer: str,
+                          temperature: float,
+                          max_new_tokens: int,
+                          max_token: int) -> str:
+
+         print("\033[1;32;40mstart format agent\033[0m")
+         if '[FinalAnswer]' in answer:
+             possible_final_answer = answer.split("[FinalAnswer]")[-1]
+         elif "\n\n" in answer:
+             possible_final_answer = answer.split("\n\n")[-1]
+         else:
+             possible_final_answer = answer.strip()
+         if len(possible_final_answer) == 1:
+             choice = possible_final_answer[0]
+             if choice in ['A', 'B', 'C', 'D', 'E']:
+                 return choice
+         elif len(possible_final_answer) > 1:
+             if possible_final_answer[1] == ':':
+                 choice = possible_final_answer[0]
+                 if choice in ['A', 'B', 'C', 'D', 'E']:
+                     print("choice", choice)
+                     return choice
+
+         conversation = []
+         format_prompt = "You are a helpful assistant that transforms the agent's answer into a final answer of 'A', 'B', 'C', or 'D'."
+         conversation = self.set_system_prompt(conversation, format_prompt)
+         conversation.append({"role": "user", "content": message +
+                              "\nThe final answer of agent:" + answer + "\n The answer is (must be a letter):"})
+         return self.llm_infer(messages=conversation,
+                               temperature=temperature,
+                               tools=None,
+                               max_new_tokens=max_new_tokens, max_token=max_token)
+
+     def run_summary_agent(self, thought_calls: str,
+                           function_response: str,
+                           temperature: float,
+                           max_new_tokens: int,
+                           max_token: int) -> str:
+         print("\033[1;32;40mSummarized Tool Result:\033[0m")
+         generate_tool_result_summary_training_prompt = """Thought and function calls:
+ {thought_calls}
+
+ Function calls' responses:
+ \"\"\"
+ {function_response}
+ \"\"\"
+
+ Based on the Thought and function calls, and the function calls' responses, you need to generate a summary of the function calls' responses that fulfills the requirements of the thought. The summary MUST BE ONE sentence and include all necessary information.
+
+ Directly respond with the summarized sentence of the function calls' responses only.
+
+ Generate **one summarized sentence** about "function calls' responses" with necessary information, and respond with a string:
+ """.format(thought_calls=thought_calls, function_response=function_response)
+         conversation = []
+         conversation.append(
+             {"role": "user", "content": generate_tool_result_summary_training_prompt})
+         output = self.llm_infer(messages=conversation,
+                                 temperature=temperature,
+                                 tools=None,
+                                 max_new_tokens=max_new_tokens, max_token=max_token)
+
+         if '[' in output:
+             output = output.split('[')[0]
+         return output
+
+     def function_result_summary(self, input_list, status, enable_summary):
+         """
+         Scans the conversation and replaces runs of tool responses with
+         one-sentence summaries to keep the context short. Supports 'step' and
+         'length' summary modes and skips the last `summary_skip_last_k` groups.
+
+         Parameters:
+             input_list (list): The conversation, a list of role/content dictionaries.
+             status (dict): Bookkeeping state carried over from the previous call
+                 (step counters and the last summarized index).
+             enable_summary (bool): Whether summarization is currently active.
+
+         Returns:
+             dict: The updated status dictionary.
+         """
+         if 'tool_call_step' not in status:
+             status['tool_call_step'] = 0
+
+         for idx in range(len(input_list)):
+             pos_id = len(input_list) - idx - 1
+             if input_list[pos_id]['role'] == 'assistant':
+                 if 'tool_calls' in input_list[pos_id]:
+                     if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
+                         status['tool_call_step'] += 1
+                 break
+
+         if 'step' in status:
+             status['step'] += 1
+         else:
+             status['step'] = 0
+
+         if not enable_summary:
+             return status
+
+         if 'summarized_index' not in status:
+             status['summarized_index'] = 0
+
+         if 'summarized_step' not in status:
+             status['summarized_step'] = 0
+
+         if 'previous_length' not in status:
+             status['previous_length'] = 0
+
+         if 'history' not in status:
+             status['history'] = []
+
+         function_response = ''
+         this_thought_calls = None  # set once the first assistant message is reached
+         current_summarized_index = status['summarized_index']
+
+         status['history'].append(self.summary_mode == 'step' and status['summarized_step']
+                                  < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
+
+         idx = current_summarized_index
+         while idx < len(input_list):
+             if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
+
+                 if input_list[idx]['role'] == 'assistant':
+                     if 'Tool_RAG' in str(input_list[idx]['tool_calls']):
+                         this_thought_calls = None
+                     else:
+                         if len(function_response) != 0:
+                             print("internal summary")
+                             status['summarized_step'] += 1
+                             result_summary = self.run_summary_agent(
+                                 thought_calls=this_thought_calls,
+                                 function_response=function_response,
+                                 temperature=0.1,
+                                 max_new_tokens=1024,
+                                 max_token=99999
+                             )
+
+                             input_list.insert(
+                                 last_call_idx + 1, {'role': 'tool', 'content': result_summary})
+                             status['summarized_index'] = last_call_idx + 2
+                             idx += 1
+
+                         last_call_idx = idx
+                         this_thought_calls = input_list[idx]['content'] + \
+                             input_list[idx]['tool_calls']
+                         function_response = ''
+
+                 elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
+                     function_response += input_list[idx]['content']
+                     del input_list[idx]
+                     idx -= 1
+
+             else:
+                 break
+             idx += 1
+
+         if len(function_response) != 0:
+             status['summarized_step'] += 1
+             result_summary = self.run_summary_agent(
+                 thought_calls=this_thought_calls,
+                 function_response=function_response,
+                 temperature=0.1,
+                 max_new_tokens=1024,
+                 max_token=99999
+             )
+
+             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
+             for tool_call in tool_calls:
+                 del tool_call['call_id']
+             input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
+             input_list.insert(
+                 last_call_idx + 1, {'role': 'tool', 'content': result_summary})
+             status['summarized_index'] = last_call_idx + 2
+
+         return status
+
+     # Following are Gradio-related functions
+
+     # General update method that accepts any new arguments through kwargs
+     def update_parameters(self, **kwargs):
+         for key, value in kwargs.items():
+             if hasattr(self, key):
+                 setattr(self, key, value)
+
+         # Return the updated attributes
+         updated_attributes = {key: value for key,
+                               value in kwargs.items() if hasattr(self, key)}
+         return updated_attributes
+
+     def run_gradio_chat(self, message: str,
+                         history: list,
+                         temperature: float,
+                         max_new_tokens: int,
+                         max_token: int,
+                         call_agent: bool,
+                         conversation: gr.State,
+                         max_round: int = 20,
+                         seed: int = None,
+                         call_agent_level: int = 0,
+                         sub_agent_task: str = None) -> str:
+         """
+         Generate a streaming response for the Gradio UI.
+         Args:
+             message (str): The input message.
+             history (list): The conversation history used by ChatInterface.
+             temperature (float): The temperature for generating the response.
+             max_new_tokens (int): The maximum number of new tokens to generate.
+         Returns:
+             str: The generated response.
+         """
+         print("\033[1;32;40mstart\033[0m")
+         print("len(message)", len(message))
+         if len(message) <= 10:
+             yield "Hi, I am TxAgent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
+             return "Please provide a valid message."
+         outputs = []
+         outputs_str = ''
+         last_outputs = []
+         # Initialized up front so the exception handler below cannot hit a NameError.
+         last_outputs_str = ''
+         last_thought = ''
+
+         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
+             call_agent,
+             call_agent_level,
+             message)
+
+         conversation = self.initialize_conversation(
+             message,
+             conversation=conversation,
+             history=history)
+         history = []
+
+         next_round = True
+         function_call_messages = []
+         current_round = 0
+         enable_summary = False
+         last_status = {}  # for summary
+         token_overflow = False
+         if self.enable_checker:
+             checker = ReasoningTraceChecker(
+                 message, conversation, init_index=len(conversation))
+
+         try:
+             while next_round and current_round < max_round:
+                 current_round += 1
+                 if len(last_outputs) > 0:
+                     function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
+                         last_outputs, return_message=True,
+                         existing_tools_prompt=picked_tools_prompt,
+                         message_for_call_agent=message,
+                         call_agent=call_agent,
+                         call_agent_level=call_agent_level,
+                         temperature=temperature)
+                     history.extend(current_gradio_history)
+                     if special_tool_call == 'Finish':
+                         yield history
+                         next_round = False
+                         conversation.extend(function_call_messages)
+                         return function_call_messages[0]['content']
+                     elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
+                         history.append(
+                             ChatMessage(role="assistant", content=history[-1].content))
+                         yield history
+                         next_round = False
+                         return history[-1].content
+                     if (self.enable_summary or token_overflow) and not call_agent:
+                         if token_overflow:
+                             print("token_overflow, using summary")
+                         enable_summary = True
+                         last_status = self.function_result_summary(
+                             conversation, status=last_status,
+                             enable_summary=enable_summary)
+                     if function_call_messages is not None:
+                         conversation.extend(function_call_messages)
+                         formatted_md_function_call_messages = tool_result_format(
+                             function_call_messages)
+                         yield history
+                     else:
+                         next_round = False
+                         conversation.extend(
+                             [{"role": "assistant", "content": ''.join(last_outputs)}])
+                         return ''.join(last_outputs).replace("</s>", "")
+                 if self.enable_checker:
+                     good_status, wrong_info = checker.check_conversation()
+                     if not good_status:
+                         next_round = False
+                         print("Internal error in reasoning: " + wrong_info)
+                         break
+                 last_outputs = []
+                 last_outputs_str, token_overflow = self.llm_infer(
+                     messages=conversation,
+                     temperature=temperature,
+                     tools=picked_tools_prompt,
+                     skip_special_tokens=False,
+                     max_new_tokens=max_new_tokens,
+                     max_token=max_token,
+                     seed=seed,
+                     check_token_status=True)
+                 if last_outputs_str is None:
+                     # Token overflow: skip this round's output so the summary
+                     # branch above can compress the conversation next iteration.
+                     continue
+                 last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
+                 for each in history:
+                     if each.metadata is not None:
+                         each.metadata['status'] = 'done'
+                 if '[FinalAnswer]' in last_thought:
+                     final_thought, final_answer = last_thought.split(
+                         '[FinalAnswer]')
+                     history.append(
+                         ChatMessage(role="assistant",
+                                     content=final_thought.strip())
+                     )
+                     yield history
+                     history.append(
+                         ChatMessage(
+                             role="assistant", content="**Answer**:\n" + final_answer.strip())
+                     )
+                     yield history
+                 else:
+                     history.append(ChatMessage(
+                         role="assistant", content=last_thought))
+                     yield history
+
+                 last_outputs.append(last_outputs_str)
+
+             if next_round:
+                 if self.force_finish:
+                     last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                         conversation, temperature, max_new_tokens, max_token)
+                     for each in history:
+                         if each.metadata is not None:
+                             each.metadata['status'] = 'done'
+                     # The forced answer is generated after a '[FinalAnswer]'
+                     # begin string, so the returned text is the answer itself
+                     # (the original code split a stale `last_thought` here).
+                     history.append(
+                         ChatMessage(
+                             role="assistant", content="**Answer**:\n" + last_outputs_str.strip())
+                     )
+                     yield history
+                 else:
+                     yield "The number of rounds exceeds the maximum limit!"
+
+         except Exception as e:
+             print(f"Error: {e}")
+             if self.force_finish:
+                 last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
+                     conversation,
+                     temperature,
+                     max_new_tokens,
+                     max_token)
+                 for each in history:
+                     if each.metadata is not None:
+                         each.metadata['status'] = 'done'
+                 history.append(
+                     ChatMessage(
+                         role="assistant", content="**Answer**:\n" + last_outputs_str.strip())
+                 )
+                 yield history
+             else:
+                 return None
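
One detail worth calling out in llm_infer above: rather than calling the tokenizer's apply_chat_template, it renders the raw chat template string through Jinja2 so the tools list can be injected alongside the messages. A minimal sketch of that mechanism with a toy template (the real template ships with the model's tokenizer; this one is invented for illustration):

from jinja2 import Template

# Toy stand-in for a chat template.
toy_template = Template(
    "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)
prompt = toy_template.render(
    messages=[{"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
)
print(prompt)  # <|system|>You are a helpful assistant. / <|user|>Hello / <|assistant|>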
src/txagent/utils.py ADDED
@@ -0,0 +1,117 @@
+ import sys
+ import json
+ import hashlib
+ import torch
+ from typing import List
+
+
+ def get_md5(input_str):
+     # Create an MD5 hash object
+     md5_hash = hashlib.md5()
+
+     # Encode the string and update the hash object
+     md5_hash.update(input_str.encode('utf-8'))
+
+     # Return the hexadecimal MD5 digest
+     return md5_hash.hexdigest()
+
+
+ def tool_result_format(function_call_messages):
+     current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
+     for each_message in function_call_messages:
+         if each_message['role'] == 'tool':
+             current_output += f"{each_message['content']}\n\n"
+     current_output += "</details>\n\n\n"
+     return current_output
+
+
+ class NoRepeatSentenceProcessor:
+     def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
+         """
+         Args:
+             forbidden_sequences (List[List[int]]): A list of token ID sequences corresponding to forbidden sentences.
+             allowed_prefix_length (int): The number k such that if the generated tokens match the first k tokens
+                 of a forbidden sequence, then the candidate token that would extend the match is blocked.
+         """
+         self.allowed_prefix_length = allowed_prefix_length
+         # Build a lookup dictionary: key is a tuple of the first k tokens, value is a set of tokens to block.
+         self.forbidden_prefix_dict = {}
+         for seq in forbidden_sequences:
+             if len(seq) > allowed_prefix_length:
+                 prefix = tuple(seq[:allowed_prefix_length])
+                 next_token = seq[allowed_prefix_length]
+                 self.forbidden_prefix_dict.setdefault(
+                     prefix, set()).add(next_token)
+
+     def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
+         """
+         Modifies the logits to block tokens that would extend a forbidden sentence.
+
+         Args:
+             token_ids (List[int]): List of token IDs generated so far.
+             logits (torch.Tensor): Logits tensor for the next token (shape: [vocab_size]).
+
+         Returns:
+             torch.Tensor: Modified logits.
+         """
+         if len(token_ids) >= self.allowed_prefix_length:
+             prefix = tuple(token_ids[:self.allowed_prefix_length])
+             if prefix in self.forbidden_prefix_dict:
+                 for token_id in self.forbidden_prefix_dict[prefix]:
+                     logits[token_id] = -float("inf")
+         return logits
+
+
+ class ReasoningTraceChecker:
+     def __init__(self, question, conversation, init_index=None):
+         self.question = question
+         self.conversation = conversation
+         self.existing_thoughts = []
+         self.existing_actions = []
+         if init_index is not None:
+             self.index = init_index
+         else:
+             self.index = 1
+         self.question = self.question.lower()
+         self.new_thoughts = []
+         self.new_actions = []
+
+     def check_conversation(self):
+         info = ''
+         current_index = self.index
+         for i in range(current_index, len(self.conversation)):
+             each = self.conversation[i]
+             self.index = i
+             if each['role'] == 'assistant':
+                 print(each)
+                 thought = each['content']
+                 # Plain assistant messages carry no tool calls; default to an empty list.
+                 actions = each.get('tool_calls', '[]')
+
+                 good_status, current_info = self.check_repeat_thought(thought)
+                 info += current_info
+                 if not good_status:
+                     return False, info
+
+                 good_status, current_info = self.check_repeat_action(actions)
+                 info += current_info
+                 if not good_status:
+                     return False, info
+         return True, info
+
+     def check_repeat_thought(self, thought):
+         if thought in self.existing_thoughts:
+             return False, "repeat_thought"
+         self.existing_thoughts.append(thought)
+         return True, ''
+
+     def check_repeat_action(self, actions):
+         if not isinstance(actions, list):
+             actions = json.loads(actions)
+         for each_action in actions:
+             if 'call_id' in each_action:
+                 del each_action['call_id']
+             each_action = json.dumps(each_action)
+             if each_action in self.existing_actions:
+                 return False, "repeat_action"
+             self.existing_actions.append(each_action)
+         return True, ''
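
To see NoRepeatSentenceProcessor in action outside vLLM, here is a small self-contained demonstration (it assumes the txagent package is importable; otherwise the class can be copied into a scratch file). The forbidden sequence is [5, 6, 7, 8, 9, 10] with an allowed prefix length of 5, so once generation has produced tokens 5 through 9, token 10 is masked out of the next-token distribution:

import torch
from txagent.utils import NoRepeatSentenceProcessor

processor = NoRepeatSentenceProcessor([[5, 6, 7, 8, 9, 10]], allowed_prefix_length=5)
logits = torch.zeros(16)                      # toy logits over a 16-token vocabulary
blocked = processor([5, 6, 7, 8, 9], logits)  # generated prefix matches the forbidden one
print(blocked[10])                            # tensor(-inf): token 10 can no longer be sampled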