ancv committed on
Commit 5cd61b8 · verified · 1 Parent(s): 883ecd9

Upload 24 files
.gitattributes CHANGED
@@ -1,35 +1,6 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
+ LLM/tokenizer.json filter=lfs diff=lfs merge=lfs -text
BiCodec/config.yaml ADDED
@@ -0,0 +1,60 @@
+ audio_tokenizer:
+   mel_params:
+     sample_rate: 16000
+     n_fft: 1024
+     win_length: 640
+     hop_length: 320
+     mel_fmin: 10
+     mel_fmax: null
+     num_mels: 128
+
+   encoder:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     sample_ratios: [1,1]
+
+   decoder:
+     input_channel: 1024
+     channels: 1536
+     rates: [8, 5, 4, 2]
+     kernel_sizes: [16,11,8,4]
+
+   quantizer:
+     input_dim: 1024
+     codebook_size: 8192
+     codebook_dim: 8
+     commitment: 0.25
+     codebook_loss_weight: 2.0
+     use_l2_normlize: True
+     threshold_ema_dead_code: 0.2
+
+   speaker_encoder:
+     input_dim: 128
+     out_dim: 1024
+     latent_dim: 128
+     token_num: 32
+     fsq_levels: [4, 4, 4, 4, 4, 4]
+     fsq_num_quantizers: 1
+
+   prenet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 12
+     out_channels: 1024
+     condition_dim: 1024
+     sample_ratios: [1,1]
+     use_tanh_at_final: False
+
+   postnet:
+     input_channels: 1024
+     vocos_dim: 384
+     vocos_intermediate_dim: 2048
+     vocos_num_layers: 6
+     out_channels: 1024
+     use_tanh_at_final: False
+
+
BiCodec/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9940cd48d4446e4340ced82d234bf5618350dd9f5db900ebe47a4fdb03867ec
+ size 625518756
LLM/added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
LLM/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "Qwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 896,
+   "initializer_range": 0.02,
+   "intermediate_size": 4864,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 21,
+   "model_type": "qwen2",
+   "num_attention_heads": 14,
+   "num_hidden_layers": 24,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.1",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 166000
+ }
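The LLM sub-folder is a standard Qwen2 causal-LM checkpoint, so it can be inspected on its own with vanilla Transformers. A minimal sketch for inspection only (paths assume a local clone of this repo; SparkTTSModel wires this backbone up internally):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Inspect the Qwen2 backbone that SparkTTS uses for token generation
config = AutoConfig.from_pretrained("./LLM")
print(config.model_type, config.hidden_size, config.vocab_size)  # qwen2 896 166000

tokenizer = AutoTokenizer.from_pretrained("./LLM")
llm = AutoModelForCausalLM.from_pretrained("./LLM", torch_dtype=torch.bfloat16)
print(sum(p.numel() for p in llm.parameters()) / 1e6, "M parameters")
```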
LLM/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
LLM/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54825baf0a2f6076eb3c78fa1d22a95aee225f59070a8b295f8169db860eb109
+ size 2026568968
LLM/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
LLM/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c8b057d6ca205a429cc3428b9fc815f0d6ee1d53106dd5e5b129ef9db2ff057
+ size 14129172
LLM/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
LLM/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
__init__.py ADDED
@@ -0,0 +1,32 @@
+ # # SparkAudio/Spark-TTS-0.5B_v2/__init__.py
+
+ # import os
+ # from pathlib import Path
+ # # Import the registry and AutoModel for registration
+ # from transformers.pipelines import PIPELINE_REGISTRY
+ # from transformers import AutoModel
+
+ # # Print to confirm that this file is executed
+ # print(f"Executing __init__.py for SparkTTS custom module in: {__file__}")
+
+ # try:
+ #     # Import the custom Pipeline class so it can be registered
+ #     from .pipeline_spark_tts import SparkTTSPipeline
+
+ #     # Register the custom pipeline with the registry
+ #     # Make sure the task name "text-to-speech" matches the name used when calling pipeline()
+ #     print(f"Attempting to register SparkTTSPipeline for task 'text-to-speech'...")
+ #     PIPELINE_REGISTRY.register_pipeline(
+ #         "text-to-speech",                 # Task name
+ #         pipeline_class=SparkTTSPipeline,  # Your custom pipeline class
+ #         pt_model=AutoModel,               # Compatible AutoModel class for PyTorch
+ #         # tf_model=None,                  # Add a TF class if available
+ #         type="text",                      # Primary input type (e.g. text, audio, image)
+ #     )
+ #     print("Pipeline registration call completed.")
+
+ # # Catch errors if the import fails (e.g. a missing dependency)
+ # except ImportError as e:
+ #     print(f"WARNING: Could not import SparkTTSPipeline in __init__.py: {e}. Pipeline registration failed.")
+ # except Exception as e:
+ #     print(f"ERROR: An unexpected error occurred during SparkTTS __init__ registration: {e}")
_modeling_bicodec_components.py ADDED
The diff for this file is too large to render. See raw diff
 
_utils.py ADDED
@@ -0,0 +1,151 @@
+ # Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Utility functions for SparkTTS """
+
+ import random
+ import soxr
+ import soundfile
+ import torch
+ import torchaudio
+ import numpy as np
+
+ from pathlib import Path
+ from typing import Tuple, Dict, Any
+ from numpy.lib.stride_tricks import sliding_window_view
+ from omegaconf import OmegaConf  # Keep if BiCodec config loading needs it
+
+
+ # --- Token Maps (from sparktts/utils/token_parser.py) ---
+ TASK_TOKEN_MAP = {
+     "vc": "<|task_vc|>",
+     "tts": "<|task_tts|>",
+     "asr": "<|task_asr|>",
+     "s2s": "<|task_s2s|>",
+     "t2s": "<|task_t2s|>",
+     "understand": "<|task_understand|>",
+     "caption": "<|task_cap|>",
+     "controllable_tts": "<|task_controllable_tts|>",
+     "prompt_tts": "<|task_prompt_tts|>",
+     "speech_edit": "<|task_edit|>",
+ }
+
+ LEVELS_MAP = {
+     "very_low": 0,
+     "low": 1,
+     "moderate": 2,
+     "high": 3,
+     "very_high": 4,
+ }
+
+ LEVELS_MAP_UI = {
+     1: 'very_low',
+     2: 'low',
+     3: 'moderate',
+     4: 'high',
+     5: 'very_high'
+ }
+
+ GENDER_MAP = {
+     "female": 0,
+     "male": 1,
+ }
+
+ # --- Audio Utils (from sparktts/utils/audio.py) ---
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
+     temp = np.sort(np.abs(audio))
+     if len(temp) == 0:  # Handle empty audio case
+         return audio
+     if temp[-1] < 0.1:
+         scaling_factor = max(temp[-1], 1e-3)
+         audio = audio / scaling_factor * 0.1
+     temp = temp[temp > 0.01]
+     L = temp.shape[0]
+     if L <= 10:
+         return audio
+     volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
+     if volume == 0:  # Avoid division by zero if volume is effectively zero
+         return audio
+     audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
+     max_value = np.max(np.abs(audio)) if len(audio) > 0 else 0
+     if max_value > 1:
+         audio = audio / max_value
+     return audio
+
+ def load_audio(
+     adfile: Path,
+     sampling_rate: int = None,
+     length: int = None,
+     volume_normalize: bool = False,
+     segment_duration: int = None,
+ ) -> np.ndarray:
+     try:
+         audio, sr = soundfile.read(adfile, dtype='float32')  # Ensure float32
+     except Exception as e:
+         raise IOError(f"Could not read audio file {adfile}: {e}")
+
+     if audio is None or len(audio) == 0:
+         raise ValueError(f"Audio file {adfile} is empty or invalid.")
+
+     if len(audio.shape) > 1:
+         audio = audio[:, 0]
+
+     if sampling_rate is not None and sr != sampling_rate:
+         try:
+             # Ensure input is float64 for soxr
+             audio = audio.astype(np.float64)
+             audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
+             # Convert back to float32
+             audio = audio.astype(np.float32)
+             sr = sampling_rate
+         except Exception as e:
+             raise RuntimeError(f"Failed to resample audio from {sr}Hz to {sampling_rate}Hz: {e}")
+
+     if segment_duration is not None:
+         seg_length = int(sr * segment_duration)
+         audio = random_select_audio_segment(audio, seg_length)
+
+     if volume_normalize:
+         audio = audio_volume_normalize(audio)
+
+     if length is not None:
+         if audio.shape[0] > length:
+             audio = audio[:length]
+         else:
+             audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
+     return audio
+
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
+     if audio.shape[0] < length:
+         audio = np.pad(audio, (0, int(length - audio.shape[0])), mode='constant')
+         start_index = 0  # If padded, start from beginning
+     elif audio.shape[0] == length:
+         start_index = 0  # If exact length, start from beginning
+     else:
+         start_index = random.randint(0, audio.shape[0] - length)
+
+     end_index = int(start_index + length)
+     return audio[start_index:end_index]
+
+ # --- File Utils (Minimal required) ---
+ def load_config_yaml(config_path: Path) -> Dict:
+     """Loads a YAML configuration file using OmegaConf."""
+     # Check if path exists
+     if not Path(config_path).is_file():
+         raise FileNotFoundError(f"YAML Config file not found: {config_path}")
+     try:
+         config = OmegaConf.load(config_path)
+         # Convert OmegaConf DictConfig to standard Python dict
+         return OmegaConf.to_container(config, resolve=True)
+     except Exception as e:
+         raise IOError(f"Error loading YAML config file {config_path}: {e}")
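A quick usage sketch for these helpers, assuming the repository root is on the Python path and `prompt.wav` is a hypothetical reference clip:

```python
from pathlib import Path
from _utils import load_audio, load_config_yaml

# Load a (hypothetical) reference clip: resampled to 16 kHz, volume-normalized,
# and padded/truncated to 6 seconds, matching ref_segment_duration in config.json
ref = load_audio(Path("prompt.wav"), sampling_rate=16000, volume_normalize=True, length=16000 * 6)
print(ref.shape, ref.dtype)  # (96000,) float32

# Load the BiCodec YAML as a plain dict
bicodec_cfg = load_config_yaml(Path("BiCodec/config.yaml"))
print(bicodec_cfg["audio_tokenizer"]["quantizer"]["codebook_size"])  # 8192
```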
config.json ADDED
@@ -0,0 +1,90 @@
+ {
+   "model_type": "spark-tts",
+   "architectures": [
+     "SparkTTSModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+     "AutoModel": "modeling_spark_tts.SparkTTSModel",
+     "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+   },
+   "processor_class": "processing_spark_tts.SparkTTSProcessor",
+   "custom_pipelines": {
+     "text-to-speech": {
+       "impl": "pipeline_spark_tts.SparkTTSPipeline",
+       "pt": ["AutoModel"]
+     }
+   },
+   "llm_model_name_or_path": "./LLM",
+   "bicodec_model_name_or_path": "./BiCodec",
+   "wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
+   "sample_rate": 16000,
+   "highpass_cutoff_freq": 40,
+   "latent_hop_length": 320,
+   "ref_segment_duration": 6.0,
+   "volume_normalize": true,
+   "bicodec_config": {
+     "audio_tokenizer": {
+       "mel_params": {
+         "sample_rate": 16000,
+         "n_fft": 1024,
+         "win_length": 640,
+         "hop_length": 320,
+         "mel_fmin": 10,
+         "mel_fmax": null,
+         "num_mels": 128
+       },
+       "encoder": {
+         "input_channels": 1024,
+         "vocos_dim": 384,
+         "vocos_intermediate_dim": 2048,
+         "vocos_num_layers": 12,
+         "out_channels": 1024,
+         "sample_ratios": [1, 1]
+       },
+       "decoder": {
+         "input_channel": 1024,
+         "channels": 1536,
+         "rates": [8, 5, 4, 2],
+         "kernel_sizes": [16, 11, 8, 4]
+       },
+       "quantizer": {
+         "input_dim": 1024,
+         "codebook_size": 8192,
+         "codebook_dim": 8,
+         "commitment": 0.25,
+         "codebook_loss_weight": 2.0,
+         "use_l2_normlize": true,
+         "threshold_ema_dead_code": 0.2
+       },
+       "speaker_encoder": {
+         "input_dim": 128,
+         "out_dim": 1024,
+         "latent_dim": 128,
+         "token_num": 32,
+         "fsq_levels": [4, 4, 4, 4, 4, 4],
+         "fsq_num_quantizers": 1
+       },
+       "prenet": {
+         "input_channels": 1024,
+         "vocos_dim": 384,
+         "vocos_intermediate_dim": 2048,
+         "vocos_num_layers": 12,
+         "out_channels": 1024,
+         "condition_dim": 1024,
+         "sample_ratios": [1, 1],
+         "use_tanh_at_final": false
+       },
+       "postnet": {
+         "input_channels": 1024,
+         "vocos_dim": 384,
+         "vocos_intermediate_dim": 2048,
+         "vocos_num_layers": 6,
+         "out_channels": 1024,
+         "use_tanh_at_final": false
+       }
+     }
+   },
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.1"
+ }
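The `auto_map` and `custom_pipelines` entries are what let the Auto classes resolve the custom code shipped alongside this config. A hedged sketch of loading everything with `trust_remote_code=True` (the local path `"."` stands in for this repo; a Hub repo id would work the same way):

```python
import torch
from transformers import AutoConfig, AutoModel, AutoProcessor

model_id = "."  # local clone of this repo; substitute the actual Hub id if preferred

# auto_map routes these calls to configuration_spark_tts.py, modeling_spark_tts.py
# and processing_spark_tts.py shipped next to this config
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, config=config, trust_remote_code=True,
                                  torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
processor.model = model  # the processor needs the model for audio (de)tokenization
print(config.sample_rate, config.ref_segment_duration)  # 16000 6.0
```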
configuration_spark_tts.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ SparkTTS model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ class SparkTTSConfig(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`SparkTTSModel`].
+     It is used to instantiate a SparkTTS model according to the specified arguments, defining the model
+     architecture and sub-component paths.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
+     Read the documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`):
+             Path to the pretrained LLM model or model identifier from huggingface.co/models.
+         bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`):
+             Path to the pretrained BiCodec model directory.
+         wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`):
+             Path to the pretrained Wav2Vec2 model directory.
+         sample_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate of the audio files.
+         highpass_cutoff_freq (`int`, *optional*, defaults to 40):
+             Highpass filter cutoff frequency for audio processing.
+         latent_hop_length (`int`, *optional*, defaults to 320):
+             Hop length used in BiCodec processing.
+         ref_segment_duration (`float`, *optional*, defaults to 6.0):
+             Duration (in seconds) of the reference audio clip used for speaker embedding.
+         volume_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the volume of audio inputs.
+         bicodec_config (`dict`, *optional*):
+             A dictionary containing the configuration for the BiCodec model components (encoder, decoder, etc.).
+             This is typically loaded from the `BiCodec/config.yaml` originally.
+         **kwargs
+             Additional keyword arguments passed along to [`PretrainedConfig`].
+     """
+
+     model_type = "spark-tts"
+     processor_class = "SparkTTSProcessor"
+     attribute_map = {}  # Add mappings if needed for renaming attributes
+
+     def __init__(
+         self,
+         llm_model_name_or_path="./LLM",
+         bicodec_model_name_or_path="./BiCodec",
+         wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
+         sample_rate=16000,
+         highpass_cutoff_freq=40,
+         latent_hop_length=320,
+         ref_segment_duration=6.0,
+         volume_normalize=True,
+         bicodec_config=None,
+         **kwargs,
+     ):
+         self.llm_model_name_or_path = llm_model_name_or_path
+         self.bicodec_model_name_or_path = bicodec_model_name_or_path
+         self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
+         self.sample_rate = sample_rate
+         self.highpass_cutoff_freq = highpass_cutoff_freq
+         self.latent_hop_length = latent_hop_length
+         self.ref_segment_duration = ref_segment_duration
+         self.volume_normalize = volume_normalize
+         self.bicodec_config = bicodec_config if bicodec_config is not None else {}
+
+         # REMOVE THIS WARNING - the check in SparkTTSModel is better
+         # if not self.bicodec_config:
+         #     logger.warning("BiCodec config is empty. BiCodec model might not load correctly.")
+
+         super().__init__(**kwargs)
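Because the config only stores sub-component paths and audio parameters, it can also be constructed directly when experimenting. A small sketch with overridden defaults (assumes the repository root is on the Python path):

```python
from configuration_spark_tts import SparkTTSConfig

# Defaults mirror the repo layout; override only what differs in your setup
cfg = SparkTTSConfig(ref_segment_duration=4.0, volume_normalize=False)
print(cfg.llm_model_name_or_path, cfg.sample_rate, cfg.ref_segment_duration)
# ./LLM 16000 4.0
```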
model_index.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "model_index": [
+     {
+       "name": "SparkTTS",
+       "results": [],
+       "config": {
+         "architectures": [
+           "SparkTTSForConditionalGeneration"
+         ],
+         "model_type": "sparktts"
+       },
+       "filenames": [
+         "config.json",
+         "model_index.json",
+         "modeling_sparktts.py",
+         "pipeline_sparktts.py",
+         "__init__.py",
+         "LLM/config.json",
+         "LLM/pytorch_model.bin",
+         "LLM/tokenizer.json",
+         "BiCodec/config.yaml",
+         "BiCodec/model.safetensors",
+         "wav2vec2-large-xlsr-53/config.json",
+         "wav2vec2-large-xlsr-53/pytorch_model.bin",
+         "wav2vec2-large-xlsr-53/preprocessor_config.json",
+         "sparktts_code/models/audio_tokenizer.py",
+         "sparktts_code/models/bicodec.py",
+         "sparktts_code/modules/speaker/speaker_encoder.py",
+         "sparktts_code/modules/encoder_decoder/feat_encoder.py",
+         "sparktts_code/modules/encoder_decoder/feat_decoder.py",
+         "sparktts_code/modules/encoder_decoder/wave_generator.py",
+         "sparktts_code/modules/vq/factorized_vector_quantize.py",
+         "sparktts_code/utils/audio.py",
+         "sparktts_code/utils/file.py",
+         "sparktts_code/utils/token_parser.py"
+       ]
+     }
+   ]
+ }
modeling_spark_tts.py ADDED
The diff for this file is too large to render. See raw diff
 
pipeline_spark_tts.py ADDED
@@ -0,0 +1,303 @@
+ # Copyright (c) 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import re
+ import numpy as np
+ from typing import Optional, Dict, Any, Union, List
+ from pathlib import Path
+ from transformers import Pipeline, PreTrainedModel  # Inherit directly from the base class
+ from transformers.utils import logging
+
+ # Import necessary items from this module (assuming they are defined in modeling_spark_tts)
+ from .modeling_spark_tts import SparkTTSModel
+ from .configuration_spark_tts import SparkTTSConfig
+ from .modeling_spark_tts import (
+     load_audio,
+     TASK_TOKEN_MAP,
+     LEVELS_MAP,
+     GENDER_MAP,
+     LEVELS_MAP_UI,
+ )
+ # No need to import SparkTTSModel/Config here, pipeline gets them during init
+
+ logger = logging.get_logger(__name__)
+
+ # Define constants if needed (e.g., for default generation args)
+ DEFAULT_MAX_NEW_TOKENS = 3000
+ DEFAULT_TEMPERATURE = 0.8
+ DEFAULT_TOP_K = 50
+ DEFAULT_TOP_P = 0.95
+
+ class SparkTTSPipeline(Pipeline):
+     """
+     Custom Pipeline for SparkTTS text-to-speech generation, following HF documentation structure.
+     Handles voice cloning and voice creation modes.
+     """
+     def __init__(self, model, tokenizer=None, framework="pt", device=None, **kwargs):
+         # --- The model should NOT be loaded here anymore ---
+         # A custom pipeline's __init__ should receive an already-loaded model and tokenizer.
+         # Loading should happen OUTSIDE, before pipeline() is called, or be handled by the pipeline factory.
+
+         # Check the model and tokenizer that were passed in
+         if model is None:
+             raise ValueError("SparkTTSPipeline requires a 'model' argument.")
+         if not isinstance(model, SparkTTSModel):
+             # The model may have been loaded via AutoModel and thus be a PreTrainedModel; check its config
+             if isinstance(model, PreTrainedModel) and isinstance(model.config, SparkTTSConfig):
+                 pass  # OK, the model is compatible based on its config
+             else:
+                 raise TypeError(f"Expected model compatible with SparkTTSConfig, but got {type(model)}")
+
+         if tokenizer is None:
+             raise ValueError("SparkTTSPipeline requires a 'tokenizer' argument.")
+         if not hasattr(tokenizer, 'encode') or not hasattr(tokenizer, 'batch_decode'):
+             raise TypeError("Tokenizer does not seem to be a valid Transformers tokenizer.")
+
+         # Call super().__init__ with the model/tokenizer received above
+         super().__init__(model=model, tokenizer=tokenizer, framework=framework, device=device, **kwargs)
+         if hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
+             self.sampling_rate = self.model.config.sample_rate
+         else:
+             # Fall back to a default (or obtain it elsewhere) if the config does not provide one
+             logger.warning("Could not determine sampling rate from model config. Defaulting to 16000.")
+             self.sampling_rate = 16000
+
+     def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
+         """
+         Sanitizes pipeline parameters and separates them for preprocess, forward, and postprocess.
+
+         Returns:
+             Tuple[dict, dict, dict]: preprocess_kwargs, forward_kwargs, postprocess_kwargs
+         """
+         preprocess_kwargs = {}
+         # --- Preprocessing specific args ---
+         if "prompt_speech_path" in kwargs:
+             preprocess_kwargs["prompt_speech_path"] = kwargs["prompt_speech_path"]
+         if "prompt_text" in kwargs:
+             preprocess_kwargs["prompt_text"] = kwargs["prompt_text"]
+         if "gender" in kwargs:
+             preprocess_kwargs["gender"] = kwargs["gender"]
+         if "pitch" in kwargs:
+             preprocess_kwargs["pitch"] = kwargs["pitch"]
+         if "speed" in kwargs:
+             preprocess_kwargs["speed"] = kwargs["speed"]
+
+         forward_kwargs = {}
+         # --- Forward specific args (LLM generation) ---
+         # Use kwargs.get to allow users to override defaults
+         forward_kwargs["max_new_tokens"] = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
+         forward_kwargs["do_sample"] = kwargs.get("do_sample", True)
+         forward_kwargs["temperature"] = kwargs.get("temperature", DEFAULT_TEMPERATURE)
+         forward_kwargs["top_k"] = kwargs.get("top_k", DEFAULT_TOP_K)
+         forward_kwargs["top_p"] = kwargs.get("top_p", DEFAULT_TOP_P)
+         # Ensure essential generation tokens are present if needed
+         if self.tokenizer.eos_token_id is not None:
+             forward_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+         if self.tokenizer.pad_token_id is not None:
+             forward_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+         elif self.tokenizer.eos_token_id is not None:
+             logger.warning("Setting pad_token_id to eos_token_id for open-end generation.")
+             forward_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
+         # Filter out None values that might cause issues with generate
+         forward_kwargs = {k: v for k, v in forward_kwargs.items() if v is not None}
+
+         postprocess_kwargs = {}
+         # --- Postprocessing specific args (if any in the future) ---
+         # Example: if you added an option to return tokens instead of audio
+         # if "return_tokens" in kwargs:
+         #     postprocess_kwargs["return_tokens"] = kwargs["return_tokens"]
+
+         return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
+     def preprocess(self, inputs, **preprocess_kwargs) -> dict:
+         """
+         Transforms text input and preprocess_kwargs into model input format.
+
+         Args:
+             inputs (str): The text to synthesize.
+             preprocess_kwargs (dict): Arguments relevant to preprocessing (e.g., prompt paths, controls).
+
+         Returns:
+             dict: Containing `model_inputs` (tokenized dict) and `global_token_ids_prompt` (optional Tensor).
+         """
+         text = inputs
+         prompt_speech_path = preprocess_kwargs.get("prompt_speech_path")
+         prompt_text = preprocess_kwargs.get("prompt_text")
+         gender = preprocess_kwargs.get("gender")
+         pitch = preprocess_kwargs.get("pitch")
+         speed = preprocess_kwargs.get("speed")
+
+         global_token_ids = None
+         llm_prompt_string = ""
+
+         # --- Logic to build llm_prompt_string and get global_token_ids ---
+         if prompt_speech_path is not None:
+             # Voice Cloning Mode
+             logger.info(f"Preprocessing for Voice Cloning (prompt: {prompt_speech_path})")
+             if not Path(prompt_speech_path).exists():
+                 raise FileNotFoundError(f"Prompt speech file not found: {prompt_speech_path}")
+             # Use the MODEL's method for tokenization (self.model is set by base Pipeline class)
+             global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
+             global_token_ids = global_tokens  # Keep Tensor for detokenization
+
+             global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
+             if prompt_text and len(prompt_text) > 1:
+                 semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
+                 llm_prompt_parts = [
+                     TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
+                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
+                     "<|start_semantic_token|>", semantic_tokens_str,
+                 ]
+             else:
+                 llm_prompt_parts = [
+                     TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
+                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
+                 ]
+             llm_prompt_string = "".join(llm_prompt_parts)
+         elif gender is not None and pitch is not None and speed is not None:
+             # Voice Creation Mode
+             logger.info(f"Preprocessing for Voice Creation (gender: {gender}, pitch: {pitch}, speed: {speed})")
+             if gender not in GENDER_MAP: raise ValueError(f"Invalid gender: {gender}")
+             if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch: {pitch}")
+             if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed: {speed}")
+
+             gender_id = GENDER_MAP[gender]
+             pitch_level_id = LEVELS_MAP[pitch]
+             speed_level_id = LEVELS_MAP[speed]
+             attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
+             llm_prompt_parts = [
+                 TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
+                 "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
+             ]
+             llm_prompt_string = "".join(llm_prompt_parts)
+         else:
+             raise ValueError("Pipeline requires 'prompt_speech_path' (for cloning) or 'gender', 'pitch', 'speed' (for creation).")
+         # --- End prompt building logic ---
+
+         # Tokenize the final prompt for the LLM
+         # Use self.tokenizer (set by base Pipeline class)
+         model_inputs = self.tokenizer(llm_prompt_string, return_tensors=self.framework, padding=False)
+
+         return {"model_inputs": model_inputs, "global_token_ids_prompt": global_token_ids}
+
+
+     def _forward(self, model_inputs, **forward_kwargs) -> dict:
+         """
+         Passes model_inputs to the model's LLM generate method with forward_kwargs.
+
+         Args:
+             model_inputs (dict): Output from `preprocess`.
+             forward_kwargs (dict): Generation arguments from `_sanitize_parameters`.
+
+         Returns:
+             dict: Containing `generated_ids`, `input_ids_len`, and context (`global_token_ids_prompt`).
+         """
+         llm_inputs = model_inputs["model_inputs"]
+         global_token_ids_prompt = model_inputs.get("global_token_ids_prompt")
+
+         # Move inputs to the correct device (self.device is set by base Pipeline class)
+         llm_inputs = {k: v.to(self.device) for k, v in llm_inputs.items()}
+         input_ids = llm_inputs["input_ids"]
+         input_ids_len = input_ids.shape[-1]
+
+         # Combine LLM inputs and generation arguments
+         generate_kwargs = {**llm_inputs, **forward_kwargs}
+
+         logger.info(f"Generating tokens with args: {forward_kwargs}")
+         # Use the model's LLM component (self.model is set by base Pipeline class)
+         with torch.no_grad():
+             generated_ids = self.model.llm.generate(**generate_kwargs)
+
+         # Prepare output dict to pass to postprocess
+         output_dict = {
+             "generated_ids": generated_ids,
+             "input_ids_len": input_ids_len,
+         }
+         if global_token_ids_prompt is not None:
+             output_dict["global_token_ids_prompt"] = global_token_ids_prompt
+
+         return output_dict
+
+     def postprocess(self, model_outputs, **postprocess_kwargs) -> dict:
+         """
+         Transforms model outputs (from _forward) into the final audio dictionary.
+
+         Args:
+             model_outputs (dict): Dictionary from `_forward`.
+             postprocess_kwargs (dict): Arguments relevant to postprocessing (currently none).
+
+         Returns:
+             dict: Containing `audio` (np.ndarray) and `sampling_rate` (int).
+         """
+         generated_ids = model_outputs["generated_ids"]
+         input_ids_len = model_outputs["input_ids_len"]
+         global_token_ids_prompt = model_outputs.get("global_token_ids_prompt")
+
+         # --- Logic to extract tokens and detokenize ---
+         output_ids = generated_ids[0, input_ids_len:]  # Assumes batch size 1
+         # Use self.tokenizer (set by base Pipeline class)
+         predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
+
+         semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
+         if not semantic_matches:
+             logger.warning("No semantic tokens found. Returning empty audio.")
+             # Use self.model.config for sampling rate
+             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.model.config.sample_rate}
+
+         pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0).to(self.device)
+
+         if global_token_ids_prompt is not None:
+             # Voice Cloning: Use prompt tokens
+             global_token_ids = global_token_ids_prompt.to(self.device)
+             logger.info("Using global tokens from prompt.")
+         else:
+             # Voice Creation: Extract generated tokens
+             global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
+             if not global_matches:
+                 raise ValueError("Voice creation failed: No bicodec_global tokens found.")
+             global_token_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0).to(self.device)
+             if global_token_ids.ndim == 2: global_token_ids = global_token_ids.unsqueeze(1)
+             logger.info("Using global tokens from generated text.")
+
+         # Use the MODEL's method for detokenization (self.model is set by base Pipeline class)
+         wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
+         logger.info(f"Generated audio shape: {wav_np.shape}")
+         # --- End detokenization logic ---
+
+         # Return final output dictionary
+         return {"audio": wav_np, "sampling_rate": self.model.config.sample_rate}
+
+
+
+ # --- Add Registration Code Here ---
+ # This code will execute when this file is loaded via trust_remote_code
+ try:
+     from transformers.pipelines import PIPELINE_REGISTRY
+     from transformers import AutoModel  # Use AutoModel for registration
+
+     print(f"Registering SparkTTSPipeline for task 'text-to-speech' from pipeline_spark_tts.py...")
+     PIPELINE_REGISTRY.register_pipeline(
+         "text-to-speech",                 # Task name
+         pipeline_class=SparkTTSPipeline,  # The class defined above
+         pt_model=AutoModel,               # Compatible PT AutoModel class
+         # tf_model=None,                  # Add TF class if needed
+     )
+     print("Pipeline registration call completed successfully.")
+ except ImportError:
+     # Handle potential import error if transformers structure changes
+     print("WARNING: Could not import PIPELINE_REGISTRY or AutoModel. Pipeline registration failed.")
+ except Exception as e:
+     print(f"ERROR: An unexpected error occurred during pipeline registration: {e}")
+ # --- End Registration Code ---
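Since the file registers `SparkTTSPipeline` for the "text-to-speech" task, the standard `pipeline()` factory can be used once a model and tokenizer are loaded. A hedged sketch of both modes — file names are hypothetical, and the exact factory behavior with a pre-loaded custom model can vary across transformers versions, so this mirrors the registration comments above rather than a tested recipe:

```python
import soundfile as sf
import torch
from transformers import AutoModel, AutoTokenizer, pipeline

# The custom __init__ expects an already-loaded model and tokenizer
model = AutoModel.from_pretrained(".", trust_remote_code=True, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("./LLM")
tts = pipeline("text-to-speech", model=model, tokenizer=tokenizer, trust_remote_code=True)

# Voice cloning: condition on a (hypothetical) reference clip and its transcript
out = tts("Hello from SparkTTS.", prompt_speech_path="prompt.wav", prompt_text="Reference transcript.")
sf.write("cloned.wav", out["audio"], out["sampling_rate"])

# Voice creation: specify gender / pitch / speed instead of a reference clip
out = tts("Hello again.", gender="female", pitch="moderate", speed="high", temperature=0.7)
sf.write("created.wav", out["audio"], out["sampling_rate"])
```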
processing_spark_tts.py ADDED
@@ -0,0 +1,345 @@
+ # coding=utf-8
+ # Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
+ # ... (license) ...
+ """Processor class for SparkTTS."""
+
+ import torch
+ import re
+ import numpy as np
+ import warnings
+ from typing import Optional, Dict, Any, Union, List, Tuple
+ from pathlib import Path
+
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.feature_extraction_utils import FeatureExtractionMixin
+ from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+ from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
+ from transformers.utils import logging
+
+ # Import necessary items directly or ensure they are available via model reference
+ # Note: Avoid direct model imports here if possible, rely on the model reference.
+ # from .modeling_spark_tts import SparkTTSModel  # Avoid direct model import if possible
+ from .configuration_spark_tts import SparkTTSConfig  # Config is okay
+
+ # Import utils needed for prompt formatting (assuming they are merged into modeling)
+ # We'll access them via the model reference if needed, or duplicate simple ones like token maps.
+
+ logger = logging.get_logger(__name__)
+
+ # --- Token Maps (Duplicate here for direct use in processor) ---
+ TASK_TOKEN_MAP = {
+     "tts": "<|task_tts|>",
+     "controllable_tts": "<|task_controllable_tts|>",
+     # Add other tasks if needed by processor logic
+ }
+ LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}
+ GENDER_MAP = {"female": 0, "male": 1}
+ # --- End Token Maps ---
+
+
+ class SparkTTSProcessor(ProcessorMixin):
+     r"""
+     Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor
+     into a single processor.
+
+     [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
+     It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio tokenization
+     to the model). It also handles decoding the final output.
+
+     Args:
+         tokenizer (`PreTrainedTokenizerBase`):
+             An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
+         feature_extractor (`Wav2Vec2FeatureExtractor`):
+             An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to process reference audio
+             (though the main processing happens inside the model).
+         model (`PreTrainedModel`, *optional*):
+             A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio processing)
+             and final audio decoding, as these steps rely on the model's internal BiCodec and Wav2Vec2 components.
+             Set this using `processor.model = model` after loading both.
+         config (`SparkTTSConfig`, *optional*):
+             The configuration object, needed for parameters like sample_rate. Can often be inferred from the model.
+     """
+     attributes = ["tokenizer", "feature_extractor"]
+     tokenizer_class = ("Qwen2TokenizerFast", "Qwen2Tokenizer")  # Specify the underlying tokenizer type
+     feature_extractor_class = ("Wav2Vec2FeatureExtractor",)  # Specify the underlying feature extractor type
+
+     def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
+         if tokenizer is None:
+             raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
+         if feature_extractor is None:
+             # Attempt to load default if path is known or provide clearer error
+             raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")
+
+         super().__init__(tokenizer, feature_extractor)
+         self.model = model  # Store model reference (can be None initially)
+         self.config = config  # Store config reference
+
+         # Get sampling rate from config if available
+         self.sampling_rate = None
+         if self.config and hasattr(self.config, 'sample_rate'):
+             self.sampling_rate = self.config.sample_rate
+         elif self.model and hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
+             self.sampling_rate = self.model.config.sample_rate
+         else:
+             # Try feature extractor default, or raise warning
+             if hasattr(self.feature_extractor, 'sampling_rate'):
+                 self.sampling_rate = self.feature_extractor.sampling_rate
+             else:
+                 logger.warning("Could not determine sampling rate. Defaulting to 16000. Set `processor.sampling_rate` manually if needed.")
+                 self.sampling_rate = 16000
+
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """
+         Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.
+
+         Args:
+             pretrained_model_name_or_path (`str` or `os.PathLike`):
+                 This can be either:
+                 - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
+                 - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
+                   e.g., `./my_model_directory/`.
+             **kwargs:
+                 Additional keyword arguments passed along to both `AutoTokenizer.from_pretrained()` and
+                 `AutoFeatureExtractor.from_pretrained()`.
+         """
+         config = kwargs.pop("config", None)
+         if config is None:
+             # Try loading the specific config first
+             try:
+                 config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+             except Exception:
+                 logger.warning(f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. Processor might lack some config values.")
+                 config = None
+
+
+         # Resolve component paths relative to the main path
+         def _resolve_path(sub_path):
+             p = Path(sub_path)
+             if p.is_absolute():
+                 return str(p)
+             # Try resolving relative to the main path if it's a directory
+             main_path = Path(pretrained_model_name_or_path)
+             if main_path.is_dir():
+                 resolved = main_path / p
+                 if resolved.exists():
+                     return str(resolved)
+             # Fallback to assuming sub_path is relative within a repo structure (might fail for local non-dirs)
+             return sub_path
+
+         # Determine paths from config or assume defaults
+         llm_tokenizer_path = "./LLM"
+         w2v_processor_path = "./wav2vec2-large-xlsr-53"
+         if config:
+             llm_tokenizer_path = getattr(config, 'llm_model_name_or_path', llm_tokenizer_path)
+             w2v_processor_path = getattr(config, 'wav2vec2_model_name_or_path', w2v_processor_path)
+
+         resolved_tokenizer_path = _resolve_path(llm_tokenizer_path)
+         resolved_w2v_path = _resolve_path(w2v_processor_path)
+
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_path, **kwargs)
+         except Exception as e:
+             raise OSError(f"Could not load tokenizer from {resolved_tokenizer_path}. Ensure path is correct and files exist. Original error: {e}")
+
+         try:
+             feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(resolved_w2v_path, **kwargs)
+         except Exception as e:
+             raise OSError(f"Could not load feature extractor from {resolved_w2v_path}. Ensure path is correct and files exist. Original error: {e}")
+
+         # The 'model' attribute will be set later externally
+         return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)
+
+
+     def __call__(self, text: str = None,
+                  prompt_speech_path: Optional[str] = None,
+                  prompt_text: Optional[str] = None,
+                  gender: Optional[str] = None,
+                  pitch: Optional[str] = None,
+                  speed: Optional[str] = None,
+                  return_tensors: Optional[str] = "pt",
+                  **kwargs) -> BatchEncoding:
+         """
+         Main method to process inputs for the SparkTTS model.
+
+         Args:
+             text (`str`): The text to be synthesized.
+             prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
+             prompt_text (`str`, *optional*): Transcript of prompt audio.
+             gender (`str`, *optional*): Target gender ('male' or 'female') for voice creation.
+             pitch (`str`, *optional*): Target pitch level ('very_low'...'very_high') for voice creation.
+             speed (`str`, *optional*): Target speed level ('very_low'...'very_high') for voice creation.
+             return_tensors (`str`, *optional*, defaults to `"pt"`):
+                 Framework of the returned tensors (`"pt"` for PyTorch, `"np"` for NumPy).
+             **kwargs: Additional arguments (currently ignored).
+
+         Returns:
+             `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
+             `global_token_ids_prompt` ready for the model's `.generate()` method.
+         """
+         if text is None:
+             raise ValueError("`text` input must be provided.")
+
+         global_token_ids_prompt = None
+         llm_prompt_string = ""
+
+         if prompt_speech_path is not None:
+             # --- Voice Cloning Mode ---
+             if self.model is None:
+                 raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning.")
+             if not hasattr(self.model, '_tokenize_audio'):
+                 raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")
+
+             logger.info(f"Processing prompt audio: {prompt_speech_path}")
+             # Delegate audio tokenization to the model
+             try:
+                 # _tokenize_audio returns (global_tokens, semantic_tokens)
+                 global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
+                 global_token_ids_prompt = global_tokens  # Keep for decoding stage
+             except Exception as e:
+                 logger.error(f"Error tokenizing prompt audio: {e}", exc_info=True)
+                 raise RuntimeError(f"Failed to process prompt audio file: {prompt_speech_path}. Check file integrity and model compatibility.") from e
+
+             # Format prompt string using token maps
+             global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
+
+             if prompt_text and len(prompt_text) > 1:
+                 semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
+                 llm_prompt_parts = [
+                     TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
+                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
+                     "<|start_semantic_token|>", semantic_tokens_str,
+                 ]
+             else:
+                 llm_prompt_parts = [
+                     TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
+                     "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
+                 ]
+             llm_prompt_string = "".join(llm_prompt_parts)
+
+         elif gender is not None and pitch is not None and speed is not None:
+             # --- Voice Creation Mode ---
+             if gender not in GENDER_MAP: raise ValueError(f"Invalid gender '{gender}'.")
+             if pitch not in LEVELS_MAP: raise ValueError(f"Invalid pitch '{pitch}'.")
+             if speed not in LEVELS_MAP: raise ValueError(f"Invalid speed '{speed}'.")
+
+             gender_id = GENDER_MAP[gender]
+             pitch_level_id = LEVELS_MAP[pitch]
+             speed_level_id = LEVELS_MAP[speed]
+
+             attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
+
+             llm_prompt_parts = [
+                 TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
+                 "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
+             ]
+             llm_prompt_string = "".join(llm_prompt_parts)
+             # No global_token_ids_prompt needed
+
+         else:
+             raise ValueError("Processor requires either 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation).")
+
+         # Tokenize the final LLM prompt string
+         inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)
+
+         # Add prompt global tokens to the output if they exist (for passing to decode)
+         if global_token_ids_prompt is not None:
+             inputs["global_token_ids_prompt"] = global_token_ids_prompt
+
+         return inputs
+
+     def decode(self,
+                generated_ids: Union[List[int], np.ndarray, torch.Tensor],
+                global_token_ids_prompt: Optional[torch.Tensor] = None,
+                input_ids_len: Optional[int] = None,
+                skip_special_tokens: bool = True) -> Dict[str, Any]:
+         """
+         Decodes the raw token IDs generated by the model into an audio waveform.
+
+         Args:
+             generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
+                 The token IDs generated by the `model.generate()` method. Assumed to be a single sequence (batch size 1).
+             global_token_ids_prompt (`torch.Tensor`, *optional*):
+                 The global tokens obtained from the prompt audio during preprocessing (needed for voice cloning).
+                 Should be passed from the `__call__` output.
+             input_ids_len (`int`, *optional*):
+                 The length of the original prompt `input_ids`. If provided, the prompt part will be stripped from
+                 `generated_ids` before decoding the text representation. If None, assumes `generated_ids` contains
+                 *only* the generated part.
+             skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                 Whether to skip special tokens when decoding the text representation for parsing.
+
+         Returns:
+             `Dict[str, Any]`: A dictionary containing:
+                 - `audio` (`np.ndarray`): The generated audio waveform.
+                 - `sampling_rate` (`int`): The sampling rate of the audio.
+         """
+         if self.model is None:
+             raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
+         if not hasattr(self.model, '_detokenize_audio'):
+             raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
+         if self.sampling_rate is None:
+             raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")
+
+         # Ensure generated_ids is a tensor on the correct device
+         if isinstance(generated_ids, (list, np.ndarray)):
+             output_ids_tensor = torch.tensor(generated_ids)
+         else:
+             output_ids_tensor = generated_ids
+
+         # Remove prompt if input_ids_len is provided
+         if input_ids_len is not None:
+             # Handle potential batch dimension if present (though usually not for decode)
+             if output_ids_tensor.ndim > 1:
+                 output_ids = output_ids_tensor[0, input_ids_len:]
+             else:
+                 output_ids = output_ids_tensor[input_ids_len:]
+         else:
+             if output_ids_tensor.ndim > 1:
+                 output_ids = output_ids_tensor[0]
+             else:
+                 output_ids = output_ids_tensor
+
+         if output_ids.numel() == 0:
+             logger.warning("Received empty generated IDs after removing prompt. Returning empty audio.")
+             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
+
+         # Decode the text representation to parse tokens
+         predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
+
+         # Extract semantic tokens
+         semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
+         if not semantic_matches:
+             logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
+             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
+         # Use model's device for tensors
+         device = self.model.device
+         pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches], dtype=torch.long, device=device).unsqueeze(0)  # Add batch dim
+
+         # Determine global tokens
+         if global_token_ids_prompt is not None:
+             # Voice Cloning: Use prompt global tokens
+             global_token_ids = global_token_ids_prompt.to(device)
+             # Ensure correct shape (B, T_token, Q) or (B, D) - BiCodec detokenize needs to handle this
+             if global_token_ids.ndim == 2:  # If (B, D), maybe unsqueeze? Check BiCodec.detokenize expectation
+                 global_token_ids = global_token_ids.unsqueeze(1)  # Assume (B, 1, D) might be needed
+         else:
+             # Voice Creation: Parse global tokens from generated text
+             global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
+             if not global_matches:
+                 logger.error("Voice creation failed: No global tokens found in generated text.")
+                 raise ValueError("Voice creation failed: Could not find bicodec_global tokens in the LLM output.")
+             global_token_ids = torch.tensor([int(token) for token in global_matches], dtype=torch.long, device=device).unsqueeze(0)  # Add batch dim
+             # Add sequence dimension if needed (check BiCodec.detokenize)
+             if global_token_ids.ndim == 2:
+                 global_token_ids = global_token_ids.unsqueeze(1)  # Assume (B, 1, D)
+
+         # Detokenize audio using the model's method
+         try:
+             wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
+         except Exception as e:
+             logger.error(f"Error during audio detokenization: {e}", exc_info=True)
+             raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e
+
+         return {"audio": wav_np, "sampling_rate": self.sampling_rate}
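When driving the model without the pipeline wrapper, the processor's `__call__` / `decode` pair mirrors the pipeline's preprocess and postprocess steps. A hedged end-to-end sketch (the `model.llm.generate` call follows the pipeline's `_forward`; output file name is hypothetical):

```python
import soundfile as sf
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained(".", trust_remote_code=True)
processor = AutoProcessor.from_pretrained(".", trust_remote_code=True)
processor.model = model  # required for audio tokenization and decoding

# Build the LLM prompt for voice creation mode
inputs = processor("A short test sentence.", gender="male", pitch="moderate", speed="moderate")
input_ids_len = inputs["input_ids"].shape[-1]

# Generate semantic (and global) tokens with the underlying Qwen2 LLM
generated_ids = model.llm.generate(input_ids=inputs["input_ids"].to(model.device),
                                   attention_mask=inputs["attention_mask"].to(model.device),
                                   max_new_tokens=2000, do_sample=True)

# Turn the generated token string back into a waveform
result = processor.decode(generated_ids, input_ids_len=input_ids_len)
sf.write("output.wav", result["audio"], result["sampling_rate"])
```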
wav2vec2-large-xlsr-53/README.md ADDED
@@ -0,0 +1,29 @@
+ ---
+ language: multilingual
+ datasets:
+ - common_voice
+ tags:
+ - speech
+ license: apache-2.0
+ ---
+
+ # Wav2Vec2-XLSR-53
+
+ [Facebook's XLSR-Wav2Vec2](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/)
+
+ The base model pretrained on 16kHz sampled speech audio. When using the model make sure that your speech input is also sampled at 16kHz. Note that this model should be fine-tuned on a downstream task, like Automatic Speech Recognition. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more information.
+
+ [Paper](https://arxiv.org/abs/2006.13979)
+
+ Authors: Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli
+
+ **Abstract**
+ This paper presents XLSR which learns cross-lingual speech representations by pretraining a single model from the raw waveform of speech in multiple languages. We build on wav2vec 2.0 which is trained by solving a contrastive task over masked latent speech representations and jointly learns a quantization of the latents shared across languages. The resulting model is fine-tuned on labeled data and experiments show that cross-lingual pretraining significantly outperforms monolingual pretraining. On the CommonVoice benchmark, XLSR shows a relative phoneme error rate reduction of 72% compared to the best known results. On BABEL, our approach improves word error rate by 16% relative compared to a comparable system. Our approach enables a single multilingual speech recognition model which is competitive to strong individual models. Analysis shows that the latent discrete speech representations are shared across languages with increased sharing for related languages. We hope to catalyze research in low-resource speech understanding by releasing XLSR-53, a large model pretrained in 53 languages.
+
+ The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
+
+ # Usage
+
+ See [this notebook](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Fine_Tune_XLSR_Wav2Vec2_on_Turkish_ASR_with_%F0%9F%A4%97_Transformers.ipynb) for more information on how to fine-tune the model.
+
+ ![model image](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/xlsr_wav2vec2.png)
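Within this repository the checkpoint serves as the feature extractor referenced by `wav2vec2_model_name_or_path` in the root config.json; its 1024-dimensional hidden states match the `input_channels` of BiCodec's encoder. A minimal, hedged sketch of running it standalone on 16 kHz audio (local sub-folder path assumed; loading a `Wav2Vec2Model` from this pre-training checkpoint will warn about unused weights):

```python
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./wav2vec2-large-xlsr-53")
w2v = Wav2Vec2Model.from_pretrained("./wav2vec2-large-xlsr-53")

waveform = torch.zeros(16000)  # one second of silence as a stand-in for real 16 kHz audio
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    hidden = w2v(**inputs).last_hidden_state
print(hidden.shape)  # roughly (1, 49, 1024); hidden size matches BiCodec's input_channels
```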
wav2vec2-large-xlsr-53/config.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "activation_dropout": 0.0,
+   "apply_spec_augment": true,
+   "architectures": [
+     "Wav2Vec2ForPreTraining"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "codevector_dim": 768,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": true,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": true,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "layer",
+   "feat_proj_dropout": 0.1,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.0,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.1,
+   "mask_channel_length": 10,
+   "mask_channel_min_space": 1,
+   "mask_channel_other": 0.0,
+   "mask_channel_prob": 0.0,
+   "mask_channel_selection": "static",
+   "mask_feature_length": 10,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_space": 1,
+   "mask_time_other": 0.0,
+   "mask_time_prob": 0.075,
+   "mask_time_selection": "static",
+   "model_type": "wav2vec2",
+   "num_attention_heads": 16,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 24,
+   "num_negatives": 100,
+   "pad_token_id": 0,
+   "proj_codevector_dim": 768,
+   "transformers_version": "4.7.0.dev0",
+   "vocab_size": 32
+ }
wav2vec2-large-xlsr-53/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "do_normalize": true,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0,
+   "return_attention_mask": true,
+   "sampling_rate": 16000
+ }
wav2vec2-large-xlsr-53/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:314340227371a608f71adcd5f0de5933824fe77e55822aa4b24dba9c1c364dcb
+ size 1269737156