ancv committed on
Commit 72fe226 · verified · 1 Parent(s): cb1f56e

Delete pipeline_spark_tts.py

Files changed (1)
  1. pipeline_spark_tts.py +0 -303
pipeline_spark_tts.py DELETED
@@ -1,303 +0,0 @@
- # Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import torch
- import re
- import numpy as np
- from typing import Optional, Dict, Any, Union, List
- from pathlib import Path
- from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base class
- from transformers.utils import logging
-
- # Import necessary items from this module (assuming they are defined in modeling_spark_tts)
- from .modeling_spark_tts import SparkTTSModel
- from .configuration_spark_tts import SparkTTSConfig
- from .modeling_spark_tts import (
-     load_audio,
-     TASK_TOKEN_MAP,
-     LEVELS_MAP,
-     GENDER_MAP,
-     LEVELS_MAP_UI,
- )
- # The pipeline receives already-instantiated model/tokenizer objects during init (see __init__ below)
-
- logger = logging.get_logger(__name__)
-
- # Default generation arguments
- DEFAULT_MAX_NEW_TOKENS = 3000
- DEFAULT_TEMPERATURE = 0.8
- DEFAULT_TOP_K = 50
- DEFAULT_TOP_P = 0.95
-
- class SparkTTSPipeline(Pipeline):
-     """
-     Custom Pipeline for SparkTTS text-to-speech generation, following the HF documentation structure.
-     Handles voice cloning and voice creation modes.
-     """
-     def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
-         # --- Do NOT load the model here ---
-         # A custom pipeline's __init__ should receive an already-loaded model and tokenizer.
-         # Loading should happen OUTSIDE, before pipeline() is called, or be handled by the pipeline factory.
-
-         # Validate the model and tokenizer that were passed in
-         if model is None:
-             raise ValueError("SparkTTSPipeline requires a 'model' argument.")
-         if not isinstance(model, SparkTTSModel):
-             # The model may have been loaded via AutoModel as a generic PreTrainedModel; check its config instead
-             if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
-                 pass  # OK, the model is compatible based on its config
-             else:
-                 raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")
-
-         if tokenizer is None:
-             raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
-         if not hasattr(tokenizer, 'encode') or not hasattr(tokenizer, 'batch_decode'):
-             raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")
-
-         # Call super().__init__ with the model/tokenizer received above
-         super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
-         if hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
-             self.sampling_rate = self.model.config.sample_rate
-         else:
-             # Fall back to a default (or fetch it elsewhere) when the config does not provide one
-             logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
-             self.sampling_rate = 16000
-
-     def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
-         """
-         Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.
-
-         Returns:
-             Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
-         """
-         preprocess_kwargs = {}
-         # --- Preprocessing specific args ---
-         if "prompt_speech_path" in kwargs:
-             preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
-         if "prompt_text" in kwargs:
-             preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
-         if "gender" in kwargs:
-             preprocess_kwargs["gender"] = kwargs["gender"]
-         if "pitch" in kwargs:
-             preprocess_kwargs["pitch"] = kwargs["pitch"]
-         if "speed" in kwargs:
-             preprocess_kwargs["speed"] = kwargs["speed"]
-
-         forward_kwargs = {}
-         # --- Forward specific args (LLM generation) ---
-         # Use kwargs.get to allow users to override the defaults
-         forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
-         forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
-         forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
-         forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
-         forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
-         # Ensure essential generation tokens are present if needed
-         if self.tokenizer.eos_token_id is not None:
-             forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
-         if self.tokenizer.pad_token_id is not None:
-             forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-         elif self.tokenizer.eos_token_id is not None:
-             logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
-             forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
-         # Filter out None values that might cause issues with generate
-         forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}
-
-         postprocess_kwargs = {}
-         # --- Postprocessing specific args (none at the moment) ---
-         # Example: if you added an option to return tokens instead of audio
-         # if "return_tokens" in kwargs:
-         #     postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]
-
-         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
-
-     def preprocess(self, inputs, **preprocess_kwargs) -> dict:
-         """
-         Transforms the text input and preprocess_kwargs into the model input format.
-
-         Args:
-             inputs (str): The text to synthesize.
-             preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).
-
-         Returns:
-             dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
-         """
-         text = inputs
-         prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
-         prompt_text = preprocess_kwargs.get("prompt_text")
-         gender = preprocess_kwargs.get("gender")
-         pitch = preprocess_kwargs.get("pitch")
-         speed = preprocess_kwargs.get("speed")
-
-         global_token_ids = None
-         llm_prompt_string = ""
-
-         # --- Logic to build llm_prompt_string and get global_token_ids ---
-         if prompt_speech_path is not None:
-             # Voice Cloning Mode
-             logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
-             if not Path(prompt_speech_path).exists():
-                 raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
-             # Use the MODEL's method for tokenization (self.model is set by the base Pipeline class)
-             global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
-             global_token_ids = global_tokens  # Keep the Tensor for detokenization
-
-             global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
-             if prompt_text and len(prompt_text) > 1:
-                 semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
-                 llm_prompt_parts = [
-                     TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
-                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
-                     "<|start_semantic_token|>", semantic_tokens_str,
-                 ]
-             else:
-                 llm_prompt_parts = [
-                     TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
-                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
-                 ]
-             llm_prompt_string = "".join(llm_prompt_parts)
-         elif gender is not None and pitch is not None and speed is not None:
-             # Voice Creation Mode
-             logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
-             if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}")
-             if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}")
-             if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}")
-
-             gender_id = GENDER_MAP[gender]
-             pitch_level_id = LEVELS_MAP[pitch]
-             speed_level_id = LEVELS_MAP[speed]
-             attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
-             llm_prompt_parts = [
-                 TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
-                 "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
-             ]
-             llm_prompt_string = "".join(llm_prompt_parts)
-         else:
-             raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).")
-         # --- End prompt building logic ---
-
-         # Tokenize the final prompt for the LLM
-         # Use self.tokenizer (set by the base Pipeline class)
-         model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)
-
-         return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}
-
-
-     def _forward(self, model_inputs, **forward_kwargs) -> dict:
-         """
-         Passes model_inputs to the model's LLM generate method with forward_kwargs.
-
-         Args:
-             model_inputs (dict): Output from `preprocess`.
-             forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.
-
-         Returns:
-             dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
-         """
-         llm_inputs = model_inputs["model_inputs"]
-         global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")
-
-         # Move inputs to the correct device (self.device is set by the base Pipeline class)
-         llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
-         input_ids = llm_inputs["input_ids"]
-         input_ids_len = input_ids.shape[-1]
-
-         # Combine LLM inputs and generation arguments
-         generate_kwargs = {**llm_inputs, **forward_kwargs}
-
-         logger.info(f"Generating tokens with args: {forward_kwargs}")
-         # Use the model's LLM component (self.model is set by the base Pipeline class)
-         with torch.no_grad():
-             generated_ids = self.model.llm.generate(**generate_kwargs)
-
-         # Prepare the output dict to pass to postprocess
-         output_dict = {
-             "generated_ids": generated_ids,
-             "input_ids_len": input_ids_len,
-         }
-         if global_token_ids_prompt is not None:
-             output_dict["global_token_ids_prompt"] = global_token_ids_prompt
-
-         return output_dict
-
-     def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
-         """
-         Transforms model outputs (from _forward) into the final audio dictionary.
-
-         Args:
-             model_outputs (dict): Dictionary from `_forward`.
-             postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).
-
-         Returns:
-             dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
-         """
-         generated_ids = model_outputs["generated_ids"]
-         input_ids_len = model_outputs["input_ids_len"]
-         global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")
-
-         # --- Logic to extract tokens and detokenize ---
-         output_ids = generated_ids[0, input_ids_len:]  # Assumes batch size 1
-         # Use self.tokenizer (set by the base Pipeline class)
-         predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
-
-         semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
-         if not semantic_matches:
-             logger.warning("No semantic tokens found. Returning empty audio.")
-             # Use self.model.config for the sampling rate
-             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.model.config.sample_rate}
-
-         pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)
-
-         if global_token_ids_prompt is not None:
-             # Voice Cloning: use the global tokens from the prompt
-             global_token_ids = global_token_ids_prompt.to(self.device)
-             logger.info("Using global tokens from prompt.")
-         else:
-             # Voice Creation: extract the generated global tokens
-             global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
-             if not global_matches:
-                 raise ValueError("Voice creation failed: No bicodec_global tokens found.")
-             global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
-             if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1)
-             logger.info("Using global tokens from generated text.")
-
-         # Use the MODEL's method for detokenization (self.model is set by the base Pipeline class)
-         wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
-         logger.info(f"Generated audio shape: {wav_np.shape}")
-         # --- End detokenization logic ---
-
-         # Return the final output dictionary
-         return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}
-
-
-
- # --- Pipeline Registration ---
- # This code executes when this file is loaded via trust_remote_code
- try:
-     from transformers.pipelines import PIPELINE_REGISTRY
-     from transformers import AutoModel  # Use AutoModel for registration
-
-     print("Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
-     PIPELINE_REGISTRY.register_pipeline(
-         "text-to-speech",  # Task name
-         pipeline_class=SparkTTSPipeline,  # The class defined above
-         pt_model=AutoModel,  # Compatible PT AutoModel class
-         # tf_model=None,  # Add a TF class if needed
-     )
-     print("Pipeline registration call completed successfully.")
- except ImportError:
-     # Handle a potential import error if the transformers structure changes
-     print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
- except Exception as e:
-     print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
- # --- End Registration Code ---
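
For context, the registration block above hooked SparkTTSPipeline into the standard pipeline() factory, so before this deletion the file was driven roughly as in the minimal sketch below. The repo id, audio path, and the "female"/"moderate" control values are illustrative assumptions (valid keys come from GENDER_MAP and LEVELS_MAP in modeling_spark_tts), not values documented in this commit.

from transformers import pipeline

# trust_remote_code=True makes Transformers load pipeline_spark_tts.py,
# which runs the registration block above as an import side effect.
tts = pipeline(
    "text-to-speech",
    model="your-namespace/spark-tts-repo",  # hypothetical repo id
    trust_remote_code=True,
)

# Voice cloning mode: condition generation on a reference recording.
out = tts(
    "Text to synthesize.",
    prompt_speech_path="reference.wav",  # hypothetical path
    prompt_text="Transcript of the reference clip.",
)

# Voice creation mode: control attributes instead of cloning.
out = tts("Text to synthesize.", gender="female", pitch="moderate", speed="moderate")

audio, sampling_rate = out["audio"], out["sampling_rate"]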