# marco-o1 / helper.py
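# Helpers for wiring an OpenAI-compatible chat endpoint into a Gradio ChatInterface.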
import os
import gradio as gr
from typing import Callable
import base64
from openai import OpenAI
def get_fn(model_name: str, **model_kwargs):
"""Create a chat function that uses the OpenAI-compatible endpoint."""
OPENAI_API_KEY = "-"
client = OpenAI(
base_url=" http://192.222.58.60:8000/v1",
api_key="tela",
)
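
    # Generator handed to gr.ChatInterface: it rebuilds the message list each turn and streams the reply.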
def predict(
message: str,
history,
system_prompt: str,
temperature: float,
top_p: float,
max_tokens: int,
):
try:
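            # Assemble OpenAI-style messages: optional system prompt, prior turns, then the new user message.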
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
                max_tokens=max_tokens,
n=1,
stream=True,
response_format={"type": "text"},
)
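
            # Accumulate streamed deltas and yield the running text so the UI updates incrementally.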
response_text = ""
for chunk in response:
chunk_message = chunk.choices[0].delta.content
if chunk_message:
response_text += chunk_message
yield assistant_message.strip()
except Exception as e:
print(f"Error during generation: {str(e)}")
yield f"An error occurred: {str(e)}"
return predict
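
# Read a local image file and return it as a base64 data URI for multimodal message content.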
def get_image_base64(url: str, ext: str):
with open(url, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
return "data:image/" + ext + ";base64," + encoded_string
def handle_user_msg(message: str):
if isinstance(message, str):
return message
elif isinstance(message, dict):
if message.get("files"):
ext = os.path.splitext(message["files"][-1])[1].strip(".")
if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
encoded_str = get_image_base64(message["files"][-1], ext)
else:
raise NotImplementedError(f"Not supported file type {ext}")
content = [
{"type": "text", "text": message.get("text", "")},
{
"type": "image_url",
"image_url": {
"url": encoded_str,
}
},
]
else:
content = message.get("text", "")
return content
else:
raise NotImplementedError
def get_model_path(name: str = None, model_path: str = None) -> str:
"""Get the model name to use with the endpoint."""
if model_path:
return model_path
if name:
return name
raise ValueError("Either name or model_path must be provided")
def registry(name: str = None, model_path: str = None, **kwargs):
"""Create a Gradio ChatInterface."""
model_name = get_model_path(name, model_path)
fn = get_fn(model_name, **kwargs)
interface = gr.ChatInterface(
fn=fn,
additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
additional_inputs=[
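            # Order must match the extra parameters of predict(): system_prompt, temperature, top_p, max_tokens.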
gr.Textbox(
"You are a helpful AI assistant.",
label="System prompt"
),
gr.Slider(0, 1, 0.7, label="Temperature"),
gr.Slider(128, 4096, 1024, label="Max new tokens"),
gr.Slider(0, 1, 0.95, label="Top P sampling"),
],
)
return interface
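
# Example usage (illustrative sketch; the model name below is a placeholder, not taken from this file):
#     demo = registry(name="your-model-name")
#     demo.launch()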