spuuntries committed on
Commit
6f75aef
·
1 Parent(s): f543cd1

feat!: demo

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. .gradio/certificate.pem +31 -0
  3. app.py +427 -0
  4. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio demo app for the LLaDA diffusion language model, generated via a Replicate-hosted model."""

from dotenv import load_dotenv
from replicate.client import Client
from transformers import AutoTokenizer  # used only to build the chat-template prompt string locally

import gradio as gr
import json
import time
import re
import os

# CSS styling: hide the HighlightedText category legend and enlarge buttons.
css = """
.category-legend{display:none}
button{height: 60px}
"""

# Constants
# Placeholder token shown for not-yet-denoised positions in the visualization.
MASK_TOKEN = "[MASK]"

# Initialize environment and client.
# load_dotenv() pulls REPLICATE_API_TOKEN from a local .env file (gitignored).
load_dotenv()
replicate = Client(api_token=os.environ.get("REPLICATE_API_TOKEN"))

# Load tokenizer for formatting chat template properly.
# NOTE(review): trust_remote_code=True executes code from the hub repo — acceptable
# here only because the repo is the model's own official one.
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)
28
+
29
+
30
def parse_constraints(constraints_text):
    """Parse a constraint spec of the form 'position:word, position:word, ...'.

    Returns a dict mapping non-negative integer positions to their words.
    Malformed entries (no colon, non-integer position, empty word, or a
    negative position) are silently skipped.
    """
    parsed = {}
    if not constraints_text:
        return parsed

    for chunk in constraints_text.split(","):
        if ":" not in chunk:
            continue
        position_part, _, word_part = chunk.partition(":")
        try:
            position = int(position_part.strip())
        except ValueError:
            continue  # non-numeric position — ignore this entry
        word_part = word_part.strip()
        if word_part and position >= 0:
            parsed[position] = word_part

    return parsed
50
+
51
+
52
def format_chat_history(history):
    """Convert [[user, assistant], ...] turn pairs into role/content message dicts.

    A turn whose assistant half is falsy (e.g. None for the pending latest
    user message) contributes only its user message.
    """
    messages = []
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        if assistant_text:  # skip when None/empty — reply not generated yet
            messages.append({"role": "assistant", "content": assistant_text})
    return messages
61
+
62
+
63
def generate_response_with_visualization(
    messages,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.5,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """Generate text using the Replicate API version of LLaDA with visualization.

    Args:
        messages: list of {"role": ..., "content": ...} chat messages.
        gen_length: number of tokens to generate (also the length of the
            initial all-masked visualization row).
        steps: number of denoising steps requested from the model.
        constraints: optional {position: word} dict forwarded to the model
            as JSON (see parse_constraints).
        temperature, cfg_scale, block_length, remasking: sampling knobs
            passed through to the Replicate model unchanged.

    Returns:
        (visualization_states, response_text) where visualization_states is a
        list of [(token, color_hex), ...] rows for gr.HighlightedText, and
        response_text is the final assistant reply with "<|eot_id|>" stripped.
    """

    # Process constraints — the model expects them as a JSON string.
    if constraints is None:
        constraints = {}
    constraints_json = json.dumps(constraints)

    # Format chat using the tokenizer's chat template (string, not token ids).
    chat_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

    # Call Replicate API.
    # NOTE(review): wait=False yet the result is subscripted immediately below —
    # presumably the client still resolves to a dict with "final_output" and
    # "states" keys here; confirm against the replicate client version in use.
    output = replicate.run(
        "spuuntries/llada-8b-kcv:e8b3ac0457f822454d662dec90edcac05f6e5947a50b55f92b22aa996acbf780",
        input={
            "steps": steps,
            "prompt": chat_input,
            "cfg_scale": cfg_scale,
            "remasking": remasking,
            "max_tokens": gen_length,
            "constraints": constraints_json,
            "temperature": temperature,
            "block_length": block_length,
            "prompt_template": "{prompt}",  # Use the already formatted prompt
        },
        wait=False,
    )

    # Extract final response and states.
    # "states" is assumed to be a list of intermediate full-text snapshots of
    # the denoising process — TODO confirm against the model's output schema.
    final_output = output["final_output"]
    states = output["states"]

    # Extract only the last assistant response by finding the last occurrence
    # of the assistant header pattern (Llama-3-style special tokens).
    last_assistant_pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\n"
    last_assistant_match = list(re.finditer(last_assistant_pattern, final_output))

    if last_assistant_match:
        # Get the last match
        last_match = last_assistant_match[-1]
        # Start position of the actual content (after the header)
        start_pos = last_match.end()
        # Extract everything from this position to the end or until end token
        end_pattern = r"<\|endoftext\|>|<\|start_header_id\|>"
        end_match = re.search(end_pattern, final_output[start_pos:])

        if end_match:
            end_pos = start_pos + end_match.start()
            response_text = final_output[start_pos:end_pos].strip()
        else:
            response_text = final_output[start_pos:].strip()
    else:
        response_text = "Error: Could not parse the model response."

    # Process states for visualization: each row is a list of (token, color).
    visualization_states = []

    # Add initial state (all masked)
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    for state in states:
        # Similar parsing for visualization states: isolate the assistant tail
        # of each snapshot and split it into whitespace-delimited tokens.
        last_assistant_match = list(re.finditer(last_assistant_pattern, state))

        if last_assistant_match:
            last_match = last_assistant_match[-1]
            start_pos = last_match.end()
            tokens_text = state[start_pos:].strip()
            tokens = tokens_text.split()

            current_state = []
            for token in tokens:
                if token == "[MASK]":
                    current_state.append((token, "#444444"))  # Dark gray for masks
                else:
                    current_state.append(
                        (token, "#6699CC")
                    )  # Light blue for revealed tokens

            visualization_states.append(current_state)
        else:
            # Fallback if we can't parse properly
            visualization_states.append(
                [(MASK_TOKEN, "#FF6666")]
            )  # Red mask as error indicator

    return visualization_states, response_text.replace("<|eot_id|>", "")
162
+
163
+
164
def create_chatbot_demo():
    """Build and return the Gradio Blocks UI for the LLaDA chat demo.

    Layout: chat column (chatbot, message box, constraints box) beside a
    visualization column, with generation settings in an accordion below.
    Event flow is two-step: first echo the user message, then stream the
    bot response + denoising animation via a generator handler.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model Demo")
        gr.Markdown(
            "[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)"
        )

        # STATE MANAGEMENT
        # Per-session history as [[user, assistant-or-None], ...] pairs.
        chat_history = gr.State([])

        # Current response text box (hidden); kept as an event output so the
        # latest reply is available as plain text.
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False,
        )

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Format: 'position:word, position:word, ...' Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value="",
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8, label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=32, step=4, label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy",
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.05,
                    step=0.01,
                    label="Visualization Delay (seconds)",
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            # Shallow copy is enough here: we only append, never mutate old pairs.
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(
            message, history, gen_length, steps, constraints, delay
        ):
            """Process a submitted user message.

            NOTE(review): gen_length/steps/constraints/delay are accepted but
            unused here — they are consumed by bot_response in step 2.
            """
            # Skip empty messages
            if not message.strip():
                # Return current state unchanged
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Add user message to history (assistant side None until step 2)
            history = add_message(history, message, None)

            # Format for display - temporarily show user message with empty response
            history_for_display = history.copy()

            # Clear the input
            message_out = ""

            # Return immediately to update UI with user message
            return history, history_for_display, message_out, [], ""

        def bot_response(
            history,
            gen_length,
            steps,
            constraints,
            delay,
            temperature,
            cfg_scale,
            block_length,
            remasking,
        ):
            """Generate bot response for the latest message.

            Generator handler: yields (history, vis_state, response_text)
            repeatedly so Gradio animates the denoising states.
            """
            if not history:
                # NOTE(review): plain return inside a generator yields nothing;
                # Gradio then leaves the outputs unchanged for this event.
                return history, [], ""

            try:
                # Format all messages except the last one (which has no response yet)
                messages = format_chat_history(history[:-1])

                # Add the last user message
                messages.append({"role": "user", "content": history[-1][0]})

                # Parse constraints
                parsed_constraints = parse_constraints(constraints)

                # Generate response with visualization (blocking remote call)
                vis_states, response_text = generate_response_with_visualization(
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking,
                )

                # Update history with the assistant's response (mutates the
                # gr.State list in place, which is what step 1 stored)
                history[-1][1] = response_text

                # Return the initial state immediately
                yield history, vis_states[0], response_text

                # Then animate through visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show error in visualization
                error_vis = [(error_msg, "red")]

                # Don't update history with error
                yield history, error_vis, error_msg

        def clear_conversation():
            """Clear the conversation history"""
            # Order matches the clear_btn.click outputs list below.
            return [], [], "", []

        # EVENT HANDLERS

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis],
        )

        # User message submission flow (2-step process)
        # Step 1: Add user message to history and update UI
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[
                user_input,
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
            ],
            outputs=[
                chat_history,
                chatbot_ui,
                user_input,
                output_vis,
                current_response,
            ],
        )

        # Also connect the send button (same handler as Enter-to-submit)
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[
                user_input,
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
            ],
            outputs=[
                chat_history,
                chatbot_ui,
                user_input,
                output_vis,
                current_response,
            ],
        )

        # Step 2: Generate bot response
        # This happens after the user message is displayed
        msg_submit.then(
            fn=bot_response,
            inputs=[
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
                temperature,
                cfg_scale,
                block_length,
                remasking_strategy,
            ],
            outputs=[chatbot_ui, output_vis, current_response],
        )

        send_click.then(
            fn=bot_response,
            inputs=[
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
                temperature,
                cfg_scale,
                block_length,
                remasking_strategy,
            ],
            outputs=[chatbot_ui, output_vis, current_response],
        )

    return demo
422
+
423
+
424
# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # queue() enables the generator-based streaming handlers; 0.0.0.0 makes
    # the server reachable from outside the container (e.g. HF Spaces).
    demo.queue().launch(server_name="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
# Runtime dependencies for app.py.
# app.py imports gradio, dotenv (python-dotenv), transformers and replicate;
# only replicate was previously pinned here — the rest would fail outside an
# environment that preinstalls them (e.g. a plain pip install).
gradio
python-dotenv
replicate
transformers