freeCS-dot-org committed on
Commit f663ac7 · verified · 1 Parent(s): 2fb89d3

Update app.py

Files changed (1)
  1. app.py +262 -57
app.py CHANGED
@@ -2,49 +2,183 @@ import os
  import time
  import spaces
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
  import gradio as gr
  from threading import Thread
+ import re

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
  MODEL = "AGI-0/Art-v0-3B"

- device = "cuda" # Use "cpu" if no GPU available
+ TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""

- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL, torch_dtype=torch.bfloat16, device_map="auto"
- )
- end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
+ PLACEHOLDER = """
+ <center>
+ <p>Hi! How can I help you today?</p>
+ </center>
+ """
+
+ CSS = """
+ .duplicate-button {
+     margin: auto !important;
+     color: white !important;
+     background: black !important;
+     border-radius: 100vh !important;
+ }
+ h3 {
+     text-align: center;
+ }
+ """

  class ConversationManager:
      def __init__(self):
-         self.user_history = [] # User-facing history with formatting
-         self.model_history = [] # Model-facing history without formatting
-
-     def add_exchange(self, user_message, model_response):
-         formatted_response = self.format_response(model_response)
-         self.model_history.append((user_message, model_response))
-         self.user_history.append((user_message, formatted_response))
-         print(f"\nModel History Updated: {self.model_history}")
-         print(f"\nUser History Updated: {self.user_history}")
-
-     def format_response(self, response):
-         """Format response for UI while keeping raw text for model."""
-         if "<|end_reasoning|>" in response:
-             parts = response.split("<|end_reasoning|>")
-             reasoning, rest = parts[0], parts[1] if len(parts) > 1 else ""
-             return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
-         return response
-
-     def get_user_history(self):
-         return self.user_history
-
+         self.user_history = [] # For displaying to user (markdown)
+         self.model_history = [] # For feeding back to model (special tags)
+         self.debug_log = []
+
+     def log(self, message):
+         """Add timestamped log entry"""
+         timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
+         log_entry = f"[{timestamp}] {message}"
+         print(log_entry)
+         self.debug_log.append(log_entry)
+
+     def convert_to_markdown(self, model_text):
+         """Convert from model format (with special tags) to markdown"""
+         self.log(f"Converting to markdown - Input length: {len(model_text)}")
+         self.log(f"Input text: {model_text[:200]}..." if len(model_text) > 200 else f"Input text: {model_text}")
+
+         markdown_text = model_text
+
+         # Convert special tags to markdown
+         tag_conversions = [
+             # Reasoning blocks
+             ("<|start_reasoning|>", "<details><summary>Reasoning</summary>\n\n"),
+             ("<|end_reasoning|>", "\n\n</details>\n\n"),
+
+             # Other special tags (add more as needed)
+             ("<|im_start|>", ""),
+             ("<|im_end|>", ""),
+             ("<|assistant|>", ""),
+             ("<|user|>", ""),
+         ]
+
+         for old, new in tag_conversions:
+             if old in markdown_text:
+                 self.log(f"Converting tag: {old} -> {new}")
+                 markdown_text = markdown_text.replace(old, new)
+
+         # Clean up any remaining special tags using regex
+         markdown_text = re.sub(r'<\|[^>]+\|>', '', markdown_text)
+
+         # Fix common markdown formatting issues
+         markdown_text = re.sub(r'\n{3,}', '\n\n', markdown_text) # Remove excess newlines
+         markdown_text = markdown_text.strip()
+
+         self.log(f"Markdown conversion complete - Output length: {len(markdown_text)}")
+         self.log(f"Output text: {markdown_text[:200]}..." if len(markdown_text) > 200 else f"Output text: {markdown_text}")
+
+         return markdown_text
+
+     def convert_to_model_format(self, markdown_text):
+         """Convert from markdown to model format (with special tags)"""
+         self.log(f"Converting to model format - Input length: {len(markdown_text)}")
+         self.log(f"Input text: {markdown_text[:200]}..." if len(markdown_text) > 200 else f"Input text: {markdown_text}")
+
+         model_text = markdown_text
+
+         # Convert markdown to special tags
+         if "<details>" in markdown_text and "</details>" in markdown_text:
+             try:
+                 # Extract content between details tags
+                 pattern = r'<details><summary>.*?</summary>\s*(.*?)\s*</details>'
+                 matches = re.findall(pattern, markdown_text, re.DOTALL)
+
+                 for match in matches:
+                     original = f"<details><summary>Reasoning</summary>\n\n{match}\n\n</details>"
+                     replacement = f"<|start_reasoning|>{match}<|end_reasoning|>"
+                     model_text = model_text.replace(original, replacement)
+                     self.log(f"Converted details block to reasoning tags")
+             except Exception as e:
+                 self.log(f"Warning: Failed to convert details block: {str(e)}")
+
+         # Clean up formatting
+         model_text = re.sub(r'\n{3,}', '\n\n', model_text) # Remove excess newlines
+         model_text = model_text.strip()
+
+         self.log(f"Model format conversion complete - Output length: {len(model_text)}")
+         self.log(f"Output text: {model_text[:200]}..." if len(model_text) > 200 else f"Output text: {model_text}")
+
+         return model_text
+
+     def add_exchange(self, user_message, assistant_response):
+         """Add a new exchange to both histories"""
+         self.log(f"\n=== Adding New Exchange ===")
+         self.log(f"User Message: {user_message[:100]}..." if len(user_message) > 100 else f"User Message: {user_message}")
+         self.log(f"Assistant Response: {assistant_response[:100]}..." if len(assistant_response) > 100 else f"Assistant Response: {assistant_response}")
+
+         # Convert assistant response to markdown for user display
+         markdown_response = self.convert_to_markdown(assistant_response)
+
+         # Store both versions
+         self.model_history.append((user_message, assistant_response))
+         self.user_history.append((user_message, markdown_response))
+
+         self.log(f"Current History State:")
+         self.log(f"- Model History: {len(self.model_history)} exchanges")
+         self.log(f"- User History: {len(self.user_history)} exchanges")
+
+     def sync_with_ui_history(self, ui_history):
+         """Sync our histories with the UI history"""
+         self.log(f"\n=== Syncing with UI History ===")
+         self.log(f"UI History Length: {len(ui_history)}")
+
+         # Clear current histories
+         self.model_history = []
+         self.user_history = []
+
+         # Rebuild histories from UI
+         for user_msg, markdown_response in ui_history:
+             model_response = self.convert_to_model_format(markdown_response)
+             self.model_history.append((user_msg, model_response))
+             self.user_history.append((user_msg, markdown_response))
+
+         self.log(f"Sync Complete:")
+         self.log(f"- Model History: {len(self.model_history)} exchanges")
+         self.log(f"- User History: {len(self.user_history)} exchanges")
+
+         # Verify sync integrity
+         if len(self.model_history) != len(self.user_history) or len(self.model_history) != len(ui_history):
+             self.log("WARNING: History length mismatch after sync!")
+
      def get_model_history(self):
