# app.py — OneEncoder demo: retrieve COCO images from text or audio queries.
# (Removed non-Python residue from the hosting page: author avatar line,
# commit message "Update app.py", and commit hash "990c0d6 verified".)
from configs import CFG
from costum_datasets import make_pairs
from text_image_audio import OneEncoder
import torch
import gradio as gr
import torchaudio
# Construct (image, caption) pairs from the COCO data.
# NOTE(review): CFG.image_dir is passed for BOTH arguments — confirm the
# second argument isn't meant to be a captions file/directory.
training_pairs = make_pairs(CFG.image_dir, CFG.image_dir, 5)  # 413.915 -> 82.783 images
# Sort pairs by image path so downstream ordering is deterministic.
training_pairs = sorted(training_pairs, key=lambda x: x[0])
coco_images, coco_captions = zip(*training_pairs)
# Keep only the first (image, caption) pair per unique image, preserving the
# sorted order. This replaces the opaque `and not seen.add(...)` comprehension
# hack, which relied on set.add returning None as a side effect.
_seen = set()
unique_pairs = []
for _image_path, _caption in training_pairs:
    if _image_path not in _seen:
        _seen.add(_image_path)
        unique_pairs.append((_image_path, _caption))
coco_images, _ = zip(*unique_pairs)
# Load the pretrained tri-modal (text/image/audio) encoder.
model = OneEncoder.from_pretrained("bilalfaye/OneEncoder-text-image-audio")
# Precomputed COCO image embeddings; only the first 3000 are kept for this
# light demo. NOTE(review): coco_images may hold more paths than the 3000
# embeddings — presumably image_retrieval only scores the first 3000; verify.
coco_image_features = torch.load("image_embeddings_best.pt", map_location=CFG.device)
coco_image_features = coco_image_features[:3000]
def text_image(query):
    """Retrieve the top-matching COCO images for a text query.

    Runs the text-image encoder's retrieval, which renders the result
    grid to ``img.png``, and returns that path for Gradio to display.
    """
    retrieval_kwargs = {
        "image_paths": coco_images,
        "image_embeddings": coco_image_features,
        "n": 9,
        "plot": True,
        "temperature": 0.0,
    }
    model.text_image_encoder.image_retrieval(query, **retrieval_kwargs)
    return "img.png"
def audio_image(query):
    """Retrieve the top-matching COCO images for a spoken (audio) query.

    Loads the audio file at *query*, downmixes to mono if needed,
    resamples to 16 kHz, encodes it, and runs image retrieval; the result
    grid is rendered to ``img.png``, whose path is returned for Gradio.
    """
    target_rate = 16000  # sample rate the audio encoder expects
    signal, rate = torchaudio.load(query)
    # Downmix multi-channel audio to a single channel by averaging.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)
    # Bring the waveform to the encoder's expected sample rate.
    if rate != target_rate:
        resample = torchaudio.transforms.Resample(orig_freq=rate, new_freq=target_rate)
        signal = resample(signal)
    # The model's audio pipeline takes a list of 1-D numpy arrays.
    audio_encoding = model.process_audio([signal.squeeze(0).numpy()])
    model.image_retrieval(
        audio_encoding,
        image_paths=coco_images,
        image_embeddings=coco_image_features,
        n=9,
        plot=True,
        temperature=0.0,
        display_audio=False,
    )
    return "img.png"
# Build the demo UI: two tabs sharing the same description and output format.
demo_description = "Implementation of OneEncoder using one layer on UP for light demo, Only coco train dataset is used in this example (3000 images)."

text_tab = gr.Interface(
    fn=text_image,
    inputs=gr.Textbox(label="Text Query"),
    outputs="image",
    title="Retrieve images using text as query",
    description=demo_description,
)
audio_tab = gr.Interface(
    fn=audio_image,
    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Provide Audio Query"),
    outputs="image",
    title="Retrieve images using audio as query",
    description=demo_description,
)

iface = gr.TabbedInterface(
    [text_tab, audio_tab],
    tab_names=["Text - Image", "Audio - Image"],
)
iface.launch(debug=True, share=True)