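"""Gradio demo for OneEncoder: retrieve COCO images from a text or an audio query."""
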
import gradio as gr
import torch
import torchaudio

from configs import CFG
from costum_datasets import make_pairs
from text_image_audio import OneEncoder

# Construct (image, caption) pairs: 5 captions per image, i.e. 413,915 pairs over 82,783 images
training_pairs = make_pairs(CFG.image_dir, CFG.image_dir, 5)
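# make_pairs is assumed to return (image_path, caption) tuples; the sorting and
# unpacking below rely on that ordering.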

# Sort pairs by image path so captions for the same image are adjacent
training_pairs = sorted(training_pairs, key=lambda x: x[0])

# Keep only the first pair for each unique image (order preserved from the sort above)
unique_images = set()
unique_pairs = []
for image_path, caption in training_pairs:
    if image_path not in unique_images:
        unique_images.add(image_path)
        unique_pairs.append((image_path, caption))
coco_images, _ = zip(*unique_pairs)
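# For example (hypothetical file names):
#   [("img1.jpg", "cap A"), ("img1.jpg", "cap B"), ("img2.jpg", "cap C")]
#   -> coco_images == ("img1.jpg", "img2.jpg")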

# Load the pretrained OneEncoder model from the Hugging Face Hub
model = OneEncoder.from_pretrained("bilalfaye/OneEncoder-text-image-audio")

# Load precomputed COCO image embeddings and keep the first 3,000 for this light demo
coco_image_features = torch.load("image_embeddings_best.pt", map_location=CFG.device)
coco_image_features = coco_image_features[:3000]
coco_images = coco_images[:3000]  # keep the path list aligned with the truncated embeddings
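# image_embeddings_best.pt is assumed to hold image features precomputed offline,
# e.g. with a hypothetical helper along these lines:
#   features = model.text_image_encoder.encode_images(coco_images)
#   torch.save(features, "image_embeddings_best.pt")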

def text_image(query):
    # Retrieve the top-9 matches for the text query; with plot=True the
    # retrieval grid is written to img.png, which Gradio then displays.
    model.text_image_encoder.image_retrieval(query,
                                             image_paths=coco_images,
                                             image_embeddings=coco_image_features,
                                             n=9,
                                             plot=True,
                                             temperature=0.0)
    return "img.png"

def audio_image(query):
    # Load the audio with torchaudio (returns tensor and sample rate)
    waveform, sample_rate = torchaudio.load(query)

    # Downmix to mono if the waveform has more than one channel
    if waveform.shape[0] > 1:
        # Average across channels (equivalent to (left + right) / 2 for stereo)
        mono_audio = waveform.mean(dim=0, keepdim=True)
    else:
        # Audio is already mono
        mono_audio = waveform

    # Resample to 16000 Hz if not already
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        mono_audio = resampler(mono_audio)

    # Convert to a 1-D NumPy array for the audio processor
    mono_audio = mono_audio.squeeze(0).numpy()

    audio_encoding = model.process_audio([mono_audio])

    model.image_retrieval(audio_encoding,
                          image_paths=coco_images,
                          image_embeddings=coco_image_features,
                          n=9,
                          plot=True,
                          temperature=0.0,
                          display_audio=False)

    return "img.png"


# Gradio interface with one tab per query modality
iface = gr.TabbedInterface(
    [
        gr.Interface(
            fn=text_image,
            inputs=gr.Textbox(label="Text Query"),
            outputs="image",
            title="Retrieve images using text as query",
            description="Implementation of OneEncoder using one layer on UP for light demo, Only coco train dataset is used in this example (3000 images)."
        ),
        gr.Interface(
            fn=audio_image,
            inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Provide Audio Query"),
            outputs="image",
            title="Retrieve images using audio as query",
            description="Implementation of OneEncoder using one layer on UP for light demo, Only coco train dataset is used in this example (3000 images)."
        )
    ],
    tab_names=["Text - Image", "Audio - Image"]
)

# debug=True surfaces errors in the console; share=True exposes a temporary public URL
iface.launch(debug=True, share=True)