Upload 24 files
Browse files
- .gitattributes +3 -32
- BiCodec/config.yaml +60 -0
- BiCodec/model.safetensors +3 -0
- LLM/added_tokens.json +0 -0
- LLM/config.json +27 -0
- LLM/merges.txt +0 -0
- LLM/model.safetensors +3 -0
- LLM/special_tokens_map.json +31 -0
- LLM/tokenizer.json +3 -0
- LLM/tokenizer_config.json +0 -0
- LLM/vocab.json +0 -0
- __init__.py +32 -0
- _modeling_bicodec_components.py +0 -0
- _utils.py +151 -0
- config.json +90 -0
- configuration_spark_tts.py +86 -0
- model_index.json +39 -0
- modeling_spark_tts.py +0 -0
- pipeline_spark_tts.py +303 -0
- processing_spark_tts.py +345 -0
- wav2vec2-large-xlsr-53/README.md +29 -0
- wav2vec2-large-xlsr-53/config.json +83 -0
- wav2vec2-large-xlsr-53/preprocessor_config.json +9 -0
- wav2vec2-large-xlsr-53/pytorch_model.bin +3 -0
.gitattributes
CHANGED
@@ -1,35 +1,6 @@
-*.
-*.arrow filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
-*.
-*.
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+LLM/tokenizer.json filter=lfs diff=lfs merge=lfs -text

BiCodec/config.yaml
ADDED
@@ -0,0 +1,60 @@
audio_tokenizer:
  mel_params:
    sample_rate: 16000
    n_fft: 1024
    win_length: 640
    hop_length: 320
    mel_fmin: 10
    mel_fmax: null
    num_mels: 128

  encoder:
    input_channels: 1024
    vocos_dim: 384
    vocos_intermediate_dim: 2048
    vocos_num_layers: 12
    out_channels: 1024
    sample_ratios: [1,1]

  decoder:
    input_channel: 1024
    channels: 1536
    rates: [8, 5, 4, 2]
    kernel_sizes: [16,11,8,4]

  quantizer:
    input_dim: 1024
    codebook_size: 8192
    codebook_dim: 8
    commitment: 0.25
    codebook_loss_weight: 2.0
    use_l2_normlize: True
    threshold_ema_dead_code: 0.2

  speaker_encoder:
    input_dim: 128
    out_dim: 1024
    latent_dim: 128
    token_num: 32
    fsq_levels: [4, 4, 4, 4, 4, 4]
    fsq_num_quantizers: 1

  prenet:
    input_channels: 1024
    vocos_dim: 384
    vocos_intermediate_dim: 2048
    vocos_num_layers: 12
    out_channels: 1024
    condition_dim: 1024
    sample_ratios: [1,1]
    use_tanh_at_final: False

  postnet:
    input_channels: 1024
    vocos_dim: 384
    vocos_intermediate_dim: 2048
    vocos_num_layers: 6
    out_channels: 1024
    use_tanh_at_final: False

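This YAML mirrors the `bicodec_config` block embedded in the top-level config.json further below. A minimal sketch of loading it the same way the `load_config_yaml` helper in `_utils.py` does, via OmegaConf (the local path is an assumption for illustration):

import json
from omegaconf import OmegaConf

# Assumed local checkout path; adjust to wherever the repo was downloaded.
cfg = OmegaConf.to_container(OmegaConf.load("BiCodec/config.yaml"), resolve=True)

print(cfg["audio_tokenizer"]["mel_params"]["sample_rate"])   # 16000
print(cfg["audio_tokenizer"]["quantizer"]["codebook_size"])  # 8192
print(json.dumps(cfg["audio_tokenizer"]["speaker_encoder"], indent=2))
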
BiCodec/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
size 625518756

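The three lines above are a Git LFS pointer, not the weights themselves; the actual blob is fetched on checkout. A small sketch, assuming the weights have already been downloaded locally, of verifying the file against the pointer's sha256 oid and size:

import hashlib
from pathlib import Path

weights = Path("BiCodec/model.safetensors")  # assumed local path
expected_oid = "e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec"
expected_size = 625518756

h = hashlib.sha256()
with weights.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)

assert weights.stat().st_size == expected_size, "size mismatch"
assert h.hexdigest() == expected_oid, "sha256 mismatch"
print("LFS object matches the pointer.")
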
LLM/added_tokens.json
ADDED
The diff for this file is too large to render.
See raw diff
LLM/config.json
ADDED
@@ -0,0 +1,27 @@
{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 896,
  "initializer_range": 0.02,
  "intermediate_size": 4864,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 14,
  "num_hidden_layers": 24,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 166000
}

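The LLM sub-directory is a standard Qwen2-architecture causal LM (0.5B scale: 24 layers, hidden size 896, vocabulary extended to 166000 for the audio and style tokens), so it can also be loaded on its own with the usual transformers API. A minimal sketch, assuming a local checkout of the repo:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "LLM" is the sub-folder shown above; adjust the path to your local checkout.
tokenizer = AutoTokenizer.from_pretrained("./LLM")
llm = AutoModelForCausalLM.from_pretrained("./LLM", torch_dtype=torch.bfloat16)

print(llm.config.model_type)   # "qwen2"
print(llm.config.vocab_size)   # 166000 (extended with SparkTTS special tokens)
print(len(tokenizer))          # tokenizer vocabulary incl. added tokens
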
LLM/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
LLM/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54825baf0a2f6076eb3c78fa1d22a95aee225f59070a8b295f8169db860eb109
size 2026568968

LLM/special_tokens_map.json
ADDED
@@ -0,0 +1,31 @@
{
  "additional_special_tokens": [
    "<|im_start|>",
    "<|im_end|>",
    "<|object_ref_start|>",
    "<|object_ref_end|>",
    "<|box_start|>",
    "<|box_end|>",
    "<|quad_start|>",
    "<|quad_end|>",
    "<|vision_start|>",
    "<|vision_end|>",
    "<|vision_pad|>",
    "<|image_pad|>",
    "<|video_pad|>"
  ],
  "eos_token": {
    "content": "<|im_end|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}

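The map keeps Qwen's chat special tokens, sets `<|im_end|>` as the EOS token and `<|endoftext|>` as padding; the SparkTTS-specific control tokens (task tokens, `bicodec_global_*`, `bicodec_semantic_*`, style labels) live in LLM/added_tokens.json instead. A quick sketch, assuming a local checkout, of checking what the tokenizer resolves:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./LLM")  # assumed local path

print(tokenizer.eos_token, tokenizer.eos_token_id)  # <|im_end|> 151645
print(tokenizer.pad_token)                          # <|endoftext|>

# SparkTTS control tokens from added_tokens.json should each map to a single id.
for tok in ["<|task_tts|>", "<|start_content|>", "<|bicodec_semantic_0|>"]:
    print(tok, tokenizer.encode(tok, add_special_tokens=False))
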
LLM/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c8b057d6ca205a429cc3428b9fc815f0d6ee1d53106dd5e5b129ef9db2ff057
size 14129172

LLM/tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
LLM/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
__init__.py
ADDED
@@ -0,0 +1,32 @@
# # SparkAudio/Spark-TTS-0.5B_v2/__init__.py

# import os
# from pathlib import Path
# # Import the registry and AutoModel for registration
# from transformers.pipelines import PIPELINE_REGISTRY
# from transformers import AutoModel

# # Print to confirm this file is being executed
# print(f"Executing __init__.py for SparkTTS custom module in: {__file__}")

# try:
#     # Import the custom Pipeline class so it can be registered
#     from .pipeline_spark_tts import SparkTTSPipeline

#     # Register the custom pipeline with the registry
#     # Make sure the task name "text-to-speech" matches the name used when calling pipeline()
#     print(f"Attempting to register SparkTTSPipeline for task 'text-to-speech'...")
#     PIPELINE_REGISTRY.register_pipeline(
#         "text-to-speech",                 # Task name
#         pipeline_class=SparkTTSPipeline,  # Your custom pipeline class
#         pt_model=AutoModel,               # Compatible AutoModel class for PyTorch
#         # tf_model=None,                  # Add a TF class if available
#         type="text",                      # Primary input type (e.g. text, audio, image)
#     )
#     print("Pipeline registration call completed.")

# # Catch errors if the import fails (e.g. a missing dependency)
# except ImportError as e:
#     print(f"WARNING: Could not import SparkTTSPipeline in __init__.py: {e}. Pipeline registration failed.")
# except Exception as e:
#     print(f"ERROR: An unexpected error occurred during SparkTTS __init__ registration: {e}")

_modeling_bicodec_components.py
ADDED
The diff for this file is too large to render.
See raw diff
_utils.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for SparkTTS"""

import random
import soxr
import soundfile
import torch
import torchaudio
import numpy as np

from pathlib import Path
from typing import Tuple, Dict, Any
from numpy.lib.stride_tricks import sliding_window_view
from omegaconf import OmegaConf  # Keep if BiCodec config loading needs it


# --- Token Maps (from sparktts/utils/token_parser.py) ---
TASK_TOKEN_MAP = {
    "vc": "<|task_vc|>",
    "tts": "<|task_tts|>",
    "asr": "<|task_asr|>",
    "s2s": "<|task_s2s|>",
    "t2s": "<|task_t2s|>",
    "understand": "<|task_understand|>",
    "caption": "<|task_cap|>",
    "controllable_tts": "<|task_controllable_tts|>",
    "prompt_tts": "<|task_prompt_tts|>",
    "speech_edit": "<|task_edit|>",
}

LEVELS_MAP = {
    "very_low": 0,
    "low": 1,
    "moderate": 2,
    "high": 3,
    "very_high": 4,
}

LEVELS_MAP_UI = {
    1: 'very_low',
    2: 'low',
    3: 'moderate',
    4: 'high',
    5: 'very_high'
}

GENDER_MAP = {
    "female": 0,
    "male": 1,
}

# --- Audio Utils (from sparktts/utils/audio.py) ---
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
    temp = np.sort(np.abs(audio))
    if len(temp) == 0:  # Handle empty audio case
        return audio
    if temp[-1] < 0.1:
        scaling_factor = max(temp[-1], 1e-3)
        audio = audio / scaling_factor * 0.1
    temp = temp[temp > 0.01]
    L = temp.shape[0]
    if L <= 10:
        return audio
    volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
    if volume == 0:  # Avoid division by zero if volume is effectively zero
        return audio
    audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
    max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0
    if max_value > 1:
        audio = audio / max_value
    return audio

def load_audio(
    adfile: Path,
    sampling_rate: int = None,
    length: int = None,
    volume_normalize: bool = False,
    segment_duration: int = None,
) -> np.ndarray:
    try:
        audio, sr = soundfile.read(adfile, dtype='float32')  # Ensure float32
    except Exception as e:
        raise IOError(f"Could not read audio file {adfile}: {e}")

    if audio is None or len(audio) == 0:
        raise ValueError(f"Audio file {adfile} is empty or invalid.")

    if len(audio.shape) > 1:
        audio = audio[:, 0]

    if sampling_rate is not None and sr != sampling_rate:
        try:
            # Ensure input is float64 for soxr
            audio = audio.astype(np.float64)
            audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
            # Convert back to float32
            audio = audio.astype(np.float32)
            sr = sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}")

    if segment_duration is not None:
        seg_length = int(sr * segment_duration)
        audio = random_select_audio_segment(audio, seg_length)

    if volume_normalize:
        audio = audio_volume_normalize(audio)

    if length is not None:
        if audio.shape[0] > length:
            audio = audio[:length]
        else:
            audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
    return audio

def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
    if audio.shape[0] < length:
        audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
        start_index = 0  # If padded, start from beginning
    elif audio.shape[0] == length:
        start_index = 0  # If exact length, start from beginning
    else:
        start_index = random.randint(0, audio.shape[0] - length)

    end_index = int(start_index + length)
    return audio[start_index:end_index]

# --- File Utils (Minimal required) ---
def load_config_yaml(config_path: Path) -> Dict:
    """Loads a YAML configuration file using OmegaConf."""
    # Check if path exists
    if not Path(config_path).is_file():
        raise FileNotFoundError(f"YAML Config file not found: {config_path}")
    try:
        config = OmegaConf.load(config_path)
        # Convert OmegaConf DictConfig to standard Python dict
        return OmegaConf.to_container(config, resolve=True)
    except Exception as e:
        raise IOError(f"Error loading YAML config file {config_path}: {e}")

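A short sketch of how these helpers compose when preparing a reference clip (the wav path is illustrative); this mirrors the `sample_rate`, `volume_normalize`, and `ref_segment_duration` defaults in the top-level config.json below:

import numpy as np
from _utils import load_audio  # assumes the repo root is on sys.path

ref = load_audio(
    "prompt.wav",            # illustrative path to a reference recording
    sampling_rate=16000,     # matches config.json "sample_rate"
    volume_normalize=True,   # matches config.json "volume_normalize"
    segment_duration=6.0,    # matches config.json "ref_segment_duration"
)
print(ref.dtype, ref.shape)  # float32, (96000,) -> 6 s at 16 kHz
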
config.json
ADDED
@@ -0,0 +1,90 @@
{
  "model_type": "spark-tts",
  "architectures": [
    "SparkTTSModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
    "AutoModel": "modeling_spark_tts.SparkTTSModel",
    "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
  },
  "processor_class": "processing_spark_tts.SparkTTSProcessor",
  "custom_pipelines": {
    "text-to-speech": {
      "impl": "pipeline_spark_tts.SparkTTSPipeline",
      "pt": ["AutoModel"]
    }
  },
  "llm_model_name_or_path": "./LLM",
  "bicodec_model_name_or_path": "./BiCodec",
  "wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
  "sample_rate": 16000,
  "highpass_cutoff_freq": 40,
  "latent_hop_length": 320,
  "ref_segment_duration": 6.0,
  "volume_normalize": true,
  "bicodec_config": {
    "audio_tokenizer": {
      "mel_params": {
        "sample_rate": 16000,
        "n_fft": 1024,
        "win_length": 640,
        "hop_length": 320,
        "mel_fmin": 10,
        "mel_fmax": null,
        "num_mels": 128
      },
      "encoder": {
        "input_channels": 1024,
        "vocos_dim": 384,
        "vocos_intermediate_dim": 2048,
        "vocos_num_layers": 12,
        "out_channels": 1024,
        "sample_ratios": [1, 1]
      },
      "decoder": {
        "input_channel": 1024,
        "channels": 1536,
        "rates": [8, 5, 4, 2],
        "kernel_sizes": [16, 11, 8, 4]
      },
      "quantizer": {
        "input_dim": 1024,
        "codebook_size": 8192,
        "codebook_dim": 8,
        "commitment": 0.25,
        "codebook_loss_weight": 2.0,
        "use_l2_normlize": true,
        "threshold_ema_dead_code": 0.2
      },
      "speaker_encoder": {
        "input_dim": 128,
        "out_dim": 1024,
        "latent_dim": 128,
        "token_num": 32,
        "fsq_levels": [4, 4, 4, 4, 4, 4],
        "fsq_num_quantizers": 1
      },
      "prenet": {
        "input_channels": 1024,
        "vocos_dim": 384,
        "vocos_intermediate_dim": 2048,
        "vocos_num_layers": 12,
        "out_channels": 1024,
        "condition_dim": 1024,
        "sample_ratios": [1, 1],
        "use_tanh_at_final": false
      },
      "postnet": {
        "input_channels": 1024,
        "vocos_dim": 384,
        "vocos_intermediate_dim": 2048,
        "vocos_num_layers": 6,
        "out_channels": 1024,
        "use_tanh_at_final": false
      }
    }
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1"
}

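The `auto_map` and `custom_pipelines` entries are what let transformers resolve the custom classes in this repo when `trust_remote_code=True` is passed. A minimal loading sketch under that assumption (the repo id/path and device handling are illustrative):

from transformers import AutoConfig, AutoModel, AutoProcessor

repo_id = "."  # local checkout, or the Hub repo id of this model

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

# The processor needs a model reference for audio tokenization/decoding
# (see processing_spark_tts.py below).
processor.model = model

print(type(config).__name__, type(model).__name__, type(processor).__name__)
# SparkTTSConfig SparkTTSModel SparkTTSProcessor
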
configuration_spark_tts.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SparkTTS model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

class SparkTTSConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`SparkTTSModel`].
    It is used to instantiate a SparkTTS model according to the specified arguments, defining the model
    architecture and sub-component paths.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`):
            Path to the pretrained LLM model or model identifier from huggingface.co/models.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`):
            Path to the pretrained BiCodec model directory.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`):
            Path to the pretrained Wav2Vec2 model directory.
        sample_rate (`int`, *optional*, defaults to 16000):
            The sampling rate of the audio files.
        highpass_cutoff_freq (`int`, *optional*, defaults to 40):
            Highpass filter cutoff frequency for audio processing.
        latent_hop_length (`int`, *optional*, defaults to 320):
            Hop length used in BiCodec processing.
        ref_segment_duration (`float`, *optional*, defaults to 6.0):
            Duration (in seconds) of the reference audio clip used for speaker embedding.
        volume_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the volume of audio inputs.
        bicodec_config (`dict`, *optional*):
            A dictionary containing the configuration for the BiCodec model components (encoder, decoder, etc.).
            This is typically loaded from the `BiCodec/config.yaml` originally.
        **kwargs
            Additional keyword arguments passed along to [`PretrainedConfig`].
    """

    model_type = "spark-tts"
    processor_class = "SparkTTSProcessor"
    attribute_map = {}  # Add mappings if needed for renaming attributes

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None,
        **kwargs,
    ):
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize
        self.bicodec_config = bicodec_config if bicodec_config is not None else {}

        # REMOVE THIS WARNING - the check in SparkTTSModel is better
        # if not self.bicodec_config:
        #     logger.warning("BiCodec config is empty. BiCodec model might not load correctly.")

        super().__init__(**kwargs)

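Since SparkTTSConfig is a regular PretrainedConfig subclass, it round-trips through the usual save/load API; a small sketch (paths are illustrative, and the import assumes the repo files are on sys.path):

from configuration_spark_tts import SparkTTSConfig

cfg = SparkTTSConfig(ref_segment_duration=8.0)  # override one default
print(cfg.model_type)            # "spark-tts"
print(cfg.sample_rate)           # 16000
print(cfg.ref_segment_duration)  # 8.0

cfg.save_pretrained("./tmp_spark_tts_config")  # writes a config.json
reloaded = SparkTTSConfig.from_pretrained("./tmp_spark_tts_config")
print(reloaded.ref_segment_duration)           # 8.0
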
model_index.json
ADDED
@@ -0,0 +1,39 @@
{
  "model_index": [
    {
      "name": "SparkTTS",
      "results": [],
      "config": {
        "architectures": [
          "SparkTTSForConditionalGeneration"
        ],
        "model_type": "sparktts"
      },
      "filenames": [
        "config.json",
        "model_index.json",
        "modeling_sparktts.py",
        "pipeline_sparktts.py",
        "__init__.py",
        "LLM/config.json",
        "LLM/pytorch_model.bin",
        "LLM/tokenizer.json",
        "BiCodec/config.yaml",
        "BiCodec/model.safetensors",
        "wav2vec2-large-xlsr-53/config.json",
        "wav2vec2-large-xlsr-53/pytorch_model.bin",
        "wav2vec2-large-xlsr-53/preprocessor_config.json",
        "sparktts_code/models/audio_tokenizer.py",
        "sparktts_code/models/bicodec.py",
        "sparktts_code/modules/speaker/speaker_encoder.py",
        "sparktts_code/modules/encoder_decoder/feat_encoder.py",
        "sparktts_code/modules/encoder_decoder/feat_decoder.py",
        "sparktts_code/modules/encoder_decoder/wave_generator.py",
        "sparktts_code/modules/vq/factorized_vector_quantize.py",
        "sparktts_code/utils/audio.py",
        "sparktts_code/utils/file.py",
        "sparktts_code/utils/token_parser.py"
      ]
    }
  ]
}

modeling_spark_tts.py
ADDED
The diff for this file is too large to render.
See raw diff
pipeline_spark_tts.py
ADDED
@@ -0,0 +1,303 @@
# Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import re
import numpy as np
from typing import Optional, Dict, Any, Union, List
from pathlib import Path
from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base class
from transformers.utils import logging

# Import necessary items from this module (assuming they are defined in modeling_spark_tts)
from .modeling_spark_tts import SparkTTSModel
from .configuration_spark_tts import SparkTTSConfig
from .modeling_spark_tts import (
    load_audio,
    TASK_TOKEN_MAP,
    LEVELS_MAP,
    GENDER_MAP,
    LEVELS_MAP_UI,
)
# No need to import SparkTTSModel/Config here, pipeline gets them during init

logger = logging.get_logger(__name__)

# Define constants if needed (e.g., for default generation args)
DEFAULT_MAX_NEW_TOKENS = 3000
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.95

class SparkTTSPipeline(Pipeline):
    """
    Custom Pipeline for SparkTTS text-to-speech generation, following HF documentation structure.
    Handles voice cloning and voice creation modes.
    """
    def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
        # --- The model should NOT be loaded here anymore ---
        # A custom pipeline's __init__ should receive an already-loaded model and tokenizer.
        # Loading should happen OUTSIDE, before calling pipeline(), or be handled by the pipeline factory.

        # Validate the model and tokenizer that were passed in
        if model is None:
            raise ValueError("SparkTTSPipeline requires a 'model' argument.")
        if not isinstance(model, SparkTTSModel):
            # The model may have been loaded via AutoModel as a PreTrainedModel; check its config instead
            if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
                pass  # OK, the model is compatible based on its config
            else:
                raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")

        if tokenizer is None:
            raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
        if not hasattr(tokenizer, 'encode') or not hasattr(tokenizer, 'batch_decode'):
            raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")

        # Call super().__init__ with the received model/tokenizer
        super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
        if hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
            self.sampling_rate = self.model.config.sample_rate
        else:
            # Fall back to a default (or obtain the value elsewhere) if the config does not provide it
            logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
            self.sampling_rate = 16000

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """
        Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.

        Returns:
            Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
        """
        preprocess_kwargs = {}
        # --- Preprocessing specific args ---
        if "prompt_speech_path" in kwargs:
            preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
        if "prompt_text" in kwargs:
            preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
        if "gender" in kwargs:
            preprocess_kwargs["gender"] = kwargs["gender"]
        if "pitch" in kwargs:
            preprocess_kwargs["pitch"] = kwargs["pitch"]
        if "speed" in kwargs:
            preprocess_kwargs["speed"] = kwargs["speed"]

        forward_kwargs = {}
        # --- Forward specific args (LLM generation) ---
        # Use kwargs.get to allow users to override defaults
        forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
        forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
        forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
        forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
        # Ensure essential generation tokens are present if needed
        if self.tokenizer.eos_token_id is not None:
            forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
        if self.tokenizer.pad_token_id is not None:
            forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        elif self.tokenizer.eos_token_id is not None:
            logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
            forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
        # Filter out None values that might cause issues with generate
        forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}

        postprocess_kwargs = {}
        # --- Postprocessing specific args (if any in the future) ---
        # Example: if you added an option to return tokens instead of audio
        # if "return_tokens" in kwargs:
        #     postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, **preprocess_kwargs) -> dict:
        """
        Transforms text input and preprocess_kwargs into model input format.

        Args:
            inputs (str): The text to synthesize.
            preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).

        Returns:
            dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
        """
        text = inputs
        prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
        prompt_text = preprocess_kwargs.get("prompt_text")
        gender = preprocess_kwargs.get("gender")
        pitch = preprocess_kwargs.get("pitch")
        speed = preprocess_kwargs.get("speed")

        global_token_ids = None
        llm_prompt_string = ""

        # --- Logic to build llm_prompt_string and get global_token_ids ---
        if prompt_speech_path is not None:
            # Voice Cloning Mode
            logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
            if not Path(prompt_speech_path).exists():
                raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
            # Use the MODEL's method for tokenization (self.model is set by base Pipeline class)
            global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
            global_token_ids = global_tokens  # Keep Tensor for detokenization

            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
            if prompt_text and len(prompt_text) > 1:
                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)
        elif gender is not None and pitch is not None and speed is not None:
            # Voice Creation Mode
            logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
            if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}")
            if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}")
            if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]
            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)
        else:
            raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).")
        # --- End prompt building logic ---

        # Tokenize the final prompt for the LLM
        # Use self.tokenizer (set by base Pipeline class)
        model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)

        return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}


    def _forward(self, model_inputs, **forward_kwargs) -> dict:
        """
        Passes model_inputs to the model's LLM generate method with forward_kwargs.

        Args:
            model_inputs (dict): Output from `preprocess`.
            forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.

        Returns:
            dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
        """
        llm_inputs = model_inputs["model_inputs"]
        global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")

        # Move inputs to the correct device (self.device is set by base Pipeline class)
        llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
        input_ids = llm_inputs["input_ids"]
        input_ids_len = input_ids.shape[-1]

        # Combine LLM inputs and generation arguments
        generate_kwargs = {**llm_inputs, **forward_kwargs}

        logger.info(f"Generating tokens with args: {forward_kwargs}")
        # Use the model's LLM component (self.model is set by base Pipeline class)
        with torch.no_grad():
            generated_ids = self.model.llm.generate(**generate_kwargs)

        # Prepare output dict to pass to postprocess
        output_dict = {
            "generated_ids": generated_ids,
            "input_ids_len": input_ids_len,
        }
        if global_token_ids_prompt is not None:
            output_dict["global_token_ids_prompt"] = global_token_ids_prompt

        return output_dict

    def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
        """
        Transforms model outputs (from _forward) into the final audio dictionary.

        Args:
            model_outputs (dict): Dictionary from `_forward`.
            postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).

        Returns:
            dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
        """
        generated_ids = model_outputs["generated_ids"]
        input_ids_len = model_outputs["input_ids_len"]
        global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")

        # --- Logic to extract tokens and detokenize ---
        output_ids = generated_ids[0, input_ids_len:]  # Assumes batch size 1
        # Use self.tokenizer (set by base Pipeline class)
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)

        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found. Returning empty audio.")
            # Use self.model.config for sampling rate
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.model.config.sample_rate}

        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)

        if global_token_ids_prompt is not None:
            # Voice Cloning: Use prompt tokens
            global_token_ids = global_token_ids_prompt.to(self.device)
            logger.info("Using global tokens from prompt.")
        else:
            # Voice Creation: Extract generated tokens
            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
            if not global_matches:
                raise ValueError("Voice creation failed: No bicodec_global tokens found.")
            global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
            if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1)
            logger.info("Using global tokens from generated text.")

        # Use the MODEL's method for detokenization (self.model is set by base Pipeline class)
        wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
        logger.info(f"Generated audio shape: {wav_np.shape}")
        # --- End detokenization logic ---

        # Return final output dictionary
        return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}



# --- Add Registration Code Here ---
# This code will execute when this file is loaded via trust_remote_code
try:
    from transformers.pipelines import PIPELINE_REGISTRY
    from transformers import AutoModel  # Use AutoModel for registration

    print(f"Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
    PIPELINE_REGISTRY.register_pipeline(
        "text-to-speech",                 # Task name
        pipeline_class=SparkTTSPipeline,  # The class defined above
        pt_model=AutoModel,               # Compatible PT AutoModel class
        # tf_model=None,                  # Add TF class if needed
    )
    print("Pipeline registration call completed successfully.")
except ImportError:
    # Handle potential import error if transformers structure changes
    print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
except Exception as e:
    print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
# --- End Registration Code ---

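Putting it together, a sketch of calling the registered pipeline in both modes (repo path, prompt wav, and output file names are illustrative; `trust_remote_code=True` is assumed so this file and the model code get loaded):

import soundfile as sf
from transformers import pipeline

tts = pipeline("text-to-speech", model=".", trust_remote_code=True)

# Voice cloning: condition on a reference recording (and optionally its transcript).
out = tts(
    "Hello, this is a cloned voice.",
    prompt_speech_path="prompt.wav",        # illustrative reference clip
    prompt_text="Transcript of the clip.",
)
sf.write("cloned.wav", out["audio"], out["sampling_rate"])

# Voice creation: describe the target voice with gender / pitch / speed labels.
out = tts(
    "Hello, this is a brand new voice.",
    gender="female", pitch="moderate", speed="moderate",
)
sf.write("created.wav", out["audio"], out["sampling_rate"])
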
processing_spark_tts.py
ADDED
@@ -0,0 +1,345 @@
# coding=utf-8
# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
# ... (license) ...
"""Processor class for SparkTTS."""

import torch
import re
import numpy as np
import warnings
from typing import Optional, Dict, Any, Union, List, Tuple
from pathlib import Path

from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
from transformers.utils import logging

# Import necessary items directly or ensure they are available via model reference
# Note: Avoid direct model imports here if possible, rely on the model reference.
# from .modeling_spark_tts import SparkTTSModel # Avoid direct model import if possible
from .configuration_spark_tts import SparkTTSConfig  # Config is okay

# Import utils needed for prompt formatting (assuming they are merged into modeling)
# We'll access them via the model reference if needed, or duplicate simple ones like token maps.

logger = logging.get_logger(__name__)

# --- Token Maps (Duplicate here for direct use in processor) ---
TASK_TOKEN_MAP = {
    "tts": "<|task_tts|>",
    "controllable_tts": "<|task_controllable_tts|>",
    # Add other tasks if needed by processor logic
}
LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}
GENDER_MAP = {"female": 0, "male": 1}
# --- End Token Maps ---


class SparkTTSProcessor(ProcessorMixin):
    r"""
    Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor
    into a single processor.

    [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
    It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio tokenization
    to the model). It also handles decoding the final output.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
        feature_extractor (`Wav2Vec2FeatureExtractor`):
            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to process reference audio
            (though the main processing happens inside the model).
        model (`PreTrainedModel`, *optional*):
            A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio processing)
            and final audio decoding, as these steps rely on the model's internal BiCodec and Wav2Vec2 components.
            Set this using `processor.model = model` after loading both.
        config (`SparkTTSConfig`, *optional*):
            The configuration object, needed for parameters like sample_rate. Can often be inferred from the model.
    """
    attributes = ["tokenizer", "feature_extractor"]
    tokenizer_class = ("Qwen2TokenizerFast", "Qwen2Tokenizer")  # Specify the underlying tokenizer type
    feature_extractor_class = ("Wav2Vec2FeatureExtractor",)  # Specify the underlying feature extractor type

    def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
        if tokenizer is None:
            raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
        if feature_extractor is None:
            # Attempt to load default if path is known or provide clearer error
            raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")

        super().__init__(tokenizer, feature_extractor)
        self.model = model  # Store model reference (can be None initially)
        self.config = config  # Store config reference

        # Get sampling rate from config if available
        self.sampling_rate = None
        if self.config and hasattr(self.config, 'sample_rate'):
            self.sampling_rate = self.config.sample_rate
        elif self.model and hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
            self.sampling_rate = self.model.config.sample_rate
        else:
            # Try feature extractor default, or raise warning
            if hasattr(self.feature_extractor, 'sampling_rate'):
                self.sampling_rate = self.feature_extractor.sampling_rate
            else:
                logger.warning("Could not determine sampling rate. Defaulting to 16000. Set `processor.sampling_rate` manually if needed.")
                self.sampling_rate = 16000


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
                  e.g., `./my_model_directory/`.
            **kwargs:
                Additional keyword arguments passed along to both `AutoTokenizer.from_pretrained()` and
                `AutoFeatureExtractor.from_pretrained()`.
        """
        config = kwargs.pop("config", None)
        if config is None:
            # Try loading the specific config first
            try:
                config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
            except Exception:
                logger.warning(f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. Processor might lack some config values.")
                config = None


        # Resolve component paths relative to the main path
        def _resolve_path(sub_path):
            p = Path(sub_path)
            if p.is_absolute():
                return str(p)
            # Try resolving relative to the main path if it's a directory
            main_path = Path(pretrained_model_name_or_path)
            if main_path.is_dir():
                resolved = main_path / p
                if resolved.exists():
                    return str(resolved)
            # Fallback to assuming sub_path is relative within a repo structure (might fail for local non-dirs)
            return sub_path

        # Determine paths from config or assume defaults
        llm_tokenizer_path = "./LLM"
        w2v_processor_path = "./wav2vec2-large-xlsr-53"
        if config:
            llm_tokenizer_path = getattr(config, 'llm_model_name_or_path', llm_tokenizer_path)
            w2v_processor_path = getattr(config, 'wav2vec2_model_name_or_path', w2v_processor_path)

        resolved_tokenizer_path = _resolve_path(llm_tokenizer_path)
        resolved_w2v_path = _resolve_path(w2v_processor_path)

        try:
            tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_path, **kwargs)
        except Exception as e:
            raise OSError(f"Could not load tokenizer from {resolved_tokenizer_path}. Ensure path is correct and files exist. Original error: {e}")

        try:
            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(resolved_w2v_path, **kwargs)
        except Exception as e:
            raise OSError(f"Could not load feature extractor from {resolved_w2v_path}. Ensure path is correct and files exist. Original error: {e}")

        # The 'model' attribute will be set later externally
        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)


    def __call__(self, text: str = None,
                 prompt_speech_path: Optional[str] = None,
                 prompt_text: Optional[str] = None,
                 gender: Optional[str] = None,
                 pitch: Optional[str] = None,
                 speed: Optional[str] = None,
                 return_tensors: Optional[str] = "pt",
                 **kwargs) -> BatchEncoding:
        """
        Main method to process inputs for the SparkTTS model.

        Args:
            text (`str`): The text to be synthesized.
            prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
            prompt_text (`str`, *optional*): Transcript of prompt audio.
            gender (`str`, *optional*): Target gender ('male' or 'female') for voice creation.
            pitch (`str`, *optional*): Target pitch level ('very_low'...'very_high') for voice creation.
            speed (`str`, *optional*): Target speed level ('very_low'...'very_high') for voice creation.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                Framework of the returned tensors (`"pt"` for PyTorch, `"np"` for NumPy).
            **kwargs: Additional arguments (currently ignored).

        Returns:
            `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
            `global_token_ids_prompt` ready for the model's `.generate()` method.
        """
        if text is None:
            raise ValueError("`text` input must be provided.")

        global_token_ids_prompt = None
        llm_prompt_string = ""

        if prompt_speech_path is not None:
            # --- Voice Cloning Mode ---
            if self.model is None:
                raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning.")
            if not hasattr(self.model, '_tokenize_audio'):
                raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")

            logger.info(f"Processing prompt audio: {prompt_speech_path}")
            # Delegate audio tokenization to the model
            try:
                # _tokenize_audio returns (global_tokens, semantic_tokens)
                global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
                global_token_ids_prompt = global_tokens  # Keep for decoding stage
            except Exception as e:
                logger.error(f"Error tokenizing prompt audio: {e}", exc_info=True)
                raise RuntimeError(f"Failed to process prompt audio file: {prompt_speech_path}. Check file integrity and model compatibility.") from e

            # Format prompt string using token maps
            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])

            if prompt_text and len(prompt_text) > 1:
                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                    "<|start_semantic_token|>", semantic_tokens_str,
                ]
            else:
                llm_prompt_parts = [
                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
                ]
            llm_prompt_string = "".join(llm_prompt_parts)

        elif gender is not None and pitch is not None and speed is not None:
            # --- Voice Creation Mode ---
            if gender not in GENDER_MAP: raise ValueError(f"Invalid gender '{gender}'.")
            if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch '{pitch}'.")
            if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed '{speed}'.")

            gender_id = GENDER_MAP[gender]
            pitch_level_id = LEVELS_MAP[pitch]
            speed_level_id = LEVELS_MAP[speed]

            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"

            llm_prompt_parts = [
                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
            ]
            llm_prompt_string = "".join(llm_prompt_parts)
            # No global_token_ids_prompt needed

        else:
            raise ValueError("Processor requires either 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation).")

        # Tokenize the final LLM prompt string
        inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)

        # Add prompt global tokens to the output if they exist (for passing to decode)
        if global_token_ids_prompt is not None:
            inputs["global_token_ids_prompt"] = global_token_ids_prompt

        return inputs

    def decode(self,
               generated_ids: Union[List[int], np.ndarray, torch.Tensor],
               global_token_ids_prompt: Optional[torch.Tensor] = None,
               input_ids_len: Optional[int] = None,
               skip_special_tokens: bool = True) -> Dict[str, Any]:
        """
        Decodes the raw token IDs generated by the model into an audio waveform.

        Args:
            generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
                The token IDs generated by the `model.generate()` method. Assumed to be a single sequence (batch size 1).
            global_token_ids_prompt (`torch.Tensor`, *optional*):
                The global tokens obtained from the prompt audio during preprocessing (needed for voice cloning).
                Should be passed from the `__call__` output.
            input_ids_len (`int`, *optional*):
                The length of the original prompt `input_ids`. If provided, the prompt part will be stripped from
                `generated_ids` before decoding the text representation. If None, assumes `generated_ids` contains
                *only* the generated part.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether to skip special tokens when decoding the text representation for parsing.

        Returns:
            `Dict[str, Any]`: A dictionary containing:
                - `audio` (`np.ndarray`): The generated audio waveform.
                - `sampling_rate` (`int`): The sampling rate of the audio.
        """
        if self.model is None:
            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
        if not hasattr(self.model, '_detokenize_audio'):
            raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
        if self.sampling_rate is None:
            raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")

        # Ensure generated_ids is a tensor on the correct device
        if isinstance(generated_ids, (list, np.ndarray)):
            output_ids_tensor = torch.tensor(generated_ids)
        else:
            output_ids_tensor = generated_ids

        # Remove prompt if input_ids_len is provided
        if input_ids_len is not None:
            # Handle potential batch dimension if present (though usually not for decode)
            if output_ids_tensor.ndim > 1:
                output_ids = output_ids_tensor[0, input_ids_len:]
            else:
                output_ids = output_ids_tensor[input_ids_len:]
        else:
            if output_ids_tensor.ndim > 1:
                output_ids = output_ids_tensor[0]
            else:
                output_ids = output_ids_tensor

        if output_ids.numel() == 0:
            logger.warning("Received empty generated IDs after removing prompt. Returning empty audio.")
            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}

        # Decode the text representation to parse tokens
        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)

        # Extract semantic tokens
        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
        if not semantic_matches:
            logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
|
315 |
+
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
|
316 |
+
# Use model's device for tensors
|
317 |
+
device = self.model.device
|
318 |
+
pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches], dtype=torch.long, device=device).unsqueeze(0) # Add batch dim
|
319 |
+
|
320 |
+
# Determine global tokens
|
321 |
+
if global_token_ids_prompt is not None:
|
322 |
+
# Voice Cloning: Use prompt global tokens
|
323 |
+
global_token_ids = global_token_ids_prompt.to(device)
|
324 |
+
# Ensure correct shape (B, T_token, Q) or (B, D) - BiCodec detokenize needs to handle this
|
325 |
+
if global_token_ids.ndim == 2: # If (B, D), maybe unsqueeze? Check BiCodec.detokenize expectation
|
326 |
+
global_token_ids = global_token_ids.unsqueeze(1) # Assume (B, 1, D) might be needed
|
327 |
+
else:
|
328 |
+
# Voice Creation: Parse global tokens from generated text
|
329 |
+
global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
|
330 |
+
if not global_matches:
|
331 |
+
logger.error("Voice creation failed: No global tokens found in generated text.")
|
332 |
+
raise ValueError("Voice creation failed: Could not find bicodec_global tokens in the LLM output.")
|
333 |
+
global_token_ids = torch.tensor([int(token) for token in global_matches], dtype=torch.long, device=device).unsqueeze(0) # Add batch dim
|
334 |
+
# Add sequence dimension if needed (check BiCodec.detokenize)
|
335 |
+
if global_token_ids.ndim == 2:
|
336 |
+
global_token_ids = global_token_ids.unsqueeze(1) # Assume (B, 1, D)
|
337 |
+
|
338 |
+
# Detokenize audio using the model's method
|
339 |
+
try:
|
340 |
+
wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
|
341 |
+
except Exception as e:
|
342 |
+
logger.error(f"Error during audio detokenization: {e}", exc_info=True)
|
343 |
+
raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e
|
344 |
+
|
345 |
+
return {"audio": wav_np, "sampling_rate": self.sampling_rate}
|
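Editor's note: the `__call__` method above builds the LLM prompt (cloning or creation mode) and `decode` turns the generated token stream back into a waveform. The sketch below shows one plausible round trip through those two methods. It is illustrative only: the `AutoProcessor`/`AutoModel` loading with `trust_remote_code`, the local checkpoint path, the prompt file name, and `max_new_tokens` are assumptions, not something this commit specifies.

# Hypothetical usage sketch of the processor's __call__ / decode pair.
# Assumes the custom classes in this repo can be resolved via trust_remote_code;
# adjust class names and paths to the actual API if they differ.
import torch
from transformers import AutoModel, AutoProcessor

repo = "./spark-tts-checkpoint"  # placeholder local path
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
processor.model = model  # the processor delegates audio (de)tokenization to the model

# Voice cloning: the prompt audio supplies the speaker's global tokens.
inputs = processor(text="Hello world.", prompt_speech_path="prompt.wav", return_tensors="pt")
global_tokens = inputs.pop("global_token_ids_prompt", None)  # not a generate() argument
input_len = inputs["input_ids"].shape[-1]

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=1024)

out = processor.decode(generated_ids=generated,
                       global_token_ids_prompt=global_tokens,
                       input_ids_len=input_len)
# out["audio"] is a NumPy waveform sampled at out["sampling_rate"] Hz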
wav2vec2-large-xlsr-53/README.md
ADDED
@@ -0,0 +1,29 @@
+---
+language: multilingual
+datasets:
+- common_voice
+tags:
+- speech
+license: apache-2.0
+---
+
+# Wav2Vec2-XLSR-53
+
+[Facebook's XLSR-Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+
+The base model, pretrained on 16kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16kHz. Note that this model should be fine-tuned on a downstream task, like Automatic Speech Recognition. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more information.
+
+[Paper](https://arxiv.org/abs/2006.13979)
+
+Authors: Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli
+
+**Abstract**
+This paper presents XLSR, which learns cross-lingual speech representations by pretraining a single model from the raw waveform of speech in multiple languages. We build on wav2vec 2.0, which is trained by solving a contrastive task over masked latent speech representations and jointly learns a quantization of the latents shared across languages. The resulting model is fine-tuned on labeled data, and experiments show that cross-lingual pretraining significantly outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative compared to a comparable system. Our approach enables a single multilingual speech recognition model which is competitive with strong individual models. Analysis shows that the latent discrete speech representations are shared across languages, with increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing XLSR-53, a large model pretrained in 53 languages.
+
+The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
+
+# Usage
+
+See [this notebook](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb) for more information on how to fine-tune the model. A short loading sketch follows this file.
+
+
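Editor's note: the model card above stresses 16kHz input. A minimal sketch of loading this subfolder as an acoustic feature encoder is below; the local path, the example audio file, and the use of `Wav2Vec2Model` (the config lists `Wav2Vec2ForPreTraining`, so the encoder loads with unused pretraining heads dropped) are assumptions for illustration.

# Illustrative only: assumes the wav2vec2-large-xlsr-53/ subfolder is available on disk.
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

path = "wav2vec2-large-xlsr-53"
extractor = Wav2Vec2FeatureExtractor.from_pretrained(path)
encoder = Wav2Vec2Model.from_pretrained(path)

wav, sr = torchaudio.load("prompt.wav")                            # placeholder audio file
wav = torchaudio.functional.resample(wav, sr, 16000).mean(dim=0)   # enforce 16 kHz mono

inputs = extractor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    hidden = encoder(**inputs).last_hidden_state  # shape (1, frames, 1024)
print(hidden.shape)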
wav2vec2-large-xlsr-53/config.json
ADDED
@@ -0,0 +1,83 @@
+{
+  "activation_dropout": 0.0,
+  "apply_spec_augment": true,
+  "architectures": [
+    "Wav2Vec2ForPreTraining"
+  ],
+  "attention_dropout": 0.1,
+  "bos_token_id": 1,
+  "codevector_dim": 768,
+  "contrastive_logits_temperature": 0.1,
+  "conv_bias": true,
+  "conv_dim": [
+    512,
+    512,
+    512,
+    512,
+    512,
+    512,
+    512
+  ],
+  "conv_kernel": [
+    10,
+    3,
+    3,
+    3,
+    3,
+    2,
+    2
+  ],
+  "conv_stride": [
+    5,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2
+  ],
+  "ctc_loss_reduction": "sum",
+  "ctc_zero_infinity": false,
+  "diversity_loss_weight": 0.1,
+  "do_stable_layer_norm": true,
+  "eos_token_id": 2,
+  "feat_extract_activation": "gelu",
+  "feat_extract_dropout": 0.0,
+  "feat_extract_norm": "layer",
+  "feat_proj_dropout": 0.1,
+  "feat_quantizer_dropout": 0.0,
+  "final_dropout": 0.0,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout": 0.1,
+  "hidden_size": 1024,
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "layer_norm_eps": 1e-05,
+  "layerdrop": 0.1,
+  "mask_channel_length": 10,
+  "mask_channel_min_space": 1,
+  "mask_channel_other": 0.0,
+  "mask_channel_prob": 0.0,
+  "mask_channel_selection": "static",
+  "mask_feature_length": 10,
+  "mask_feature_prob": 0.0,
+  "mask_time_length": 10,
+  "mask_time_min_space": 1,
+  "mask_time_other": 0.0,
+  "mask_time_prob": 0.075,
+  "mask_time_selection": "static",
+  "model_type": "wav2vec2",
+  "num_attention_heads": 16,
+  "num_codevector_groups": 2,
+  "num_codevectors_per_group": 320,
+  "num_conv_pos_embedding_groups": 16,
+  "num_conv_pos_embeddings": 128,
+  "num_feat_extract_layers": 7,
+  "num_hidden_layers": 24,
+  "num_negatives": 100,
+  "pad_token_id": 0,
+  "proj_codevector_dim": 768,
+  "transformers_version": "4.7.0.dev0",
+  "vocab_size": 32
+}
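Editor's note: the `conv_stride` entries in this config multiply to 320, so at 16kHz the convolutional feature extractor emits one frame every 20 ms. The stride values below are copied from the config above; the product/frame-rate arithmetic is added here only as a sanity check.

# Sanity-check the frame rate implied by the feature-extractor strides.
from math import prod

conv_stride = [5, 2, 2, 2, 2, 2, 2]  # from config.json above
sampling_rate = 16000                # Hz

hop = prod(conv_stride)              # 320 samples per output frame
print(hop, sampling_rate / hop)      # 320, 50.0 -> one feature frame every 20 ms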
wav2vec2-large-xlsr-53/preprocessor_config.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "do_normalize": true,
+  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+  "feature_size": 1,
+  "padding_side": "right",
+  "padding_value": 0,
+  "return_attention_mask": true,
+  "sampling_rate": 16000
+}
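Editor's note: these fields map directly onto the `Wav2Vec2FeatureExtractor` constructor. The sketch below builds an equivalent extractor by hand to make that mapping explicit; in practice `Wav2Vec2FeatureExtractor.from_pretrained` on this folder does the same thing, and `padding_side` is left at its default here.

# Equivalent construction of the feature extractor from the fields above (illustrative).
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,              # raw mono waveform, one value per sample
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,           # zero-mean / unit-variance normalization per utterance
    return_attention_mask=True,  # attention mask is required for layer-norm wav2vec2 variants
)
print(extractor)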
wav2vec2-large-xlsr-53/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
+size 1269737156
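Editor's note: `pytorch_model.bin` is committed as a Git LFS pointer; the actual weight file is about 1.27 GB. A quick way to confirm that a clone pulled the real weights rather than the pointer is to compare the file's size and SHA-256 against the pointer fields above. The local path in this sketch is an assumption.

# Verify a downloaded copy against the LFS pointer's oid/size shown above.
import hashlib
import os

path = "wav2vec2-large-xlsr-53/pytorch_model.bin"  # placeholder local path
expected_size = 1269737156
expected_sha = "314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb"

assert os.path.getsize(path) == expected_size, "size mismatch: LFS object not pulled?"

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected_sha, "checksum mismatch"
print("pytorch_model.bin matches the LFS pointer")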