# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import re
import numpy as np
from typing import Optional, Dict, Any, Union, List
from pathlib import Path
from transformers import Pipeline, PreTrainedModel # Inherit directly from the base class
from transformers.utils import logging
# Import necessary items from this module (assuming they are defined in modeling_spark_tts)
from .modeling_spark_tts import SparkTTSModel
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
load_audio,
TASK_TOKEN_MAP,
LEVELS_MAP,
GENDER_MAP,
LEVELS_MAP_UI,
)
# SparkTTSModel/SparkTTSConfig are imported above only for the type checks in __init__;
# the pipeline itself receives already-instantiated model/tokenizer objects at construction time.
logger = logging.get_logger(__name__)
# Define constants if needed (e.g., for default generation args)
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95
class SparkTTSPipeline(Pipeline):
"""
Custom Pipeline for SparkTTS text-to-speech generation, following HF documentation structure.
Handles voice cloning and voice creation modes.
"""
def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
# --- Do NOT load the model here ---
# A custom pipeline's __init__ should receive an already-loaded model and tokenizer.
# Loading should happen OUTSIDE, before pipeline() is called, or be handled by the pipeline factory.
# Validate the model and tokenizer that were passed in.
if model is None:
raise ValueError("SparkTTSPipeline requires a 'model' argument.")
if not isinstance(model, SparkTTSModel):
# The model may have been loaded via AutoModel as a plain PreTrainedModel; check its config instead.
if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
pass # OK, the model is compatible based on its config
else:
raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")
if tokenizer is None:
raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
if not hasattr(tokenizer, 'encode') or not hasattr(tokenizer, 'batch_decode'):
raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")
# Call super().__init__ with the model/tokenizer we received
super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
if hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
self.sampling_rate = self.model.config.sample_rate
else:
# Fall back to a default value (or obtain it elsewhere) if the config does not provide a sample rate
logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
self.sampling_rate = 16000
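# Construction sketch (illustrative, not executed here): the model and tokenizer are
# expected to be loaded *outside* this class, e.g. with trust_remote_code;
# "<spark-tts-repo-id>" below is a placeholder, not a real checkpoint name.
#
#   from transformers import AutoModel, AutoTokenizer
#   model = AutoModel.from_pretrained("<spark-tts-repo-id>", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("<spark-tts-repo-id>")
#   tts = SparkTTSPipeline(model=model, tokenizer=tokenizer)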
def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
"""
Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.
Returns:
Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
"""
preprocess_kwargs = {}
# --- Preprocessing specific args ---
if "prompt_speech_path" in kwargs:
preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
if "prompt_text" in kwargs:
preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
if "gender" in kwargs:
preprocess_kwargs["gender"] = kwargs["gender"]
if "pitch" in kwargs:
preprocess_kwargs["pitch"] = kwargs["pitch"]
if "speed" in kwargs:
preprocess_kwargs["speed"] = kwargs["speed"]
forward_kwargs = {}
# --- Forward specific args (LLM generation) ---
# Use kwargs.get to allow users to override defaults
forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
# Ensure essential generation tokens are present if needed
if self.tokenizer.eos_token_id is not None:
forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
if self.tokenizer.pad_token_id is not None:
forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
elif self.tokenizer.eos_token_id is not None:
logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
# Filter out None values that might cause issues with generate
forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}
postprocess_kwargs = {}
# --- Postprocessing specific args (if any in the future) ---
# Example: if you added an option to return tokens instead of audio
# if "return_tokens" in kwargs:
# postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
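# Parameter-routing sketch (illustrative only; "female"/"moderate" are assumed to be
# valid GENDER_MAP / LEVELS_MAP keys):
#   pipe("Hello world", gender="female", pitch="moderate", speed="moderate", temperature=0.7)
# -> gender/pitch/speed go to preprocess, temperature (plus the remaining generation
#    defaults) goes to _forward, and postprocess currently receives no extra kwargs.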
def preprocess(self, inputs, **preprocess_kwargs) -> dict:
"""
Transforms text input and preprocess_kwargs into model input format.
Args:
inputs (str): The text to synthesize.
preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).
Returns:
dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
"""
text = inputs
prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
prompt_text = preprocess_kwargs.get("prompt_text")
gender = preprocess_kwargs.get("gender")
pitch = preprocess_kwargs.get("pitch")
speed = preprocess_kwargs.get("speed")
global_token_ids = None
llm_prompt_string = ""
# --- Logic to build llm_prompt_string and get global_token_ids ---
if prompt_speech_path is not None:
# Voice Cloning Mode
logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
if not Path(prompt_speech_path).exists():
raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
# Use the MODEL's method for tokenization (self.model is set by base Pipeline class)
global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
global_token_ids = global_tokens # Keep Tensor for detokenization
global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
if prompt_text and len(prompt_text) > 1:
semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
llm_prompt_parts = [
TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
"<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
"<|start_semantic_token|>", semantic_tokens_str,
]
else:
llm_prompt_parts = [
TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
"<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
]
llm_prompt_string = "".join(llm_prompt_parts)
elif gender is not None and pitch is not None and speed is not None:
# Voice Creation Mode
logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}")
if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}")
if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}")
gender_id = GENDER_MAP[gender]
pitch_level_id = LEVELS_MAP[pitch]
speed_level_id = LEVELS_MAP[speed]
attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
llm_prompt_parts = [
TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
"<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
]
llm_prompt_string = "".join(llm_prompt_parts)
else:
raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).")
# --- End prompt building logic ---
# Tokenize the final prompt for the LLM
# Use self.tokenizer (set by base Pipeline class)
model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)
return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}
def _forward(self, model_inputs, **forward_kwargs) -> dict:
"""
Passes model_inputs to the model's LLM generate method with forward_kwargs.
Args:
model_inputs (dict): Output from `preprocess`.
forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.
Returns:
dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
"""
llm_inputs = model_inputs["model_inputs"]
global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")
# Move inputs to the correct device (self.device is set by base Pipeline class)
llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
input_ids = llm_inputs["input_ids"]
input_ids_len = input_ids.shape[-1]
# Combine LLM inputs and generation arguments
generate_kwargs = {**llm_inputs, **forward_kwargs}
logger.info(f"Generating tokens with args: {forward_kwargs}")
# Use the model's LLM component (self.model is set by base Pipeline class)
with torch.no_grad():
generated_ids = self.model.llm.generate(**generate_kwargs)
# Prepare output dict to pass to postprocess
output_dict = {
"generated_ids": generated_ids,
"input_ids_len": input_ids_len,
}
if global_token_ids_prompt is not None:
output_dict["global_token_ids_prompt"] = global_token_ids_prompt
return output_dict
def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
"""
Transforms model outputs (from _forward) into the final audio dictionary.
Args:
model_outputs (dict): Dictionary from `_forward`.
postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).
Returns:
dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
"""
generated_ids = model_outputs["generated_ids"]
input_ids_len = model_outputs["input_ids_len"]
global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")
# --- Logic to extract tokens and detokenize ---
output_ids = generated_ids[0, input_ids_len:] # Assumes batch size 1
# Use self.tokenizer (set by base Pipeline class)
predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
if not semantic_matches:
logger.warning("No semantic tokens found. Returning empty audio.")
# Use the sampling rate resolved in __init__ (falls back to 16000 if the config lacks one)
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)
if global_token_ids_prompt is not None:
# Voice Cloning: Use prompt tokens
global_token_ids = global_token_ids_prompt.to(self.device)
logger.info("Using global tokens from prompt.")
else:
# Voice Creation: Extract generated tokens
global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
if not global_matches:
raise ValueError("Voice creation failed: No bicodec_global tokens found.")
global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1)
logger.info("Using global tokens from generated text.")
# Use the MODEL's method for detokenization (self.model is set by base Pipeline class)
wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
logger.info(f"Generated audio shape: {wav_np.shape}")
# --- End detokenization logic ---
# Return final output dictionary
return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}
# --- Add Registration Code Here ---
# This code will execute when this file is loaded via trust_remote_code
try:
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModel # Use AutoModel for registration
print(f"Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
PIPELINE_REGISTRY.register_pipeline(
"text-to-speech", # Task name
pipeline_class=SparkTTSPipeline, # The class defined above
pt_model=AutoModel, # Compatible PT AutoModel class
# tf_model=None, # Add TF class if needed
)
print("Pipeline registration call completed successfully.")
except ImportError:
# Handle potential import error if transformers structure changes
print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
# --- End Registration Code ---
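# End-to-end usage sketch relying on the registration above (illustrative only;
# "<spark-tts-repo-id>" is a placeholder and "female"/"moderate" are assumed to be
# valid GENDER_MAP / LEVELS_MAP keys):
#   from transformers import pipeline
#   pipe = pipeline("text-to-speech", model="<spark-tts-repo-id>", trust_remote_code=True)
#   result = pipe("Hello from SparkTTS!", gender="female", pitch="moderate", speed="moderate")
#   # result["audio"] is a float32 numpy array sampled at result["sampling_rate"] Hz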