# Vi-SparkTTS-0.5B / _utils.py
# Uploaded by ancv ("Upload 24 files"), commit 5cd61b8 (verified), 4.96 kB.
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utility functions for SparkTTS """
import random
import soxr
import soundfile
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Tuple, Dict, Any
from numpy.lib.stride_tricks import sliding_window_view
from omegaconf import OmegaConf # Keep if BiCodec config loading needs it
# --- Token Maps (from sparktts/utils/token_parser.py) ---
# Task name -> special task token string.
# NOTE(review): these strings presumably must match special tokens in the
# model's tokenizer vocabulary; confirm before changing any of them.
TASK_TOKEN_MAP = dict(
    vc="<|task_vc|>",
    tts="<|task_tts|>",
    asr="<|task_asr|>",
    s2s="<|task_s2s|>",
    t2s="<|task_t2s|>",
    understand="<|task_understand|>",
    caption="<|task_cap|>",
    controllable_tts="<|task_controllable_tts|>",
    prompt_tts="<|task_prompt_tts|>",
    speech_edit="<|task_edit|>",
)
# Level names, lowest to highest, mapped to 0-based indices.
LEVELS_MAP = {
    level: index
    for index, level in enumerate(
        ("very_low", "low", "moderate", "high", "very_high")
    )
}
# The same levels keyed 1-based (1 -> "very_low" ... 5 -> "very_high").
# NOTE: deliberately offset by one from LEVELS_MAP; presumably matches a
# 1..5 UI control, so confirm with callers before changing either map.
LEVELS_MAP_UI = {index + 1: level for level, index in LEVELS_MAP.items()}
# Gender label -> integer index.
GENDER_MAP = dict(female=0, male=1)
# --- Audio Utils (from sparktts/utils/audio.py) ---
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
    """Rescale a waveform so the level of its loud samples is close to ``coeff``.

    Steps:
      1. If the peak magnitude is below 0.1, boost the signal so its peak
         becomes 0.1. The divisor is floored at 1e-3, so a near-silent clip
         is boosted by at most 100x (its peak may then stay below 0.1).
      2. Over samples whose magnitude exceeds 0.01, estimate the level as the
         mean of the sorted magnitudes between the 90% and 99% rank positions.
      3. Multiply by ``coeff / level``, with that gain clamped to [0.1, 10].
      4. If the result peaks above 1, divide by that peak so it becomes 1.

    Args:
        audio: Waveform samples. Assumed 1-D (load_audio in this file passes a
            single channel); TODO confirm for any other callers.
        coeff: Target value for the step-2 level estimate.

    Returns:
        The rescaled waveform. Empty input is returned unchanged. If at most
        10 samples exceed 0.01 after step 1, only step 1 is applied.
    """
    magnitudes = np.sort(np.abs(audio))
    if len(magnitudes) == 0:  # nothing to normalize
        return audio
    if magnitudes[-1] < 0.1:
        scaling_factor = max(magnitudes[-1], 1e-3)
        audio = audio / scaling_factor * 0.1
        # Re-measure after the boost so the step-2 level estimate describes the
        # signal the gain is actually applied to. Reusing the pre-boost
        # magnitudes would overshoot ``coeff`` by the boost factor, and would
        # discard every sample of a clip that was entirely below 0.01.
        magnitudes = np.sort(np.abs(audio))
    loud = magnitudes[magnitudes > 0.01]
    n_loud = loud.shape[0]
    if n_loud <= 10:
        return audio
    # ``loud`` is sorted ascending, so this slice is the 90%..99% rank band;
    # it is non-empty whenever n_loud >= 11.
    volume = np.mean(loud[int(0.9 * n_loud) : int(0.99 * n_loud)])
    # Every element of ``loud`` exceeds 0.01, hence volume > 0.01 and the
    # division cannot be by zero.
    audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
    max_value = np.max(np.abs(audio))
    if max_value > 1:
        audio = audio / max_value
    return audio
def load_audio(
    adfile: Path,
    sampling_rate: int = None,
    length: int = None,
    volume_normalize: bool = False,
    segment_duration: int = None,
) -> np.ndarray:
    """Read an audio file into a single-channel numpy waveform, with optional processing.

    Processing order: read as float32 -> keep the first channel -> resample
    -> random segment selection -> volume normalization -> trim/zero-pad.

    Args:
        adfile: Path of the audio file, passed to ``soundfile.read``.
        sampling_rate: Target sample rate in Hz. No resampling happens when it
            is None or already equals the file's rate.
        length: If given, the result is truncated, or zero-padded at the end,
            to exactly this many samples.
        volume_normalize: If True, apply ``audio_volume_normalize``.
        segment_duration: If given, a random window of
            ``int(sr * segment_duration)`` samples is selected (``sr`` being
            the post-resampling rate), zero-padded when the clip is shorter.

    Returns:
        A 1-D waveform.

    Raises:
        IOError: The file could not be read; the original error is chained.
        ValueError: The file yielded no samples.
        RuntimeError: Resampling failed; the original error is chained.
    """
    try:
        audio, sr = soundfile.read(adfile, dtype='float32')  # Ensure float32
    except Exception as e:
        raise IOError(f"Could not read audio file {adfile}: {e}") from e
    if audio is None or len(audio) == 0:
        raise ValueError(f"Audio file {adfile} is empty or invalid.")
    if audio.ndim > 1:
        # Multi-channel input: keep only column 0 (the first channel); no down-mix.
        audio = audio[:, 0]
    if sampling_rate is not None and sr != sampling_rate:
        try:
            # Resample in float64, then convert back to float32 to match the read dtype.
            audio = audio.astype(np.float64)
            audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
            audio = audio.astype(np.float32)
        except Exception as e:
            raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}") from e
        sr = sampling_rate
    if segment_duration is not None:
        seg_length = int(sr * segment_duration)
        audio = random_select_audio_segment(audio, seg_length)
    if volume_normalize:
        audio = audio_volume_normalize(audio)
    if length is not None:
        if audio.shape[0] > length:
            audio = audio[:length]
        else:
            audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
    return audio
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
    """Return exactly ``length`` samples taken from ``audio``.

    - Shorter input is zero-padded at the end and returned whole.
    - Input of exactly ``length`` samples is returned as-is (sliced).
    - Longer input yields a contiguous window whose start offset is drawn with
      ``random.randint(0, len(audio) - length)``. This uses the global
      ``random`` state, so ``random.seed`` makes the choice reproducible.
    """
    surplus = audio.shape[0] - length
    if surplus < 0:
        audio = np.pad(audio, (0, int(-surplus)), mode='constant')
        offset = 0
    elif surplus > 0:
        offset = random.randint(0, surplus)
    else:
        offset = 0
    return audio[offset : int(offset + length)]
# --- File Utils (Minimal required) ---
def load_config_yaml(config_path: Path) -> Dict:
    """Load a YAML file with OmegaConf and return it as plain Python containers.

    Interpolations are resolved during conversion (``resolve=True``).

    Args:
        config_path: Path (or path-like string) of the YAML file.

    Returns:
        The configuration converted by ``OmegaConf.to_container``; a plain
        ``dict`` for a mapping-rooted YAML file.

    Raises:
        FileNotFoundError: ``config_path`` is not an existing regular file.
        IOError: Loading or converting failed; the original error is chained.
    """
    if not Path(config_path).is_file():
        raise FileNotFoundError(f"YAML Config file not found: {config_path}")
    try:
        config = OmegaConf.load(config_path)
        # Convert the OmegaConf container into standard Python containers.
        return OmegaConf.to_container(config, resolve=True)
    except Exception as e:
        raise IOError(f"Error loading YAML config file {config_path}: {e}") from e