Upload 2 files
Browse files
- preprocessor_config.json +4 -0
- processing_gemma3_omni.py +345 -0
preprocessor_config.json
CHANGED
@@ -1,4 +1,8 @@
 {
+  "auto_map": {
+    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
+    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor"
+  },
   "do_convert_rgb": null,
   "do_normalize": true,
   "do_pan_and_scan": null,
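
The auto_map entries are what let AutoProcessor and AutoFeatureExtractor resolve to the custom classes shipped in this repository. A minimal loading sketch, assuming the files are pushed to a Hub repo (the repo id below is a placeholder); loading remote custom code requires trust_remote_code=True:

from transformers import AutoProcessor

# Placeholder repo id; replace with the actual Hub repository.
processor = AutoProcessor.from_pretrained(
    "your-namespace/gemma-3-omni",
    trust_remote_code=True,  # needed so processing_gemma3_omni.py is downloaded and used
)
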
processing_gemma3_omni.py
ADDED
@@ -0,0 +1,345 @@
import math
import re
from typing import List, Optional, Union, Dict, Any
import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, AudioKwargs, Unpack
from transformers.utils import TensorType, to_py_obj, logging

# Constants
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 2
DEFAULT_FEAT_STRIDE = 4
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)


def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                          fmax: Optional[float] = None) -> np.ndarray:
    """Create Mel filterbank for audio processing."""
    fmax = fmax or sampling_rate / 2

    def hz_to_mel(f: float) -> float:
        return 1127.0 * math.log(1 + f / 700.0)

    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)

    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1:m + 2]
        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)

    return filterbank


class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    """Converts 16-kHz mono waveform to (T, 80) log-Mel frames."""

    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        compression_rate: int = DEFAULT_COMPRESSION_RATE,
        qformer_rate: int = DEFAULT_QFORMER_RATE,
        feat_stride: int = DEFAULT_FEAT_STRIDE,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        n_fft: int = DEFAULT_N_FFT,
        win_length: int = DEFAULT_WIN_LENGTH,
        hop_length: int = DEFAULT_HOP_LENGTH,
        n_mels: int = DEFAULT_N_MELS,
        **kwargs
    ):
        super().__init__(n_mels, sampling_rate, 0.0, **kwargs)
        self.compression_rate = compression_rate
        self.qformer_rate = qformer_rate
        self.feat_stride = feat_stride
        self.sampling_rate = sampling_rate

        self.window = np.hamming(win_length).astype(np.float32)
        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def __call__(
        self,
        audios: List[AudioInput],
        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
    ) -> BatchFeature:
        features, sizes, frames = [], [], []

        for wav in audios:
            processed_wav = self._preprocess_audio(wav, 22500)
            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
            features.append(feature_tensor)
            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
            frames.append(feature_tensor.shape[0] * self.feat_stride)

        audio_embeds = pad_sequence(features, batch_first=True)
        size_tensor = torch.stack(sizes)

        attention_mask = None
        if len(audios) > 1:
            frame_lengths = torch.tensor(frames)
            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)

        output_data = {
            "input_audio_embeds": audio_embeds,
            "audio_embed_sizes": size_tensor
        }
        if attention_mask is not None:
            output_data["audio_attention_mask"] = attention_mask

        return BatchFeature(data=output_data, tensor_type=return_tensors)

    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
        wav = torch.as_tensor(wav).float().numpy()
        if wav.ndim > 1:
            wav = wav.mean(axis=0)
        if source_sr != self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
        return wav / max(np.abs(wav).max(), 1e-6)

    def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
        strides = wav.strides[0]
        frames = np.lib.stride_tricks.as_strided(
            wav,
            shape=(frame_count, self.win_length),
            strides=(strides * self.hop_length, strides),
            writeable=False
        ).copy()
        frames *= self.window

        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
        power = np.abs(spectrum) ** 2
        mel_spectrogram = np.dot(power, self.mel_filterbank)
        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
        return np.log(mel_spectrogram, dtype=np.float32)

    def _calculate_embed_length(self, frame_count: int) -> int:
        compressed = math.ceil(frame_count / self.compression_rate)
        return math.ceil(compressed / self.qformer_rate)

class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Dict[str, Any]
    audio_kwargs: Dict[str, Any]
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {}
    }


class Gemma3OmniProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer", "audio_processor"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        audio_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs
    ):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        self.audio_token = "<audio_soft_token>"
        self.expected_audio_token_id = 262143
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"

        self.compression_rate = 8
        self.qformer_compression_rate = 1
        self.feat_stride = 1

        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        if self.audio_token_id != self.expected_audio_token_id:
            logger.warning(
                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
                "Using assigned ID. Model embedding layer may need resizing."
            )

        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs
        )

    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
        default_kwargs = {}
        for modality in ModelProcessorKwargs._defaults:
            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()

        # Update defaults with tokenizer init kwargs
        for modality in default_kwargs:
            modality_kwargs = default_kwargs[modality]
            for key in modality_kwargs:
                if key in tokenizer_init_kwargs:
                    value = (
                        getattr(self.tokenizer, key)
                        if hasattr(self.tokenizer, key)
                        else tokenizer_init_kwargs[key]
                    )
                    modality_kwargs[key] = value

        # Update with user-provided kwargs
        for modality in default_kwargs:
            if modality in kwargs:
                default_kwargs[modality].update(kwargs[modality])

        # Ensure text_kwargs has truncation=False and large max_length
        default_kwargs["text_kwargs"]["truncation"] = False
        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length",
                                                                                        DEFAULT_MAX_LENGTH)

        return default_kwargs

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        result = math.ceil(audio_frames / self.compression_rate)
        return math.ceil(result / self.qformer_compression_rate)

    def __call__(
        self,
        images=None,
        text=None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs]
    ) -> BatchFeature:
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs
        )

        if isinstance(text, str):
            text = [text]
        elif text is not None and (not isinstance(text, list) or not all(isinstance(t, str) for t in text)):
            raise ValueError("Input text must be a string or list of strings")

        image_inputs = {}
        if images is not None:
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])

            if not text:
                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
                )

            num_crops = to_py_obj(image_inputs.pop("num_crops"))
            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]

            for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
                if len(images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
                    )

                for num, idx in reversed(list(zip(crops, image_indexes))):
                    if num:
                        formatted_image_text = (
                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                            + " ".join([self.boi_token] * num)
                        )
                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
                        text[batch_idx] = prompt

            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]

        audio_inputs = {}
        if audio is not None:
            audio_inputs = self.audio_processor(audio, "pt")
            audio_embeds = audio_inputs['input_audio_embeds']
            audio_frames = audio_embeds.shape[1] * self.feat_stride
            audio_seq_length = self._compute_audio_embed_size(audio_frames)

            audio_tokens = {
                "boa_token": "<start_of_audio>",
                "eoa_token": "<end_of_audio>",
                "audio_token": "<audio_soft_token>",
                "boa_token_id": 256001,
                "eoa_token_id": 256002,
                "audio_token_id": self.audio_token_id  # Use dynamic ID
            }

            audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
            text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")

        # Debug: Log text and token counts before validation
        for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
            audio_text_count = txt.count(self.audio_token)
            audio_ids_count = list(ids).count(self.audio_token_id)
            logger.debug(
                f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
                f"Text length={len(txt)}, Input IDs length={len(ids)}"
            )

        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "audio"])

        array_ids = text_inputs["input_ids"]
        mm_token_type_ids = np.zeros_like(array_ids)
        mm_token_type_ids[array_ids == self.image_token_id] = 1  # Image token type
        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # Audio token type
        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_inputs = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs))
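
For a quick sanity check of the audio front end in isolation, the feature extractor can be run on a synthetic waveform. This is a minimal sketch, assuming the file above is importable as processing_gemma3_omni; note that __call__ passes a hard-coded source rate of 22500 Hz to _preprocess_audio, so inputs are resampled to 16 kHz internally.

import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor

extractor = Gemma3AudioFeatureExtractor()

# One second of noise standing in for audio at the assumed 22500 Hz source rate.
wav = np.random.randn(22500).astype(np.float32)
batch = extractor([wav])

print(batch["input_audio_embeds"].shape)  # (1, T, 80) log-Mel frames
print(batch["audio_embed_sizes"])         # per-clip lengths after compression_rate and qformer_rate pooling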