# coding=utf-8
# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
# ... (license) ...
"""Processor class for SparkTTS."""

import torch
import re
import numpy as np
import warnings
from typing import Optional, Dict, Any, Union, List, Tuple
from pathlib import Path

from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
from transformers.utils import logging

# The model class is deliberately NOT imported here. Audio tokenization and
# detokenization are reached through the `model` reference stored on the
# processor (`processor.model = model`), which avoids a circular import.
from .configuration_spark_tts import SparkTTSConfig  # Config is okay

logger = logging.get_logger(__name__)

# --- Token Maps (duplicated here so the processor can build prompts directly) ---
TASK_TOKEN_MAP = {
    "tts": "<|task_tts|>",
    "controllable_tts": "<|task_controllable_tts|>",
    # Add other tasks if needed by processor logic
}

LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}

GENDER_MAP = {"female": 0, "male": 1}
# --- End Token Maps ---

# Patterns used by `SparkTTSProcessor.decode` to pull BiCodec token ids out of
# the decoded LLM text. Compiled once at import time.
_SEMANTIC_TOKEN_PATTERN = re.compile(r"bicodec_semantic_(\d+)")
_GLOBAL_TOKEN_PATTERN = re.compile(r"bicodec_global_(\d+)")


def _format_audio_tokens(kind: str, token_ids) -> str:
    """Render audio token ids as a run of ``<|{kind}_{id}|>`` prompt tokens.

    Args:
        kind: Token family prefix, e.g. ``"bicodec_global"`` or ``"bicodec_semantic"``.
        token_ids: A torch tensor or numpy array of integer ids. It is flattened
            first, so size-1 inputs and extra singleton/batch dimensions are handled
            (`squeeze().tolist()` would yield a bare int for a single element).

    Returns:
        The concatenated token string.
    """
    return "".join(f"<|{kind}_{i}|>" for i in token_ids.reshape(-1).tolist())


