freeCS-dot-org committed
Commit befe84a · verified · 1 Parent(s): 1898bf7

Update app.py

Files changed (1)
  1. app.py +146 -120
app.py CHANGED
@@ -2,81 +2,55 @@ import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from threading import Thread
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import gradio as gr
+from threading import Thread
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = "AGI-0/Art-v0-3B"
 
+TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
+
+PLACEHOLDER = """
+<center>
+<p>Hi! How can I help you today?</p>
+</center>
+"""
+
+CSS = """
+.duplicate-button {
+    margin: auto !important;
+    color: white !important;
+    background: black !important;
+    border-radius: 100vh !important;
+}
+h3 {
+    text-align: center;
+}
+"""
+
 class ConversationManager:
     def __init__(self):
-        self.model_messages = []  # Stores raw responses with tags
-
-    def format_for_display(self, raw_response):
-        """Convert model response to user-friendly markdown.
-        Keeps original response intact for model."""
-
-        # No response? Return empty
-        if not raw_response:
-            return ""
-
-        display_response = raw_response
-
-        # Handle reasoning sections
-        while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
-            start = display_response.find("<|start_reasoning|>")
-            end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")
-
-            # Extract reasoning content
-            reasoning_block = display_response[start:end]
-            reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")
-
-            # Replace with markdown details/summary
-            markdown_block = f"\n<details><summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
-            display_response = display_response[:start] + markdown_block + display_response[end:]
-
-        # Clean up other tags
-        tags_to_remove = [
-            "<|im_start|>",
-            "<|im_end|>",
-            "<|assistant|>",
-            "<|user|>"
-        ]
-
-        for tag in tags_to_remove:
-            display_response = display_response.replace(tag, "")
-
-        # Clean up any extra whitespace
-        display_response = "\n".join(line.strip() for line in display_response.split("\n"))
-        display_response = "\n".join(filter(None, display_response.split("\n")))
-
-        return display_response.strip()
+        self.user_history = []  # For displaying to user (with markdown)
+        self.model_history = []  # For feeding back to model (with original tags)
 
-    def add_exchange(self, user_message, assistant_response):
-        """Store raw response in model history"""
-        print("\n=== New Exchange ===")
-        print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
-        print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")
-
-        self.model_messages.append({
-            "role": "user",
-            "content": user_message
-        })
-        self.model_messages.append({
-            "role": "assistant",
-            "content": assistant_response
-        })
-
-        print(f"Current history length: {len(self.model_messages)} messages")
+    def add_exchange(self, user_message, assistant_response, formatted_response):
+        self.model_history.append((user_message, assistant_response))
+        self.user_history.append((user_message, formatted_response))
+        print(f"\nModel History Exchange:")
+        print(f"User: {user_message}")
+        print(f"Assistant (Original): {assistant_response}")
+        print(f"Assistant (Formatted): {formatted_response}")
 
-    def get_conversation_messages(self):
-        """Get full conversation history for model"""
-        return self.model_messages
+    def get_model_history(self):
+        return self.model_history
+
+    def get_user_history(self):
+        return self.user_history
 
-# Initialize globals
 conversation_manager = ConversationManager()
-device = "cuda"
+
+device = "cuda" # for GPU usage or "cpu" for CPU usage
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
@@ -86,6 +60,15 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
+def format_response(response):
+    """Format the response for user display"""
+    if "<|end_reasoning|>" in response:
+        parts = response.split("<|end_reasoning|>")
+        reasoning = parts[0]
+        rest = parts[1] if len(parts) > 1 else ""
+        return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
+    return response
+
 @spaces.GPU()
 def stream_chat(
     message: str,
@@ -97,34 +80,48 @@ def stream_chat(
     top_k: int = 1,
     penalty: float = 1.1,
 ):
-    print(f"\n=== New Chat Request ===")
-    print(f"Message: {message}")
-    print(f"History length: {len(history)}")
+    print(f'\nNew Chat Request:')
+    print(f'Message: {message}')
+    print(f'History from UI: {history}')
+    print(f'System Prompt: {system_prompt}')
+    print(f'Parameters: temp={temperature}, max_tokens={max_new_tokens}, top_p={top_p}, top_k={top_k}, penalty={penalty}')
 
-    # Build conversation history from model's stored messages
+    # Build conversation from UI history instead of model_history
     conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
+    for prompt, answer in (history or []):
+        # Extract original response if it's in the details format
+        if "<details>" in answer:
+            # Extract content between <details> tags and after </details>
+            parts = answer.split("</details>")
+            if len(parts) > 1:
+                # Get the content after the </details> tag
+                answer_content = parts[1].strip()
+                # Get the reasoning part
+                reasoning = answer.split("<summary>")[1].split("</summary>")[1].strip()
+                # Reconstruct the original format
+                answer = f"{reasoning}<|end_reasoning|>{answer_content}"
+            else:
+                # If no </details> tag found, use the answer as is
+                answer = answer
+        conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer},
+        ])
 
