import os
import base64

import gradio as gr
from openai import OpenAI


def get_fn(model_name: str, **model_kwargs):
    """Create a chat function that uses the OpenAI-compatible endpoint."""
    client = OpenAI(
        base_url="http://192.222.58.60:8000/v1",
        api_key="tela",
    )

    def predict(
        message: str,
        history,
        system_prompt: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        try:
            # Rebuild the running conversation in the OpenAI messages format.
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})
            messages.append({"role": "user", "content": message})

            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                n=1,
                stream=True,
                response_format={"type": "text"},
            )
response_text = "" | |
for chunk in response: | |
chunk_message = chunk.choices[0].delta.content | |
if chunk_message: | |
response_text += chunk_message | |
yield assistant_message.strip() | |
except Exception as e: | |
print(f"Error during generation: {str(e)}") | |
yield f"An error occurred: {str(e)}" | |
return predict | |
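
# Usage sketch: get_fn returns a generator function, so each yield is the
# partial reply that gr.ChatInterface renders as streaming output. "my-model"
# is a placeholder name, not something this file defines.
#
#   fn = get_fn("my-model")
#   for partial in fn("Hello!", [], "You are helpful.", 0.7, 0.95, 256):
#       print(partial)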


def get_image_base64(url: str, ext: str):
    """Read a local file and return it as a base64-encoded data URL."""
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:image/" + ext + ";base64," + encoded_string


def handle_user_msg(message: str):
    """Normalize a Gradio message (plain string or multimodal dict) into
    OpenAI-style message content."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Unsupported file type: {ext}")
            content = [
                {"type": "text", "text": message.get("text", "")},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    },
                },
            ]
        else:
            content = message.get("text", "")
        return content
    else:
        raise NotImplementedError
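
# Example (hypothetical input):
#   handle_user_msg({"text": "Describe this image", "files": ["cat.png"]})
# returns a list with a text part and an image_url part, the content shape
# that vision-capable OpenAI-compatible endpoints expect.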


def get_model_path(name: str = None, model_path: str = None) -> str:
    """Get the model name to use with the endpoint."""
    if model_path:
        return model_path
    if name:
        return name
    raise ValueError("Either name or model_path must be provided")


def registry(name: str = None, model_path: str = None, **kwargs):
    """Create a Gradio ChatInterface."""
    model_name = get_model_path(name, model_path)
    fn = get_fn(model_name, **kwargs)

    interface = gr.ChatInterface(
        fn=fn,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        # This order must match predict's extra parameters:
        # system_prompt, temperature, top_p, max_tokens.
        additional_inputs=[
            gr.Textbox(
                "You are a helpful AI assistant.",
                label="System prompt",
            ),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
        ],
    )
    return interface
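
# Minimal launch sketch, assuming this module is run directly; "my-model" is a
# placeholder for whatever model the endpoint actually serves.
if __name__ == "__main__":
    demo = registry(name="my-model")
    demo.launch()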