freeCS-dot-org committed
Commit 1898bf7 · verified · 1 Parent(s): f663ac7

Update app.py

Files changed (1):
  1. app.py +102 -233

app.py CHANGED
Old version of app.py (lines removed by this commit are prefixed with -, unchanged context lines with a space):

@@ -2,175 +2,82 @@ import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
-import gradio as gr
 from threading import Thread
-import re

 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = "AGI-0/Art-v0-3B"

-TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
-
-PLACEHOLDER = """
-<center>
-<p>Hi! How can I help you today?</p>
-</center>
-"""
-
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 {
-    text-align: center;
-}
-"""
-
 class ConversationManager:
     def __init__(self):
-        self.user_history = []  # For displaying to user (markdown)
-        self.model_history = []  # For feeding back to model (special tags)
-        self.debug_log = []

-    def log(self, message):
-        """Add timestamped log entry"""
-        timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
-        log_entry = f"[{timestamp}] {message}"
-        print(log_entry)
-        self.debug_log.append(log_entry)
-
-    def convert_to_markdown(self, model_text):
-        """Convert from model format (with special tags) to markdown"""
-        self.log(f"Converting to markdown - Input length: {len(model_text)}")
-        self.log(f"Input text: {model_text[:200]}..." if len(model_text) > 200 else f"Input text: {model_text}")

-        markdown_text = model_text

-        # Convert special tags to markdown
-        tag_conversions = [
-            # Reasoning blocks
-            ("<|start_reasoning|>", "<details><summary>Reasoning</summary>\n\n"),
-            ("<|end_reasoning|>", "\n\n</details>\n\n"),

-            # Other special tags (add more as needed)
-            ("<|im_start|>", ""),
-            ("<|im_end|>", ""),
-            ("<|assistant|>", ""),
-            ("<|user|>", ""),
         ]

-        for old, new in tag_conversions:
-            if old in markdown_text:
-                self.log(f"Converting tag: {old} -> {new}")
-                markdown_text = markdown_text.replace(old, new)
-
-        # Clean up any remaining special tags using regex
-        markdown_text = re.sub(r'<\|[^>]+\|>', '', markdown_text)
-
-        # Fix common markdown formatting issues
-        markdown_text = re.sub(r'\n{3,}', '\n\n', markdown_text)  # Remove excess newlines
-        markdown_text = markdown_text.strip()
-
-        self.log(f"Markdown conversion complete - Output length: {len(markdown_text)}")
-        self.log(f"Output text: {markdown_text[:200]}..." if len(markdown_text) > 200 else f"Output text: {markdown_text}")
-
-        return markdown_text
-
-    def convert_to_model_format(self, markdown_text):
-        """Convert from markdown to model format (with special tags)"""
-        self.log(f"Converting to model format - Input length: {len(markdown_text)}")
-        self.log(f"Input text: {markdown_text[:200]}..." if len(markdown_text) > 200 else f"Input text: {markdown_text}")
-
-        model_text = markdown_text
-
-        # Convert markdown to special tags
-        if "<details>" in markdown_text and "</details>" in markdown_text:
-            try:
-                # Extract content between details tags
-                pattern = r'<details><summary>.*?</summary>\s*(.*?)\s*</details>'
-                matches = re.findall(pattern, markdown_text, re.DOTALL)
-
-                for match in matches:
-                    original = f"<details><summary>Reasoning</summary>\n\n{match}\n\n</details>"
-                    replacement = f"<|start_reasoning|>{match}<|end_reasoning|>"
-                    model_text = model_text.replace(original, replacement)
-                    self.log(f"Converted details block to reasoning tags")
-            except Exception as e:
-                self.log(f"Warning: Failed to convert details block: {str(e)}")

-        # Clean up formatting
-        model_text = re.sub(r'\n{3,}', '\n\n', model_text)  # Remove excess newlines
-        model_text = model_text.strip()

-        self.log(f"Model format conversion complete - Output length: {len(model_text)}")
-        self.log(f"Output text: {model_text[:200]}..." if len(model_text) > 200 else f"Output text: {model_text}")
-
-        return model_text

     def add_exchange(self, user_message, assistant_response):
-        """Add a new exchange to both histories"""
-        self.log(f"\n=== Adding New Exchange ===")
-        self.log(f"User Message: {user_message[:100]}..." if len(user_message) > 100 else f"User Message: {user_message}")
-        self.log(f"Assistant Response: {assistant_response[:100]}..." if len(assistant_response) > 100 else f"Assistant Response: {assistant_response}")
-
-        # Convert assistant response to markdown for user display
-        markdown_response = self.convert_to_markdown(assistant_response)
-
-        # Store both versions
-        self.model_history.append((user_message, assistant_response))
-        self.user_history.append((user_message, markdown_response))
-
-        self.log(f"Current History State:")
-        self.log(f"- Model History: {len(self.model_history)} exchanges")
-        self.log(f"- User History: {len(self.user_history)} exchanges")
-
-    def sync_with_ui_history(self, ui_history):
-        """Sync our histories with the UI history"""
-        self.log(f"\n=== Syncing with UI History ===")
-        self.log(f"UI History Length: {len(ui_history)}")
-
-        # Clear current histories
-        self.model_history = []
-        self.user_history = []
-
-        # Rebuild histories from UI
-        for user_msg, markdown_response in ui_history:
-            model_response = self.convert_to_model_format(markdown_response)
-            self.model_history.append((user_msg, model_response))
-            self.user_history.append((user_msg, markdown_response))
-
-        self.log(f"Sync Complete:")
-        self.log(f"- Model History: {len(self.model_history)} exchanges")
-        self.log(f"- User History: {len(self.user_history)} exchanges")
-
-        # Verify sync integrity
-        if len(self.model_history) != len(self.user_history) or len(self.model_history) != len(ui_history):
-            self.log("WARNING: History length mismatch after sync!")
-
-    def get_model_history(self):
-        """Get history in model format"""
-        self.log(f"\nReturning Model History ({len(self.model_history)} exchanges)")
-        return self.model_history
-
-    def get_user_history(self):
-        """Get history in markdown format"""
-        self.log(f"\nReturning User History ({len(self.user_history)} exchanges)")
-        return self.user_history
-
-    def get_debug_log(self):
-        """Get the full debug log"""
-        return "\n".join(self.debug_log)

-# Initialize global conversation manager
 conversation_manager = ConversationManager()

-device = "cuda"  # for GPU usage or "cpu" for CPU usage
-
-# Initialize model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
@@ -190,33 +97,24 @@ def stream_chat(
     top_k: int = 1,
     penalty: float = 1.1,
 ):
-    conversation_manager.log(f'\n=== New Chat Request ===')
-    conversation_manager.log(f'Message: {message}')
-    conversation_manager.log(f'History Length: {len(history)}')
-    conversation_manager.log(f'System Prompt: {system_prompt}')
-    conversation_manager.log(f'Parameters: temp={temperature}, max_tokens={max_new_tokens}, top_p={top_p}, top_k={top_k}, penalty={penalty}')

-    # Sync with UI history
-    conversation_manager.sync_with_ui_history(history)
-
-    # Get model-formatted history
-    model_history = conversation_manager.get_model_history()
-
-    # Build conversation for model
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})

-    for prompt, answer in model_history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])

     conversation.append({"role": "user", "content": message})

-    conversation_manager.log(f'Built conversation with {len(conversation)} messages')

     input_ids = tokenizer.apply_chat_template(
         conversation,
         add_generation_prompt=True,
@@ -242,8 +140,9 @@ def stream_chat(
         streamer=streamer,
     )

     buffer = ""
-    original_response = ""

     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -251,27 +150,40 @@

         for new_text in streamer:
             buffer += new_text
-            original_response += new_text

-            # Convert buffer to markdown for display
-            formatted_buffer = conversation_manager.convert_to_markdown(buffer)

-            if thread.is_alive() is False:
-                conversation_manager.log(f'Generation Complete:')
-                conversation_manager.log(f'Final Response Length: {len(original_response)}')
-
-                conversation_manager.add_exchange(
-                    message,
-                    original_response  # Original for model
-                )

-            yield formatted_buffer

-# Initialize Gradio interface
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

 with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.HTML(TITLE)
     gr.DuplicateButton(
         value="Duplicate Space for private use",
         elem_classes="duplicate-button"
@@ -280,57 +192,14 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(
-            label="⚙️ Parameters",
-            open=False,
-            render=False
-        ),
         additional_inputs=[
-            gr.Textbox(
-                value="",
-                label="System Prompt",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.2,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=8192,
-                step=1,
-                value=4096,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=50,
-                step=1,
-                value=1,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.1,
-                label="Repetition penalty",
-                render=False,
-            ),
         ],
         examples=[
             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
New version of app.py (lines added by this commit are prefixed with +, unchanged context lines with a space):

@@ -2,175 +2,82 @@ import os
 import time
 import spaces
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+import gradio as gr

 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = "AGI-0/Art-v0-3B"

 class ConversationManager:
     def __init__(self):
+        self.model_messages = []  # Stores raw responses with tags

+    def format_for_display(self, raw_response):
+        """Convert model response to user-friendly markdown.
+        Keeps original response intact for model."""

+        # No response? Return empty
+        if not raw_response:
+            return ""
+
+        display_response = raw_response

+        # Handle reasoning sections
+        while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
+            start = display_response.find("<|start_reasoning|>")
+            end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")
+
+            # Extract reasoning content
+            reasoning_block = display_response[start:end]
+            reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")

+            # Replace with markdown details/summary
+            markdown_block = f"\n<details><summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
+            display_response = display_response[:start] + markdown_block + display_response[end:]
+
+        # Clean up other tags
+        tags_to_remove = [
+            "<|im_start|>",
+            "<|im_end|>",
+            "<|assistant|>",
+            "<|user|>"
         ]

+        for tag in tags_to_remove:
+            display_response = display_response.replace(tag, "")

+        # Clean up any extra whitespace
+        display_response = "\n".join(line.strip() for line in display_response.split("\n"))
+        display_response = "\n".join(filter(None, display_response.split("\n")))

+        return display_response.strip()

     def add_exchange(self, user_message, assistant_response):
+        """Store raw response in model history"""
+        print("\n=== New Exchange ===")
+        print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
+        print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")
+
+        self.model_messages.append({
+            "role": "user",
+            "content": user_message
+        })
+        self.model_messages.append({
+            "role": "assistant",
+            "content": assistant_response
+        })
+
+        print(f"Current history length: {len(self.model_messages)} messages")
+
+    def get_conversation_messages(self):
+        """Get full conversation history for model"""
+        return self.model_messages

+# Initialize globals
 conversation_manager = ConversationManager()
+device = "cuda"

 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
 
@@ -190,33 +97,24 @@ def stream_chat(
     top_k: int = 1,
     penalty: float = 1.1,
 ):
+    print(f"\n=== New Chat Request ===")
+    print(f"Message: {message}")
+    print(f"History length: {len(history)}")

+    # Build conversation history from model's stored messages
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})

+    # Add all previous messages
+    conversation.extend(conversation_manager.get_conversation_messages())

+    # Add new message
     conversation.append({"role": "user", "content": message})

+    print(f"Sending {len(conversation)} messages to model")

+    # Prepare model input
     input_ids = tokenizer.apply_chat_template(
         conversation,
         add_generation_prompt=True,
 
@@ -242,8 +140,9 @@ def stream_chat(
         streamer=streamer,
     )

+    # Storage for building complete response
     buffer = ""
+    model_response = ""

     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
 
@@ -251,27 +150,40 @@

         for new_text in streamer:
             buffer += new_text
+            model_response += new_text

+            # Convert current buffer for display
+            display_text = conversation_manager.format_for_display(buffer)

+            if not thread.is_alive():
+                print("Generation complete")
+                # Store final response in model history
+                conversation_manager.add_exchange(message, model_response)

+            yield display_text
+
+# Set up Gradio interface
+CSS = """
+.duplicate-button {
+    margin: auto !important;
+    color: white !important;
+    background: black !important;
+    border-radius: 100vh !important;
+}
+h3 { text-align: center; }
+"""

+chatbot = gr.Chatbot(
+    height=600,
+    placeholder="""
+    <center>
+    <p>Hi! How can I help you today?</p>
+    </center>
+    """
+)

 with gr.Blocks(css=CSS, theme="soft") as demo:
+    gr.HTML("""<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>""")
     gr.DuplicateButton(
         value="Duplicate Space for private use",
         elem_classes="duplicate-button"
 
@@ -280,57 +192,14 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
+        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
         additional_inputs=[
+            gr.Textbox(value="", label="System Prompt", render=False),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
+            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
+            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
+            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
+            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
         ],
         examples=[
             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],