class SparkTTSProcessor(ProcessorMixin):
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor into a single processor.

    [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
    It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio
    tokenization to the model). It also handles decoding the final output.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
        feature_extractor (`Wav2Vec2FeatureExtractor`):
            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to process
            reference audio (though the main processing happens inside the model).
        model (`PreTrainedModel`, *optional*):
            A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio
            processing) and final audio decoding, as these steps rely on the model's `_tokenize_audio` /
            `_detokenize_audio` methods. Set this using `processor.model = model` after loading both.
        config (`SparkTTSConfig`, *optional*):
            The configuration object, needed for parameters like `sample_rate`. Can often be inferred
            from the model.
    """

    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = ("Qwen2TokenizerFast", "Qwen2Tokenizer")  # Underlying tokenizer type(s)
    feature_extractor_class = ("Wav2Vec2FeatureExtractor",)  # Underlying feature extractor type

    def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
        """Validate components, store optional model/config references and resolve `sampling_rate`.

        Raises:
            ValueError: If `tokenizer` or `feature_extractor` is missing.
        """
        if tokenizer is None:
            raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
        if feature_extractor is None:
            raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")

        super().__init__(tokenizer, feature_extractor)

        self.model = model  # May be None initially; set later via `processor.model = model`.
        self.config = config
        # Resolved once here. If `model` is attached later, this value is NOT
        # refreshed; set `processor.sampling_rate` manually if it must differ.
        self.sampling_rate = self._infer_sampling_rate()

    def _infer_sampling_rate(self) -> int:
        """Pick the sampling rate from config, then model config, then feature extractor, else 16000."""
        if self.config is not None and hasattr(self.config, "sample_rate"):
            return self.config.sample_rate

        model_config = getattr(self.model, "config", None) if self.model is not None else None
        if model_config is not None and hasattr(model_config, "sample_rate"):
            return model_config.sample_rate

        if hasattr(self.feature_extractor, "sampling_rate"):
            return self.feature_extractor.sampling_rate

        logger.warning(
            "Could not determine sampling rate. Defaulting to 16000. "
            "Set `processor.sampling_rate` manually if needed."
        )
        return 16000

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
                  e.g., `./my_model_directory/`.
            **kwargs:
                Additional keyword arguments passed along to `SparkTTSConfig.from_pretrained()` (when no
                `config` is given), `AutoTokenizer.from_pretrained()` and
                `Wav2Vec2FeatureExtractor.from_pretrained()`.

        Raises:
            OSError: If the tokenizer or feature extractor cannot be loaded.

        Note:
            The returned processor has no `model` reference; set `processor.model = model` afterwards.
        """
        config = kwargs.pop("config", None)
        if config is None:
            # Best effort: the processor still works without a config, only some
            # values (e.g. sample_rate) fall back to other sources.
            try:
                config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            except Exception:
                logger.warning(
                    f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
                    "Processor might lack some config values."
                )
                config = None

        def _resolve_path(sub_path):
            """Resolve a component path relative to `pretrained_model_name_or_path` when it is a local dir."""
            p = Path(sub_path)
            if p.is_absolute():
                return str(p)
            main_path = Path(pretrained_model_name_or_path)
            if main_path.is_dir():
                resolved = main_path / p
                if resolved.exists():
                    return str(resolved)
            # Fallback: return as-is. NOTE(review): for a hub model id, a relative
            # sub path such as "./LLM" is then resolved against the current working
            # directory, not inside the hub repo.
            return sub_path

        # Component locations: config values when present, otherwise defaults.
        llm_tokenizer_path = "./LLM"
        w2v_processor_path = "./wav2vec2-large-xlsr-53"
        if config is not None:
            llm_tokenizer_path = getattr(config, "llm_model_name_or_path", llm_tokenizer_path)
            w2v_processor_path = getattr(config, "wav2vec2_model_name_or_path", w2v_processor_path)

        resolved_tokenizer_path = _resolve_path(llm_tokenizer_path)
        resolved_w2v_path = _resolve_path(w2v_processor_path)

        try:
            tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_path, **kwargs)
        except Exception as e:
            raise OSError(
                f"Could not load tokenizer from {resolved_tokenizer_path}. "
                f"Ensure path is correct and files exist. Original error: {e}"
            ) from e

        try:
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(resolved_w2v_path, **kwargs)
        except Exception as e:
            raise OSError(
                f"Could not load feature extractor from {resolved_w2v_path}. "
                f"Ensure path is correct and files exist. Original error: {e}"
            ) from e

        # The 'model' attribute will be set later externally.
        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)

    def __call__(
        self,
        text: Optional[str] = None,
        prompt_speech_path: Optional[str] = None,
        prompt_text: Optional[str] = None,
        gender: Optional[str] = None,
        pitch: Optional[str] = None,
        speed: Optional[str] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process inputs for the SparkTTS model.

        Two modes are supported:
        - voice cloning, selected when `prompt_speech_path` is given (takes precedence);
        - voice creation, selected when `gender`, `pitch` and `speed` are all given.

        Args:
            text (`str`): The text to be synthesized.
            prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
            prompt_text (`str`, *optional*): Transcript of prompt audio. Only used when longer than
                one character.
            gender (`str`, *optional*): Target gender (a key of `GENDER_MAP`) for voice creation.
            pitch (`str`, *optional*): Target pitch level (a key of `LEVELS_MAP`) for voice creation.
            speed (`str`, *optional*): Target speed level (a key of `LEVELS_MAP`) for voice creation.
            return_tensors (`str`, *optional*, defaults to `"pt"`): Framework of the returned tensors
                (`"pt"` for PyTorch, `"np"` for NumPy).
            **kwargs: Additional arguments (currently ignored).

        Returns:
            `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
            `global_token_ids_prompt` ready for the model's `.generate()` method.
            `global_token_ids_prompt` is stored exactly as returned by the model's
            `_tokenize_audio` and is not converted according to `return_tensors`.

        Raises:
            ValueError: If `text` is missing, neither mode's inputs are complete, an attribute value is
                invalid, or cloning is requested without a model reference.
            AttributeError: If the model lacks `_tokenize_audio`.
            RuntimeError: If tokenizing the prompt audio fails.
        """
        if text is None:
            raise ValueError("`text` input must be provided.")

        global_token_ids_prompt = None
        if prompt_speech_path is not None:
            llm_prompt_string, global_token_ids_prompt = self._build_voice_clone_prompt(
                text, prompt_speech_path, prompt_text
            )
        elif gender is not None and pitch is not None and speed is not None:
            llm_prompt_string = self._build_voice_creation_prompt(text, gender, pitch, speed)
        else:
            raise ValueError(
                "Processor requires either 'prompt_speech_path' (for cloning) or "
                "'gender', 'pitch', and 'speed' (for creation)."
            )

        inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)

        # Carry the prompt's global tokens through so `decode` can reuse them.
        if global_token_ids_prompt is not None:
            inputs["global_token_ids_prompt"] = global_token_ids_prompt

        return inputs

    def _build_voice_clone_prompt(self, text, prompt_speech_path, prompt_text):
        """Tokenize the reference audio via the model and build the voice-cloning LLM prompt.

        Returns:
            A tuple `(prompt_string, global_tokens)`, where `global_tokens` is the first value returned
            by `model._tokenize_audio`.
        """
        if self.model is None:
            raise ValueError(
                "Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning."
            )
        if not hasattr(self.model, "_tokenize_audio"):
            raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")

        logger.info("Processing prompt audio: %s", prompt_speech_path)
        try:
            # _tokenize_audio returns (global_tokens, semantic_tokens).
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
        except Exception as e:
            logger.error(f"Error tokenizing prompt audio: {e}", exc_info=True)
            raise RuntimeError(
                f"Failed to process prompt audio file: {prompt_speech_path}. "
                "Check file integrity and model compatibility."
            ) from e

        global_tokens_str = _format_audio_tokens("bicodec_global", global_tokens)

        # NOTE(review): transcripts of length <= 1 are treated as absent; presumably
        # to ignore trivially short input — confirm this is intended.
        if prompt_text and len(prompt_text) > 1:
            # With a transcript, the prompt's semantic tokens are included and the
            # prompt ends inside the semantic section so generation continues it.
            llm_prompt_parts = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                prompt_text,
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens_str,
                "<|end_global_token|>",
                "<|start_semantic_token|>",
                _format_audio_tokens("bicodec_semantic", semantic_tokens),
            ]
        else:
            llm_prompt_parts = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens_str,
                "<|end_global_token|>",
            ]
        return "".join(llm_prompt_parts), global_tokens

    @staticmethod
    def _build_voice_creation_prompt(text, gender, pitch, speed) -> str:
        """Validate the style attributes and build the controllable-TTS (voice creation) LLM prompt.

        Raises:
            ValueError: If `gender`, `pitch` or `speed` is not a known key.
        """
        if gender not in GENDER_MAP:
            raise ValueError(f"Invalid gender '{gender}'. Expected one of {list(GENDER_MAP)}.")
        if pitch not in LEVELS_MAP:
            raise ValueError(f"Invalid pitch '{pitch}'. Expected one of {list(LEVELS_MAP)}.")
        if speed not in LEVELS_MAP:
            raise ValueError(f"Invalid speed '{speed}'. Expected one of {list(LEVELS_MAP)}.")

        attribute_tokens = (
            f"<|gender_{GENDER_MAP[gender]}|>"
            f"<|pitch_label_{LEVELS_MAP[pitch]}|>"
            f"<|speed_label_{LEVELS_MAP[speed]}|>"
        )
        llm_prompt_parts = [
            TASK_TOKEN_MAP["controllable_tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            attribute_tokens,
            "<|end_style_label|>",
        ]
        # No global_token_ids_prompt in this mode: the LLM generates global tokens itself.
        return "".join(llm_prompt_parts)

    def decode(
        self,
        generated_ids: Union[List[int], np.ndarray, torch.Tensor],
        global_token_ids_prompt: Optional[torch.Tensor] = None,
        input_ids_len: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Dict[str, Any]:
        """
        Decodes the raw token IDs generated by the model into an audio waveform.

        Args:
            generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
                The token IDs generated by the `model.generate()` method. If a batch dimension is present,
                only the first sequence is decoded.
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens obtained from the prompt audio during preprocessing (needed for voice
                cloning). Should be passed from the `__call__` output. Array-likes are converted with
                `torch.as_tensor`.
            input_ids_len (`int`, *optional*):
                The length of the original prompt `input_ids`. If provided, the prompt part will be stripped
                from `generated_ids` before decoding the text representation. If None, assumes `generated_ids`
                contains *only* the generated part.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens when decoding the text representation for parsing.

        Returns:
            `Dict[str, Any]`: A dictionary containing:
                - `audio` (`np.ndarray`): The generated audio waveform (empty if nothing could be decoded).
                - `sampling_rate` (`int`): The sampling rate of the audio.

        Raises:
            ValueError: If no model is attached, `sampling_rate` is unknown, or voice creation output has no
                global tokens.
            AttributeError: If the model lacks `_detokenize_audio`.
            RuntimeError: If audio detokenization fails.
        """
        if self.model is None:
            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
        if not hasattr(self.model, "_detokenize_audio"):
            raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
        if self.sampling_rate is None:
            raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")

        if isinstance(generated_ids, (list, np.ndarray)):
            output_ids_tensor = torch.tensor(generated_ids)
        else:
            output_ids_tensor = generated_ids

        # Keep only the first sequence, then drop the prompt prefix if its length is known.
        output_ids = output_ids_tensor[0] if output_ids_tensor.ndim > 1 else output_ids_tensor
        if input_ids_len is not None:
            output_ids = output_ids[input_ids_len:]

        if output_ids.numel() == 0:
            logger.warning("Received empty generated IDs after removing prompt. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # Audio tokens are recovered by parsing the decoded text rather than by id arithmetic.
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)

        semantic_matches = _SEMANTIC_TOKEN_PATTERN.findall(predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        device = self.model.device
        # Shape (1, num_semantic_tokens).
        pred_semantic_ids = torch.tensor(
            [int(token) for token in semantic_matches], dtype=torch.long, device=device
        ).unsqueeze(0)

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the reference audio's global tokens.
            global_token_ids = torch.as_tensor(global_token_ids_prompt, device=device)
            # NOTE(review): a 2-D input is lifted to (B, 1, D); confirm this matches
            # what `_detokenize_audio` expects.
            if global_token_ids.ndim == 2:
                global_token_ids = global_token_ids.unsqueeze(1)
        else:
            # Voice creation: the LLM generated the global tokens itself.
            global_matches = _GLOBAL_TOKEN_PATTERN.findall(predicts_text)
            if not global_matches:
                logger.error("Voice creation failed: No global tokens found in generated text.")
                raise ValueError("Voice creation failed: Could not find bicodec_global tokens in the LLM output.")
            # Shape (1, 1, num_global_tokens).
            global_token_ids = torch.tensor(
                [int(token) for token in global_matches], dtype=torch.long, device=device
            ).reshape(1, 1, -1)

        try:
            wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        except Exception as e:
            logger.error(f"Error during audio detokenization: {e}", exc_info=True)
            raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e

        return {"audio": wav_np, "sampling_rate": self.sampling_rate}