"""Gradio chat demo for the LLaDA diffusion language model, served through Replicate."""

import json
import os
import re
import time

import gradio as gr
from dotenv import load_dotenv
from replicate.client import Client
from transformers import AutoTokenizer

# CSS styling
css = """
.category-legend{display:none}
button{height: 60px}
"""

# Constants
MASK_TOKEN = "[MASK]"
MASK_COLOR = "#444444"  # Dark gray for masks
REVEALED_COLOR = "#6699CC"  # Light blue for revealed tokens
ERROR_COLOR = "#FF6666"  # Red mask as error indicator

# Compiled once at import time; both are applied to every intermediate state.
_ASSISTANT_HEADER_RE = re.compile(
    r"<\|start_header_id\|>assistant<\|end_header_id\|>\n"
)
_END_OF_RESPONSE_RE = re.compile(r"<\|endoftext\|>|<\|start_header_id\|>")

# Initialize environment and client
load_dotenv()
replicate = Client(api_token=os.environ.get("REPLICATE_API_TOKEN"))

# Load tokenizer for formatting chat template properly
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)


def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'

    Args:
        constraints_text: Comma-separated ``position:word`` pairs; may be
            empty or ``None``.

    Returns:
        A dict mapping each non-negative integer position to its word.
        Entries without a colon, with a non-integer or negative position,
        or with an empty word are silently skipped.
    """
    constraints = {}
    if not constraints_text:
        return constraints

    for part in constraints_text.split(","):
        if ":" not in part:
            continue
        pos_str, word = part.split(":", 1)
        try:
            pos = int(pos_str.strip())
        except ValueError:
            continue
        word = word.strip()
        if word and pos >= 0:
            constraints[pos] = word

    return constraints


def format_chat_history(history):
    """Format chat history for the LLaDA model

    Args:
        history: Iterable of ``(user_msg, assistant_msg)`` pairs.

    Returns:
        A flat list of ``{"role": ..., "content": ...}`` dicts. A falsy
        assistant message (e.g. ``None`` for a turn not yet answered) is
        left out.
    """
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (for the latest user message)
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _last_assistant_header_end(text):
    """Return the index just past the last assistant header in *text*, or None."""
    last_match = None
    for last_match in _ASSISTANT_HEADER_RE.finditer(text):
        pass
    return None if last_match is None else last_match.end()


def _extract_response_text(final_output):
    """Return the text of the last assistant turn found in *final_output*."""
    start_pos = _last_assistant_header_end(final_output)
    if start_pos is None:
        return "Error: Could not parse the model response."

    # Cut at the first end-of-text token or the next header, whichever is first.
    end_match = _END_OF_RESPONSE_RE.search(final_output[start_pos:])
    if end_match:
        end_pos = start_pos + end_match.start()
        return final_output[start_pos:end_pos].strip()
    return final_output[start_pos:].strip()


def _build_visualization_states(states, gen_length):
    """Turn raw intermediate states into lists of ``(token, color)`` pairs."""
    # Initial state: everything is still masked.
    visualization_states = [[(MASK_TOKEN, MASK_COLOR) for _ in range(gen_length)]]

    for state in states:
        start_pos = _last_assistant_header_end(state)
        if start_pos is None:
            # Fallback if we can't parse properly
            visualization_states.append([(MASK_TOKEN, ERROR_COLOR)])
            continue

        tokens = state[start_pos:].strip().split()
        visualization_states.append(
            [
                (token, MASK_COLOR if token == MASK_TOKEN else REVEALED_COLOR)
                for token in tokens
            ]
        )

    return visualization_states


def generate_response_with_visualization(
    messages,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.5,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """Generate text using the Replicate API version of LLaDA with visualization

    Args:
        messages: Chat messages as ``{"role", "content"}`` dicts.
        gen_length: Sent to the model as ``max_tokens``; also the number of
            mask tokens in the initial visualization state.
        steps: Sent to the model as ``steps``.
        constraints: Optional ``{position: word}`` dict, sent JSON-encoded.
        temperature: Sent to the model as ``temperature``.
        cfg_scale: Sent to the model as ``cfg_scale``.
        block_length: Sent to the model as ``block_length``.
        remasking: Sent to the model as ``remasking``.

    Returns:
        ``(visualization_states, response_text)``: a list of per-step
        ``(token, color)`` lists, and the final assistant text with any
        ``<|eot_id|>`` markers removed.
    """
    # Process constraints
    if constraints is None:
        constraints = {}
    constraints_json = json.dumps(constraints)

    # Format chat using the tokenizer's chat template
    chat_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

    # Call Replicate API
    output = replicate.run(
        "spuuntries/llada-8b-kcv:e8b3ac0457f822454d662dec90edcac05f6e5947a50b55f92b22aa996acbf780",
        input={
            "steps": steps,
            "prompt": chat_input,
            "cfg_scale": cfg_scale,
            "remasking": remasking,
            "max_tokens": gen_length,
            "constraints": constraints_json,
            "temperature": temperature,
            "block_length": block_length,
            "prompt_template": "{prompt}",  # Use the already formatted prompt
        },
        wait=False,
    )

    # Extract final response and states
    final_output = output["final_output"]
    states = output["states"]

    response_text = _extract_response_text(final_output)
    visualization_states = _build_visualization_states(states, gen_length)

    return visualization_states, response_text.replace("<|eot_id|>", "")


def create_chatbot_demo():
    """Build the Gradio Blocks app and wire up its event handlers.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model Demo")
        gr.Markdown(
            "[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # Current response text box (hidden)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False,
        )

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Format: 'position:word, position:word, ...' Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value="",
                )

            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8, label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=128, step=4, label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.3, step=0.1, label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy",
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.05,
                    step=0.01,
                    label="Visualization Delay (seconds)",
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(
            message, history, gen_length, steps, constraints, delay
        ):
            """Process a submitted user message

            Returns values for (chat_history, chatbot_ui, user_input,
            output_vis, current_response). Only ``message`` and ``history``
            are used; the other parameters mirror the wired inputs.
            """
            # Skip empty messages: leave the history as it is
            if not message.strip():
                return history, history.copy(), "", [], ""

            # Add the user message with a pending (None) response so the UI
            # shows it immediately, and clear the input box.
            history = add_message(history, message, None)
            return history, history.copy(), "", [], ""

        def bot_response(
            history,
            gen_length,
            steps,
            constraints,
            delay,
            temperature,
            cfg_scale,
            block_length,
            remasking,
        ):
            """Generate bot response for the latest message

            This is a generator: it yields (chatbot_ui, output_vis,
            current_response) updates, first the initial state and then one
            per visualization state, sleeping ``delay`` seconds in between.
            """
            # Nothing to answer: the history is empty, or the last turn already
            # has a reply (a blank submission leaves the history untouched but
            # still triggers this step). The value must be yielded -- a
            # `return value` inside a generator is never delivered as output.
            if not history or history[-1][1] is not None:
                yield history, [], ""
                return

            try:
                # Format all messages except the last one (which has no response yet)
                messages = format_chat_history(history[:-1])

                # Add the last user message
                messages.append({"role": "user", "content": history[-1][0]})

                # Parse constraints
                parsed_constraints = parse_constraints(constraints)

                # Generate response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking,
                )

                # Update history with the assistant's response
                history[-1][1] = response_text

                # Return the initial state immediately
                yield history, vis_states[0], response_text

                # Then animate through visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text

            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show error in visualization
                error_vis = [(error_msg, "red")]

                # Don't update history with error
                yield history, error_vis, error_msg

        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", []

        # EVENT HANDLERS

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis],
        )

        # User message submission flow (2-step process)
        submit_inputs = [
            user_input,
            chat_history,
            gen_length,
            steps,
            constraints_input,
            visualization_delay,
        ]
        submit_outputs = [
            chat_history,
            chatbot_ui,
            user_input,
            output_vis,
            current_response,
        ]
        response_inputs = [
            chat_history,
            gen_length,
            steps,
            constraints_input,
            visualization_delay,
            temperature,
            cfg_scale,
            block_length,
            remasking_strategy,
        ]
        response_outputs = [chatbot_ui, output_vis, current_response]

        # Step 1: Add user message to history and update UI
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=submit_inputs,
            outputs=submit_outputs,
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=submit_inputs,
            outputs=submit_outputs,
        )

        # Step 2: Generate bot response
        # This happens after the user message is displayed
        for submit_event in (msg_submit, send_click):
            submit_event.then(
                fn=bot_response,
                inputs=response_inputs,
                outputs=response_outputs,
            )

    return demo


# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(server_name="0.0.0.0")