import re

import numpy as np
import torch
from pathlib import Path

from transformers import Pipeline, PreTrainedModel
from transformers.utils import logging

from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
    SparkTTSModel,
    load_audio,
    TASK_TOKEN_MAP,
    LEVELS_MAP,
    GENDER_MAP,
    LEVELS_MAP_UI,
)

logger = logging.get_logger(__name__)

# Default generation parameters for the underlying LLM; any of them can be
# overridden per pipeline call.
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95

class SparkTTSPipeline(Pipeline):
    """
    Custom pipeline for SparkTTS text-to-speech generation, following the
    Hugging Face custom-pipeline structure. Supports two modes: voice cloning
    (from a reference recording) and voice creation (from gender/pitch/speed
    attributes).
    """

    def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        """
        Args:
            model: A loaded SparkTTSModel (or any PreTrainedModel carrying a
                SparkTTSConfig).
            tokenizer: The text tokenizer paired with the model's LLM.
            framework: The framework flag passed to the base Pipeline ("pt").
            device: Device for inference, forwarded to the base Pipeline.
        """
        if model is None:
            raise ValueError("SparkTTSPipeline requires a 'model' argument.")
        if not isinstance(model, SparkTTSModel) and not (
            isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig)
        ):
            raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")

        if tokenizer is None:
            raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
        if not hasattr(tokenizer, "encode") or not hasattr(tokenizer, "batch_decode"):
            raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")

        super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)

        # Resolve the output sampling rate from the model config, falling back to 16 kHz.
        if hasattr(self.model, "config") and hasattr(self.model.config, "sample_rate"):
            self.sampling_rate = self.model.config.sample_rate
        else:
            logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
            self.sampling_rate = 16000

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """
        Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.

        Returns:
            tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
        """
        # Parameters consumed by `preprocess`: a voice-cloning prompt or
        # voice-creation attributes.
        preprocess_kwargs = {}
        for key in ("prompt_speech_path", "prompt_text", "gender", "pitch", "speed"):
            if key in kwargs:
                preprocess_kwargs[key] = kwargs[key]

        # Generation parameters consumed by `_forward`, with module defaults.
        forward_kwargs = {}
        forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
        forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
        forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
        forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)

        if self.tokenizer.eos_token_id is not None:
            forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token_id is not None:
            forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            logger.warning("Setting pad_token_id to eos_token_id for open-ended generation.")
            forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id

        forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}

        # No postprocess-specific parameters at the moment.
        postprocess_kwargs = {}

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
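
    # Call-time overrides are routed here, e.g. (illustrative; attribute values
    # must be keys of GENDER_MAP / LEVELS_MAP):
    #   tts("Hi there", gender="male", pitch="high", speed="low",
    #       temperature=0.6, top_k=20)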

    def preprocess(self, inputs, **preprocess_kwargs) -> dict:
        """
        Transforms the input text and preprocess_kwargs into the model's input format.

        Args:
            inputs (str): The text to synthesize.
            preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, voice controls).

        Returns:
            dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
        """
        text = inputs
        prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
        prompt_text = preprocess_kwargs.get("prompt_text")
        gender = preprocess_kwargs.get("gender")
        pitch = preprocess_kwargs.get("pitch")
        speed = preprocess_kwargs.get("speed")

        global_token_ids = None
        llm_prompt_string = ""

        # Voice cloning: condition the LLM on tokens extracted from a reference recording.
        if prompt_speech_path is not None:
            logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
            if not Path(prompt_speech_path).exists():
                raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")

            # Tokenize the reference audio into global (speaker) and semantic (content) tokens.
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
            global_token_ids = global_tokens  # carried through to postprocess for detokenization

            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
            if prompt_text and len(prompt_text) > 1:
                # With a transcript, include the prompt's semantic tokens so the LLM
                # continues the reference speech rather than starting from scratch.
                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)
        # Voice creation: synthesize a new voice from discrete attribute labels.
        elif gender is not None and pitch is not None and speed is not None:
            logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
            if gender not in GENDER_MAP:
                raise ValueError(f"Invalid gender: {gender}")
            if pitch not in LEVELS_MAP:
                raise ValueError(f"Invalid pitch: {pitch}")
            if speed not in LEVELS_MAP:
                raise ValueError(f"Invalid speed: {speed}")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]
            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)
        else:
            raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation).")

        # Tokenize the assembled LLM prompt.
        model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)

        return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}

    def _forward(self, model_inputs, **forward_kwargs) -> dict:
        """
        Passes model_inputs to the model's LLM `generate` method with forward_kwargs.

        Args:
            model_inputs (dict): Output from `preprocess`.
            forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.

        Returns:
            dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
        """
        llm_inputs = model_inputs["model_inputs"]
        global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")

        # Move tokenized inputs to the pipeline's device.
        llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
        input_ids = llm_inputs["input_ids"]
        input_ids_len = input_ids.shape[-1]

        generate_kwargs = {**llm_inputs, **forward_kwargs}
        logger.info(f"Generating tokens with args: {forward_kwargs}")

        with torch.no_grad():
            generated_ids = self.model.llm.generate(**generate_kwargs)

        output_dict = {
            "generated_ids": generated_ids,
            "input_ids_len": input_ids_len,
        }
        if global_token_ids_prompt is not None:
            # Carry the prompt's speaker (global) tokens through to postprocess.
            output_dict["global_token_ids_prompt"] = global_token_ids_prompt

        return output_dict
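
    # Note: `generate` returns prompt + continuation in a single tensor;
    # `postprocess` uses `input_ids_len` to slice off the prompt tokens.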

    def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
        """
        Transforms model outputs (from `_forward`) into the final audio dictionary.

        Args:
            model_outputs (dict): Dictionary from `_forward`.
            postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).

        Returns:
            dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
        """
        generated_ids = model_outputs["generated_ids"]
        input_ids_len = model_outputs["input_ids_len"]
        global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")

        # Drop the prompt tokens; keep only the newly generated continuation.
        output_ids = generated_ids[0, input_ids_len:]

        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        # Extract the semantic (content) token ids from the decoded text.
        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the speaker (global) tokens from the prompt audio.
            global_token_ids = global_token_ids_prompt.to(self.device)
            logger.info("Using global tokens from prompt.")
        else:
            # Voice creation: the speaker tokens were generated by the LLM itself.
            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
            if not global_matches:
                raise ValueError("Voice creation failed: No bicodec_global tokens found.")
            global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
            if global_token_ids.ndim == 2:
                # Match the (batch, 1, n_tokens) layout expected by detokenization.
                global_token_ids = global_token_ids.unsqueeze(1)
            logger.info("Using global tokens from generated text.")

        # Detokenize speaker + semantic tokens back into a waveform.
        wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        logger.info(f"Generated audio shape: {wav_np.shape}")

        return {"audio": wav_np, "sampling_rate": self.sampling_rate}


# Register the pipeline so that `pipeline("text-to-speech", ...)` resolves to
# SparkTTSPipeline when this module is loaded (e.g. via trust_remote_code).
try:
    from transformers import AutoModel
    from transformers.pipelines import PIPELINE_REGISTRY

    logger.info("Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
    PIPELINE_REGISTRY.register_pipeline(
        "text-to-speech",
        pipeline_class=SparkTTSPipeline,
        pt_model=AutoModel,
    )
    logger.info("Pipeline registration call completed successfully.")
except ImportError:
    logger.warning("Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
    logger.error(f"An unexpected error occurred during pipeline registration: {e}")