# reference: https://huggingface.co/spaces/r3gm/Audio_separator
import gradio as gr
import shutil
import numpy as np
from pathlib import Path
import os
import time
import torch
from huggingface_hub import hf_hub_download
from uvr_processing import process_uvr_task, run_mdx, get_model_params
from utils import convert_to_stereo_and_wav
import onnxruntime as ort
import io
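
# Gradio app for vocal / background-music separation using MDX-Net (UVR) ONNX models.
# The model weights below are fetched from the Hugging Face Hub; hf_hub_download
# caches each file locally and returns the path to the cached copy.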
MODEL_ID = "masszhou/mdxnet"
MODELS_PATH = {
    "bgm": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Inst_HQ_3.onnx")),
    "basic_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR-MDX-NET-Voc_FT.onnx")),
    "main_vocal": Path(hf_hub_download(repo_id=MODEL_ID, filename="UVR_MDXNET_KARA_2.onnx")),
}


def get_device_info():
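    """Return a short description of the runtime device (GPU name or CPU)."""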
    if torch.cuda.is_available():
        device = f"GPU ({torch.cuda.get_device_name(0)})"
    else:
        device = "CPU"
    return f"Current running environment: {device}"


def inference(audio_file: str,
              stem: str = "vocal") -> list[str]:
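    """Run the full UVR separation task and return [background_path, vocals_path].

    `stem` is only validated here; process_uvr_task always produces both stems.
    """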
    if not audio_file:
        raise ValueError("The audio path is missing.")
    if not stem:
        raise ValueError("Please select 'vocal' or 'background' stem.")

    audio_file = Path(audio_file)
    output_dir = Path("./output")
    outputs = []

    start_time = time.time()
    background_path, vocals_path = process_uvr_task(
        input_file_path=audio_file,
        output_dir=output_dir,
        models_path=MODELS_PATH,
    )
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")
    print(f"Background file: {background_path}")
    print(f"Vocals file: {vocals_path}")

    os.makedirs("static/results", exist_ok=True)
    # shutil.copy(background_path, bg_dst)
    # shutil.copy(vocals_path, vc_dst)
    outputs.append(str(background_path))
    outputs.append(str(vocals_path))
    return outputs


def inference_bgm(audio_file: str) -> list[str]:
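    """Extract only the background/instrumental stem with the MDX-Net Inst_HQ_3 model."""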
    mdx_model_params = get_model_params(Path("./mdx_models"))
    audio_file = convert_to_stereo_and_wav(Path(audio_file))  # resampling at 44100 Hz
    device_base = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir = Path("./output")
    model_bgm_path = MODELS_PATH["bgm"]
    background_path, _ = run_mdx(model_params=mdx_model_params,
                                 input_filename=audio_file,
                                 output_dir=output_dir,
                                 model_path=model_bgm_path,
                                 denoise=False,
                                 device_base=device_base,
                                 )
    return [str(background_path)]


def return_original_file(file_path: str):
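    """Hand the uploaded file back to the UI so it can be offered for download.

    A minimal sketch assuming Gradio 4.x: gr.File passes the upload as a filepath
    string, and returning a gr.DownloadButton(...) updates the button in the UI.
    """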
    # Read the original file name and return the path so Gradio can serve it for download.
    filename = os.path.basename(file_path)
    return gr.DownloadButton(label=f"Download {filename}", value=file_path)


def get_gui(theme, title, description):
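    """Build the Blocks-based GUI tab: title, description, device info, and file upload."""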
    with gr.Blocks(theme=theme) as app:
        # Add title and description
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(get_device_info())

        # audio_input = gr.Audio(label="Audio file", type="filepath")
        # download_button = gr.Button("Inference")
        # file_output = gr.File(label="Result", file_count="multiple", interactive=False)
        # download_button.click(
        #     inference,
        #     inputs=[audio_input],
        #     outputs=[file_output],
        # )
        audio_input = gr.File(file_types=[".mp3", ".wav"], label="Upload audio")
        # Sketch of a download flow, assuming Gradio 4.x with gr.DownloadButton:
        # the upload event passes the file path to return_original_file, which
        # updates the button so the user can download the original file again.
        download_btn = gr.DownloadButton("Download original file")
        audio_input.upload(return_original_file, inputs=audio_input, outputs=download_btn)
    return app


if __name__ == "__main__":
    title = "<center><strong><font size='7'>Vocal BGM Separator</font></strong></center>"
    description = "This demo uses MDX-Net models to perform the Ultimate Vocal Remover (UVR) task: separating vocals from background sound."
    theme = "NoCrypt/miku"
print(f"ort.get_available_providers(): {ort.get_available_providers()}") | |
print(gr.__version__) | |
    # entry point for GUI
    # predict(audio_file, api_name="/inference") -> result
    app_gui = get_gui(theme, title, description)

    # entry point for API
    # predict(audio_file, api_name="/predict") -> output
    app_api = gr.Interface(
        fn=inference_bgm,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.File(file_count="multiple"),
    )
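    # Example client call against the API tab (a sketch; the Space id is a placeholder
    # and the api_name follows the comment above):
    #   from gradio_client import Client, handle_file
    #   client = Client("user/space-name")
    #   result = client.predict(handle_file("song.wav"), api_name="/predict")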

    app = gr.TabbedInterface(
        interface_list=[app_gui, app_api],
        tab_names=["GUI", "API"],
    )
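    # default_concurrency_limit=40 lets each event listener process up to 40 requests
    # concurrently (Gradio 4.x queue setting).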
    app.queue(default_concurrency_limit=40)
    app.launch()