Art3B-chat / app.py
freeCS-dot-org's picture
Update app.py
2fb89d3 verified
raw
history blame
5.26 kB
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
# Hugging Face access token from the environment (None when unset).
# NOTE(review): read here but not passed to from_pretrained below;
# huggingface_hub presumably picks up the same env var on its own — confirm.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository id of the checkpoint served by this app.
MODEL = "AGI-0/Art-v0-3B"

# Not referenced anywhere else in this file: weight placement is done by
# device_map="auto" in the model load below.
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Token id of "<|im_end|>", used as the stop token for generation.
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
class ConversationManager:
    """Keep two parallel transcripts of completed chat turns.

    ``model_history`` holds ``(user_message, raw_response)`` pairs exactly as the
    model produced them. ``user_history`` holds the same turns, with each response
    rendered for the UI by :meth:`format_response`.
    """

    # Marker separating reasoning (before it) from the answer (after it) in a raw
    # model response.
    REASONING_DELIMITER = "<|end_reasoning|>"

    def __init__(self):
        self.user_history = []  # (user_message, formatted_response) pairs for the UI
        self.model_history = []  # (user_message, raw_response) pairs fed back to the model

    def add_exchange(self, user_message, model_response):
        """Record one completed turn in both histories and print them.

        Args:
            user_message: The user's message for this turn.
            model_response: Raw model output. It is stored verbatim in
                ``model_history`` and in formatted form in ``user_history``.
        """
        formatted_response = self.format_response(model_response)
        self.model_history.append((user_message, model_response))
        self.user_history.append((user_message, formatted_response))
        print(f"\nModel History Updated: {self.model_history}")
        print(f"\nUser History Updated: {self.user_history}")

    def format_response(self, response):
        """Format response for UI while keeping raw text for model.

        If ``response`` contains the reasoning delimiter, the text before its
        first occurrence is wrapped in a collapsible ``<details>`` element. Text
        after it, including any later delimiters, is shown below the element.
        Otherwise ``response`` is returned unchanged.
        """
        # partition splits only on the first marker, so nothing after a second
        # marker is lost (split()[1] used to drop it).
        reasoning, delimiter, answer = response.partition(self.REASONING_DELIMITER)
        if not delimiter:
            return response
        return (
            "<details><summary>Click to see reasoning</summary>\n\n"
            f"{reasoning}\n\n</details>\n\n{answer}"
        )

    def get_user_history(self):
        """Return the UI-facing history list (the live list, not a copy)."""
        return self.user_history

    def get_model_history(self):
        """Return the raw, model-facing history list (the live list, not a copy)."""
        return self.model_history
# Single module-level instance: there is one per Python process, so every chat
# session served by this process reads from and appends to the same histories.
conversation_manager = ConversationManager()
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
    """Stream the model's reply to ``message``.

    The prompt has three parts, in order:

    - the optional system prompt;
    - every exchange recorded in the module-level ``conversation_manager``, using
      raw model text rather than the HTML shown in the UI;
    - the new user message.

    ``history`` is only used to build the value yielded back to the UI.

    NOTE(review): ``conversation_manager`` is one process-wide object. The model
    context therefore includes turns from every session served by this process,
    and it is not reset when a user clears or retries the chat. TODO: derive the
    context from the per-session ``history`` argument instead.

    Args:
        message: The new user message.
        history: Chat history supplied by the UI; not sent to the model.
        system_prompt: Optional system prompt; skipped when empty or whitespace.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cut-off (only used when sampling).
        penalty: Repetition penalty passed to ``generate``.

    Yields:
        ``(formatted_reply, history + [[message, formatted_reply]])`` after each
        streamed chunk. ``formatted_reply`` is the partial reply rendered by
        ``conversation_manager.format_response``.
        NOTE(review): ``gr.ChatInterface`` documents a plain string here; confirm
        the installed Gradio version accepts this tuple.
    """
    print(f'User Message: {message}')
    model_history = conversation_manager.get_model_history()
    print(f'Model History: {model_history}')

    conversation = []
    # The UI exposes a "System Prompt" textbox; previously its value was ignored.
    if system_prompt and system_prompt.strip():
        conversation.append({"role": "system", "content": system_prompt})
    for prompt, answer in model_history:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    print(f'Formatted Conversation for Model: {conversation}')

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence, so every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=penalty,
        eos_token_id=[end_of_sentence],
        streamer=streamer,
    )
    if do_sample:
        # These only affect sampling. Passing them to greedy decoding merely
        # triggers transformers' "flag is only used in sample-based modes" warnings.
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # generate() runs in a worker thread so chunks can be yielded as they arrive.
    # Grad mode is thread-local, so a torch.no_grad() around thread start would
    # not reach the worker; transformers' generate() disables grad itself.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        formatted_response = conversation_manager.format_response(response)
        yield formatted_response, history + [[message, formatted_response]]
    # The streamer is exhausted once generation ends; join to reap the worker.
    thread.join()

    conversation_manager.add_exchange(message, response)
# Chat display widget; the ChatInterface below renders it inside the page.
chatbot = gr.Chatbot(height=600, placeholder="<center><p>Hi! How can I help you today?</p></center>")

with gr.Blocks() as demo:
    gr.HTML("<h2>Link to the model: <a href='https://huggingface.co/AGI-0/Art-v0-3B'>click here</a></h2>")
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Extra controls, passed to stream_chat after (message, history) in this order:
    # system_prompt, temperature, max_new_tokens, top_p, top_k, penalty.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    parameter_inputs = [
        gr.Textbox(value="", label="System Prompt", render=False),
        gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
        gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
        gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
    ]

    starter_prompts = [
        ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
        ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
        ["Tell me a random fun fact about the Roman Empire."],
        ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=parameter_inputs,
        examples=starter_prompts,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()