# 🤖⚡ ── [ I M P O R T S ]
import accelerate  # not used directly; kept so transformers can use the accelerate backend
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
# 🧠🔧 ── [ M O D E L ]
microsoft_model = None
microsoft_tokenizer = None

def load_model():
    """Lazily load the BitNet model and tokenizer once, then reuse them."""
    global microsoft_model, microsoft_tokenizer
    if microsoft_model is None or microsoft_tokenizer is None:
        model_id = "microsoft/bitnet-b1.58-2B-4T"
        microsoft_tokenizer = AutoTokenizer.from_pretrained(model_id)
        config = AutoConfig.from_pretrained(model_id)
        microsoft_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch.bfloat16
        )
    return microsoft_model, microsoft_tokenizer
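
# Minimal usage sketch (hedged: assumes the checkpoint downloads and fits in
# memory; left as comments so it doesn't run at import time):
#   model, tokenizer = load_model()
#   ids = tokenizer("Hello, BitNet!", return_tensors="pt").to(model.device)
#   out = model.generate(**ids, max_new_tokens=16)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))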
# 🗂️🕰️ ── [ C O N V E R S A T I O N - H I S T O R Y ]
def manage_history(history):
    # Limit to 3 turns (each turn is user + assistant = 2 messages)
    max_messages = 6  # 3 turns * 2 messages per turn
    if len(history) > max_messages:
        history = history[-max_messages:]
    # Limit total character count to 300
    total_chars = sum(len(msg["content"]) for msg in history)
    while total_chars > 300 and history:
        history.pop(0)  # Remove oldest message
        total_chars = sum(len(msg["content"]) for msg in history)
    return history
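
# Worked example of the trimming rules (illustrative values):
#   history = [{"role": "user", "content": "x" * 200},
#              {"role": "assistant", "content": "y" * 200}]
#   manage_history(history)  # 400 chars > 300: the oldest message is dropped,
#                            # leaving only the assistant turn (200 <= 300).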
# 💬✨ ── [ G E N E R A T E - R E S P O N S E ]
def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
    model, tokenizer = load_model()
    # Include the (trimmed) prior turns so the model actually sees the conversation context
    messages = (
        [{"role": "system", "content": system_prompt}]
        + history
        + [{"role": "user", "content": user_input}]
    )
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate Response
    chat_outputs = model.generate(
        **chat_input,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True
    )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True)
    # Update History
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    # Manage History Limits
    history = manage_history(history)
    return history, history
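
# Plumbing sanity check (hypothetical call, not executed at import time):
#   hist, _ = generate_response("Hi!", "You are a helpful AI assistant.",
#                               50, 0.7, 0.9, 50, [])
#   hist[-1]["role"]  # -> "assistant"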
# 🎛️🖥️ ── [ G R A D I O - I N T E R F A C E ]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# BitNet b1.58 2B4T Demo")
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ## About BitNet b1.58 2B4T
            BitNet b1.58 2B4T is the first open-source, native 1-bit Large Language Model at the
            2-billion-parameter scale, developed by Microsoft Research. Trained on 4 trillion tokens,
            it achieves performance comparable to full-precision models of similar size while offering
            significant efficiency gains in memory, energy, and latency. Features include:
            - Transformer-based architecture with BitLinear layers
            - Native 1.58-bit weights and 8-bit activations
            - Maximum context length of 4096 tokens
            - Optimized for efficient inference with bitnet.cpp
            """)
        with gr.Column():
            gr.Markdown("""
            ## About Tonic AI
            Tonic AI is a vibrant community of AI enthusiasts and developers, always building cool demos
            and pushing the boundaries of what's possible with AI. We're passionate about creating
            innovative, accessible, and engaging AI experiences for everyone. Join us in exploring
            the future of AI!
            """)
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                placeholder="Enter system prompt..."
            )
            with gr.Accordion("Advanced Options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
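            # The three sampling knobs above feed model.generate directly:
            # temperature rescales the logits, top_p samples from the smallest
            # set of tokens whose cumulative probability exceeds p, and top_k
            # truncates sampling to the k most likely tokens.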
            submit_btn = gr.Button("Send")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    chat_history = gr.State([])
    submit_btn.click(
        fn=generate_response,
        inputs=[
            user_input,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            chat_history
        ],
        outputs=[chatbot, chat_history]
    )
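    # generate_response returns (history, history): the first value renders in
    # the Chatbot, the second persists the same list in gr.State between clicks.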
# 🚀🔥 ── [ M A I N ]
if __name__ == "__main__":
    load_model()  # warm the model at startup so the first request isn't slow
    demo.launch(ssr_mode=False, share=False)