# Hugging Face Space app — runs on ZeroGPU ("Running on Zero") hardware
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = "AGI-0/Art-v0-3B"
class ConversationManager:
    def __init__(self):
        self.model_messages = []  # Stores raw responses with tags

    def format_for_display(self, raw_response):
        """Convert a model response to user-friendly markdown.
        Keeps the original response intact for the model."""
        # No response? Return empty
        if not raw_response:
            return ""

        display_response = raw_response

        # Handle reasoning sections
        while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
            start = display_response.find("<|start_reasoning|>")
            end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")
            # Extract reasoning content
            reasoning_block = display_response[start:end]
            reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")
            # Replace with a markdown details/summary block
            markdown_block = f"\n<details><summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
            display_response = display_response[:start] + markdown_block + display_response[end:]

        # Clean up other tags
        tags_to_remove = [
            "<|im_start|>",
            "<|im_end|>",
            "<|assistant|>",
            "<|user|>",
        ]
        for tag in tags_to_remove:
            display_response = display_response.replace(tag, "")

        # Clean up any extra whitespace
        display_response = "\n".join(line.strip() for line in display_response.split("\n"))
        display_response = "\n".join(filter(None, display_response.split("\n")))

        return display_response.strip()
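
    # Illustrative example of the conversion above (the raw output shape is an
    # assumption, not taken from the model card):
    #   raw:     "<|start_reasoning|>Let me check 2+2.<|end_reasoning|>The answer is 4.<|im_end|>"
    #   display: "<details><summary>View Reasoning</summary>\nLet me check 2+2.\n</details>\nThe answer is 4."
    # The reasoning ends up in a collapsible <details> block and the chat-template
    # tags are stripped, while the stored model history keeps the raw tagged text.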
    def add_exchange(self, user_message, assistant_response):
        """Store the raw exchange in the model-side history"""
        print("\n=== New Exchange ===")
        print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
        print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")
        self.model_messages.append({
            "role": "user",
            "content": user_message
        })
        self.model_messages.append({
            "role": "assistant",
            "content": assistant_response
        })
        print(f"Current history length: {len(self.model_messages)} messages")

    def get_conversation_messages(self):
        """Get the full conversation history for the model"""
        return self.model_messages
# Initialize globals
conversation_manager = ConversationManager()
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Token id of <|im_end|>, used as the end-of-turn stop token during generation
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
# ZeroGPU: request a GPU for each generation call (the `spaces` import implies this
# Space runs on ZeroGPU hardware)
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,  # Gradio's display history; the raw tagged history lives in ConversationManager
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
print(f"\n=== New Chat Request ===") | |
print(f"Message: {message}") | |
print(f"History length: {len(history)}") | |
# Build conversation history from model's stored messages | |
conversation = [] | |
if system_prompt: | |
conversation.append({"role": "system", "content": system_prompt}) | |
# Add all previous messages | |
conversation.extend(conversation_manager.get_conversation_messages()) | |
# Add new message | |
conversation.append({"role": "user", "content": message}) | |
print(f"Sending {len(conversation)} messages to model") | |
# Prepare model input | |
input_ids = tokenizer.apply_chat_template( | |
conversation, | |
add_generation_prompt=True, | |
return_tensors="pt" | |
).to(model.device) | |
streamer = TextIteratorStreamer( | |
tokenizer, | |
timeout=60.0, | |
skip_prompt=True, | |
skip_special_tokens=True | |
) | |
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False if temperature == 0 else True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[end_of_sentence],
        streamer=streamer,
    )
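
    # Generation runs in a background thread so tokens can be consumed from the
    # streamer as they are produced; do_sample is turned off at temperature 0
    # (greedy decoding), and <|im_end|> serves as the stop token.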
    # Storage for building the complete response
    buffer = ""
    model_response = ""

    with torch.no_grad():
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        for new_text in streamer:
            buffer += new_text
            model_response += new_text
            # Convert the current buffer for display
            display_text = conversation_manager.format_for_display(buffer)
            yield display_text

    # The streamer is exhausted only after generation has finished, so store the raw
    # exchange here; checking thread.is_alive() inside the loop could miss the final
    # chunk and leave the exchange out of the model-side history.
    print("Generation complete")
    conversation_manager.add_exchange(message, model_response)
# Set up the Gradio interface
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 { text-align: center; }
"""

chatbot = gr.Chatbot(
    height=600,
    placeholder="""
<center>
<p>Hi! How can I help you today?</p>
</center>
"""
)
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML("""<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>""")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button"
    )
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(value="", label="System Prompt", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()