import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = "AGI-0/Art-v0-3B"
class ConversationManager:
def __init__(self):
self.model_messages = [] # Stores raw responses with tags
def format_for_display(self, raw_response):
"""Convert model response to user-friendly markdown.
Keeps original response intact for model."""
# No response? Return empty
if not raw_response:
return ""
display_response = raw_response
# Handle reasoning sections
while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
start = display_response.find("<|start_reasoning|>")
end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")
# Extract reasoning content
reasoning_block = display_response[start:end]
reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")
# Replace with markdown details/summary
            markdown_block = f"\n<details>\n<summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
display_response = display_response[:start] + markdown_block + display_response[end:]
# Clean up other tags
tags_to_remove = [
"<|im_start|>",
"<|im_end|>",
"<|assistant|>",
"<|user|>"
]
for tag in tags_to_remove:
display_response = display_response.replace(tag, "")
# Clean up any extra whitespace
display_response = "\n".join(line.strip() for line in display_response.split("\n"))
display_response = "\n".join(filter(None, display_response.split("\n")))
return display_response.strip()
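    # Illustrative example (hypothetical response text, not from the model): a raw response like
    #   "<|start_reasoning|>Check the premise first.<|end_reasoning|>The answer is 42."
    # should come out of format_for_display roughly as
    #   "<details>\n<summary>View Reasoning</summary>\nCheck the premise first.\n</details>\nThe answer is 42."
    # (the whitespace cleanup above collapses blank lines).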
def add_exchange(self, user_message, assistant_response):
"""Store raw response in model history"""
print("\n=== New Exchange ===")
print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")
self.model_messages.append({
"role": "user",
"content": user_message
})
self.model_messages.append({
"role": "assistant",
"content": assistant_response
})
print(f"Current history length: {len(self.model_messages)} messages")
def get_conversation_messages(self):
"""Get full conversation history for model"""
return self.model_messages
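# Note: gr.ChatInterface keeps its own display-side history; this manager keeps the raw,
# tag-bearing history that stream_chat sends back to the model on every turn.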
# Initialize globals
conversation_manager = ConversationManager()
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
MODEL,
torch_dtype=torch.bfloat16,
device_map="auto"
)
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
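# convert_tokens_to_ids may return None (fast tokenizers) or the unk id (slow tokenizers) if
# "<|im_end|>" is missing from the vocabulary; a minimal, assumption-based fallback would be:
#   if end_of_sentence is None:
#       end_of_sentence = tokenizer.eos_token_id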
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float = 0.2,
max_new_tokens: int = 4096,
top_p: float = 1.0,
top_k: int = 1,
penalty: float = 1.1,
):
print(f"\n=== New Chat Request ===")
print(f"Message: {message}")
print(f"History length: {len(history)}")
# Build conversation history from model's stored messages
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
# Add all previous messages
conversation.extend(conversation_manager.get_conversation_messages())
# Add new message
conversation.append({"role": "user", "content": message})
print(f"Sending {len(conversation)} messages to model")
# Prepare model input
input_ids = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
streamer = TextIteratorStreamer(
tokenizer,
timeout=60.0,
skip_prompt=True,
skip_special_tokens=True
)
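    # Assumption: <|start_reasoning|> / <|end_reasoning|> reach the streamer as plain text rather
    # than registered special tokens, so skip_special_tokens=True leaves them intact for
    # format_for_display to convert; any stray <|im_end|> is handled by the tag cleanup above.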
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
do_sample=False if temperature == 0 else True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=penalty,
eos_token_id=[end_of_sentence],
streamer=streamer,
)
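    # With temperature == 0 the request falls back to greedy decoding (do_sample=False), and
    # transformers then ignores the temperature, top_p, and top_k settings for that call.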
# Storage for building complete response
buffer = ""
model_response = ""
with torch.no_grad():
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
        for new_text in streamer:
            buffer += new_text
            model_response += new_text
            # Convert the partial buffer for display and stream it to the UI
            display_text = conversation_manager.format_for_display(buffer)
            yield display_text
        thread.join()
        print("Generation complete")
        # Store the final raw response (tags intact) in the model-facing history
        conversation_manager.add_exchange(message, model_response)
# Set up Gradio interface
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 { text-align: center; }
"""
chatbot = gr.Chatbot(
height=600,
placeholder="""
Hi! How can I help you today?
"""
)
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML("""""")
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_classes="duplicate-button"
)
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Textbox(value="", label="System Prompt", render=False),
gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
],
examples=[
["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
["Tell me a random fun fact about the Roman Empire."],
["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
],
cache_examples=False,
)
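    # The additional_inputs above are passed to stream_chat positionally after (message, history):
    # system_prompt, temperature, max_new_tokens, top_p, top_k, penalty.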
if __name__ == "__main__":
demo.launch()