"""Processor class for SparkTTS.""" |
|
|
|
import torch |
|
import re |
|
import numpy as np |
|
import warnings |
|
from typing import Optional, Dict, Any, Union, List, Tuple |
|
from pathlib import Path |
|
|
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.feature_extraction_utils import FeatureExtractionMixin |
|
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase |
|
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor |
|
from transformers.utils import logging |
|
|
|
|
|
|
|
|
|
from .configuration_spark_tts import SparkTTSConfig |
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
TASK_TOKEN_MAP = { |
|
"tts": "<|task_tts|>", |
|
"controllable_tts": "<|task_controllable_tts|>", |
|
|
|
} |
|
LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4} |
|
GENDER_MAP = {"female": 0, "male": 1} |
|
|
|
|
|
|
|
class SparkTTSProcessor(ProcessorMixin):
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor
    into a single processor.

    [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
    It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio tokenization
    to the model). It also handles decoding the final output.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
        feature_extractor (`Wav2Vec2FeatureExtractor`):
            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to process reference audio
            (though the main processing happens inside the model).
        model (`PreTrainedModel`, *optional*):
            A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio processing)
            and final audio decoding, as these steps rely on the model's internal BiCodec and Wav2Vec2 components.
            Set it with `processor.model = model` after loading both.
        config (`SparkTTSConfig`, *optional*):
            The configuration object, needed for parameters such as `sample_rate`. Can often be inferred from the
            model.
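
    Example (a minimal sketch; the checkpoint path is hypothetical and the snippet assumes the checkpoint
    registers `SparkTTSModel` for `trust_remote_code` auto-loading):

    ```python
    from transformers import AutoModel

    checkpoint = "./SparkTTS-checkpoint"  # hypothetical local directory
    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
    processor = SparkTTSProcessor.from_pretrained(checkpoint)

    # Wire the model into the processor; required for voice cloning and audio decoding.
    processor.model = model
    ```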
    """

    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = ("Qwen2TokenizerFast", "Qwen2Tokenizer")
    feature_extractor_class = ("Wav2Vec2FeatureExtractor",)

    def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
        if tokenizer is None:
            raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
        if feature_extractor is None:
            raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")

        super().__init__(tokenizer, feature_extractor)
        self.model = model
        self.config = config

        # Resolve the output sampling rate: prefer the explicit config, then the model's config,
        # then the feature extractor, and finally fall back to 16 kHz.
        self.sampling_rate = None
        if self.config and hasattr(self.config, "sample_rate"):
            self.sampling_rate = self.config.sample_rate
        elif self.model and hasattr(self.model, "config") and hasattr(self.model.config, "sample_rate"):
            self.sampling_rate = self.model.config.sample_rate
        else:
            if hasattr(self.feature_extractor, "sampling_rate"):
                self.sampling_rate = self.feature_extractor.sampling_rate
            else:
                logger.warning(
                    "Could not determine sampling rate. Defaulting to 16000. "
                    "Set `processor.sampling_rate` manually if needed."
                )
                self.sampling_rate = 16000

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
                  e.g., `./my_model_directory/`.
            **kwargs:
                Additional keyword arguments passed along to both `AutoTokenizer.from_pretrained()` and
                `Wav2Vec2FeatureExtractor.from_pretrained()`.
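
        Example (a minimal sketch; the directory name is hypothetical, and `./LLM` / `./wav2vec2-large-xlsr-53`
        are the default sub-folder locations used when the config does not override them):

        ```python
        # Sub-component paths from the config (e.g. `./LLM`) are resolved relative to the checkpoint directory.
        processor = SparkTTSProcessor.from_pretrained("./SparkTTS-checkpoint")
        ```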
        """
        config = kwargs.pop("config", None)
        if config is None:
            # Try to load the composite model config; it carries the sub-component paths.
            try:
                config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            except Exception:
                logger.warning(
                    f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
                    "Processor might lack some config values."
                )
                config = None

        def _resolve_path(sub_path):
            # Resolve a sub-component path relative to the main checkpoint directory when possible.
            p = Path(sub_path)
            if p.is_absolute():
                return str(p)
            main_path = Path(pretrained_model_name_or_path)
            if main_path.is_dir():
                resolved = main_path / p
                if resolved.exists():
                    return str(resolved)
            return sub_path

        llm_tokenizer_path = "./LLM"
        w2v_processor_path = "./wav2vec2-large-xlsr-53"
        if config:
            llm_tokenizer_path = getattr(config, "llm_model_name_or_path", llm_tokenizer_path)
            w2v_processor_path = getattr(config, "wav2vec2_model_name_or_path", w2v_processor_path)

        resolved_tokenizer_path = _resolve_path(llm_tokenizer_path)
        resolved_w2v_path = _resolve_path(w2v_processor_path)

        try:
            tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_path, **kwargs)
        except Exception as e:
            raise OSError(
                f"Could not load tokenizer from {resolved_tokenizer_path}. "
                f"Ensure path is correct and files exist. Original error: {e}"
            )

        try:
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(resolved_w2v_path, **kwargs)
        except Exception as e:
            raise OSError(
                f"Could not load feature extractor from {resolved_w2v_path}. "
                f"Ensure path is correct and files exist. Original error: {e}"
            )

        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)

    def __call__(
        self,
        text: Optional[str] = None,
        prompt_speech_path: Optional[str] = None,
        prompt_text: Optional[str] = None,
        gender: Optional[str] = None,
        pitch: Optional[str] = None,
        speed: Optional[str] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process inputs for the SparkTTS model.

        Args:
            text (`str`): The text to be synthesized.
            prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
            prompt_text (`str`, *optional*): Transcript of the prompt audio.
            gender (`str`, *optional*): Target gender (`'male'` or `'female'`) for voice creation.
            pitch (`str`, *optional*): Target pitch level (`'very_low'` ... `'very_high'`) for voice creation.
            speed (`str`, *optional*): Target speed level (`'very_low'` ... `'very_high'`) for voice creation.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                Framework of the returned tensors (`"pt"` for PyTorch, `"np"` for NumPy).
            **kwargs: Additional arguments (currently ignored).

        Returns:
            `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
            `global_token_ids_prompt`, ready for the model's `.generate()` method.
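
        Example (a minimal sketch; assumes `processor.model` has been set and that `prompt.wav` is a
        hypothetical reference recording):

        ```python
        # Voice cloning from a reference recording:
        inputs = processor(
            text="Hello world.",
            prompt_speech_path="prompt.wav",
            prompt_text="Transcript of the reference recording.",
            return_tensors="pt",
        )

        # Voice creation from coarse style attributes:
        inputs = processor(text="Hello world.", gender="female", pitch="moderate", speed="moderate")
        ```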
        """
        if text is None:
            raise ValueError("`text` input must be provided.")

        global_token_ids_prompt = None
        llm_prompt_string = ""

        if prompt_speech_path is not None:
            # Voice cloning: tokenize the prompt audio with the model's internal BiCodec/Wav2Vec2 components.
            if self.model is None:
                raise ValueError(
                    "Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning."
                )
            if not hasattr(self.model, "_tokenize_audio"):
                raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")

            logger.info(f"Processing prompt audio: {prompt_speech_path}")
            try:
                global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
                global_token_ids_prompt = global_tokens
            except Exception as e:
                logger.error(f"Error tokenizing prompt audio: {e}", exc_info=True)
                raise RuntimeError(
                    f"Failed to process prompt audio file: {prompt_speech_path}. "
                    "Check file integrity and model compatibility."
                ) from e

            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])

            if prompt_text and len(prompt_text) > 1:
                # With a transcript, the prompt's semantic tokens are included so generation continues the prompt.
                semantic_tokens_str = "".join(
                    [f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()]
                )
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)

        elif gender is not None and pitch is not None and speed is not None:
            # Voice creation: encode coarse style attributes as control tokens.
            if gender not in GENDER_MAP:
                raise ValueError(f"Invalid gender '{gender}'.")
            if pitch not in LEVELS_MAP:
                raise ValueError(f"Invalid pitch '{pitch}'.")
            if speed not in LEVELS_MAP:
                raise ValueError(f"Invalid speed '{speed}'.")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]

            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"

            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)

        else:
            raise ValueError(
                "Processor requires either 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' "
                "(for creation)."
            )

        inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)

        # Pass the prompt's global tokens through so `decode()` can reuse the cloned voice.
        if global_token_ids_prompt is not None:
            inputs["global_token_ids_prompt"] = global_token_ids_prompt

        return inputs

    def decode(
        self,
        generated_ids: Union[List[int], np.ndarray, torch.Tensor],
        global_token_ids_prompt: Optional[torch.Tensor] = None,
        input_ids_len: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Dict[str, Any]:
        """
        Decodes the raw token IDs generated by the model into an audio waveform.

        Args:
            generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
                The token IDs generated by the `model.generate()` method. Assumed to be a single sequence
                (batch size 1).
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens obtained from the prompt audio during preprocessing (needed for voice cloning).
                Should be passed from the `__call__` output.
            input_ids_len (`int`, *optional*):
                The length of the original prompt `input_ids`. If provided, the prompt part will be stripped from
                `generated_ids` before decoding the text representation. If None, assumes `generated_ids` contains
                *only* the generated part.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens when decoding the text representation for parsing.

        Returns:
            `Dict[str, Any]`: A dictionary containing:

            - `audio` (`np.ndarray`): The generated audio waveform.
            - `sampling_rate` (`int`): The sampling rate of the audio.
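
        Example (a minimal sketch; generation settings are up to the caller, and `soundfile` is just one
        possible way to write the waveform to disk):

        ```python
        prompt_len = inputs["input_ids"].shape[-1]
        generated_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        out = processor.decode(
            generated_ids=generated_ids,
            global_token_ids_prompt=inputs.get("global_token_ids_prompt"),
            input_ids_len=prompt_len,
        )

        import soundfile as sf

        sf.write("output.wav", out["audio"], out["sampling_rate"])
        ```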
        """
        if self.model is None:
            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
        if not hasattr(self.model, "_detokenize_audio"):
            raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
        if self.sampling_rate is None:
            raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")

        if isinstance(generated_ids, (list, np.ndarray)):
            output_ids_tensor = torch.tensor(generated_ids)
        else:
            output_ids_tensor = generated_ids

        # Strip the prompt portion (if its length is known) and reduce to a single sequence.
        if input_ids_len is not None:
            if output_ids_tensor.ndim > 1:
                output_ids = output_ids_tensor[0, input_ids_len:]
            else:
                output_ids = output_ids_tensor[input_ids_len:]
        else:
            if output_ids_tensor.ndim > 1:
                output_ids = output_ids_tensor[0]
            else:
                output_ids = output_ids_tensor

        if output_ids.numel() == 0:
            logger.warning("Received empty generated IDs after removing prompt. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # Decode to text so the BiCodec token IDs can be parsed back out of the special-token names.
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)

        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        device = self.model.device
        pred_semantic_ids = torch.tensor(
            [int(token) for token in semantic_matches], dtype=torch.long, device=device
        ).unsqueeze(0)

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the global (speaker) tokens extracted from the prompt audio.
            global_token_ids = global_token_ids_prompt.to(device)
            if global_token_ids.ndim == 2:
                global_token_ids = global_token_ids.unsqueeze(1)
        else:
            # Voice creation: the LLM must have generated the global tokens itself.
            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
            if not global_matches:
                logger.error("Voice creation failed: No global tokens found in generated text.")
                raise ValueError("Voice creation failed: Could not find bicodec_global tokens in the LLM output.")
            global_token_ids = torch.tensor(
                [int(token) for token in global_matches], dtype=torch.long, device=device
            ).unsqueeze(0)
            if global_token_ids.ndim == 2:
                global_token_ids = global_token_ids.unsqueeze(1)

        try:
            wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        except Exception as e:
            logger.error(f"Error during audio detokenization: {e}", exc_info=True)
            raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e

        return {"audio": wav_np, "sampling_rate": self.sampling_rate}