# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path

import numpy as np
import torch
from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base Pipeline class
from transformers.utils import logging

# Import the model/config and helper constants from this repo's modules.
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
    SparkTTSModel,
    load_audio,
    TASK_TOKEN_MAP,
    LEVELS_MAP,
    GENDER_MAP,
    LEVELS_MAP_UI,
)

logger = logging.get_logger(__name__)

# Default LLM generation arguments (each can be overridden per pipeline call).
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95

class SparkTTSPipeline(Pipeline):
    """
    Custom pipeline for SparkTTS text-to-speech generation.
    Supports two modes: voice cloning (from a reference audio prompt) and
    voice creation (from gender/pitch/speed style controls).
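
    Example (illustrative usage; the repo id, file paths, and style values are
    placeholders, not guaranteed identifiers):

        from transformers import pipeline

        pipe = pipeline(
            "text-to-speech", model="<spark-tts-repo-id>", trust_remote_code=True
        )

        # Voice creation: synthesize with explicit style controls.
        out = pipe("Hello world!", gender="female", pitch="moderate", speed="moderate")

        # Voice cloning: condition on a reference recording (transcript optional).
        out = pipe("Hello world!", prompt_speech_path="prompt.wav", prompt_text="...")

        # Both return {"audio": np.ndarray, "sampling_rate": int}, e.g. save with:
        # import soundfile as sf; sf.write("out.wav", out["audio"], out["sampling_rate"])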
    """
    def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        # --- Do NOT load the model here ---
        # A custom pipeline's __init__ should receive an already-loaded model and
        # tokenizer; loading should happen OUTSIDE, before pipeline() is called,
        # or be handled by the pipeline factory.

        # Validate the model and tokenizer that were passed in.
        if model is None:
            raise ValueError("SparkTTSPipeline requires a 'model' argument.")
        if not isinstance(model, SparkTTSModel):
            # The model may have been loaded via AutoModel as a generic
            # PreTrainedModel; in that case, check its config instead.
            if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
                pass  # OK, the model is compatible based on its config.
            else:
                raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")

        if tokenizer is None:
            raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
        if not hasattr(tokenizer, "encode") or not hasattr(tokenizer, "batch_decode"):
            raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")

        # Call super().__init__ with the validated model/tokenizer.
        super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
        if hasattr(self.model, "config") and hasattr(self.model.config, "sample_rate"):
            self.sampling_rate = self.model.config.sample_rate
        else:
            # Fall back to a default if the config does not provide one.
            logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
            self.sampling_rate = 16000

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """
        Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.

        Returns:
            Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
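
        Example of the routing (with the parameters this pipeline accepts):
            `pipe("Hi", gender="female", temperature=0.7)` sends `gender` to
            `preprocess` and `temperature` (a generation argument) to `_forward`.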
        """
        preprocess_kwargs = {}
        # --- Preprocessing specific args ---
        if "prompt_speech_path" in kwargs:
            preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
        if "prompt_text" in kwargs:
            preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
        if "gender" in kwargs:
            preprocess_kwargs["gender"] = kwargs["gender"]
        if "pitch" in kwargs:
            preprocess_kwargs["pitch"] = kwargs["pitch"]
        if "speed" in kwargs:
            preprocess_kwargs["speed"] = kwargs["speed"]

        forward_kwargs = {}
        # --- Forward specific args (LLM generation) ---
        # Use kwargs.get to allow users to override defaults
        forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
        forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
        forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
        forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
        # Ensure essential generation tokens are present if needed
        if self.tokenizer.eos_token_id is not None:
            forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token_id is not None:
            forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
            forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
        # Filter out None values that might cause issues with generate
        forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}

        postprocess_kwargs = {}
        # --- Postprocessing specific args (if any in the future) ---
        # Example: if you added an option to return tokens instead of audio
        # if "return_tokens" in kwargs:
        #    postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **preprocess_kwargs) -> dict:
        """
        Transforms text input and preprocess_kwargs into model input format.

        Args:
            inputs (str): The text to synthesize.
            preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).

        Returns:
            dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
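
        Prompt layout (assembled by the branches below; `[tts]` etc. denote
        TASK_TOKEN_MAP entries):
            cloning : [tts]<|start_content|>{prompt_text?}{text}<|end_content|>
                      <|start_global_token|>...<|end_global_token|>[<|start_semantic_token|>...]
            creation: [controllable_tts]<|start_content|>{text}<|end_content|>
                      <|start_style_label|><|gender_g|><|pitch_label_p|><|speed_label_s|><|end_style_label|>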
        """
        text = inputs
        prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
        prompt_text = preprocess_kwargs.get("prompt_text")
        gender = preprocess_kwargs.get("gender")
        pitch = preprocess_kwargs.get("pitch")
        speed = preprocess_kwargs.get("speed")

        global_token_ids = None
        llm_prompt_string = ""

        # --- Logic to build llm_prompt_string and get global_token_ids ---
        if prompt_speech_path is not None:
            # Voice Cloning Mode
            logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
            if not Path(prompt_speech_path).exists():
                raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
            # Use the MODEL's method for tokenization (self.model is set by base Pipeline class)
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
            global_token_ids = global_tokens # Keep Tensor for detokenization

            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
            if prompt_text and len(prompt_text) > 1:
                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)
        elif gender is not None and pitch is not None and speed is not None:
            # Voice Creation Mode
            logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
            if gender not in GENDER_MAP:
                raise ValueError(f"Invalid gender: {gender}. Expected one of {list(GENDER_MAP)}.")
            if pitch not in LEVELS_MAP:
                raise ValueError(f"Invalid pitch: {pitch}. Expected one of {list(LEVELS_MAP)}.")
            if speed not in LEVELS_MAP:
                raise ValueError(f"Invalid speed: {speed}. Expected one of {list(LEVELS_MAP)}.")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]
            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)
        else:
            raise ValueError(
                "Pipeline requires either 'prompt_speech_path' (voice cloning) or all of "
                "'gender', 'pitch', and 'speed' (voice creation)."
            )
        # --- End prompt building logic ---

        # Tokenize the final prompt for the LLM
        # Use self.tokenizer (set by base Pipeline class)
        model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)

        return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}


    def _forward(self, model_inputs, **forward_kwargs) -> dict:
        """
        Passes model_inputs to the model's LLM generate method with forward_kwargs.

        Args:
            model_inputs (dict): Output from `preprocess`.
            forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.

        Returns:
            dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
        """
        llm_inputs = model_inputs["model_inputs"]
        global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")

        # Move inputs to the correct device (self.device is set by base Pipeline class)
        llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
        input_ids = llm_inputs["input_ids"]
        input_ids_len = input_ids.shape[-1]
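        # Record the prompt length so postprocess can slice the prompt off the output.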

        # Combine LLM inputs and generation arguments
        generate_kwargs = {**llm_inputs, **forward_kwargs}

        logger.info(f"Generating tokens with args: {forward_kwargs}")
        # Use the model's LLM component (self.model is set by base Pipeline class)
        with torch.no_grad():
            generated_ids = self.model.llm.generate(**generate_kwargs)

        # Prepare output dict to pass to postprocess
        output_dict = {
            "generated_ids": generated_ids,
            "input_ids_len": input_ids_len,
        }
        if global_token_ids_prompt is not None:
            output_dict["global_token_ids_prompt"] = global_token_ids_prompt

        return output_dict

    def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
        """
        Transforms model outputs (from _forward) into the final audio dictionary.

        Args:
            model_outputs (dict): Dictionary from `_forward`.
            postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).

        Returns:
            dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
        """
        generated_ids = model_outputs["generated_ids"]
        input_ids_len = model_outputs["input_ids_len"]
        global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")

        # --- Logic to extract tokens and detokenize ---
        output_ids = generated_ids[0, input_ids_len:] # Assumes batch size 1
        # Use self.tokenizer (set by base Pipeline class)
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
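        # NOTE (assumption): the <|bicodec_*|> markers are ordinary added tokens rather
        # than special tokens, so they survive decoding above and the regexes below can
        # recover their integer ids from the text.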

        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found in the generated text. Returning empty audio.")
            # self.sampling_rate was resolved in __init__, with a fallback if the config lacks it.
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)

        if global_token_ids_prompt is not None:
            # Voice cloning: reuse the global (speaker) tokens taken from the prompt audio.
            global_token_ids = global_token_ids_prompt.to(self.device)
            logger.info("Using global tokens from prompt.")
        else:
            # Voice creation: extract the generated global tokens from the LLM output.
            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
            if not global_matches:
                raise ValueError("Voice creation failed: no bicodec_global tokens found in generated text.")
            global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
            # _detokenize_audio appears to expect (batch, 1, n_tokens); add the middle dim.
            if global_token_ids.ndim == 2:
                global_token_ids = global_token_ids.unsqueeze(1)
            logger.info("Using global tokens from generated text.")

        # Use the MODEL's method for detokenization (self.model is set by base Pipeline class)
        wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        logger.info(f"Generated audio shape: {wav_np.shape}")
        # --- End detokenization logic ---

        # Return final output dictionary
        return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}



# --- Pipeline registration ---
# This code runs when the file is loaded via trust_remote_code.
try:
    from transformers.pipelines import PIPELINE_REGISTRY
    from transformers import AutoModel  # Use AutoModel for registration

    logger.info("Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
    PIPELINE_REGISTRY.register_pipeline(
        "text-to-speech",                 # Task name
        pipeline_class=SparkTTSPipeline,  # The class defined above
        pt_model=AutoModel,               # Compatible PT AutoModel class
        # tf_model=None,                  # Add a TF class here if needed
    )
    logger.info("Pipeline registration call completed.")
except ImportError:
    # Guard against potential changes in transformers' internal structure.
    logger.warning("Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
    logger.error(f"An unexpected error occurred during pipeline registration: {e}")
# --- End registration ---