import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
import cv2
# Load the pre-trained CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
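# Inference mode (disables dropout). Gradients stay enabled because Grad-CAM needs them.
model.eval()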
def apply_gradcam(image, text):
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)

    # Capture activations and gradients at the last vision transformer block.
    activations, gradients = [], []

    def forward_hook(module, args, output):
        activations.append(output[0])
        output[0].register_hook(lambda grad: gradients.append(grad))

    handle = model.vision_model.encoder.layers[-1].register_forward_hook(forward_hook)
    outputs = model(**inputs)
    handle.remove()

    # The image-text similarity is the scalar we backpropagate for Grad-CAM.
    similarity = torch.nn.functional.cosine_similarity(outputs.image_embeds, outputs.text_embeds)
    model.zero_grad()
    similarity.backward()

    # Drop the CLS token; keep the patch-token grid (7x7 for ViT-B/32 at 224px input).
    acts, grads = activations[0][:, 1:, :], gradients[0][:, 1:, :]

    # Grad-CAM: weight each hidden channel by its mean gradient, then sum over channels.
    weights = grads.mean(dim=1, keepdim=True)
    cam = (weights * acts).sum(dim=-1).squeeze(0)

    side = int(cam.numel() ** 0.5)
    heatmap = cam.reshape(side, side).detach().cpu().numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap /= heatmap.max() + 1e-8
    heatmap = cv2.resize(heatmap, (image.size[0], image.size[1]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR

    superimposed_img = cv2.addWeighted(np.array(image.convert("RGB")), 0.6, heatmap, 0.4, 0)
    return superimposed_img
def highlight_image(image, text):
    highlighted_image = apply_gradcam(image, text)
    return Image.fromarray(highlighted_image)
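# Quick local test, no Gradio needed (hypothetical file names):
#   img = Image.open("cat.jpg")
#   highlight_image(img, "a cat sitting on a sofa").save("highlighted.jpg")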
# Define Gradio interface
iface = gr.Interface(
    fn=highlight_image,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Text Description")],
    outputs=gr.Image(type="pil"),
    title="Image Text Highlight",
    description="Upload an image and provide a text description to highlight the relevant part of the image.",
)
# Launch the Gradio app
iface.launch()