"""Initialize a SpeechEncoderDecoderModel (wav2vec2 speech encoder + BART text
decoder) and save it, together with its feature extractor and tokenizer, to a
single directory that a fine-tuning script can consume via --model_name_or_path."""

from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel

# Encoder for speech feature extraction
encoder_checkpoint = "facebook/wav2vec2-base-en-voxpopuli-v2"
# Decoder for text generation + its tokenizer
decoder_checkpoint = "facebook/bart-base"

# Path where this initial combined model is saved
# This path is then used as --model_name_or_path in the fine-tuning script
# e.g., "./seq2seq_wav2vec2_bart-base_24k-en-voxpopuli"
INITIAL_MODEL_SAVE_PATH = "path_to_save_initial_model"

model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_checkpoint,
    decoder_checkpoint,
    encoder_add_adapter=True,  # insert a convolutional adapter after the encoder
    encoder_num_adapter_layers=3,  # 3 conv layers, each downsampling the time axis (stride 2 by default)
)
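
# Sanity check (an illustrative addition, not in the original script): the
# encoder_* kwargs above are forwarded into the wav2vec2 config, so the adapter
# settings should now be visible on model.config.encoder.
assert model.config.encoder.add_adapter is True
assert model.config.encoder.num_adapter_layers == 3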

# Configure encoder regularisation (values taken from the thesis experiments)
model.config.encoder.feat_proj_dropout = 0.0  # dropout after the feature projection layer
# model.config.encoder.mask_time_prob = 0.0  # uncomment to disable SpecAugment time masking at initialization

# Configure the special token ids on the top-level config from the decoder's config
model.config.decoder_start_token_id = model.decoder.config.bos_token_id
model.config.pad_token_id = model.decoder.config.pad_token_id  # or tokenizer.pad_token_id
model.config.eos_token_id = model.decoder.config.eos_token_id  # or tokenizer.eos_token_id

# Configure generation and training parameters
model.config.max_length = 128  # maximum length of generated transcripts
model.config.encoder.layerdrop = 0.0  # disable encoder LayerDrop (a training setting, not a generation one)
model.config.use_cache = False  # the decoder KV cache is only useful at inference and conflicts with gradient checkpointing

# Save the initialized model, feature extractor, and tokenizer
model.save_pretrained(INITIAL_MODEL_SAVE_PATH)

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_checkpoint)
feature_extractor.save_pretrained(INITIAL_MODEL_SAVE_PATH)

tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
tokenizer.save_pretrained(INITIAL_MODEL_SAVE_PATH)

print(
    f"Initialized model, feature extractor, and tokenizer saved to {INITIAL_MODEL_SAVE_PATH}"
)
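
# Usage sketch (an assumption for illustration; the flags follow the
# run_speech_recognition_seq2seq.py example script from the Transformers
# repository and may differ between versions):
#
#   python run_speech_recognition_seq2seq.py \
#       --model_name_or_path path_to_save_initial_model \
#       --output_dir ./wav2vec2-bart-finetuned \
#       ...
#
# The saved directory round-trips cleanly, e.g.:
#   model = SpeechEncoderDecoderModel.from_pretrained(INITIAL_MODEL_SAVE_PATH)
#   feature_extractor = AutoFeatureExtractor.from_pretrained(INITIAL_MODEL_SAVE_PATH)
#   tokenizer = AutoTokenizer.from_pretrained(INITIAL_MODEL_SAVE_PATH)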