import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = "AGI-0/Art-v0-3B"


class ConversationManager:
    def __init__(self):
        self.model_messages = []  # Stores raw responses with tags

    def format_for_display(self, raw_response):
        """Convert model response to user-friendly markdown.
        Keeps original response intact for model."""
        # No response? Return empty
        if not raw_response:
            return ""

        display_response = raw_response

        # Handle reasoning sections
        while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
            start = display_response.find("<|start_reasoning|>")
            end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")

            # Extract reasoning content
            reasoning_block = display_response[start:end]
            reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")

            # Replace with markdown details/summary
            markdown_block = f"\n<details>\n<summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
            display_response = display_response[:start] + markdown_block + display_response[end:]

        # Clean up other tags
        tags_to_remove = [
            "<|im_start|>",
            "<|im_end|>",
            "<|assistant|>",
            "<|user|>"
        ]
        for tag in tags_to_remove:
            display_response = display_response.replace(tag, "")

        # Clean up any extra whitespace
        display_response = "\n".join(line.strip() for line in display_response.split("\n"))
        display_response = "\n".join(filter(None, display_response.split("\n")))

        return display_response.strip()

    def add_exchange(self, user_message, assistant_response):
        """Store raw response in model history"""
        print("\n=== New Exchange ===")
        print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
        print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")

        self.model_messages.append({
            "role": "user",
            "content": user_message
        })
        self.model_messages.append({
            "role": "assistant",
            "content": assistant_response
        })

        print(f"Current history length: {len(self.model_messages)} messages")

    def get_conversation_messages(self):
        """Get full conversation history for model"""
        return self.model_messages


# Initialize globals (a single ConversationManager instance is shared by all sessions)
conversation_manager = ConversationManager()
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
    print("\n=== New Chat Request ===")
    print(f"Message: {message}")
    print(f"History length: {len(history)}")

    # Build conversation history from model's stored messages
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    # Add all previous messages
    conversation.extend(conversation_manager.get_conversation_messages())

    # Add new message
    conversation.append({"role": "user", "content": message})

    print(f"Sending {len(conversation)} messages to model")

    # Prepare model input
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False if temperature == 0 else True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[end_of_sentence],
        streamer=streamer,
    )

    # Storage for building complete response
    buffer = ""
    model_response = ""

    with torch.no_grad():
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        for new_text in streamer:
            buffer += new_text
            model_response += new_text

            # Convert current buffer for display
            display_text = conversation_manager.format_for_display(buffer)
            yield display_text

        # Wait for the generation thread to finish, then record the exchange
        # exactly once (checking thread.is_alive() inside the loop could store
        # the exchange several times, or not at all)
        thread.join()
        print("Generation complete")
        conversation_manager.add_exchange(message, model_response)


# Set up Gradio interface
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
chatbot = gr.Chatbot(
    height=600,
    # Greeting shown while the chat is still empty
    placeholder="""
<center>
Hi! How can I help you today?
</center>
""",
)

with gr.Blocks(css=CSS, theme="soft") as demo:
    # Header linking to the model card (URL built from the MODEL repo id)
    gr.HTML(f"""
<h3>
Link to the model: <a href="https://huggingface.co/{MODEL}">click here</a>
</h3>
""")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button"
    )
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(value="", label="System Prompt", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()