import os
import gradio as gr
from typing import Callable
import base64
from openai import OpenAI
def get_fn(model_name: str, **model_kwargs) -> Callable:
    """Create a chat function that streams completions from an OpenAI-compatible endpoint."""
    # Endpoint and key are hardcoded for this Space; adjust for your own deployment.
    client = OpenAI(
        base_url="http://192.222.58.60:8000/v1",
        api_key="tela",
    )
def predict(
message: str,
history,
system_prompt: str,
temperature: float,
top_p: float,
max_tokens: int,
):
try:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
top_p=top_p,
                max_tokens=max_tokens,
n=1,
stream=True,
response_format={"type": "text"},
)
response_text = ""
for chunk in response:
chunk_message = chunk.choices[0].delta.content
if chunk_message:
response_text += chunk_message
yield assistant_message.strip()
except Exception as e:
print(f"Error during generation: {str(e)}")
yield f"An error occurred: {str(e)}"
return predict
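
# Example (sketch): consuming the chat function directly, outside Gradio. The model
# name and prompt are placeholders; history uses the (user, assistant) tuple pairs
# that gr.ChatInterface passes to fn. Each yielded value is the full reply so far.
#
#     chat = get_fn("my-model")
#     for partial in chat("Hello!", [], "You are a helpful AI assistant.", 0.7, 0.95, 256):
#         print(partial)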
def get_image_base64(url: str, ext: str):
    """Read a local image file and return it as a data URL,
    e.g. get_image_base64("cat.png", "png") -> "data:image/png;base64,iVBORw0...".
    """
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:image/" + ext + ";base64," + encoded_string
def handle_user_msg(message: str | dict):
    """Normalize a Gradio chat message (plain string or multimodal dict with
    "text" and "files") into OpenAI-style message content."""
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if message.get("files"):
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Not supported file type {ext}")
content = [
{"type": "text", "text": message.get("text", "")},
{
"type": "image_url",
"image_url": {
"url": encoded_str,
}
},
]
else:
content = message.get("text", "")
return content
else:
raise NotImplementedError
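
# Note (sketch): handle_user_msg is not wired into predict above. With
# gr.ChatInterface(multimodal=True), the user turn could be built like this,
# assuming the endpoint accepts OpenAI-style image_url content parts:
#
#     messages.append({"role": "user", "content": handle_user_msg(message)})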
def get_model_path(name: str = None, model_path: str = None) -> str:
"""Get the model name to use with the endpoint."""
if model_path:
return model_path
if name:
return name
raise ValueError("Either name or model_path must be provided")
def registry(name: str = None, model_path: str = None, **kwargs):
"""Create a Gradio ChatInterface."""
model_name = get_model_path(name, model_path)
fn = get_fn(model_name, **kwargs)
interface = gr.ChatInterface(
fn=fn,
additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox(
"You are a helpful AI assistant.",
label="System prompt"
),
            # Order must match predict's parameters: temperature, top_p, max_tokens.
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
],
)
return interface
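
# Launch sketch: "my-model" is a placeholder; assumes the OpenAI-compatible
# endpoint configured in get_fn is reachable.
if __name__ == "__main__":
    demo = registry(name="my-model")
    demo.launch()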