patrickvonplaten committed
Commit 1a847ef · 1 Parent(s): 17fe79b

adapt config to robust wav2vec2; play around with dummy run

.gitattributes CHANGED
@@ -14,3 +14,4 @@
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -21,6 +21,11 @@ Authors: Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
  We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks the speech input in the latent space and solves a contrastive task defined over a quantization of the latent representations which are jointly learned. Experiments using all labeled data of Librispeech achieve 1.8/3.3 WER on the clean/other test sets. When lowering the amount of labeled data to one hour, wav2vec 2.0 outperforms the previous state of the art on the 100 hour subset while using 100 times less labeled data. Using just ten minutes of labeled data and pre-training on 53k hours of unlabeled data still achieves 4.8/8.2 WER. This demonstrates the feasibility of speech recognition with limited amounts of labeled data.
  The original model can be found under https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.
  
+ ## Necessary installations:
+ 
+ - sndfile library: `sudo apt-get install libsndfile1-dev`
+ - ffmpeg: `sudo apt install ffmpeg` & `pip install ffmpeg`
+ 
  ## Model description `TODO: Update`
  
  ## How to use `TODO: Update`
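The system packages listed in the README's installation section back librosa's audio decoding, which the pretraining script uses to read the speech files. A minimal sketch of that step (the path `sample.wav` is a placeholder, not part of this repo):

```python
# Hypothetical example: decode an audio file and resample it to the model's 16 kHz rate.
# librosa delegates decoding to libsndfile/ffmpeg, which is why they must be installed.
import librosa

speech, sampling_rate = librosa.load("sample.wav", sr=16_000)  # placeholder path
print(speech.shape, sampling_rate)
```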
config.json CHANGED
@@ -39,10 +39,10 @@
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
- "do_stable_layer_norm": false,
+ "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
- "feat_extract_norm": "group",
+ "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
dummy/events.out.tfevents.1625513107.t1v-n-3abeb69a-w-0.321469.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f577bc179f9667346b0ed1dba80874f1381cc2be7c673d7993d0e3d5986ac40
+ size 40
dummy/events.out.tfevents.1625513446.t1v-n-3abeb69a-w-0.323511.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0c445e153190a878b14309f5d5aff5c23ad9cba191537a0f3ef757a5f2cb01c
+ size 40
dummy/events.out.tfevents.1625513922.t1v-n-3abeb69a-w-0.327753.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:857103668126e88fc40e91917710bc62dc17b7c6c917f653b350439ca66cbfd7
+ size 40
dummy/events.out.tfevents.1625514432.t1v-n-3abeb69a-w-0.496301.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d5e783fd11032a79389eee5b16af099eb6d8faebec27c2eed4c30a716c6634ab
+ size 40
preprocessor_config.json CHANGED
@@ -3,6 +3,6 @@
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
- "return_attention_mask": false,
+ "return_attention_mask": true,
  "sampling_rate": 16000
  }
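With `return_attention_mask` set to `true`, the feature extractor also returns an `attention_mask`, which the layer-norm wav2vec2 variant uses to ignore padded samples. A minimal sketch with dummy inputs:

```python
# Sketch: padded batches now come with an attention_mask alongside input_values.
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16_000, padding_value=0.0, return_attention_mask=True
)
dummy_speech = [np.zeros(16_000, dtype=np.float32), np.zeros(8_000, dtype=np.float32)]
batch = feature_extractor(dummy_speech, sampling_rate=16_000, padding=True, return_tensors="np")
print(batch["input_values"].shape, batch["attention_mask"].shape)  # both (2, 16000)
```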
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ librosa
+ tensorflow
+ ffmpeg
run_dummy.sh ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env bash
+ ./run_wav2vec2_pretrain_flax.py \
+ --output_dir="./dummy" \
+ --speech_file_column="path" \
+ --num_train_epochs="1" \
+ --per_device_train_batch_size="4" \
+ --per_device_eval_batch_size="4" \
+ --validation_split_percentage="50" \
+ --learning_rate="5e-4" \
+ --weight_decay="0.01" \
+ --warmup_steps="200" \
+ --model_name_or_path="./" \
+ --dataset_name="common_voice" \
+ --dataset_config_name="cnh" \
+ --preprocessing_num_workers="4" \
+ --max_duration_in_seconds="20.0" \
+ --adam_beta1="0.9" \
+ --adam_beta2="0.98" \
+ --dtype="bfloat16" \
+ --cache_dir="/home/wav2vec2-experiments/data_cache/" \
+ --pad_to_multiple_of="32768" \
+ --push_to_hub
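The `--pad_to_multiple_of="32768"` flag keeps the number of distinct input shapes small so XLA does not recompile the training step for every new sequence length on TPU. A rough, illustrative sketch of the resulting bucket sizes:

```python
# Sketch: lengths are rounded up to the next multiple of 32768 samples (~2 s at 16 kHz).
import math

def padded_length(num_samples: int, multiple: int = 32_768) -> int:
    return math.ceil(num_samples / multiple) * multiple

for n in (12_345, 32_768, 40_000):
    print(n, "->", padded_length(n))  # 32768, 32768, 65536
```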
run_wav2vec2_pretrain_flax.py ADDED
@@ -0,0 +1,589 @@
1
+ #!/usr/bin/env python3
2
+ import logging
3
+ import sys
4
+ import time
5
+ from dataclasses import field
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ import numpy as np
10
+ from datasets import DatasetDict, load_dataset
11
+ from tqdm import tqdm
12
+
13
+ import flax
14
+ import jax
15
+ import jax.numpy as jnp
16
+ import librosa
17
+ import optax
18
+ from flax import jax_utils, traverse_util
19
+ from flax.training import train_state
20
+ from flax.training.common_utils import get_metrics, onehot, shard
21
+ from transformers import (
22
+ FlaxWav2Vec2ForPreTraining,
23
+ HfArgumentParser,
24
+ TrainingArguments,
25
+ Wav2Vec2Config,
26
+ Wav2Vec2FeatureExtractor,
27
+ is_tensorboard_available,
28
+ )
29
+ from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices, _sample_negative_indices
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ @flax.struct.dataclass
36
+ class ModelArguments:
37
+ """
38
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
39
+ """
40
+
41
+ model_name_or_path: str = field(
42
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
43
+ )
44
+ cache_dir: Optional[str] = field(
45
+ default=None,
46
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
47
+ )
48
+ freeze_feature_extractor: Optional[bool] = field(
49
+ default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
50
+ )
51
+ gradient_checkpointing: Optional[bool] = field(
52
+ default=False, metadata={"help": "Whether to use gradient checkpointing to save memory at the expense of a slower backward pass."}
53
+ )
54
+ verbose_logging: Optional[bool] = field(
55
+ default=False,
56
+ metadata={"help": "Whether to log verbose messages or not."},
57
+ )
58
+ max_gumbel_temperature: Optional[float] = field(
59
+ default=2.0, metadata={"help": "Maximum temperature for gumbel softmax."}
60
+ )
61
+ min_gumbel_temperature: Optional[float] = field(
62
+ default=0.1, metadata={"help": "Minimum temperature for gumbel softmax."}
63
+ )
64
+ gumbel_temperature_decay: Optional[float] = field(
65
+ default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
66
+ )
67
+ dtype: Optional[str] = field(
68
+ default="float32",
69
+ metadata={
70
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
71
+ },
72
+ )
73
+
74
+
75
+ @flax.struct.dataclass
76
+ class DataTrainingArguments:
77
+ """
78
+ Arguments pertaining to what data we are going to input our model for training and eval.
79
+
80
+ Using `HfArgumentParser` we can turn this class
81
+ into argparse arguments to be able to specify them on
82
+ the command line.
83
+ """
84
+
85
+ dataset_name: str = field(
86
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
87
+ )
88
+ dataset_config_name: Optional[str] = field(
89
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
90
+ )
91
+ train_split_name: Optional[str] = field(
92
+ default="train",
93
+ metadata={
94
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
95
+ },
96
+ )
97
+ validation_split_name: Optional[str] = field(
98
+ default="validation",
99
+ metadata={
100
+ "help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
101
+ },
102
+ )
103
+ speech_file_column: Optional[str] = field(
104
+ default="file",
105
+ metadata={"help": "Column in the dataset that contains speech file path. Defaults to 'file'"},
106
+ )
107
+ overwrite_cache: bool = field(
108
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
109
+ )
110
+ validation_split_percentage: Optional[int] = field(
111
+ default=5,
112
+ metadata={
113
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
114
+ },
115
+ )
116
+ preprocessing_num_workers: Optional[int] = field(
117
+ default=None,
118
+ metadata={"help": "The number of processes to use for the preprocessing."},
119
+ )
120
+ max_duration_in_seconds: Optional[float] = field(
121
+ default=20.0, metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"}
122
+ )
123
+ pad_to_multiple_of: Optional[int] = field(
124
+ default=1024,
125
+ metadata={
126
+ "help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
127
+ },
128
+ )
129
+
130
+
131
+ @flax.struct.dataclass
132
+ class FlaxDataCollatorForWav2Vec2Pretraining:
133
+ """
134
+ Data collator that will dynamically pad the inputs received and prepare masked indices
135
+ for self-supervised pretraining.
136
+
137
+ Args:
138
+ model (:class:`~transformers.FlaxWav2Vec2ForPreTraining`):
139
+ The Wav2Vec2 model used for pretraining. The data collator needs to have access
140
+ to config and ``_get_feat_extract_output_lengths`` function for correct padding.
141
+ feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
142
+ The processor used for processing the data.
143
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
144
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
145
+ among:
146
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
147
+ sequence if provided).
148
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
149
+ maximum acceptable input length for the model if that argument is not provided.
150
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
151
+ different lengths).
152
+ max_length (:obj:`int`, `optional`):
153
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
154
+ pad_to_multiple_of (:obj:`int`, `optional`):
155
+ If set will pad the sequence to a multiple of the provided value.
156
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
157
+ 7.5 (Volta).
158
+ """
159
+
160
+ model: FlaxWav2Vec2ForPreTraining
161
+ feature_extractor: Wav2Vec2FeatureExtractor
162
+ padding: Union[bool, str] = "longest"
163
+ pad_to_multiple_of: Optional[int] = None
164
+ max_length: Optional[int] = None
165
+
166
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
167
+ # reformat list to dict and set to pytorch format
168
+ batch = self.feature_extractor.pad(
169
+ features,
170
+ max_length=self.max_length,
171
+ padding=self.padding,
172
+ pad_to_multiple_of=self.pad_to_multiple_of,
173
+ return_tensors="np",
174
+ )
175
+ mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
176
+
177
+ # sample randomly masked indices
178
+ batch["mask_time_indices"] = _compute_mask_indices(
179
+ (batch["input_values"].shape[0], mask_indices_seq_length),
180
+ self.model.config.mask_time_prob,
181
+ self.model.config.mask_time_length,
182
+ min_masks=2,
183
+ )
184
+
185
+ # sample indices to take for negative vectors
186
+ batch["sampled_negative_indices"] = _sample_negative_indices(
187
+ (batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
188
+ self.model.config.num_negatives,
189
+ )
190
+
191
+ return batch
192
+
193
+
194
+ def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
195
+ logging.basicConfig(
196
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
197
+ datefmt="%m/%d/%Y %H:%M:%S",
198
+ handlers=[logging.StreamHandler(sys.stdout)],
199
+ )
200
+ logging_level = logging.WARNING
201
+ if model_args.verbose_logging:
202
+ logging_level = logging.DEBUG
203
+ logger.setLevel(logging_level)
204
+
205
+
206
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
207
+ summary_writer.scalar("train_time", train_time, step)
208
+
209
+ train_metrics = get_metrics(train_metrics)
210
+ for key, vals in train_metrics.items():
211
+ tag = f"train_{key}"
212
+ for i, val in enumerate(vals):
213
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
214
+
215
+
216
+ def write_eval_metric(summary_writer, eval_metrics, step):
217
+ for metric_name, value in eval_metrics.items():
218
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
219
+
220
+
221
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
222
+ num_samples = len(samples_idx)
223
+ samples_to_remove = num_samples % batch_size
224
+
225
+ if samples_to_remove != 0:
226
+ samples_idx = samples_idx[:-samples_to_remove]
227
+ sections_split = num_samples // batch_size
228
+ batch_idx = np.split(samples_idx, sections_split)
229
+ return batch_idx
230
+
231
+
232
+ def compute_contrastive_loss(
233
+ quantized_features, transformer_features, negative_indices, mask_time_indices, logits_temp, num_negatives
234
+ ):
235
+ batch_size, sequence_length, hidden_size = quantized_features.shape
236
+
237
+ # take negative vectors from sampled indices
238
+ quantized_negatives = quantized_features.reshape(-1, hidden_size)[negative_indices.reshape(-1)]
239
+ quantized_negatives = quantized_negatives.reshape(
240
+ batch_size, sequence_length, num_negatives, hidden_size
241
+ ).transpose(2, 0, 1, 3)
242
+
243
+ target_features = jnp.concatenate([quantized_features[None, :], quantized_negatives], axis=0)
244
+ loss_logits = optax.cosine_similarity(transformer_features, target_features)
245
+ loss_logits = loss_logits / logits_temp
246
+
247
+ neg_is_pos = (quantized_features == quantized_negatives).all(-1)
248
+ neg_is_pos = jnp.concatenate([jnp.full((1,) + loss_logits.shape[1:], False), neg_is_pos], axis=0)
249
+
250
+ # make sure incorrectly sampled vectors don't contribute to loss
251
+ loss_logits = jnp.where(neg_is_pos, -1e9, loss_logits)
252
+
253
+ predictions = loss_logits.transpose(2, 1, 0).reshape(-1, loss_logits.shape[0])
254
+ targets = ((1 - mask_time_indices) * -100).transpose(1, 0).flatten()
255
+
256
+ target_mask = jnp.where(targets >= 0, 1.0, 0.0)
257
+ contrastive_loss = optax.softmax_cross_entropy(predictions, onehot(targets, predictions.shape[-1])) * target_mask
258
+
259
+ contrastive_loss = contrastive_loss.sum()
260
+
261
+ return contrastive_loss
262
+
263
+
264
+ def main():
265
+ # See all possible arguments in src/transformers/training_args.py
266
+ # or by passing the --help flag to this script.
267
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
268
+
269
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
270
+
271
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
272
+ configure_logger(model_args, training_args)
273
+
274
+ # Downloading and loading a dataset from the hub.
275
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
276
+
277
+ if "validation" not in datasets.keys():
278
+ # make sure only "validation" and "train" keys remain"
279
+ datasets = DatasetDict()
280
+ datasets["validation"] = load_dataset(
281
+ data_args.dataset_name,
282
+ data_args.dataset_config_name,
283
+ split=f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
284
+ cache_dir=model_args.cache_dir,
285
+ )
286
+ datasets["train"] = load_dataset(
287
+ data_args.dataset_name,
288
+ data_args.dataset_config_name,
289
+ split=f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
290
+ cache_dir=model_args.cache_dir,
291
+ )
292
+ else:
293
+ # make sure only "validation" and "train" keys remain"
294
+ datasets = DatasetDict()
295
+ datasets["validation"] = load_dataset(
296
+ data_args.dataset_name,
297
+ data_args.dataset_config_name,
298
+ split="validation",
299
+ cache_dir=model_args.cache_dir,
300
+ )
301
+ datasets["train"] = load_dataset(
302
+ data_args.dataset_name,
303
+ data_args.dataset_config_name,
304
+ split=f"{data_args.train_split_name}",
305
+ cache_dir=model_args.cache_dir,
306
+ )
307
+
308
+ # only normalized-inputs-training is supported
309
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
310
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, do_normalize=True
311
+ )
312
+
313
+ def prepare_dataset(batch):
314
+ # check that all files have the correct sampling rate
315
+ batch["speech"], _ = librosa.load(batch[data_args.speech_file_column], sr=feature_extractor.sampling_rate)
316
+ return batch
317
+
318
+ # load audio files into numpy arrays
319
+ vectorized_datasets = datasets.map(
320
+ prepare_dataset, num_proc=data_args.preprocessing_num_workers, remove_columns=datasets["train"].column_names
321
+ )
322
+
323
+ # filter audio files that are too long
324
+ vectorized_datasets = vectorized_datasets.filter(
325
+ lambda data: len(data["speech"]) < int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
326
+ )
327
+
328
+ def normalize(batch):
329
+ return feature_extractor(batch["speech"], sampling_rate=feature_extractor.sampling_rate)
330
+
331
+ # normalize and transform to `BatchFeatures`
332
+ vectorized_datasets = vectorized_datasets.map(
333
+ normalize,
334
+ batched=True,
335
+ num_proc=data_args.preprocessing_num_workers,
336
+ load_from_cache_file=not data_args.overwrite_cache,
337
+ remove_columns=vectorized_datasets["train"].column_names,
338
+ )
339
+
340
+ # pretraining is only supported for "newer" stable layer norm architecture
341
+ # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
342
+ config = Wav2Vec2Config.from_pretrained(
343
+ model_args.model_name_or_path,
344
+ cache_dir=model_args.cache_dir,
345
+ gradient_checkpointing=model_args.gradient_checkpointing,
346
+ )
347
+
348
+ if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
349
+ raise ValueError(
350
+ "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
351
+ )
352
+
353
+ model = FlaxWav2Vec2ForPreTraining(
354
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
355
+ )
356
+
357
+ data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
358
+ model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
359
+ )
360
+
361
+ # Enable tensorboard only on the master node
362
+ has_tensorboard = is_tensorboard_available()
363
+ if has_tensorboard and jax.process_index() == 0:
364
+ try:
365
+ from flax.metrics.tensorboard import SummaryWriter
366
+
367
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
368
+ except ImportError as ie:
369
+ has_tensorboard = False
370
+ logger.warning(
371
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
372
+ )
373
+ else:
374
+ logger.warning(
375
+ "Unable to display metrics through TensorBoard because the package is not installed: "
376
+ "Please run pip install tensorboard to enable."
377
+ )
378
+
379
+ # Initialize our training
380
+ rng = jax.random.PRNGKey(training_args.seed)
381
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
382
+ gumbel_rngs = jax.random.split(rng, jax.local_device_count())
383
+
384
+ num_epochs = int(training_args.num_train_epochs)
385
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
386
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
387
+
388
+ num_train_steps = len(vectorized_datasets["train"]) // train_batch_size * num_epochs
389
+
390
+ # Create learning rate schedule
391
+ warmup_fn = optax.linear_schedule(
392
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
393
+ )
394
+ decay_fn = optax.linear_schedule(
395
+ init_value=training_args.learning_rate,
396
+ end_value=0,
397
+ transition_steps=num_train_steps - training_args.warmup_steps,
398
+ )
399
+ linear_decay_lr_schedule_fn = optax.join_schedules(
400
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
401
+ )
402
+
403
+ # We use Optax's "masking" functionality to not apply weight decay
404
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
405
+ # mask boolean with the same structure as the parameters.
406
+ # The mask is True for parameters that should be decayed.
407
+ def decay_mask_fn(params):
408
+ flat_params = traverse_util.flatten_dict(params)
409
+ flat_mask = {
410
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
411
+ for path in flat_params
412
+ }
413
+ return traverse_util.unflatten_dict(flat_mask)
414
+
415
+ # create adam optimizer
416
+ adamw = optax.adamw(
417
+ learning_rate=linear_decay_lr_schedule_fn,
418
+ b1=training_args.adam_beta1,
419
+ b2=training_args.adam_beta2,
420
+ eps=training_args.adam_epsilon,
421
+ weight_decay=training_args.weight_decay,
422
+ mask=decay_mask_fn,
423
+ )
424
+
425
+ # Setup train state and define training hyper-parameters
426
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
427
+ num_negatives = model.config.num_negatives
428
+ contrastive_logits_temperature = model.config.contrastive_logits_temperature
429
+ num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
430
+ diversity_loss_weight = model.config.diversity_loss_weight
431
+
432
+ # Define gradient update step fn
433
+ def train_step(state, batch, dropout_rng, gumbel_rng):
434
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
435
+ gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)
436
+
437
+ def loss_fn(params):
438
+ negative_indices = batch.pop("sampled_negative_indices")
439
+
440
+ gumbel_temperature = jnp.clip(
441
+ model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
442
+ a_min=model_args.min_gumbel_temperature,
443
+ )
444
+
445
+ outputs = state.apply_fn(
446
+ **batch,
447
+ gumbel_temperature=gumbel_temperature,
448
+ params=params,
449
+ dropout_rng=dropout_rng,
450
+ gumbel_rng=gumbel_rng,
451
+ train=True,
452
+ )
453
+
454
+ contrastive_loss = compute_contrastive_loss(
455
+ outputs.projected_quantized_states,
456
+ outputs.projected_states,
457
+ negative_indices,
458
+ batch["mask_time_indices"],
459
+ contrastive_logits_temperature,
460
+ num_negatives,
461
+ )
462
+
463
+ diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
464
+ loss = contrastive_loss + diversity_loss_weight * diversity_loss
465
+
466
+ return loss
467
+
468
+ grad_fn = jax.value_and_grad(loss_fn)
469
+ loss, grad = grad_fn(state.params)
470
+ grad = jax.lax.pmean(grad, "batch")
471
+ new_state = state.apply_gradients(grads=grad)
472
+
473
+ metrics = jax.lax.pmean(
474
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
475
+ )
476
+
477
+ return new_state, metrics, new_dropout_rng, new_gumbel_rng
478
+
479
+ # Create parallel version of the train step
480
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
481
+
482
+ # Define eval fn
483
+ def eval_step(params, batch):
484
+ negative_indices = batch.pop("sampled_negative_indices")
485
+
486
+ outputs = model(**batch, params=params, train=False)
487
+
488
+ contrastive_loss = compute_contrastive_loss(
489
+ outputs.projected_quantized_states,
490
+ outputs.projected_states,
491
+ negative_indices,
492
+ batch["mask_time_indices"],
493
+ contrastive_logits_temperature,
494
+ num_negatives,
495
+ )
496
+
497
+ diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
498
+ loss = contrastive_loss + diversity_loss_weight * diversity_loss
499
+
500
+ # summarize metrics
501
+ metrics = {"loss": loss.mean(), "codevector_perplexity": outputs.codevector_perplexity}
502
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
503
+
504
+ return metrics
505
+
506
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
507
+
508
+ # Replicate the train state on each device
509
+ state = jax_utils.replicate(state)
510
+
511
+ train_time = 0
512
+ train_metrics = []
513
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
514
+ for epoch in epochs:
515
+ # ======================== Training ================================
516
+ train_start = time.time()
517
+
518
+ # Create sampling rng
519
+ rng, input_rng = jax.random.split(rng)
520
+
521
+ # Generate an epoch by shuffling sampling indices from the train dataset
522
+ num_train_samples = len(vectorized_datasets["train"])
523
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
524
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
525
+
526
+ # Gather the indexes for creating the batch and do a training step
527
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
528
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
529
+ model_inputs = data_collator(samples)
530
+ model_inputs = shard(model_inputs.data)
531
+
532
+ # Model forward
533
+ state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
534
+ state, model_inputs, dropout_rngs, gumbel_rngs
535
+ )
536
+ train_metrics.append(train_metric)
537
+
538
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
539
+
540
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
541
+ # Save metrics
542
+ train_metric = jax_utils.unreplicate(train_metric)
543
+ train_time += time.time() - train_start
544
+ if has_tensorboard and jax.process_index() == 0:
545
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
546
+
547
+ epochs.write(
548
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
549
+ )
550
+
551
+ train_metrics = []
552
+
553
+ # ======================== Evaluating ==============================
554
+ num_eval_samples = len(vectorized_datasets["validation"])
555
+ eval_samples_idx = jnp.arange(num_eval_samples)
556
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
557
+
558
+ eval_metrics = []
559
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
560
+ samples = [vectorized_datasets["validation"][int(idx)] for idx in batch_idx]
561
+ model_inputs = data_collator(samples)
562
+
563
+ # Model forward
564
+ model_inputs = shard(model_inputs.data)
565
+ metrics = p_eval_step(state.params, model_inputs)
566
+ eval_metrics.append(metrics)
567
+
568
+ # get eval metrics
569
+ eval_metrics = get_metrics(eval_metrics)
570
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
571
+
572
+ # Update progress bar
573
+ epochs.write(
574
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
575
+ )
576
+
577
+ # Save metrics
578
+ if has_tensorboard and jax.process_index() == 0:
579
+ cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
580
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
581
+
582
+ # save checkpoint after each epoch and push checkpoint to the hub
583
+ if jax.process_index() == 0:
584
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
585
+ model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
586
+
587
+
588
+ if __name__ == "__main__":
589
+ main()
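For reference, the gumbel temperature used inside `train_step` follows a simple geometric decay clipped at a floor; a small sketch with the script's default hyperparameters:

```python
# Sketch: temperature = clip(max_temp * decay**step, min_temp), as in loss_fn above.
max_temp, min_temp, decay = 2.0, 0.1, 0.999995  # script defaults

def gumbel_temperature(step: int) -> float:
    return max(max_temp * decay**step, min_temp)

for step in (0, 100_000, 1_000_000):
    print(step, round(gumbel_temperature(step), 3))  # decays from 2.0 toward the 0.1 floor
```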
special_tokens_map.json DELETED
@@ -1 +0,0 @@
- {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
tokenizer_config.json DELETED
@@ -1 +0,0 @@
- {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>", "do_lower_case": false, "return_attention_mask": false, "do_normalize": true}
vocab.json DELETED
@@ -1 +0,0 @@
- {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}