# Author: preston-cell (Hugging Face Space)
# Commit f984625 "Update app.py" (verified)
import gradio as gr
from transformers import (
pipeline,
AutoProcessor,
AutoModelForCausalLM,
AutoTokenizer,
set_seed
)
from datasets import load_dataset
import torch
import numpy as np
# Fix the global RNG seed via transformers.set_seed so that any sampling-based
# generation is repeatable across restarts.
set_seed(42)

# --- Model pipelines -------------------------------------------------------
# BLIP base: image -> caption text.
caption_model = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
)
# GPT-2: free-text continuation used to derive a "context" for the image.
gpt2_generator = pipeline(
    "text-generation",
    model="gpt2",
)
# SpeechT5: text -> waveform; it is called with a speaker x-vector (below).
synthesiser = pipeline(
    "text-to-speech",
    model="microsoft/speecht5_tts",
)

# --- Florence-2 OCR --------------------------------------------------------
# Use the GPU in half precision when available, otherwise CPU in float32.
ocr_device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_dtype = torch.float16 if ocr_device == "cuda" else torch.float32
# Florence-2 loads modeling/processing code from its repo, which requires
# trust_remote_code=True.
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    torch_dtype=ocr_dtype,
    trust_remote_code=True,
).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)

# --- Speaker embedding -----------------------------------------------------
# Row 7306 of the CMU ARCTIC x-vector validation split; unsqueeze(0) adds a
# leading (batch) axis before it is passed to the TTS pipeline.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
def _caption_image(image):
    """Return the BLIP caption for *image*."""
    return caption_model(image)[0]["generated_text"]


def _extract_text(image):
    """Run Florence-2 with the ``<OCR>`` task prompt and return the decoded text."""
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _truncate_tokens(text, max_tokens):
    """Clip *text* to at most *max_tokens* tokens of the GPT-2 tokenizer."""
    tokenizer = gpt2_generator.tokenizer
    token_ids = tokenizer(text)["input_ids"]
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[:max_tokens])


def _generate_context(caption, extracted_text, max_ocr_tokens=600, max_new_tokens=60):
    """Ask GPT-2 to continue a context prompt and return only the new text.

    Args:
        caption: Caption produced for the image.
        extracted_text: OCR text; it is clipped to *max_ocr_tokens* GPT-2
            tokens before being put in the prompt.
        max_ocr_tokens: Token budget for the OCR text inside the prompt.
        max_new_tokens: Number of tokens GPT-2 may generate.

    Returns:
        The generated continuation (prompt not included), whitespace-stripped.
        May be an empty string.
    """
    # OCR output alone can be up to 1024 tokens (see _extract_text). GPT-2's
    # position window is 1024, so an unclipped prompt can overflow it.
    ocr_snippet = _truncate_tokens(extracted_text, max_ocr_tokens)
    prompt = (
        f"Determine the context of this image based on the caption and extracted text. "
        f"Caption: {caption}. Extracted text: {ocr_snippet}. Context:"
    )
    context_output = gpt2_generator(
        prompt,
        # max_new_tokens budgets only generated tokens; the old max_length=100
        # also counted the prompt, leaving nothing to generate for long OCR text.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # Without this the pipeline echoes the prompt, which would then be shown
        # as "context" and read aloud by the TTS model.
        return_full_text=False,
    )
    return context_output[0]["generated_text"].strip()


def _speak(text):
    """Synthesize *text* with SpeechT5; return ``(sampling_rate, samples)`` for gr.Audio."""
    speech = synthesiser(text, forward_params={"speaker_embeddings": speaker_embedding})
    return speech["sampling_rate"], np.array(speech["audio"])


def process_image(image):
    """Caption, OCR, contextualize and voice a single uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A 4-tuple matching the Gradio outputs:
        ``((sampling_rate, samples), caption, extracted_text, context)``.
        On failure the audio slot is ``None``, the caption slot carries an
        ``"Error: ..."`` message, and the remaining two slots are ``""``.
    """
    if image is None:
        return None, "Error: no image provided.", "", ""
    try:
        caption = _caption_image(image)
        extracted_text = _extract_text(image)
        context = _generate_context(caption, extracted_text)
        # GPT-2 can produce an empty continuation; fall back to the caption so
        # the TTS model always receives some text.
        audio = _speak(context or caption)
        return audio, caption, extracted_text, context
    except Exception as e:
        # UI boundary: surface the message in the textbox but keep the full
        # traceback in the server logs for debugging.
        import traceback

        traceback.print_exc()
        return None, f"Error: {str(e)}", "", ""
# Gradio UI: one PIL image in; audio plus three text panels out. The output
# components are listed in the same order as the 4-tuple process_image returns.
iface = gr.Interface(
    title="SeeSay Contextualizer",
    description=(
        "Upload an image to generate a caption, extract text, "
        "create audio from context, and determine the context "
        "using GPT-2 and Florence-2-base."
    ),
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[gr.Audio(label="Generated Audio")]
    + [
        gr.Textbox(label=panel_label)
        for panel_label in (
            "Generated Caption",
            "Extracted Text (OCR)",
            "Generated Context",
        )
    ],
)
# Start the Gradio web app.
iface.launch()