Spaces: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor | |
from PIL import Image | |
import spaces | |
import os | |
from huggingface_hub import login | |
MODEL_ID = "stabilityai/japanese-stable-vlm"

# Authenticate with the token stored in the HF_KEY secret. Presumably the model
# repo is gated (the original code always logged in before loading it).
# If the variable is unset, skip login instead of crashing with KeyError. Loading
# then relies on any cached credentials, and from_pretrained reports an
# authentication error itself if access is really required.
hf_token = os.environ.get("HF_KEY")
if hf_token:
    login(token=hf_token)

# Kept as a module-level name for callers that move tensors with it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# device_map="auto" lets transformers/accelerate decide where the weights go;
# trust_remote_code is needed because the model ships its own modeling code.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto"
)
# device_map only applies to model weights, so it is not passed to the
# image processor or the tokenizer.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Task name -> Japanese instruction placed under the "指示" (instruction) header.
# NOTE(review): these strings presumably have to match the prompts the model was
# trained with, so they are kept verbatim (including the "vqa" wording);
# confirm against the model card before editing them.
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


def build_prompt(task="caption", input=None, sep="\n\n### "):
    """Build the instruction-style text prompt for ``task``.

    The prompt is a fixed Japanese system message followed by sections. Each
    section is written as ``sep + header + ": \\n" + body``, in this order:
    the instruction for ``task`` ("指示"), the optional user input ("入力"),
    and an empty response section ("応答") for the model to continue.

    Args:
        task: One of the keys of ``TASK2INSTRUCTION`` ("caption", "tag", "vqa").
        input: Text for the "tag" or "vqa" task. For "tag" a list or tuple of
            words is also accepted and joined with the Japanese comma "、".
            ``None``, an empty string, or a whitespace-only string all mean
            "no input" (a cleared UI textbox sends "" rather than None).
            The parameter name shadows the builtin but is kept because callers
            pass it by keyword.
        sep: Separator written before every section header.

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``task`` is unknown, if "tag"/"vqa" has no input, or if
            "caption" receives a non-empty input.
    """
    # Explicit raises instead of assert: asserts are stripped under ``python -O``.
    if task not in TASK2INSTRUCTION:
        raise ValueError(f"Please choose from {list(TASK2INSTRUCTION.keys())}")
    if task == "tag" and isinstance(input, (list, tuple)):
        input = "、".join(input)
    # Blank text is treated exactly like no input at all.
    if isinstance(input, str) and not input.strip():
        input = None
    if task in ("tag", "vqa"):
        if input is None:
            raise ValueError("Please fill in `input`!")
    elif input is not None:
        raise ValueError(f"`{task}` mode doesn't support to input questions")

    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    sections = [("指示", TASK2INSTRUCTION[task])]
    if input is not None:
        sections.append(("入力", input))
    # The response section is left empty so generation continues from here.
    sections.append(("応答", ""))

    parts = [sys_msg]
    for role, body in sections:
        parts.append(sep + role + ": \n" + body)
    return "".join(parts)
# Define the function to generate text from the image and prompt.
# @spaces.GPU requests a ZeroGPU device for the duration of each call. On
# ZeroGPU hardware CUDA work has to happen inside such a function; on other
# hardware the decorator has no effect.
@spaces.GPU
def generate_text(image, task, input_text=None):
    """Run the vision-language model on ``image`` for ``task``.

    Args:
        image: The image handed over by the ``gr.Image`` input component.
            Presumably an image array (Gradio's default) — TODO confirm
            against the installed Gradio version.
        task: "caption", "tag" or "vqa" (see ``build_prompt``).
        input_text: Tags (for "tag") or a question (for "vqa"). It is ignored
            for "caption", and blank text is treated as missing.

    Returns:
        The decoded text of the single generated sequence, stripped of
        surrounding whitespace.

    Raises:
        gr.Error: If no image was provided. Validation errors from
            ``build_prompt`` (e.g. "tag"/"vqa" without text) propagate.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # The UI always sends the textbox value: "" when empty, and possibly stale
    # text left over from another task. Captioning takes no input, so drop it.
    # For the other tasks, normalise blank text to "missing".
    if task == "caption" or (isinstance(input_text, str) and not input_text.strip()):
        input_text = None

    prompt = build_prompt(task=task, input=input_text)
    inputs = processor(images=image, return_tensors="pt")
    inputs.update(tokenizer(prompt, add_special_tokens=False, return_tensors="pt"))
    # Put the inputs where device_map="auto" actually placed the model.
    # BatchFeature.to() casts only floating-point tensors to model.dtype,
    # so the token ids keep their integer dtype.
    inputs = inputs.to(device=model.device, dtype=model.dtype)

    # Deterministic 5-beam search, capped at 128 new tokens, with a strong
    # repetition penalty.
    outputs = model.generate(
        **inputs,
        do_sample=False,
        num_beams=5,
        max_new_tokens=128,
        min_length=1,
        repetition_penalty=1.5,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
# Define the Gradio interface: image + task selector + optional text in,
# generated text out.
image_input = gr.Image(label="Upload an image")
# Derive the choices from TASK2INSTRUCTION (dicts keep insertion order), so the
# UI always offers exactly the tasks build_prompt accepts.
task_input = gr.Radio(
    choices=list(TASK2INSTRUCTION), value="caption", label="Select a task"
)
text_input = gr.Textbox(label="Enter text (for tag or vqa tasks)")
output = gr.Textbox(label="Generated text")
interface = gr.Interface(
    fn=generate_text,
    inputs=[image_input, task_input, text_input],
    outputs=output,
    # One example per task, all using the same bundled image.
    examples=[
        ["examples/example_image.jpg", "caption", None],
        ["examples/example_image.jpg", "tag", "河津桜、青空"],
        ["examples/example_image.jpg", "vqa", "OCRはできますか?"],
    ],
)
interface.launch()