Vi-SparkTTS-0.5B / pipeline_spark_tts.py
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import re
import numpy as np
from typing import Optional, Dict, Any, Union, List
from pathlib import Path
from transformers import Pipeline, PreTrainedModel # Inherit directly from the base class
from transformers.utils import logging
# Import necessary items from this module (assuming they are defined in modeling_spark_tts)
from .modeling_spark_tts import SparkTTSModel
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
load_audio,
TASK_TOKEN_MAP,
LEVELS_MAP,
GENDER_MAP,
LEVELS_MAP_UI,
)
# SparkTTSModel and SparkTTSConfig are imported above so the pipeline can type-check what it receives in __init__
logger = logging.get_logger(__name__)
# Define constants if needed (e.g., for default generation args)
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95
class SparkTTSPipeline(Pipeline):
"""
Custom Pipeline for SparkTTS text-to-speech generation, following HF documentation structure.
Handles voice cloning and voice creation modes.
"""
def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        # --- Do NOT load the model here ---
        # A custom pipeline's __init__ should receive an already-loaded model and tokenizer.
        # Loading should happen OUTSIDE, before pipeline() is called, or be handled by the pipeline factory.
        # Validate the model and tokenizer that were passed in.
if model is None:
raise ValueError("SparkTTSPipeline requires a 'model' argument.")
if not isinstance(model, SparkTTSModel):
            # The model may have been loaded via AutoModel as a generic PreTrainedModel; check its config instead
if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
                pass # OK, the model is compatible based on its config
else:
raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")
if tokenizer is None:
raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
if not hasattr(tokenizer, 'encode') or not hasattr(tokenizer, 'batch_decode'):
raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")
        # Call super().__init__ with the model/tokenizer that were passed in
super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
if hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
self.sampling_rate = self.model.config.sample_rate
else:
            # Fall back to a default value (or fetch it from elsewhere) if the config does not provide one
logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
self.sampling_rate = 16000
def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
"""
Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.
Returns:
Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
"""
preprocess_kwargs = {}
# --- Preprocessing specific args ---
if "prompt_speech_path" in kwargs:
preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
if "prompt_text" in kwargs:
preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
if "gender" in kwargs:
preprocess_kwargs["gender"] = kwargs["gender"]
if "pitch" in kwargs:
preprocess_kwargs["pitch"] = kwargs["pitch"]
if "speed" in kwargs:
preprocess_kwargs["speed"] = kwargs["speed"]
forward_kwargs = {}
# --- Forward specific args (LLM generation) ---
# Use kwargs.get to allow users to override defaults
forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
# Ensure essential generation tokens are present if needed
if self.tokenizer.eos_token_id is not None:
forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
if self.tokenizer.pad_token_id is not None:
forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
elif self.tokenizer.eos_token_id is not None:
logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
# Filter out None values that might cause issues with generate
forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}
postprocess_kwargs = {}
# --- Postprocessing specific args (if any in the future) ---
# Example: if you added an option to return tokens instead of audio
# if "return_tokens" in kwargs:
# postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
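    # For example (illustrative), a call such as
    #   tts("hello", gender="female", pitch="moderate", speed="moderate", temperature=0.7)
    # would be split by the method above into roughly:
    #   preprocess_kwargs  = {"gender": "female", "pitch": "moderate", "speed": "moderate"}
    #   forward_kwargs     = {"max_new_tokens": 3000, "do_sample": True, "temperature": 0.7,
    #                         "top_k": 50, "top_p": 0.95, ...}  # plus eos/pad token ids if available
    #   postprocess_kwargs = {}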
def preprocess(self, inputs, **preprocess_kwargs) -> dict:
"""
Transforms text input and preprocess_kwargs into model input format.
Args:
inputs (str): The text to synthesize.
preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).
Returns:
dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
"""
text = inputs
prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
prompt_text = preprocess_kwargs.get("prompt_text")
gender = preprocess_kwargs.get("gender")
pitch = preprocess_kwargs.get("pitch")
speed = preprocess_kwargs.get("speed")
global_token_ids = None
llm_prompt_string = ""
# --- Logic to build llm_prompt_string and get global_token_ids ---
if prompt_speech_path is not None:
# Voice Cloning Mode
logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
if not Path(prompt_speech_path).exists():
raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
# Use the MODEL's method for tokenization (self.model is set by base Pipeline class)
global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
global_token_ids = global_tokens # Keep Tensor for detokenization
global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
if prompt_text and len(prompt_text) > 1:
semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
llm_prompt_parts = [
TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
"<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
"<|start_semantic_token|>", semantic_tokens_str,
]
else:
llm_prompt_parts = [
TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
"<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
]
llm_prompt_string = "".join(llm_prompt_parts)
elif gender is not None and pitch is not None and speed is not None:
# Voice Creation Mode
logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}")
if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}")
if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}")
gender_id = GENDER_MAP[gender]
pitch_level_id = LEVELS_MAP[pitch]
speed_level_id = LEVELS_MAP[speed]
attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
llm_prompt_parts = [
TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
"<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
]
llm_prompt_string = "".join(llm_prompt_parts)
else:
raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).")
# --- End prompt building logic ---
# Tokenize the final prompt for the LLM
# Use self.tokenizer (set by base Pipeline class)
model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)
return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}
def _forward(self, model_inputs, **forward_kwargs) -> dict:
"""
Passes model_inputs to the model's LLM generate method with forward_kwargs.
Args:
model_inputs (dict): Output from `preprocess`.
forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.
Returns:
dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
"""
llm_inputs = model_inputs["model_inputs"]
global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")
# Move inputs to the correct device (self.device is set by base Pipeline class)
llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
input_ids = llm_inputs["input_ids"]
input_ids_len = input_ids.shape[-1]
# Combine LLM inputs and generation arguments
generate_kwargs = {**llm_inputs, **forward_kwargs}
logger.info(f"Generating tokens with args: {forward_kwargs}")
# Use the model's LLM component (self.model is set by base Pipeline class)
with torch.no_grad():
generated_ids = self.model.llm.generate(**generate_kwargs)
# Prepare output dict to pass to postprocess
output_dict = {
"generated_ids": generated_ids,
"input_ids_len": input_ids_len,
}
if global_token_ids_prompt is not None:
output_dict["global_token_ids_prompt"] = global_token_ids_prompt
return output_dict
def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
"""
Transforms model outputs (from _forward) into the final audio dictionary.
Args:
model_outputs (dict): Dictionary from `_forward`.
postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).
Returns:
dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
"""
generated_ids = model_outputs["generated_ids"]
input_ids_len = model_outputs["input_ids_len"]
global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")
# --- Logic to extract tokens and detokenize ---
output_ids = generated_ids[0, input_ids_len:] # Assumes batch size 1
# Use self.tokenizer (set by base Pipeline class)
predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
if not semantic_matches:
logger.warning("No semantic tokens found. Returning empty audio.")
            # Use the sampling rate resolved in __init__ (falls back to 16000 if the config lacks one)
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)
if global_token_ids_prompt is not None:
# Voice Cloning: Use prompt tokens
global_token_ids = global_token_ids_prompt.to(self.device)
logger.info("Using global tokens from prompt.")
else:
# Voice Creation: Extract generated tokens
global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
if not global_matches:
raise ValueError("Voice creation failed: No bicodec_global tokens found.")
global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1)
logger.info("Using global tokens from generated text.")
# Use the MODEL's method for detokenization (self.model is set by base Pipeline class)
wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
logger.info(f"Generated audio shape: {wav_np.shape}")
# --- End detokenization logic ---
# Return final output dictionary
return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}
# --- Add Registration Code Here ---
# This code will execute when this file is loaded via trust_remote_code
try:
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModel # Use AutoModel for registration
print(f"Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
PIPELINE_REGISTRY.register_pipeline(
"text-to-speech", # Task name
pipeline_class=SparkTTSPipeline, # The class defined above
pt_model=AutoModel, # Compatible PT AutoModel class
# tf_model=None, # Add TF class if needed
)
print("Pipeline registration call completed successfully.")
except ImportError:
# Handle potential import error if transformers structure changes
print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
# --- End Registration Code ---
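# Example of consuming the pipeline output (a sketch; writing the WAV file uses the optional
# `soundfile` package, which this module does not itself require):
#
#   import soundfile as sf
#   result = tts("Text to speak", gender="female", pitch="moderate", speed="moderate")
#   sf.write("output.wav", result["audio"], result["sampling_rate"])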