# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base class
from transformers.utils import logging

# Import necessary items from this module (assuming they are defined in modeling_spark_tts)
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
    GENDER_MAP,
    LEVELS_MAP,
    LEVELS_MAP_UI,
    SparkTTSModel,
    TASK_TOKEN_MAP,
    load_audio,
)

logger = logging.get_logger(__name__)

# Default generation arguments
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95


class SparkTTSPipeline(Pipeline):
    """
    Custom pipeline for SparkTTS text-to-speech generation, following the HF custom-pipeline structure.

    Handles voice cloning and voice creation modes.
    """

    def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        # Do NOT load the model here: a custom pipeline's __init__ should receive an
        # already-loaded model and tokenizer. Loading should happen OUTSIDE, before
        # pipeline() is called, or be handled by the pipeline factory.

        # Validate the model and tokenizer that were passed in.
        if model is None:
            raise ValueError("SparkTTSPipeline requires a 'model' argument.")
        if not isinstance(model, SparkTTSModel):
            # The model may have been loaded via AutoModel as a generic PreTrainedModel; check its config instead.
            if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
                pass  # OK, the model is compatible based on its config
            else:
                raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")
        if tokenizer is None:
            raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
        if not hasattr(tokenizer, "encode") or not hasattr(tokenizer, "batch_decode"):
            raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")

        # Call super().__init__ with the validated model/tokenizer.
        super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)

        if hasattr(self.model, "config") and hasattr(self.model.config, "sample_rate"):
            self.sampling_rate = self.model.config.sample_rate
        else:
            # Fall back to a default if the config does not provide a sample rate.
            logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
            self.sampling_rate = 16000

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """
        Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.
        Returns:
            Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
        """
        preprocess_kwargs = {}
        # --- Preprocessing specific args ---
        if "prompt_speech_path" in kwargs:
            preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
        if "prompt_text" in kwargs:
            preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
        if "gender" in kwargs:
            preprocess_kwargs["gender"] = kwargs["gender"]
        if "pitch" in kwargs:
            preprocess_kwargs["pitch"] = kwargs["pitch"]
        if "speed" in kwargs:
            preprocess_kwargs["speed"] = kwargs["speed"]

        forward_kwargs = {}
        # --- Forward specific args (LLM generation) ---
        # Use kwargs.get to allow users to override defaults
        forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
        forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
        forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
        forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
        # Ensure essential generation tokens are present if needed
        if self.tokenizer.eos_token_id is not None:
            forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token_id is not None:
            forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
            forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
        # Filter out None values that might cause issues with generate
        forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}

        postprocess_kwargs = {}
        # --- Postprocessing specific args (if any in the future) ---
        # Example: if you added an option to return tokens instead of audio
        # if "return_tokens" in kwargs:
        #     postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **preprocess_kwargs) -> dict:
        """
        Transforms text input and preprocess_kwargs into model input format.

        Args:
            inputs (str): The text to synthesize.
            preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).

        Returns:
            dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
""" text = inputs prompt_speech_path = preprocess_kwargs.get("prompt_speech_path") prompt_text = preprocess_kwargs.get("prompt_text") gender = preprocess_kwargs.get("gender") pitch = preprocess_kwargs.get("pitch") speed = preprocess_kwargs.get("speed") global_token_ids = None llm_prompt_string = "" # --- Logic to build llm_prompt_string and get global_token_ids --- if prompt_speech_path is not None: # Voice Cloning Mode logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})") if not Path(prompt_speech_path).exists(): raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}") # Use the MODEL's method for tokenization (self.model is set by base Pipeline class) global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path) global_token_ids = global_tokens # Keep Tensor for detokenization global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()]) if prompt_text and len(prompt_text) > 1: semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()]) llm_prompt_parts = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>", "<|start_global_token|>", global_tokens_str, "<|end_global_token|>", "<|start_semantic_token|>", semantic_tokens_str, ] else: llm_prompt_parts = [ TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>", "<|start_global_token|>", global_tokens_str, "<|end_global_token|>", ] llm_prompt_string = "".join(llm_prompt_parts) elif gender is not None and pitch is not None and speed is not None: # Voice Creation Mode logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})") if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}") if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}") if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}") gender_id = GENDER_MAP[gender] pitch_level_id = LEVELS_MAP[pitch] speed_level_id = LEVELS_MAP[speed] attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>" llm_prompt_parts = [ TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>", "<|start_style_label|>", attribute_tokens, "<|end_style_label|>", ] llm_prompt_string = "".join(llm_prompt_parts) else: raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).") # --- End prompt building logic --- # Tokenize the final prompt for the LLM # Use self.tokenizer (set by base Pipeline class) model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False) return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids} def _forward(self, model_inputs, **forward_kwargs) -> dict: """ Passes model_inputs to the model's LLM generate method with forward_kwargs. Args: model_inputs (dict): Output from `preprocess`. forward_kwargs (dict): Generation arguments from `_sanitize_parameters`. Returns: dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`). 
""" llm_inputs = model_inputs["model_inputs"] global_token_ids_prompt = model_inputs.get("global_token_ids_prompt") # Move inputs to the correct device (self.device is set by base Pipeline class) llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()} input_ids = llm_inputs["input_ids"] input_ids_len = input_ids.shape[-1] # Combine LLM inputs and generation arguments generate_kwargs = {**llm_inputs, **forward_kwargs} logger.info(f"Generating tokens with args: {forward_kwargs}") # Use the model's LLM component (self.model is set by base Pipeline class) with torch.no_grad(): generated_ids = self.model.llm.generate(**generate_kwargs) # Prepare output dict to pass to postprocess output_dict = { "generated_ids": generated_ids, "input_ids_len": input_ids_len, } if global_token_ids_prompt is not None: output_dict["global_token_ids_prompt"] = global_token_ids_prompt return output_dict def postprocess(self, model_outputs, **postprocess_kwargs) -> dict: """ Transforms model outputs (from _forward) into the final audio dictionary. Args: model_outputs (dict): Dictionary from `_forward`. postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none). Returns: dict: Containing `audio` (np.ndarray) and `sampling_rate` (int). """ generated_ids = model_outputs["generated_ids"] input_ids_len = model_outputs["input_ids_len"] global_token_ids_prompt = model_outputs.get("global_token_ids_prompt") # --- Logic to extract tokens and detokenize --- output_ids = generated_ids[0, input_ids_len:] # Assumes batch size 1 # Use self.tokenizer (set by base Pipeline class) predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True) semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text) if not semantic_matches: logger.warning("No semantic tokens found. 
Returning empty audio.") # Use self.model.config for sampling rate return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.model.config.sample_rate} pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device) if global_token_ids_prompt is not None: # Voice Cloning: Use prompt tokens global_token_ids = global_token_ids_prompt.to(self.device) logger.info("Using global tokens from prompt.") else: # Voice Creation: Extract generated tokens global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text) if not global_matches: raise ValueError("Voice creation failed: No bicodec_global tokens found.") global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device) if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1) logger.info("Using global tokens from generated text.") # Use the MODEL's method for detokenization (self.model is set by base Pipeline class) wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids) logger.info(f"Generated audio shape: {wav_np.shape}") # --- End detokenization logic --- # Return final output dictionary return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate} # --- Add Registration Code Here --- # This code will execute when this file is loaded via trust_remote_code try: from transformers.pipelines import PIPELINE_REGISTRY from transformers import AutoModel # Use AutoModel for registration print(f"Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...") PIPELINE_REGISTRY.register_pipeline( "text-to-speech", # Task name pipeline_class=SparkTTSPipeline, # The class defined above pt_model=AutoModel, # Compatible PT AutoModel class # tf_model=None, # Add TF class if needed ) print("Pipeline registration call completed successfully.") except ImportError: # Handle potential import error if transformers structure changes print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.") except Exception as e: print(f"ERROR: An unexpected error occurred during pipeline registration: {e}") # --- End Registration Code ---