Vi-SparkTTS-0.5B / processing_spark_tts.py
ancv's picture
Upload 24 files
5cd61b8 verified
raw
history blame
18.4 kB
# coding=utf-8
# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
# ... (license) ...
"""Processor class for SparkTTS."""
import torch
import re
import numpy as np
import warnings
from typing import Optional, Dict, Any, Union, List, Tuple
from pathlib import Path
from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
from transformers.utils import logging
# Import necessary items directly or ensure they are available via model reference
# Note: Avoid direct model imports here if possible, rely on the model reference.
# from .modeling_spark_tts import SparkTTSModel # Avoid direct model import if possible
from .configuration_spark_tts import SparkTTSConfig # Config is okay
# Import utils needed for prompt formatting (assuming they are merged into modeling)
# We'll access them via the model reference if needed, or duplicate simple ones like token maps.
# Module-level logger from `transformers.utils.logging`; the processor uses it for its warnings,
# info messages and error reports.
logger = logging.get_logger(__name__)
# --- Prompt token maps (duplicated from the modeling code so the processor can build prompts directly) ---

# Task name -> task-marker token placed at the very start of the LLM prompt.
TASK_TOKEN_MAP = {
    "tts": "<|task_tts|>",  # used for voice-cloning prompts
    "controllable_tts": "<|task_controllable_tts|>",  # used for attribute-controlled voice creation
    # Extend with further tasks if the processor logic ever needs them.
}

# Ordinal levels shared by pitch and speed; a level's position is the id written into the prompt.
LEVELS_MAP = {
    level: index
    for index, level in enumerate(("very_low", "low", "moderate", "high", "very_high"))
}

# Gender name -> id written into the `<|gender_N|>` prompt token.
GENDER_MAP = {gender: index for index, gender in enumerate(("female", "male"))}