+         """Get history in model format"""
+         self.log(f"\nReturning Model History ({len(self.model_history)} exchanges)")
          return self.model_history
+
+     def get_user_history(self):
+         """Get history in markdown format"""
+         self.log(f"\nReturning User History ({len(self.user_history)} exchanges)")
+         return self.user_history
+
+     def get_debug_log(self):
+         """Get the full debug log"""
+         return "\n".join(self.debug_log)

+ # Initialize global conversation manager
  conversation_manager = ConversationManager()

+ device = "cuda" # for GPU usage or "cpu" for CPU usage
+
+ # Initialize model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
+ )
+ end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
  @spaces.GPU()
  def stream_chat(
      message: str,
@@ -56,29 +190,46 @@ def stream_chat(
      top_k: int = 1,
      penalty: float = 1.1,
  ):
-     print(f'User Message: {message}')
+     conversation_manager.log(f'\n=== New Chat Request ===')
+     conversation_manager.log(f'Message: {message}')
+     conversation_manager.log(f'History Length: {len(history)}')
+     conversation_manager.log(f'System Prompt: {system_prompt}')
+     conversation_manager.log(f'Parameters: temp={temperature}, max_tokens={max_new_tokens}, top_p={top_p}, top_k={top_k}, penalty={penalty}')
+
+     # Sync with UI history
+     conversation_manager.sync_with_ui_history(history)

+     # Get model-formatted history
      model_history = conversation_manager.get_model_history()
-     print(f'Model History: {model_history}')

+     # Build conversation for model
      conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+
      for prompt, answer in model_history:
          conversation.extend([
              {"role": "user", "content": prompt},
              {"role": "assistant", "content": answer},
          ])
+
      conversation.append({"role": "user", "content": message})
-
-     print(f'Formatted Conversation for Model: {conversation}')
+
+     conversation_manager.log(f'Built conversation with {len(conversation)} messages')

      input_ids = tokenizer.apply_chat_template(
-         conversation, add_generation_prompt=True, return_tensors="pt"
+         conversation,
+         add_generation_prompt=True,
+         return_tensors="pt"
      ).to(model.device)
-
+
      streamer = TextIteratorStreamer(
-         tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+         tokenizer,
+         timeout=60.0,
+         skip_prompt=True,
+         skip_special_tokens=True
      )
-
+
      generate_kwargs = dict(
          input_ids=input_ids,
          max_new_tokens=max_new_tokens,
@@ -90,42 +241,96 @@ def stream_chat(
          eos_token_id=[end_of_sentence],
          streamer=streamer,
      )
-
+
      buffer = ""
      original_response = ""
-
+
      with torch.no_grad():
          thread = Thread(target=model.generate, kwargs=generate_kwargs)
          thread.start()
-
+
          for new_text in streamer:
              buffer += new_text
              original_response += new_text
-             print(f'Streaming: {buffer}')
-             formatted_buffer = conversation_manager.format_response(buffer)
-             yield formatted_buffer, history + [[message, formatted_buffer]]
-
+
+             # Convert buffer to markdown for display
+             formatted_buffer = conversation_manager.convert_to_markdown(buffer)
+
+             if thread.is_alive() is False:
+                 conversation_manager.log(f'Generation Complete:')
+                 conversation_manager.log(f'Final Response Length: {len(original_response)}')
+
+                 conversation_manager.add_exchange(
+                     message,
+                     original_response # Original for model
+                 )
+
+             yield formatted_buffer

-     conversation_manager.add_exchange(message, original_response)
+ # Initialize Gradio interface
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

- chatbot = gr.Chatbot(height=600, placeholder="<center><p>Hi! How can I help you today?</p></center>")
-
- demo = gr.Blocks()
- with demo:
-     gr.HTML("<h2>Link to the model: <a href='https://huggingface.co/AGI-0/Art-v0-3B'>click here</a></h2>")
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+ with gr.Blocks(css=CSS, theme="soft") as demo:
+     gr.HTML(TITLE)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_classes="duplicate-button"
+     )
      gr.ChatInterface(
          fn=stream_chat,
          chatbot=chatbot,
          fill_height=True,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs_accordion=gr.Accordion(
+             label="⚙️ Parameters",
+             open=False,
+             render=False
+         ),
          additional_inputs=[
-             gr.Textbox(value="", label="System Prompt", render=False),
-             gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
-             gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
-             gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
-             gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
-             gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
+             gr.Textbox(
+                 value="",
+                 label="System Prompt",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.2,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=128,
+                 maximum=8192,
+                 step=1,
+                 value=4096,
+                 label="Max new tokens",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=1.0,
+                 label="top_p",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=50,
+                 step=1,
+                 value=1,
+                 label="top_k",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 step=0.1,
+                 value=1.1,
+                 label="Repetition penalty",
+                 render=False,
+             ),
          ],
          examples=[
              ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],