Spaces:
Running
Running
import os | |
import torch | |
from .pytorch_utils import move_data_to_device | |
from .models import Cnn14_DecisionLevelMax | |
from .config import labels, classes_num | |
def create_folder(fd):
    """Create directory ``fd`` (including missing parents) if it does not exist.

    An empty ``fd`` -- e.g. the result of ``os.path.dirname('model.pth')`` --
    means the current directory, which always exists, so it is a no-op rather
    than letting ``os.makedirs('')`` raise ``FileNotFoundError``.

    Args:
        fd: Directory path to create.
    """
    if fd and not os.path.exists(fd):
        # exist_ok=True tolerates another process creating the directory
        # between the existence check and this call.
        os.makedirs(fd, exist_ok=True)
def get_filename(path):
    """Return the final component of ``path`` with its last extension removed.

    The path is first resolved with ``os.path.realpath``, so a symlink yields
    the name of its target. ``os.path.basename`` is used instead of splitting
    on '/' so the function also works with Windows path separators.

    Args:
        path: File path (absolute or relative to the current directory).

    Returns:
        The base name without its final extension, e.g. 'clip' for
        'dir/clip.wav' and 'clip.tar' for 'dir/clip.tar.gz'.
    """
    name_with_ext = os.path.basename(os.path.realpath(path))
    return os.path.splitext(name_with_ext)[0]
class SoundEventDetection(object):
    """Sound event detection inference wrapper.

    Loads weights into ``Cnn14_DecisionLevelMax`` (or a caller-supplied
    model), downloading the default checkpoint from Zenodo when the file at
    ``checkpoint_path`` is missing or smaller than ``MIN_CHECKPOINT_BYTES``,
    and exposes frame-wise predictions through :meth:`inference`.
    """

    # Default checkpoint location, relative to the current working directory.
    DEFAULT_CHECKPOINT_PATH = 'panns_data/Cnn14_DecisionLevelMax.pth'
    # Download source for the default Cnn14_DecisionLevelMax weights.
    CHECKPOINT_URL = ('https://zenodo.org/record/3987831/files/'
                      'Cnn14_DecisionLevelMax_mAP%3D0.385.pth?download=1')
    # Files smaller than this (bytes) are treated as missing/partial downloads.
    MIN_CHECKPOINT_BYTES = 3e8

    def __init__(self, model=None, checkpoint_path=None, device='cuda'):
        """Build the model, load its weights and place it on a device.

        Args:
            model: Optional pre-built ``torch.nn.Module``. When None, a
                ``Cnn14_DecisionLevelMax`` is constructed with the fixed
                hyper-parameters below.
            checkpoint_path: Checkpoint file whose ``'model'`` entry is a
                state dict. Falsy values select ``DEFAULT_CHECKPOINT_PATH``.
                Note: if this file is missing or too small, the default
                Cnn14_DecisionLevelMax weights are downloaded to it.
            device: 'cuda' selects the GPU when one is available; any other
                value (or no available GPU) selects the CPU.

        Raises:
            subprocess.CalledProcessError: If the checkpoint download fails.
            FileNotFoundError: If a download is needed but ``wget`` is not
                installed.
        """
        if not checkpoint_path:
            checkpoint_path = self.DEFAULT_CHECKPOINT_PATH
        print('Checkpoint path: {}'.format(checkpoint_path))

        if self._needs_download(checkpoint_path):
            self._download_checkpoint(checkpoint_path)

        if device == 'cuda' and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        self.labels = labels
        self.classes_num = classes_num

        # Model
        if model is None:
            self.model = Cnn14_DecisionLevelMax(
                sample_rate=32000, window_size=1024, hop_size=320,
                mel_bins=64, fmin=50, fmax=14000,
                classes_num=self.classes_num)
        else:
            self.model = model

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])

        # Parallel
        if 'cuda' in str(self.device):
            self.model.to(self.device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')

    @classmethod
    def _needs_download(cls, checkpoint_path):
        """Return True if ``checkpoint_path`` is missing or below the size threshold."""
        return (not os.path.exists(checkpoint_path)
                or os.path.getsize(checkpoint_path) < cls.MIN_CHECKPOINT_BYTES)

    @classmethod
    def _download_checkpoint(cls, checkpoint_path):
        """Download ``CHECKPOINT_URL`` to ``checkpoint_path`` using wget."""
        import subprocess  # only needed on the (rare) download path

        folder = os.path.dirname(checkpoint_path)
        if folder:  # a bare filename lives in the cwd, which already exists
            create_folder(folder)
        # Pass an argument list (no shell) so a path containing quotes, spaces
        # or shell metacharacters cannot inject commands, and the URL's '?'
        # is never glob-expanded. check=True reports a failed download here
        # instead of as an obscure torch.load error later.
        subprocess.run(['wget', '-O', checkpoint_path, cls.CHECKPOINT_URL],
                       check=True)

    def inference(self, audio):
        """Run the model on ``audio`` and return its frame-wise output.

        Args:
            audio: Model input, moved to ``self.device`` with
                ``move_data_to_device``. Presumably a batch of waveforms at
                32 kHz -- TODO confirm against callers.

        Returns:
            ``output_dict['framewise_output']`` as a NumPy array on the CPU.
        """
        audio = move_data_to_device(audio, self.device)
        self.model.eval()
        with torch.no_grad():
            output_dict = self.model(audio, None)
        return output_dict['framewise_output'].detach().cpu().numpy()