# --- End of token maps ---
class SparkTTSProcessor(ProcessorMixin):
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor
    into a single processor.

    [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
    It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio tokenization
    to the model). It also handles decoding the final output.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
        feature_extractor (`Wav2Vec2FeatureExtractor`):
            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to process reference audio
            (though the main processing happens inside the model).
        model (`PreTrainedModel`, *optional*):
            A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio processing)
            and final audio decoding, because those steps call the model's `_tokenize_audio` and
            `_detokenize_audio` methods.
            Set this using `processor.model = model` after loading both.
        config (`SparkTTSConfig`, *optional*):
            The configuration object, needed for parameters like sample_rate. Can often be inferred from the model.
    """

    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = ("Qwen2TokenizerFast", "Qwen2Tokenizer")  # Specify the underlying tokenizer type
    feature_extractor_class = ("Wav2Vec2FeatureExtractor",)  # Specify the underlying feature extractor type

    # Regexes used by `decode` to recover BiCodec token ids from the decoded LLM text.
    # Compiled once here instead of on every `decode` call.
    _SEMANTIC_TOKEN_PATTERN = re.compile(r"bicodec_semantic_(\d+)")
    _GLOBAL_TOKEN_PATTERN = re.compile(r"bicodec_global_(\d+)")

    # Fallback when neither the config, the model config, nor the feature extractor provides a rate.
    _DEFAULT_SAMPLING_RATE = 16000

    def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
        """
        Args:
            tokenizer: Text tokenizer used to encode the LLM prompt (required).
            feature_extractor: Wav2Vec2 feature extractor (required).
            model: Optional reference to the loaded SparkTTS model; may be attached later via `processor.model`.
            config: Optional `SparkTTSConfig`, consulted for `sample_rate`.
            **kwargs: Accepted for compatibility and ignored.

        Raises:
            ValueError: If `tokenizer` or `feature_extractor` is missing.
        """
        if tokenizer is None:
            raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
        if feature_extractor is None:
            raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")
        super().__init__(tokenizer, feature_extractor)
        self.model = model  # May be None initially; required later for cloning and decoding.
        self.config = config
        self.sampling_rate = self._resolve_sampling_rate()

    def _resolve_sampling_rate(self):
        """Return the first non-None sample rate found, falling back to 16000 with a warning.

        Lookup order: `config.sample_rate`, `model.config.sample_rate`, `feature_extractor.sampling_rate`.
        A source whose attribute is missing *or* None is skipped, so it cannot mask a later, valid source.
        """
        # getattr(None, name, None) is None, so an absent config/model simply falls through.
        for source in (self.config, getattr(self.model, "config", None)):
            rate = getattr(source, "sample_rate", None)
            if rate is not None:
                return rate
        rate = getattr(self.feature_extractor, "sampling_rate", None)
        if rate is not None:
            return rate
        logger.warning("Could not determine sampling rate. Defaulting to 16000. Set `processor.sampling_rate` manually if needed.")
        return self._DEFAULT_SAMPLING_RATE

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
                  e.g., `./my_model_directory/`.
            **kwargs:
                Additional keyword arguments passed along to `SparkTTSConfig.from_pretrained()` (when no `config`
                kwarg is given), `AutoTokenizer.from_pretrained()` and `Wav2Vec2FeatureExtractor.from_pretrained()`.
                A `config` keyword is consumed here and used instead of loading one.

        Returns:
            `SparkTTSProcessor` without a model attached (set `processor.model` afterwards).

        Raises:
            OSError: If the tokenizer or the feature extractor cannot be loaded.
        """
        config = kwargs.pop("config", None)
        if config is None:
            # Try loading the specific config first; the processor still works without it.
            try:
                config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            except Exception:
                logger.warning(f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. Processor might lack some config values.")
                config = None

        # Component locations: config values when set, otherwise these defaults.
        llm_tokenizer_path = "./LLM"
        w2v_processor_path = "./wav2vec2-large-xlsr-53"
        if config is not None:
            # `or` keeps the default when the attribute exists but is unset (None / empty).
            llm_tokenizer_path = getattr(config, "llm_model_name_or_path", None) or llm_tokenizer_path
            w2v_processor_path = getattr(config, "wav2vec2_model_name_or_path", None) or w2v_processor_path

        tokenizer = cls._load_component(
            AutoTokenizer, "tokenizer", pretrained_model_name_or_path, llm_tokenizer_path, kwargs
        )
        feature_extractor = cls._load_component(
            Wav2Vec2FeatureExtractor, "feature extractor", pretrained_model_name_or_path, w2v_processor_path, kwargs
        )
        # The 'model' attribute will be set later externally
        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)

    @staticmethod
    def _resolve_component_path(main_path, sub_path):
        """Resolve `sub_path` against `main_path` when `main_path` is a local directory containing it.

        Absolute paths are returned unchanged; otherwise `sub_path` itself is returned.
        """
        candidate = Path(sub_path)
        if candidate.is_absolute():
            return str(candidate)
        main_dir = Path(main_path)
        if main_dir.is_dir():
            resolved = main_dir / candidate
            if resolved.exists():
                return str(resolved)
        # Fallback: hand the raw value to the loader (it may be a Hub id or a cwd-relative path).
        return sub_path

    @classmethod
    def _load_component(cls, loader, description, main_path, sub_path, kwargs):
        """Load one sub-component (tokenizer or feature extractor) via `loader.from_pretrained`.

        The locally resolved path is tried first. If that fails and `main_path` is not a local directory
        (e.g. a Hub repo id), the load is retried as `subfolder=` of `main_path`. This assumes relative
        sub-paths such as "./LLM" are subfolders of that repo. If both attempts fail, the first error is
        reported.

        Raises:
            OSError: If the component cannot be loaded.
        """
        resolved = cls._resolve_component_path(main_path, sub_path)
        try:
            return loader.from_pretrained(resolved, **kwargs)
        except Exception as err:
            first_error = err

        retry_in_repo = (
            "subfolder" not in kwargs  # never override an explicit caller choice
            and not Path(sub_path).is_absolute()
            and not Path(main_path).is_dir()
        )
        if retry_in_repo:
            try:
                return loader.from_pretrained(main_path, subfolder=Path(sub_path).as_posix(), **kwargs)
            except Exception:
                # Report the first error below; it refers to the path the user actually configured.
                pass
        raise OSError(
            f"Could not load {description} from {resolved}. Ensure path is correct and files exist. "
            f"Original error: {first_error}"
        ) from first_error

    def __call__(self, text: str = None,
                 prompt_speech_path: Optional[str] = None,
                 prompt_text: Optional[str] = None,
                 gender: Optional[str] = None,
                 pitch: Optional[str] = None,
                 speed: Optional[str] = None,
                 return_tensors: Optional[str] = "pt",
                 **kwargs) -> BatchEncoding:
        """
        Main method to process inputs for the SparkTTS model.

        Args:
            text (`str`): The text to be synthesized.
            prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
            prompt_text (`str`, *optional*): Transcript of prompt audio. Ignored when shorter than 2 characters.
            gender (`str`, *optional*): Target gender ('male' or 'female') for voice creation.
            pitch (`str`, *optional*): Target pitch level ('very_low'...'very_high') for voice creation.
            speed (`str`, *optional*): Target speed level ('very_low'...'very_high') for voice creation.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                Framework of the returned tensors (`"pt"` for PyTorch, `"np"` for NumPy).
            **kwargs: Additional arguments (currently ignored).

        Returns:
            `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
            `global_token_ids_prompt` ready for the model's `.generate()` method.

        Raises:
            ValueError: If `text` is missing, neither mode's inputs are complete, or an attribute is invalid.
            AttributeError: If the attached model lacks `_tokenize_audio` (cloning mode).
            RuntimeError: If prompt audio tokenization fails (cloning mode).
        """
        if text is None:
            raise ValueError("`text` input must be provided.")

        global_token_ids_prompt = None
        if prompt_speech_path is not None:
            # Voice cloning takes precedence when a prompt audio path is given.
            llm_prompt_string, global_token_ids_prompt = self._build_cloning_prompt(
                text, prompt_speech_path, prompt_text
            )
        elif gender is not None and pitch is not None and speed is not None:
            llm_prompt_string = self._build_creation_prompt(text, gender, pitch, speed)
        else:
            raise ValueError("Processor requires either 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation).")

        inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)
        # Carry the prompt's global tokens through so they can be handed to `decode`.
        if global_token_ids_prompt is not None:
            inputs["global_token_ids_prompt"] = global_token_ids_prompt
        return inputs

    @staticmethod
    def _format_bicodec_tokens(kind, token_ids):
        """Render token ids as concatenated `<|bicodec_{kind}_{id}|>` strings.

        The ids are flattened first, so a single-token tensor yields one token instead of failing
        (`squeeze().tolist()` would return a bare int there).
        """
        return "".join(f"<|bicodec_{kind}_{i}|>" for i in token_ids.flatten().tolist())

    def _build_cloning_prompt(self, text, prompt_speech_path, prompt_text):
        """Tokenize the prompt audio through the attached model and build the voice-cloning prompt.

        Returns:
            Tuple of (LLM prompt string, global tokens as returned by `model._tokenize_audio`).
        """
        if self.model is None:
            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning.")
        if not hasattr(self.model, '_tokenize_audio'):
            raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")
        logger.info("Processing prompt audio: %s", prompt_speech_path)
        try:
            # _tokenize_audio returns (global_tokens, semantic_tokens)
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
        except Exception as e:
            logger.error("Error tokenizing prompt audio: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to process prompt audio file: {prompt_speech_path}. Check file integrity and model compatibility.") from e

        global_tokens_str = self._format_bicodec_tokens("global", global_tokens)
        if prompt_text and len(prompt_text) > 1:
            # With a usable transcript, the prompt audio's semantic tokens are included too.
            # NOTE(review): no `<|end_semantic_token|>` is appended — presumably so generation
            # continues the semantic stream; confirm against the model's training format.
            semantic_tokens_str = self._format_bicodec_tokens("semantic", semantic_tokens)
            llm_prompt_parts = [
                TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                "<|start_semantic_token|>", semantic_tokens_str,
            ]
        else:
            llm_prompt_parts = [
                TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
            ]
        return "".join(llm_prompt_parts), global_tokens

    @staticmethod
    def _build_creation_prompt(text, gender, pitch, speed):
        """Validate the voice attributes and build the controllable (voice-creation) prompt.

        Raises:
            ValueError: If `gender`, `pitch`, or `speed` is not a known label.
        """
        if gender not in GENDER_MAP:
            raise ValueError(f"Invalid gender '{gender}'.")
        if pitch not in LEVELS_MAP:
            raise ValueError(f"Invalid pitch '{pitch}'.")
        if speed not in LEVELS_MAP:
            raise ValueError(f"Invalid speed '{speed}'.")
        attribute_tokens = (
            f"<|gender_{GENDER_MAP[gender]}|>"
            f"<|pitch_label_{LEVELS_MAP[pitch]}|>"
            f"<|speed_label_{LEVELS_MAP[speed]}|>"
        )
        # No global tokens are supplied here; `decode` parses them from the generated text instead.
        return "".join([
            TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
            "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
        ])

    def decode(self,
               generated_ids: Union[List[int], np.ndarray, torch.Tensor],
               global_token_ids_prompt: Optional[Union[torch.Tensor, np.ndarray, List[int]]] = None,
               input_ids_len: Optional[int] = None,
               skip_special_tokens: bool = True) -> Dict[str, Any]:
        """
        Decodes the raw token IDs generated by the model into an audio waveform.

        Args:
            generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
                The token IDs generated by the `model.generate()` method. Assumed to be a single sequence
                (batch size 1); if a batch is given, only the first sequence is decoded.
            global_token_ids_prompt (`torch.Tensor`, array or list, *optional*):
                The global tokens obtained from the prompt audio during preprocessing (needed for voice cloning).
                Should be passed from the `__call__` output.
            input_ids_len (`int`, *optional*):
                The length of the original prompt `input_ids`. If provided, the prompt part will be stripped from
                `generated_ids` before decoding the text representation. If None, assumes `generated_ids` contains
                *only* the generated part.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens when decoding the text representation for parsing.

        Returns:
            `Dict[str, Any]`: A dictionary containing:
                - `audio` (`np.ndarray`): The generated audio waveform (empty float32 array if nothing to decode).
                - `sampling_rate` (`int`): The sampling rate of the audio.

        Raises:
            ValueError: If no model / sampling rate is set, or voice creation output lacks global tokens.
            AttributeError: If the model lacks `_detokenize_audio`.
            RuntimeError: If waveform synthesis fails.
        """
        if self.model is None:
            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
        if not hasattr(self.model, '_detokenize_audio'):
            raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
        if self.sampling_rate is None:
            raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")

        # Normalize to a tensor. No device move is needed: these ids are only read by the tokenizer.
        output_ids = torch.as_tensor(generated_ids)
        if output_ids.ndim > 1:
            output_ids = output_ids[0]  # batch size 1 assumed; keep the first sequence
        if input_ids_len is not None:
            output_ids = output_ids[input_ids_len:]  # strip the prompt part

        if output_ids.numel() == 0:
            logger.warning("Received empty generated IDs after removing prompt. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # The generated audio tokens are recovered by parsing the decoded text.
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)

        semantic_matches = self._SEMANTIC_TOKEN_PATTERN.findall(predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        device = self.model.device
        pred_semantic_ids = torch.tensor(
            [int(token) for token in semantic_matches], dtype=torch.long, device=device
        ).unsqueeze(0)  # add batch dim

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the prompt's global tokens. as_tensor also accepts arrays/lists;
            # for an existing tensor it behaves like `.to(device)`.
            global_token_ids = torch.as_tensor(global_token_ids_prompt, device=device)
        else:
            # Voice creation: the global tokens must have been generated by the LLM.
            global_matches = self._GLOBAL_TOKEN_PATTERN.findall(predicts_text)
            if not global_matches:
                logger.error("Voice creation failed: No global tokens found in generated text.")
                raise ValueError("Voice creation failed: Could not find bicodec_global tokens in the LLM output.")
            global_token_ids = torch.tensor(
                [int(token) for token in global_matches], dtype=torch.long, device=device
            ).unsqueeze(0)  # add batch dim

        # A 2-D (B, D) tensor gets a middle axis -> (B, 1, D).
        # NOTE(review): mirrors the original assumption about `_detokenize_audio`'s expected shape; confirm.
        if global_token_ids.ndim == 2:
            global_token_ids = global_token_ids.unsqueeze(1)

        try:
            wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        except Exception as e:
            logger.error("Error during audio detokenization: %s", e, exc_info=True)
            raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e

        return {"audio": wav_np, "sampling_rate": self.sampling_rate}