Spaces: Running on Zero
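# Chat demo for AGI-0/Art-v0-3B: a Gradio UI with streamed generation. The `spaces`
# import and the "Running on Zero" badge indicate this Space targets ZeroGPU hardware.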
import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL = "AGI-0/Art-v0-3B"
device = "cuda"  # Use "cpu" if no GPU is available (unused below; device_map="auto" handles placement)

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, device_map="auto"
)
end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")

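# ConversationManager keeps two parallel histories: a raw one that is replayed to the
# model on every turn, and a user-facing one in which the text before "<|end_reasoning|>"
# is wrapped in a collapsible <details> element for the Gradio chat window.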
class ConversationManager:
    def __init__(self):
        self.user_history = []   # User-facing history with formatting
        self.model_history = []  # Model-facing history without formatting

    def add_exchange(self, user_message, model_response):
        formatted_response = self.format_response(model_response)
        self.model_history.append((user_message, model_response))
        self.user_history.append((user_message, formatted_response))
        print(f"\nModel History Updated: {self.model_history}")
        print(f"\nUser History Updated: {self.user_history}")

    def format_response(self, response):
        """Format response for UI while keeping raw text for model."""
        if "<|end_reasoning|>" in response:
            parts = response.split("<|end_reasoning|>")
            reasoning, rest = parts[0], parts[1] if len(parts) > 1 else ""
            return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
        return response

    def get_user_history(self):
        return self.user_history

    def get_model_history(self):
        return self.model_history


conversation_manager = ConversationManager()
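
# stream_chat rebuilds the full conversation from the model-facing history, applies the
# tokenizer's chat template, and streams tokens through a TextIteratorStreamer while
# model.generate runs on a background thread. Each partial buffer is re-formatted so the
# reasoning section collapses in the UI as soon as "<|end_reasoning|>" appears.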
@spaces.GPU  # Requests a GPU for this call on ZeroGPU hardware (the reason `spaces` is imported)
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.2,
    max_new_tokens: int = 4096,
    top_p: float = 1.0,
    top_k: int = 1,
    penalty: float = 1.1,
):
    print(f'User Message: {message}')
    model_history = conversation_manager.get_model_history()
    print(f'Model History: {model_history}')

    conversation = []
    if system_prompt:  # Include the optional system prompt so the UI field is not ignored
        conversation.append({"role": "system", "content": system_prompt})
    for prompt, answer in model_history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})
    print(f'Formatted Conversation for Model: {conversation}')
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False if temperature == 0 else True,  # Greedy decoding when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        eos_token_id=[end_of_sentence],  # Stop on "<|im_end|>"
        streamer=streamer,
    )

    buffer = ""
    original_response = ""
    with torch.no_grad():
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        for new_text in streamer:
            buffer += new_text
            original_response += new_text
            print(f'Streaming: {buffer}')
            formatted_buffer = conversation_manager.format_response(buffer)
            yield formatted_buffer, history + [[message, formatted_buffer]]

    conversation_manager.add_exchange(message, original_response)
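
# Gradio UI: a Blocks page with a link to the model card, a duplicate-Space button, and a
# ChatInterface wired to stream_chat, with the sampling parameters exposed in an accordion.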
chatbot = gr.Chatbot(height=600, placeholder="<center><p>Hi! How can I help you today?</p></center>")

demo = gr.Blocks()
with demo:
    gr.HTML("<h2>Link to the model: <a href='https://huggingface.co/AGI-0/Art-v0-3B'>click here</a></h2>")
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(value="", label="System Prompt", render=False),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
            gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
            gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
            gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.1, label="Repetition penalty", render=False),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
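
# Presumed runtime dependencies (not listed in the source; a minimal requirements.txt sketch):
#   torch
#   transformers
#   gradio
#   spaces
#   accelerate  # required by transformers for device_map="auto"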