Delete pipeline_spark_tts.py
pipeline_spark_tts.py  +0 -303
DELETED
@@ -1,303 +0,0 @@
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path

import numpy as np
import torch
from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base class
from transformers.utils import logging

# Import shared components from this package. The pipeline receives the
# already-instantiated model and tokenizer at init; these imports are only
# used for validation and for the prompt constants.
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
    SparkTTSModel,
    load_audio,
    TASK_TOKEN_MAP,
    LEVELS_MAP,
    GENDER_MAP,
    LEVELS_MAP_UI,
)

logger = logging.get_logger(__name__)

# Default generation arguments
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95


class SparkTTSPipeline(Pipeline):
    """
    Custom pipeline for SparkTTS text-to-speech generation, following the HF pipeline structure.
    Handles both voice-cloning and voice-creation modes.
    """

    def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        # --- Do NOT load the model here ---
        # A custom pipeline's __init__ should receive an already-loaded model and
        # tokenizer. Loading should happen OUTSIDE, before pipeline() is called,
        # or be handled by the pipeline factory.

        # Validate the model and tokenizer that were passed in.
        if model is None:
            raise ValueError("SparkTTSPipeline requires a 'model' argument.")
        if not isinstance(model, SparkTTSModel):
            # The model may have been loaded via AutoModel as a generic
            # PreTrainedModel; accept it if its config matches.
            if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
                pass  # OK, the model is compatible based on its config.
            else:
                raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")

        if tokenizer is None:
            raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
        if not hasattr(tokenizer, "encode") or not hasattr(tokenizer, "batch_decode"):
            raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")

        # Hand the validated model/tokenizer to the base Pipeline class.
        super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)

        if hasattr(self.model, "config") and hasattr(self.model.config, "sample_rate"):
            self.sampling_rate = self.model.config.sample_rate
        else:
            # Fall back to a default if the config does not carry a sample rate.
            logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
            self.sampling_rate = 16000

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """
        Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.

        Returns:
            Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
        """
        preprocess_kwargs = {}
        # --- Preprocessing-specific args ---
        if "prompt_speech_path" in kwargs:
            preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
        if "prompt_text" in kwargs:
            preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
        if "gender" in kwargs:
            preprocess_kwargs["gender"] = kwargs["gender"]
        if "pitch" in kwargs:
            preprocess_kwargs["pitch"] = kwargs["pitch"]
        if "speed" in kwargs:
            preprocess_kwargs["speed"] = kwargs["speed"]

        forward_kwargs = {}
        # --- Forward-specific args (LLM generation) ---
        # Use kwargs.get so callers can override the defaults.
        forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
        forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
        forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
        forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
        # Ensure essential generation tokens are present if available.
        if self.tokenizer.eos_token_id is not None:
            forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token_id is not None:
            forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
            forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
        # Filter out None values that might cause issues with generate().
        forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}

        postprocess_kwargs = {}
        # --- Postprocessing-specific args (none at the moment) ---
        # Example: if an option is added to return tokens instead of audio:
        # if "return_tokens" in kwargs:
        #     postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **preprocess_kwargs) -> dict:
        """
        Transforms the text input and preprocess_kwargs into the model input format.

        Args:
            inputs (str): The text to synthesize.
            preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).

        Returns:
            dict: Contains `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
        """
        text = inputs
        prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
        prompt_text = preprocess_kwargs.get("prompt_text")
        gender = preprocess_kwargs.get("gender")
        pitch = preprocess_kwargs.get("pitch")
        speed = preprocess_kwargs.get("speed")

        global_token_ids = None
        llm_prompt_string = ""

        # --- Build llm_prompt_string and obtain global_token_ids ---
        if prompt_speech_path is not None:
            # Voice-cloning mode
            logger.info(f"Preprocessing for voice cloning (prompt: {prompt_speech_path})")
            if not Path(prompt_speech_path).exists():
                raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
            # Use the MODEL's method for audio tokenization (self.model is set by the base Pipeline class).
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
            global_token_ids = global_tokens  # Keep the tensor for detokenization.

            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
            if prompt_text and len(prompt_text) > 1:
                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)
        elif gender is not None and pitch is not None and speed is not None:
            # Voice-creation mode
            logger.info(f"Preprocessing for voice creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
            if gender not in GENDER_MAP:
                raise ValueError(f"Invalid gender: {gender}")
            if pitch not in LEVELS_MAP:
                raise ValueError(f"Invalid pitch: {pitch}")
            if speed not in LEVELS_MAP:
                raise ValueError(f"Invalid speed: {speed}")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]
            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)
        else:
            raise ValueError(
                "Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation)."
            )
        # --- End prompt-building logic ---

        # Tokenize the final prompt for the LLM (self.tokenizer is set by the base Pipeline class).
        model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)

        return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}

    def _forward(self, model_inputs, **forward_kwargs) -> dict:
        """
        Passes model_inputs to the model's LLM generate method with forward_kwargs.

        Args:
            model_inputs (dict): Output from `preprocess`.
            forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.

        Returns:
            dict: Contains `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
        """
        llm_inputs = model_inputs["model_inputs"]
        global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")

        # Move inputs to the correct device (self.device is set by the base Pipeline class).
        llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
        input_ids = llm_inputs["input_ids"]
        input_ids_len = input_ids.shape[-1]

        # Combine LLM inputs and generation arguments.
        generate_kwargs = {**llm_inputs, **forward_kwargs}

        logger.info(f"Generating tokens with args: {forward_kwargs}")
        # Use the model's LLM component (self.model is set by the base Pipeline class).
        with torch.no_grad():
            generated_ids = self.model.llm.generate(**generate_kwargs)

        # Prepare the output dict to pass to postprocess.
        output_dict = {
            "generated_ids": generated_ids,
            "input_ids_len": input_ids_len,
        }
        if global_token_ids_prompt is not None:
            output_dict["global_token_ids_prompt"] = global_token_ids_prompt

        return output_dict

    def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
        """
        Transforms model outputs (from `_forward`) into the final audio dictionary.

        Args:
            model_outputs (dict): Dictionary from `_forward`.
            postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).

        Returns:
            dict: Contains `audio` (np.ndarray) and `sampling_rate` (int).
        """
        generated_ids = model_outputs["generated_ids"]
        input_ids_len = model_outputs["input_ids_len"]
        global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")

        # --- Extract tokens and detokenize ---
        output_ids = generated_ids[0, input_ids_len:]  # Assumes batch size 1.
        # Use self.tokenizer (set by the base Pipeline class).
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the global tokens from the prompt audio.
            global_token_ids = global_token_ids_prompt.to(self.device)
            logger.info("Using global tokens from prompt.")
        else:
            # Voice creation: extract the generated global tokens.
            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
            if not global_matches:
                raise ValueError("Voice creation failed: no bicodec_global tokens found.")
            global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
            if global_token_ids.ndim == 2:
                global_token_ids = global_token_ids.unsqueeze(1)
            logger.info("Using global tokens from generated text.")

        # Use the MODEL's method for detokenization (self.model is set by the base Pipeline class).
        wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        logger.info(f"Generated audio shape: {wav_np.shape}")
        # --- End detokenization logic ---

        return {"audio": wav_np, "sampling_rate": self.sampling_rate}


# --- Pipeline registration ---
# This code executes when this file is loaded via trust_remote_code.
try:
    from transformers.pipelines import PIPELINE_REGISTRY
    from transformers import AutoModel  # Use AutoModel for registration.

    print("Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
    PIPELINE_REGISTRY.register_pipeline(
        "text-to-speech",                  # Task name
        pipeline_class=SparkTTSPipeline,   # The class defined above
        pt_model=AutoModel,                # Compatible PT AutoModel class
        # tf_model=None,                   # Add a TF class if needed.
    )
    print("Pipeline registration call completed successfully.")
except ImportError:
    # Handle a potential import error if the transformers structure changes.
    print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
    print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
# --- End registration code ---
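For context, a minimal sketch of how the registered pipeline was meant to be obtained and driven before this file was deleted. The repo id is a placeholder, and the gender/pitch/speed values assume the keys defined by GENDER_MAP and LEVELS_MAP in modeling_spark_tts.py:

from transformers import pipeline

# Placeholder repo id; any SparkTTS checkpoint bundling this custom code would do.
pipe = pipeline(
    "text-to-speech",
    model="SparkAudio/Spark-TTS-0.5B",
    trust_remote_code=True,  # required so this file is fetched and its registration block runs
)

# Voice-creation mode: gender/pitch/speed are routed to preprocess() by _sanitize_parameters().
# "female" and "moderate" are assumed keys of GENDER_MAP / LEVELS_MAP.
out = pipe("Hello from SparkTTS!", gender="female", pitch="moderate", speed="moderate")
print(out["sampling_rate"], out["audio"].shape)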
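Voice cloning instead keys off prompt_speech_path; a sketch under the same assumptions (the file names and transcript are placeholders, and soundfile is just one way to write the returned array):

import soundfile as sf

# Voice-cloning mode: prompt_speech_path selects the cloning branch in preprocess().
out = pipe(
    "Text to speak in the cloned voice.",
    prompt_speech_path="reference.wav",          # placeholder path to the reference clip
    prompt_text="Transcript of the reference.",  # optional; >1 char adds the semantic-token prefix
    max_new_tokens=2000,                         # forwarded to self.model.llm.generate()
)
sf.write("cloned.wav", out["audio"], out["sampling_rate"])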
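And, as the __init__ comments prescribe, the components can be loaded outside and handed to the class directly, bypassing the factory. This assumes SparkTTSPipeline is importable in your environment; note the file uses relative imports, so it must be loaded as part of the repo's custom-code package rather than as a standalone script:

from transformers import AutoModel, AutoTokenizer

# Load model and tokenizer OUTSIDE the pipeline, as the constructor expects.
model = AutoModel.from_pretrained("SparkAudio/Spark-TTS-0.5B", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("SparkAudio/Spark-TTS-0.5B")

tts = SparkTTSPipeline(model=model, tokenizer=tokenizer, device=0)  # device=0 -> first CUDA GPU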