import torch
import torchaudio
import gradio as gr

from configs import CFG
from costum_datasets import make_pairs
from text_image_audio import OneEncoder

# Construct (image, caption) pairs from the COCO training set
training_pairs = make_pairs(CFG.image_dir, CFG.image_dir, 5)  # 413,915 pairs -> 82,783 unique images
# Sort pairs by image path
training_pairs = sorted(training_pairs, key=lambda x: x[0])
coco_images, coco_captions = zip(*training_pairs)

# Keep only the first pair for each unique image
# (set.add returns None, so the second clause is always True and simply records the path)
unique_images = set()
unique_pairs = [(img, cap) for img, cap in training_pairs
                if img not in unique_images and not unique_images.add(img)]
coco_images, _ = zip(*unique_pairs)
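
# For illustration, the dedup idiom above applied to toy data (hypothetical paths):
#   [("a.jpg", "c1"), ("a.jpg", "c2"), ("b.jpg", "c3")] -> [("a.jpg", "c1"), ("b.jpg", "c3")]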

# Load the pretrained OneEncoder model
model = OneEncoder.from_pretrained("bilalfaye/OneEncoder-text-image-audio")

# Load precomputed COCO image embeddings and keep the first 3000
# (assumed to be aligned, row by row, with the sorted `coco_images` list)
coco_image_features = torch.load("image_embeddings_best.pt", map_location=CFG.device)
coco_image_features = coco_image_features[:3000]
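
# Sketch (not executed): "image_embeddings_best.pt" is assumed to hold image
# embeddings precomputed offline for these images, roughly along the lines of:
#
#   feats = model.text_image_encoder.encode_image(coco_images[:3000])  # hypothetical helper name
#   torch.save(feats, "image_embeddings_best.pt")
#
# `encode_image` is an assumed method name; the actual OneEncoder API may differ.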


def text_image(query):
    # Retrieve the 9 most similar COCO images for a text query; the retrieval
    # call plots the results, which are expected to be saved as "img.png"
    model.text_image_encoder.image_retrieval(query,
                                             image_paths=coco_images,
                                             image_embeddings=coco_image_features,
                                             n=9,
                                             plot=True,
                                             temperature=0.0)
    return "img.png"


def audio_image(query):
    # Load the audio with torchaudio (returns a waveform tensor and its sample rate)
    waveform, sample_rate = torchaudio.load(query)

    # Downmix to mono by averaging the channels if the audio has more than one
    if waveform.shape[0] > 1:
        mono_audio = waveform.mean(dim=0, keepdim=True)
    else:
        mono_audio = waveform

    # Resample to 16 kHz if not already at that rate
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        mono_audio = resampler(mono_audio)

    # Convert to a 1-D numpy array for the audio processor
    mono_audio = mono_audio.squeeze(0).numpy()
    audio_encoding = model.process_audio([mono_audio])

    # Retrieve the 9 most similar COCO images for the audio query
    model.image_retrieval(audio_encoding,
                          image_paths=coco_images,
                          image_embeddings=coco_image_features,
                          n=9,
                          plot=True,
                          temperature=0.0,
                          display_audio=False)
    return "img.png"


# Gradio interface: one tab per query modality
iface = gr.TabbedInterface(
    [
        gr.Interface(
            fn=text_image,
            inputs=gr.Textbox(label="Text Query"),
            outputs="image",
            title="Retrieve images using text as query",
            description="Lightweight OneEncoder demo using a single UP layer; only the COCO train set is used in this example (3000 images)."
        ),
        gr.Interface(
            fn=audio_image,
            inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Provide Audio Query"),
            outputs="image",
            title="Retrieve images using audio as query",
            description="Lightweight OneEncoder demo using a single UP layer; only the COCO train set is used in this example (3000 images)."
        )
    ],
    tab_names=["Text - Image", "Audio - Image"]
)

iface.launch(debug=True, share=True)