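"""Gradio demo for LLaVA-Med 1.5 running on OpenVINO.

Streams answers to questions about an uploaded medical image, using an
OpenVINO-optimized LLaVA-Med checkpoint (FP16 or INT8 image encoder, INT4 LLM).
"""
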
from threading import Thread

import gradio as gr
import openvino as ov
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from transformers import TextIteratorStreamer

css = """
.text textarea {font-size: 24px !important;}
.text p {font-size: 24px !important;}
"""

# FP16 image encoder + INT4 language model; the INT8 image encoder variant is below.
model_path = "llava-med-imf16-llmint4"
# model_path = "llava-med-imint8-llmint4"
model_name = get_model_name_from_path(model_path)

device = "GPU" if "GPU" in ov.Core().available_devices else "CPU"
image_device = "NPU" if "NPU" in ov.Core().available_devices else device
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    device=device,
    openvino=True,
    image_device=image_device,
)
print("models loaded")


def reset_inputs():
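    """Clear the image, question, and answer fields."""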
    return None, "", ""


def prepare_inputs_image(image, question):
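    """Build model inputs: token IDs with the image placeholder and a preprocessed image tensor."""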
    conv_mode = "vicuna_v1"  # default
    qs = question.replace(DEFAULT_IMAGE_TOKEN, "").strip()
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs  # model.config.mm_use_im_start_end is False

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    # tokenizer_image_token tokenizes the prompt and inserts IMAGE_TOKEN_INDEX
    # where the image placeholder appears.
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)

    # The Gradio Image component supplies a PIL image directly, so no file loading is needed.
    image_tensor = process_images([image], image_processor, model.config)[0]
    return input_ids, image_tensor


def run_inference(image, message):
    """
    Function to handle the chat input and generate model responses.
    """
    # This function is a generator (it yields below), so an early exit must
    # yield the empty answer rather than return it.
    if image is None or not message:
        yield ""
        return

    input_ids, image_tensor = prepare_inputs_image(image, message)

    # TextIteratorStreamer exposes generated text as an iterator; generate() runs in a
    # background thread so partial answers can be streamed to the UI as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "streamer": streamer,
        "input_ids": input_ids,
        "images": image_tensor.unsqueeze(0).half(),  # add batch dimension and convert to FP16
        "do_sample": False,  # greedy decoding for deterministic answers
        "max_new_tokens": 512,
        "use_cache": True,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate streamed chunks; each yield updates the UI with the answer so far.
    response = ""
    for new_text in streamer:
        response += new_text
        yield response


with gr.Blocks(css=css) as demo:
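    # Layout: image upload on the left; question box and streamed answer on the right.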
    gr.Markdown("# LLaVA-Med 1.5 OpenVINO Demo")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image", height=300, width=500)
        with gr.Column():
            text_input = gr.Textbox(label="Enter a Question", elem_classes="text", interactive=True)
            chatbot = gr.Textbox(label="Answer", elem_classes="text")

    with gr.Row():
        process_button = gr.Button("Process")
        reset_button = gr.Button("Reset")

    gr.Markdown("NOTE: This OpenVINO model is unvalidated. Results are provisional and may contain errors. Use this demo to explore AI PC and OpenVINO optimizations")
    gr.Markdown("Source model: [microsoft/LLaVA-Med](https://github.com/microsoft/LLaVA-Med). For research purposes only.")

    # Both the Process button and pressing Enter in the question box run inference;
    # Reset clears the image, question, and answer.
    process_button.click(run_inference, inputs=[image_input, text_input], outputs=chatbot)
    text_input.submit(run_inference, inputs=[image_input, text_input], outputs=chatbot)
    reset_button.click(reset_inputs, inputs=[], outputs=[image_input, text_input, chatbot])

if __name__ == "__main__":
    demo.launch(server_port=7788, server_name="0.0.0.0")