voidful committed
Commit 129c05b · verified · 1 Parent(s): ea6e367

Upload 2 files

preprocessor_config.json CHANGED
@@ -1,4 +1,8 @@
 {
+  "auto_map": {
+    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
+    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor"
+  },
   "do_convert_rgb": null,
   "do_normalize": true,
   "do_pan_and_scan": null,
processing_gemma3_omni.py ADDED
@@ -0,0 +1,345 @@
+import math
+import re
+from typing import List, Optional, Union, Dict, Any
+import numpy as np
+import scipy.signal
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from transformers.audio_utils import AudioInput
+from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import make_nested_list_of_images
+from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, AudioKwargs, Unpack
+from transformers.utils import TensorType, to_py_obj, logging
+
+# Constants
+DEFAULT_SAMPLING_RATE = 16000
+DEFAULT_N_FFT = 512
+DEFAULT_WIN_LENGTH = 400
+DEFAULT_HOP_LENGTH = 160
+DEFAULT_N_MELS = 80
+DEFAULT_COMPRESSION_RATE = 4
+DEFAULT_QFORMER_RATE = 2
+DEFAULT_FEAT_STRIDE = 4
+IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
+AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
+DEFAULT_MAX_LENGTH = 16384
+
+logger = logging.get_logger(__name__)
+
+
+def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
+                          fmax: Optional[float] = None) -> np.ndarray:
+    """Create a triangular Mel filterbank for audio processing."""
+    fmax = fmax or sampling_rate / 2
+
+    def hz_to_mel(f: float) -> float:
+        return 1127.0 * math.log(1 + f / 700.0)
+
+    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
+    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
+    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
+
+    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
+    for m in range(1, n_mels + 1):
+        left, center, right = bins[m - 1:m + 2]
+        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
+        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
+
+    return filterbank
+
+
+class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
+    """Converts a 16-kHz mono waveform to (T, 80) log-Mel frames."""
+
+    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
+
+    def __init__(
+        self,
+        compression_rate: int = DEFAULT_COMPRESSION_RATE,
+        qformer_rate: int = DEFAULT_QFORMER_RATE,
+        feat_stride: int = DEFAULT_FEAT_STRIDE,
+        sampling_rate: int = DEFAULT_SAMPLING_RATE,
+        n_fft: int = DEFAULT_N_FFT,
+        win_length: int = DEFAULT_WIN_LENGTH,
+        hop_length: int = DEFAULT_HOP_LENGTH,
+        n_mels: int = DEFAULT_N_MELS,
+        **kwargs
+    ):
+        # SequenceFeatureExtractor positional args: feature_size, sampling_rate, padding_value.
+        super().__init__(n_mels, sampling_rate, 0.0, **kwargs)
+        self.compression_rate = compression_rate
+        self.qformer_rate = qformer_rate
+        self.feat_stride = feat_stride
+        self.sampling_rate = sampling_rate
+
+        self.window = np.hamming(win_length).astype(np.float32)
+        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+    def __call__(
+        self,
+        audios: List[AudioInput],
+        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
+    ) -> BatchFeature:
+        features, sizes, frames = [], [], []
+
+        for wav in audios:
+            # NOTE: the source sampling rate is hard-coded to 22500 Hz here;
+            # inputs at other rates will be resampled as if they were 22.5 kHz.
+            processed_wav = self._preprocess_audio(wav, 22500)
+            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
+            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
+            features.append(feature_tensor)
+            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
+            frames.append(feature_tensor.shape[0] * self.feat_stride)
+
+        audio_embeds = pad_sequence(features, batch_first=True)
+        size_tensor = torch.stack(sizes)
+
+        # Only build an attention mask when batching more than one clip.
+        attention_mask = None
+        if len(audios) > 1:
+            frame_lengths = torch.tensor(frames)
+            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)
+
+        output_data = {
+            "input_audio_embeds": audio_embeds,
+            "audio_embed_sizes": size_tensor
+        }
+        if attention_mask is not None:
+            output_data["audio_attention_mask"] = attention_mask
+
+        return BatchFeature(data=output_data, tensor_type=return_tensors)
+
+    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
+        wav = torch.as_tensor(wav).float().numpy()
+        if wav.ndim > 1:
+            wav = wav.mean(axis=0)  # downmix multi-channel audio to mono
+        if source_sr != self.sampling_rate:
+            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
+        # Peak-normalize; the 1e-6 floor avoids division by zero on silent clips.
+        return wav / max(np.abs(wav).max(), 1e-6)
+
+    def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
+        # Frame the waveform with a strided view, then window each frame.
+        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
+        strides = wav.strides[0]
+        frames = np.lib.stride_tricks.as_strided(
+            wav,
+            shape=(frame_count, self.win_length),
+            strides=(strides * self.hop_length, strides),
+            writeable=False
+        ).copy()
+        frames *= self.window
+
+        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
+        power = np.abs(spectrum) ** 2
+        mel_spectrogram = np.dot(power, self.mel_filterbank)
+        # Clip to 1.0 before the log so outputs are non-negative.
+        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
+        return np.log(mel_spectrogram, dtype=np.float32)
+
+    def _calculate_embed_length(self, frame_count: int) -> int:
+        compressed = math.ceil(frame_count / self.compression_rate)
+        return math.ceil(compressed / self.qformer_rate)
+
+
+class Gemma3ImagesKwargs(ImagesKwargs):
+    do_pan_and_scan: Optional[bool]
+    pan_and_scan_min_crop_size: Optional[int]
+    pan_and_scan_max_num_crops: Optional[int]
+    pan_and_scan_min_ratio_to_activate: Optional[float]
+    do_convert_rgb: Optional[bool]
+
+
+class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Dict[str, Any]
+    audio_kwargs: Dict[str, Any]
+    _defaults = {
+        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
+        "images_kwargs": {},
+        "audio_kwargs": {}
+    }
+
+
+class Gemma3OmniProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer", "audio_processor"]
+    valid_kwargs = ["chat_template", "image_seq_length"]
+    image_processor_class = "AutoImageProcessor"
+    audio_processor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor,
+        audio_processor,
+        tokenizer,
+        chat_template=None,
+        image_seq_length: int = 256,
+        **kwargs
+    ):
+        self.image_seq_length = image_seq_length
+        self.image_token_id = tokenizer.image_token_id
+        self.boi_token = tokenizer.boi_token
+        self.image_token = tokenizer.image_token
+        self.audio_token = "<audio_soft_token>"
+        self.expected_audio_token_id = 262143
+        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"
+
+        # Rates used to map audio frames to soft-token counts on the text side;
+        # these are independent of the feature extractor's own compression settings.
+        self.compression_rate = 8
+        self.qformer_compression_rate = 1
+        self.feat_stride = 1
+
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+        if self.audio_token_id != self.expected_audio_token_id:
+            logger.warning(
+                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
+                "Using the assigned ID; the model embedding layer may need resizing."
+            )
+
+        super().__init__(
+            image_processor=image_processor,
+            audio_processor=audio_processor,
+            tokenizer=tokenizer,
+            chat_template=chat_template,
+            **kwargs
+        )
+
+    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
+        default_kwargs = {}
+        for modality in ModelProcessorKwargs._defaults:
+            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+
+        # Update defaults with tokenizer init kwargs
+        for modality in default_kwargs:
+            modality_kwargs = default_kwargs[modality]
+            for key in modality_kwargs:
+                if key in tokenizer_init_kwargs:
+                    value = (
+                        getattr(self.tokenizer, key)
+                        if hasattr(self.tokenizer, key)
+                        else tokenizer_init_kwargs[key]
+                    )
+                    modality_kwargs[key] = value
+
+        # Update with user-provided kwargs
+        for modality in default_kwargs:
+            if modality in kwargs:
+                default_kwargs[modality].update(kwargs[modality])
+
+        # Route top-level kwargs (e.g. return_tensors) into text_kwargs so they
+        # are not silently dropped.
+        for key, value in kwargs.items():
+            if key not in default_kwargs:
+                default_kwargs["text_kwargs"][key] = value
+
+        # Force truncation off and keep a large max_length for long multimodal prompts.
+        default_kwargs["text_kwargs"]["truncation"] = False
+        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)
+
+        return default_kwargs
+
+    def _compute_audio_embed_size(self, audio_frames: int) -> int:
+        result = math.ceil(audio_frames / self.compression_rate)
+        return math.ceil(result / self.qformer_compression_rate)
+
+    def __call__(
+        self,
+        images=None,
+        text=None,
+        videos=None,
+        audio=None,
+        **kwargs: Unpack[Gemma3ProcessorKwargs]
+    ) -> BatchFeature:
+        if text is None and images is None:
+            raise ValueError("Provide at least one of `text` or `images`.")
+
+        output_kwargs = self._merge_kwargs(
+            Gemma3ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs
+        )
+
+        if isinstance(text, str):
+            text = [text]
+        elif text is None:
+            # Images without text: placeholder prompts are built below.
+            text = []
+        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
+            raise ValueError("Input text must be a string or a list of strings")
+
+        image_inputs = {}
+        if images is not None:
+            batched_images = make_nested_list_of_images(images)
+            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
+
+            if not text:
+                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]
+
+            if len(batched_images) != len(text):
+                raise ValueError(
+                    f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
+                )
+
+            num_crops = to_py_obj(image_inputs.pop("num_crops"))
+            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
+
+            for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
+                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
+                if len(images) != len(image_indexes):
+                    raise ValueError(
+                        f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
+                    )
+
+                # Iterate in reverse so earlier indexes stay valid while splicing crops in.
+                for num, idx in reversed(list(zip(crops, image_indexes))):
+                    if num:
+                        formatted_image_text = (
+                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
+                            + " ".join([self.boi_token] * num)
+                        )
+                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
+                        text[batch_idx] = prompt
+
+            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
+
+        audio_inputs = {}
+        if audio is not None:
+            audio_inputs = self.audio_processor(audio, return_tensors="pt")
+            audio_embeds = audio_inputs['input_audio_embeds']
+            audio_frames = audio_embeds.shape[1] * self.feat_stride
+            audio_seq_length = self._compute_audio_embed_size(audio_frames)
+
+            audio_tokens = {
+                "boa_token": "<start_of_audio>",
+                "eoa_token": "<end_of_audio>",
+                "audio_token": "<audio_soft_token>",
+                "boa_token_id": 256001,
+                "eoa_token_id": 256002,
+                "audio_token_id": self.audio_token_id  # use the dynamically assigned ID
+            }
+
+            # Expand each <start_of_audio> marker in the prompt into the full soft-token sequence.
+            audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
+            text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
+
+        # Debug: log text and token counts before validation
+        for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
+            audio_text_count = txt.count(self.audio_token)
+            audio_ids_count = list(ids).count(self.audio_token_id)
+            logger.debug(
+                f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
+                f"Text length={len(txt)}, Input IDs length={len(ids)}"
+            )
+
+        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "audio"])
+
+        array_ids = text_inputs["input_ids"]
+        mm_token_type_ids = np.zeros_like(array_ids)
+        mm_token_type_ids[array_ids == self.image_token_id] = 1  # image token type
+        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # audio token type
+        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}
+        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
+
+        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
+
+    def batch_decode(self, *args, **kwargs):
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
+        image_processor_inputs = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs))
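
For reference, a hedged end-to-end sketch of calling the processor above with text plus audio; the waveform is a synthetic stand-in, and the `<start_of_audio>` marker in the prompt is what the audio branch expands into soft tokens:

    import numpy as np

    # `processor` as loaded in the earlier sketch via AutoProcessor.
    wav = np.zeros(22500, dtype=np.float32)  # 1 s of silence at the hard-coded 22.5 kHz source rate

    inputs = processor(
        text="Transcribe this clip: <start_of_audio>",
        audio=[wav],
        return_tensors="pt",
    )
    # input_ids include the expanded audio soft tokens; token_type_ids mark them
    # (1 = image, 2 = audio), and input_audio_embeds hold the (1, T, 80) log-Mel frames.
    print(inputs["input_ids"].shape, inputs["input_audio_embeds"].shape)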