kubinooo's picture
fixing broken predictions - debug prints enabled
01d3f3f
"""
Pre-processing module for uploaded audio.
Uses silero_vad for silence removal and librosa (rendered through matplotlib) for mel-spectrogram image generation.
Author: Jakub Polnis
Copyright: Copyright 2025, Jakub Polnis
License: Apache 2.0
Email: [email protected]
"""
import io
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from silero_vad import (load_silero_vad,
read_audio,
get_speech_timestamps,
save_audio,
VADIterator,
collect_chunks)
# Selects which variant of the Silero VAD model is loaded (passed as ``onnx=``).
USE_ONNX = False
# The VAD model is loaded once, at import time, and reused by every call to
# silero_vad_remove_silence.
model = load_silero_vad(onnx=USE_ONNX)
# Sampling rate handed to read_audio / get_speech_timestamps and used for all
# sample-count arithmetic in create_mel_spectrograms.
SAMPLING_RATE = 16000
def silero_vad_remove_silence(audio_file_path):
    """Return the audio at ``audio_file_path`` with non-speech parts removed.

    The file is read at ``SAMPLING_RATE`` and run through the module-level
    Silero VAD ``model``. When speech is found, the detected speech chunks
    are merged and returned; when none is found, a notice is printed and
    the audio is returned exactly as it was read.

    Args:
        audio_file_path: Path of the audio file, forwarded to ``read_audio``.

    Returns:
        The merged speech chunks, or the unmodified audio if no speech was
        detected.
    """
    # Limits torch to a single thread (process-wide setting, re-applied on
    # every call).
    torch.set_num_threads(1)

    waveform = read_audio(audio_file_path, sampling_rate=SAMPLING_RATE)
    # Speech timestamps are computed over the whole file in one pass.
    speech_segments = get_speech_timestamps(
        waveform, model, sampling_rate=SAMPLING_RATE
    )

    if speech_segments:
        # Concatenate every detected speech chunk into one signal.
        return collect_chunks(speech_segments, waveform)

    print(f"No speech detected in {audio_file_path}. Returning original audio.")
    return waveform
def _render_mel_image(y_segment, sr):
    """Render one audio segment as a 224x224 px mel-spectrogram PIL image."""
    # Import the submodule explicitly: some librosa releases do not expose
    # ``librosa.display`` after a plain ``import librosa``.
    import librosa.display

    S = librosa.feature.melspectrogram(
        y=y_segment, sr=sr, n_mels=128, fmax=8000, center=True
    )
    S_dB = librosa.power_to_db(S, ref=np.max)

    # 2.24 in x 2.24 in at dpi=100 -> 224 x 224 pixels.
    fig, ax = plt.subplots(figsize=(224 / 100, 224 / 100))
    try:
        librosa.display.specshow(S_dB, sr=sr, fmax=8000, ax=ax)
        ax.set_xlim(0, S.shape[-1])
        ax.set_ylim(0, S.shape[0])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        with io.BytesIO() as buffer:
            # Save through the figure object itself rather than pyplot's
            # global "current figure", so the right figure is always written.
            fig.savefig(buffer, format='PNG', bbox_inches=None, pad_inches=0,
                        dpi=100, transparent=True)
            buffer.seek(0)
            # Copy so the image no longer depends on the buffer once it closes.
            return Image.open(buffer).copy()
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)


def create_mel_spectrograms(file_path, segment_duration, start_offset):
    """Cut an audio file into segments and render each as a mel-spectrogram.

    Silence is first removed with ``silero_vad_remove_silence``. The remaining
    signal is trimmed or zero-padded to a whole number of seconds, then one
    segment of ``segment_duration`` seconds is taken starting at every whole
    second. Segments that would run past the end of the signal are dropped.

    Args:
        file_path: Path of the audio file, forwarded to the silence remover.
        segment_duration: Length of each segment, in seconds. A non-positive
            value yields an empty result.
        start_offset: Accepted for interface compatibility but currently not
            used; segment starts are always one second apart.

    Returns:
        A list of PIL images, one per segment, in time order. The list is
        empty when no complete segment fits in the signal.
    """
    # Remove silence before segmenting.
    processed_audio = silero_vad_remove_silence(file_path)
    y = processed_audio.numpy()
    sr = SAMPLING_RATE

    # Duration of the silence-free audio, in seconds.
    audio_duration = librosa.get_duration(y=y, sr=sr)
    # Segment length expressed in samples.
    segment_samples = int(segment_duration * sr)
    # Nearest whole number of seconds, and the matching sample count.
    rounded_duration = int(np.round(audio_duration))
    target_len = rounded_duration * sr

    # Trim or zero-pad so the signal is exactly ``rounded_duration`` seconds.
    if len(y) > target_len:
        y = y[:target_len]
    elif len(y) < target_len:
        y = np.pad(y, (0, target_len - len(y)), mode='constant')

    pil_images = []
    # Guard against non-positive lengths: a negative end index would otherwise
    # wrap around in the slice below and produce a bogus segment.
    if segment_samples > 0:
        # One candidate segment per whole second of audio.
        for start_sample in range(0, target_len, sr):
            end_sample = start_sample + segment_samples
            if end_sample > len(y):
                # Every later segment would overrun the signal as well.
                break
            pil_images.append(_render_mel_image(y[start_sample:end_sample], sr))

    # Debug output of the generated images.
    print(pil_images)
    return pil_images