# app.py - Gradio chat demo for AGI-0/Art-v0-3B (Art3B-chat Space)
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gradio as gr
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = "AGI-0/Art-v0-3B"
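# HF_TOKEN is read from the environment but is not passed to from_pretrained below;
# it would only matter if the checkpoint were gated or private.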
class ConversationManager:
    def __init__(self):
        self.model_messages = []  # Stores raw responses with tags

    def format_for_display(self, raw_response):
        """Convert a raw model response into user-friendly markdown.
        The original tagged response is kept intact for the model."""
        # No response? Return empty
        if not raw_response:
            return ""
        display_response = raw_response

        # Handle reasoning sections
        while "<|start_reasoning|>" in display_response and "<|end_reasoning|>" in display_response:
            start = display_response.find("<|start_reasoning|>")
            end = display_response.find("<|end_reasoning|>") + len("<|end_reasoning|>")
            # Extract reasoning content
            reasoning_block = display_response[start:end]
            reasoning_content = reasoning_block.replace("<|start_reasoning|>", "").replace("<|end_reasoning|>", "")
            # Replace it with a markdown details/summary block
            markdown_block = f"\n<details><summary>View Reasoning</summary>\n\n{reasoning_content}\n\n</details>\n"
            display_response = display_response[:start] + markdown_block + display_response[end:]

        # Clean up other chat-template tags
        tags_to_remove = [
            "<|im_start|>",
            "<|im_end|>",
            "<|assistant|>",
            "<|user|>",
        ]
        for tag in tags_to_remove:
            display_response = display_response.replace(tag, "")

        # Clean up any extra whitespace and drop empty lines
        display_response = "\n".join(line.strip() for line in display_response.split("\n"))
        display_response = "\n".join(filter(None, display_response.split("\n")))
        return display_response.strip()

    def add_exchange(self, user_message, assistant_response):
        """Store the raw exchange in the model-facing history."""
        print("\n=== New Exchange ===")
        print(f"User: {user_message[:100]}{'...' if len(user_message) > 100 else ''}")
        print(f"Assistant (raw): {assistant_response[:100]}{'...' if len(assistant_response) > 100 else ''}")
        self.model_messages.append({
            "role": "user",
            "content": user_message
        })
        self.model_messages.append({
            "role": "assistant",
            "content": assistant_response
        })
        print(f"Current history length: {len(self.model_messages)} messages")

    def get_conversation_messages(self):
        """Return the full conversation history for the model."""
        return self.model_messages
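# Illustrative example (not executed): a raw response such as
#   "<|start_reasoning|>Let me double-check the dates...<|end_reasoning|>The answer is 1066.<|im_end|>"
# is rendered by format_for_display() as a collapsible reasoning block followed by the answer:
#   <details><summary>View Reasoning</summary>
#   Let me double-check the dates...
#   </details>
#   The answer is 1066.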
# Initialize globals
conversation_manager = ConversationManager()
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
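# <|im_end|> marks the end of a turn in the ChatML-style template these tags come from;
# passing it as eos_token_id below stops generation as soon as the assistant turn ends.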
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
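    # Note: the `history` list supplied by gr.ChatInterface is not replayed here; the
    # raw, tag-bearing history is kept in conversation_manager so the reasoning tags
    # survive between turns.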
print(f"\n=== New Chat Request ===")
print(f"Message: {message}")
print(f"History length: {len(history)}")
# Build conversation history from model's stored messages
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
# Add all previous messages
conversation.extend(conversation_manager.get_conversation_messages())
# Add new message
conversation.append({"role": "user", "content": message})
print(f"Sending {len(conversation)} messages to model")
    # Prepare model input
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
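    # apply_chat_template renders the message list with the model's chat template,
    # appends the assistant generation prompt, and returns token ids that are then
    # moved to the model's device.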
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True
    )
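    # The streamer yields decoded text incrementally as the generation thread produces
    # tokens; skip_prompt=True prevents the input turn from being echoed back.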
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[end_of_sentence],
        streamer=streamer,
    )
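    # Generation runs on a background thread so this generator can consume the streamer
    # and push partially formatted responses to the UI as they arrive.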
    # Storage for building the complete response
    buffer = ""
    model_response = ""

    # A torch.no_grad() wrapper here would have no effect: grad mode is thread-local,
    # and model.generate on the worker thread already disables gradient tracking.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    for new_text in streamer:
        buffer += new_text
        model_response += new_text
        # Convert the current buffer for display
        display_text = conversation_manager.format_for_display(buffer)
        yield display_text

    thread.join()
    print("Generation complete")
    # Store the final raw response in the model-facing history exactly once
    conversation_manager.add_exchange(message, model_response)
    yield conversation_manager.format_for_display(model_response)
# Set up Gradio interface
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 { text-align: center; }
"""
chatbot = gr.Chatbot(
    height=600,
    placeholder="""
<center>
<p>Hi! How can I help you today?</p>
</center>
""",
)
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML("""<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>""")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button"
    )
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(value="", label="System Prompt", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )
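    # The additional inputs are declared with render=False and get rendered inside the
    # "Parameters" accordion by ChatInterface, keeping the main chat layout uncluttered.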
if __name__ == "__main__":
    demo.launch()