import gradio as gr
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import numpy as np
import librosa
import json
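
# ISO_codes.json maps display names like "English (eng)" to the ISO 639-3
# codes that MMS uses to identify its per-language adapters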
with open('ISO_codes.json', 'r') as file:
    iso_codes = json.load(file)
languages = list(iso_codes.keys())
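
# MMS 1B-all: a single multilingual checkpoint; adapter weights for the
# selected language are swapped in at transcription time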
model_id = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)
def transcribe(audio_file_mic=None, audio_file_upload=None, language="English (eng)", progress=gr.Progress()):
    if audio_file_mic:
        audio_file = audio_file_mic
    elif audio_file_upload:
        audio_file = audio_file_upload
    else:
        return "Please upload an audio file or record one"

    progress(0, desc="Starting")

    # Make sure audio is 16kHz, the sample rate MMS expects.
    # sr=None keeps the file's native rate (librosa otherwise resamples to 22.05 kHz,
    # which would make the check below always trigger).
    speech, sample_rate = librosa.load(audio_file, sr=None)
    if sample_rate != 16000:
        progress(1, desc="Resampling")
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)

    # Cut speech into 30-second chunks to keep memory use bounded
    chunk_size = 30 * 16000  # 30s * 16000 Hz
    chunks = np.split(speech, np.arange(chunk_size, len(speech), chunk_size))

    # Load the model adapter for this language
    language_code = iso_codes[language]
    processor.tokenizer.set_target_lang(language_code)
    model.load_adapter(language_code)

    transcriptions = []
    progress(2, desc="Transcribing")
    for chunk in progress.tqdm(chunks, desc="Transcribing"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs).logits
        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)
        transcriptions.append(transcription)

    transcription = ' '.join(transcriptions)
    return transcription
examples = [
    ["balinese.mp3", None, "Bali (ban)"],
    ["madura.mp3", None, "Madura (mad)"],
    ["toba_batak.mp3", None, "Batak Toba (bbc)"],
    ["minangkabau.mp3", None, "Minangkabau (min)"],
]
description = '''Automatic Speech Recognition with [MMS](https://ai.facebook.com/blog/multilingual-model-speech-recognition/) (Massively Multilingual Speech) by Meta.'''
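
# Gradio UI: record via the microphone or upload a file, then pick the language from a dropdown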
demo = gr.Interface(
    transcribe,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record Audio"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Dropdown(choices=languages, label="Language", value="English (eng)")
    ],
    outputs=gr.Textbox(label="Transcription"),
    examples=examples,
    description=description
)
if __name__ == "__main__":
    demo.queue(concurrency_count=1).launch()