File size: 3,783 Bytes
87c4b82
 
 
 
ee40bdf
87c4b82
ee40bdf
 
87c4b82
ee40bdf
 
 
 
 
87c4b82
 
 
 
 
 
ee40bdf
 
 
87c4b82
 
ee40bdf
 
 
87c4b82
ee40bdf
 
 
87c4b82
ccfb364
ee40bdf
 
87c4b82
ee40bdf
 
 
 
 
87c4b82
b5fc8ee
 
 
 
 
 
 
87c4b82
 
 
 
 
 
 
 
 
 
 
 
ee40bdf
87c4b82
ee40bdf
 
87c4b82
 
 
 
 
 
ee40bdf
 
 
 
 
 
 
 
87c4b82
ee40bdf
87c4b82
 
 
 
 
ee40bdf
87c4b82
 
 
ee40bdf
87c4b82
 
 
ee40bdf
 
 
87c4b82
 
 
 
 
 
 
 
 
 
 
 
 
 
ee40bdf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import gradio as gr
from typing import Callable
import base64
from openai import OpenAI

def get_fn(model_name: str, **model_kwargs):
    """Create a streaming chat function that uses an OpenAI-compatible endpoint.

    Args:
        model_name: Model identifier sent with every completion request.
        **model_kwargs: Optional ``base_url`` and ``api_key`` overriding the
            built-in endpoint defaults. Other keys are currently ignored.

    Returns:
        A generator function ``predict(message, history, system_prompt,
        temperature, top_p, max_tokens)``. It yields the accumulated reply
        text as streamed chunks arrive, or a single error string on failure.
    """
    # Strip defensively: a URL with surrounding whitespace (as the original
    # literal had) may not parse as an absolute URL.
    base_url = str(
        model_kwargs.get("base_url", "http://192.222.58.60:8000/v1")
    ).strip()
    api_key = model_kwargs.get("api_key", "tela")
    client = OpenAI(base_url=base_url, api_key=api_key)

    def _history_to_messages(history):
        """Convert Gradio chat history into OpenAI chat messages.

        Accepts both ``[user, assistant]`` pairs and dicts that already
        carry ``role``/``content`` keys.
        """
        converted = []
        for turn in history or []:
            if isinstance(turn, dict):
                converted.append({"role": turn["role"], "content": turn["content"]})
                continue
            user_msg, assistant_msg = turn
            if user_msg is not None:
                converted.append({"role": "user", "content": user_msg})
            # A turn without a reply yet (or a failed one) has no assistant text.
            if assistant_msg is not None:
                converted.append({"role": "assistant", "content": assistant_msg})
        return converted

    def predict(
        message: str,
        history,
        system_prompt: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
    ):
        try:
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.extend(_history_to_messages(history))
            messages.append({"role": "user", "content": message})

            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                # UI sliders may deliver floats; the API expects an integer.
                max_tokens=int(max_tokens),
                n=1,
                stream=True,
                response_format={"type": "text"},
            )

            response_text = ""
            for chunk in response:
                # Some servers send chunks with no choices (e.g. usage reports).
                if not chunk.choices:
                    continue
                chunk_message = chunk.choices[0].delta.content
                if chunk_message:
                    response_text += chunk_message
                    # Yield the full text so far; the UI replaces the message.
                    yield response_text.strip()
        except Exception as e:
            # UI boundary: report the failure in the chat instead of crashing.
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict

def get_image_base64(url: str, ext: str):
    """Read a local file and return its contents as a base64 ``data:`` URI.

    Args:
        url: Path to a local file (despite the name, not a remote URL).
        ext: File extension such as ``"png"``. Letter case and a leading
            dot are tolerated.

    Returns:
        A string of the form ``data:<mime-type>;base64,<payload>``.
    """
    normalized = ext.lower().lstrip(".")
    # "image/jpg" is not a registered media type (image/jpeg is), and a PDF
    # is not an image; everything else falls back to image/<ext>.
    mime_overrides = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "pdf": "application/pdf",
    }
    mime_type = mime_overrides.get(normalized, "image/" + normalized)
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return "data:" + mime_type + ";base64," + encoded_string

def handle_user_msg(message: "str | dict"):
    """Convert a chat message into OpenAI chat ``content``.

    Args:
        message: Either plain text, or a dict with an optional ``"text"``
            string and an optional ``"files"`` list. Only the last file is
            attached.

    Returns:
        The text unchanged when there is no file. Otherwise, a list of two
        content parts: the text and an ``image_url`` part holding a base64
        data URI of the file.

    Raises:
        NotImplementedError: If the file extension is unsupported, or if
            ``message`` is neither a ``str`` nor a ``dict``.
    """
    if isinstance(message, str):
        return message
    if not isinstance(message, dict):
        raise NotImplementedError(
            f"Unsupported message type {type(message).__name__}"
        )

    # A "text" key present with value None must not become null content.
    text = message.get("text") or ""
    files = message.get("files")
    if not files:
        return text

    last_file = files[-1]
    # File entries may be plain paths or dicts carrying a "path" key,
    # depending on the Gradio version — accept both.
    path = last_file["path"] if isinstance(last_file, dict) else last_file
    ext = os.path.splitext(path)[1].strip(".")
    if ext.lower() not in ("png", "jpg", "jpeg", "gif", "pdf"):
        raise NotImplementedError(f"Not supported file type {ext}")
    encoded_str = get_image_base64(path, ext.lower())
    return [
        {"type": "text", "text": text},
        {"type": "image_url", "image_url": {"url": encoded_str}},
    ]

def get_model_path(name: str = None, model_path: str = None) -> str:
    """Resolve which model identifier to send to the endpoint.

    A non-empty ``model_path`` wins over ``name``. Empty strings count as
    missing.

    Raises:
        ValueError: If neither argument is a non-empty value.
    """
    resolved = model_path or name
    if not resolved:
        raise ValueError("Either name or model_path must be provided")
    return resolved

def registry(name: str = None, model_path: str = None, **kwargs):
    """Create a Gradio ChatInterface backed by an OpenAI-compatible endpoint.

    Args:
        name: Model name, used when ``model_path`` is not given.
        model_path: Model identifier that takes precedence over ``name``.
        **kwargs: Forwarded unchanged to ``get_fn``.

    Returns:
        The configured ``gr.ChatInterface``.
    """
    model_name = get_model_path(name, model_path)
    fn = get_fn(model_name, **kwargs)

    interface = gr.ChatInterface(
        fn=fn,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        # Gradio passes these values positionally after (message, history).
        # Their order must therefore match the chat function's parameters:
        # system_prompt, temperature, top_p, max_tokens.
        additional_inputs=[
            gr.Textbox(
                "You are a helpful AI assistant.",
                label="System prompt"
            ),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
        ],
    )

    return interface