-    # Add all previous messages
-    conversation.extend(conversation_manager.get_conversation_messages())
-
-    # Add new message
     conversation.append({"role": "user", "content": message})
+    print(f'\nFormatted Conversation for Model:')
+    print(conversation)
 
-    print(f"Sending {len(conversation)} messages to model")
-
-    # Prepare model input
     input_ids = tokenizer.apply_chat_template(
-        conversation,
-        add_generation_prompt=True,
+        conversation,
+        add_generation_prompt=True,
         return_tensors="pt"
     ).to(model.device)
 
     streamer = TextIteratorStreamer(
-        tokenizer,
-        timeout=60.0,
-        skip_prompt=True,
+        tokenizer,
+        timeout=60.0,
+        skip_prompt=True,
         skip_special_tokens=True
     )
 
@@ -140,9 +137,8 @@ def stream_chat(
         streamer=streamer,
     )
 
-    # Storage for building complete response
     buffer = ""
-    model_response = ""
+    original_response = ""
 
    with torch.no_grad():
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -150,56 +146,86 @@ def stream_chat(
 
     for new_text in streamer:
         buffer += new_text
-        model_response += new_text
+        original_response += new_text
 
-        # Convert current buffer for display
-        display_text = conversation_manager.format_for_display(buffer)
+        formatted_buffer = format_response(buffer)
 
-        if not thread.is_alive():
-            print("Generation complete")
-            # Store final response in model history
-            conversation_manager.add_exchange(message, model_response)
+        if thread.is_alive() is False:
+            print(f'\nGeneration Complete:')
+            print(f'Original Response: {original_response}')
+            print(f'Formatted Response: {formatted_buffer}')
+
+            conversation_manager.add_exchange(
+                message,
+                original_response, # Original for model
+                formatted_buffer # Formatted for user
+            )
 
-        yield display_text
+        yield formatted_buffer
 
-# Set up Gradio interface
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 { text-align: center; }
-"""
-
-chatbot = gr.Chatbot(
-    height=600,
-    placeholder="""
-    <center>
-    <p>Hi! How can I help you today?</p>
-    </center>
-    """
-)
+chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.HTML("""<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>""")
+    gr.HTML(TITLE)
     gr.DuplicateButton(
-        value="Duplicate Space for private use",
+        value="Duplicate Space for private use",
         elem_classes="duplicate-button"
     )
     gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
-        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
+        additional_inputs_accordion=gr.Accordion(
+            label="⚙️ Parameters",
+            open=False,
+            render=False
+        ),
        additional_inputs=[
-            gr.Textbox(value="", label="System Prompt", render=False),
-            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
-            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
-            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
-            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
-            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
+            gr.Textbox(
+                value="",
+                label="System Prompt",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.1,
+                value=0.2,
+                label="Temperature",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=128,
+                maximum=8192,
+                step=1,
+                value=4096,
+                label="Max new tokens",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=1.0,
+                label="top_p",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=1,
+                maximum=50,
+                step=1,
+                value=1,
+                label="top_k",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=0.0,
+                maximum=2.0,
+                step=0.1,
+                value=1.1,
+                label="Repetition penalty",
+                render=False,
+            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],