File size: 15,135 Bytes
129c05b 19649d5 129c05b 19649d5 129c05b 19649d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 |
import re
from typing import List, Optional, Union, Dict, Any
import math
import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, Unpack
from transformers.utils import TensorType, to_py_obj, logging
# Constants
# Audio front-end defaults. At 16 kHz a 400-sample window spans 25 ms and a
# 160-sample hop spans 10 ms.
DEFAULT_SAMPLING_RATE = 16_000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80

# Factors the audio feature extractor uses to size its outputs: frame counts are
# divided by the compression and q-former rates, and multiplied by the stride.
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 2
DEFAULT_FEAT_STRIDE = 4

# Regexes for numbered placeholders such as "<|image_1|>" and "<|audio_1|>".
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"

# Default tokenizer ``max_length``.
DEFAULT_MAX_LENGTH = 16_384
# Module-level logger from transformers' logging utilities (``transformers.utils.logging``).
logger = logging.get_logger(__name__)
def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                          fmax: Optional[float] = None) -> np.ndarray:
    """Create a triangular Mel filterbank for audio processing.

    Band edges are spaced evenly on the natural-log Mel scale
    ``mel = 1127 * ln(1 + f / 700)`` between ``fmin`` and ``fmax``.

    Args:
        sampling_rate: Sample rate of the audio in Hz.
        n_fft: FFT size; the filterbank covers ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of Mel bands (rows of the result).
        fmin: Lowest frequency covered, in Hz.
        fmax: Highest frequency covered, in Hz. ``None`` means ``sampling_rate / 2``.

    Returns:
        ``float32`` array of shape ``(n_mels, n_fft // 2 + 1)``. Each row is a
        triangle that peaks at 1.0 on its center bin.

    Raises:
        ValueError: If ``fmax`` is not greater than ``fmin``.
    """
    # Compare against None explicitly: ``fmax or ...`` used to replace an
    # explicitly passed 0.0 with the Nyquist frequency.
    if fmax is None:
        fmax = sampling_rate / 2
    if fmax <= fmin:
        raise ValueError(f"fmax ({fmax}) must be greater than fmin ({fmin})")

    def hz_to_mel(f: float) -> float:
        return 1127.0 * math.log(1 + f / 700.0)

    # n_mels + 2 edges: every band needs a left, center and right edge.
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)
    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1:m + 2]
        # Rising edge [left, center) and falling edge [center, right). The slices
        # are empty when edges coincide, so nothing is divided by zero.
        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
    return filterbank
class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    """Convert raw waveforms into padded log-Mel feature sequences.

    Each waveform is averaged over its first axis when multi-dimensional,
    resampled to ``sampling_rate`` (16 kHz by default) and peak-normalised. It is
    then cut into Hamming-windowed frames and projected onto an ``n_mels``-band
    Mel filterbank, giving a ``(T, n_mels)`` array per clip.
    """

    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    # Input sample rate assumed when ``__call__`` is not given ``sampling_rate``.
    # This value used to be hard-coded and is kept as the default for compatibility.
    # NOTE(review): 22500 Hz is unusual (22050 Hz is the common rate); confirm.
    legacy_input_sampling_rate = 22500

    def __init__(
        self,
        compression_rate: int = DEFAULT_COMPRESSION_RATE,
        qformer_rate: int = DEFAULT_QFORMER_RATE,
        feat_stride: int = DEFAULT_FEAT_STRIDE,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        n_fft: int = DEFAULT_N_FFT,
        win_length: int = DEFAULT_WIN_LENGTH,
        hop_length: int = DEFAULT_HOP_LENGTH,
        n_mels: int = DEFAULT_N_MELS,
        **kwargs
    ):
        """Configure framing, the Mel projection and output-length bookkeeping.

        Args:
            compression_rate: First divisor applied to the frame count when computing
                ``audio_embed_sizes``.
            qformer_rate: Second divisor applied after ``compression_rate``.
            feat_stride: Multiplier from feature frames to attention-mask length.
            sampling_rate: Target sample rate (Hz) that waveforms are resampled to.
            n_fft: FFT size per frame.
            win_length: Frame length in samples (Hamming window size).
            hop_length: Step between frame starts in samples.
            n_mels: Number of Mel bands.
        """
        super().__init__(n_mels, sampling_rate, 0.0, **kwargs)
        self.compression_rate = compression_rate
        self.qformer_rate = qformer_rate
        self.feat_stride = feat_stride
        self.sampling_rate = sampling_rate
        self.window = np.hamming(win_length).astype(np.float32)
        # Transposed to (n_fft // 2 + 1, n_mels) so a power spectrum can be right-multiplied.
        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def __call__(
        self,
        audios: List[AudioInput],
        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """Extract log-Mel features for a batch of waveforms.

        Args:
            audios: Waveforms (array-likes accepted by ``torch.as_tensor``), one per clip.
            return_tensors: Tensor type for the returned ``BatchFeature``.
            sampling_rate: Sample rate of the *input* waveforms in Hz. ``None`` keeps
                the legacy assumption of ``legacy_input_sampling_rate``.

        Returns:
            BatchFeature with:
            - ``input_audio_embeds``: features zero-padded to the longest clip.
            - ``audio_embed_sizes``: per-clip embedding lengths.
            - ``audio_attention_mask``: only present when more than one clip is given.

        Raises:
            ValueError: If ``audios`` is empty or contains an empty waveform.
        """
        if len(audios) == 0:
            raise ValueError("`audios` must contain at least one waveform.")
        source_sr = self.legacy_input_sampling_rate if sampling_rate is None else sampling_rate

        features, sizes, frames = [], [], []
        for wav in audios:
            processed_wav = self._preprocess_audio(wav, source_sr)
            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
            features.append(feature_tensor)
            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
            frames.append(feature_tensor.shape[0] * self.feat_stride)

        audio_embeds = pad_sequence(features, batch_first=True)
        size_tensor = torch.stack(sizes)

        # The mask is measured in stride-scaled frames and is only built for batches
        # of more than one clip.
        attention_mask = None
        if len(audios) > 1:
            frame_lengths = torch.tensor(frames)
            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)

        output_data = {
            "input_audio_embeds": audio_embeds,
            "audio_embed_sizes": size_tensor
        }
        if attention_mask is not None:
            output_data["audio_attention_mask"] = attention_mask
        return BatchFeature(data=output_data, tensor_type=return_tensors)

    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
        """Convert to float, average channels, resample and peak-normalise one waveform.

        Raises:
            ValueError: If the waveform contains no samples.
        """
        # detach/cpu so tensors that require grad or live on a GPU are accepted too.
        wav = torch.as_tensor(wav).detach().cpu().float().numpy()
        if wav.ndim > 1:
            # Averages over axis 0; assumes channels-first input. TODO: confirm.
            wav = wav.mean(axis=0)
        if wav.size == 0:
            raise ValueError("Received an empty audio waveform.")
        if source_sr != self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
        # Unit peak; the 1e-6 floor avoids dividing by zero on silent input.
        return wav / max(np.abs(wav).max(), 1e-6)

    def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
        """Return a ``(frames, n_mels)`` float32 log-Mel spectrogram of ``wav``."""
        if len(wav) < self.win_length:
            # Zero-pad clips shorter than one window. Otherwise the frame count below
            # would be negative and as_strided would fail.
            wav = np.pad(wav, (0, self.win_length - len(wav)))
        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
        strides = wav.strides[0]
        # Overlapping frames as a strided view, copied so the window can be applied in place.
        frames = np.lib.stride_tricks.as_strided(
            wav,
            shape=(frame_count, self.win_length),
            strides=(strides * self.hop_length, strides),
            writeable=False
        ).copy()
        frames *= self.window
        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
        power = np.abs(spectrum) ** 2
        mel_spectrogram = np.dot(power, self.mel_filterbank)
        # Floor at 1.0 so the log below is always >= 0 and never -inf.
        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
        return np.log(mel_spectrogram, dtype=np.float32)

    def _calculate_embed_length(self, frame_count: int) -> int:
        """Embedding length after ceil-dividing by ``compression_rate`` then ``qformer_rate``."""
        compressed = math.ceil(frame_count / self.compression_rate)
        return math.ceil(compressed / self.qformer_rate)
class Gemma3ImagesKwargs(ImagesKwargs):
    """Gemma3-specific image-processing options layered on ``ImagesKwargs``.

    NOTE(review): ``Gemma3ProcessorKwargs`` in this file annotates ``images_kwargs``
    as ``Dict[str, Any]`` rather than this class. Confirm whether that is intended.
    """

    # Pan-and-scan cropping options; presumably consumed by the image processor.
    # TODO: confirm semantics against the image processor.
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    # Presumably toggles RGB conversion of inputs. TODO: confirm.
    do_convert_rgb: Optional[bool]
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    """Per-modality call kwargs for ``Gemma3OmniProcessor``, with their defaults."""

    images_kwargs: Dict[str, Any]
    audio_kwargs: Dict[str, Any]
    # Starting point for each call's kwargs. Text defaults turn padding and
    # truncation off and set max_length to DEFAULT_MAX_LENGTH.
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {}
    }
class Gemma3OmniProcessor(ProcessorMixin):
    """Processor combining an image processor, a tokenizer and an audio feature extractor.

    Image placeholders (``tokenizer.boi_token``) in the prompt are expanded to
    ``image_seq_length`` image tokens, plus pan-and-scan crop placeholders when the
    image processor reports crops. ``<start_of_audio>`` markers are expanded to a
    run of ``<audio_soft_token>`` tokens sized from the audio features.
    """

    # ``ProcessorMixin.from_pretrained`` loads the sub-processors in this order and
    # passes them to ``__init__`` positionally, so the order must match the
    # ``__init__`` signature (image_processor, audio_processor, tokenizer).
    attributes = ["image_processor", "audio_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        audio_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs
    ):
        """Read special-token metadata from ``tokenizer`` and register the sub-processors.

        Args:
            image_processor: Image processor; must return ``num_crops`` in its output.
            audio_processor: Audio feature extractor producing ``input_audio_embeds``.
            tokenizer: Tokenizer exposing ``image_token_id``, ``boi_token``,
                ``image_token`` and ``eoi_token``.
            chat_template: Optional chat template forwarded to ``ProcessorMixin``.
            image_seq_length: Number of image tokens per image placeholder.
        """
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        self.audio_token = "<audio_soft_token>"
        self.expected_audio_token_id = 262143
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"
        # Used by ``_compute_audio_embed_size``. The padded feature length is divided
        # by 8 and then by 1.
        self.compression_rate = 8
        self.qformer_compression_rate = 1
        self.feat_stride = 1
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        if self.audio_token_id != self.expected_audio_token_id:
            logger.warning(
                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
                "Using assigned ID. Model embedding layer may need resizing."
            )
        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs
        )

    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
        """Build the per-modality kwargs for one call.

        Precedence, from lowest to highest:
        1. ``ModelProcessorKwargs._defaults``.
        2. Tokenizer init settings, only for keys that already have a default.
        3. Flat call kwargs such as ``return_tensors="pt"`` or ``padding=True``.
        4. ``common_kwargs`` (merged into text kwargs).
        5. Explicit modality dicts such as ``text_kwargs={...}``.

        Unrecognised flat kwargs are ignored with a one-time warning. Text
        ``truncation`` is always forced off so expanded image/audio token runs are
        never cut.

        Returns:
            dict mapping each modality name (e.g. ``"text_kwargs"``) to its kwargs dict.
        """
        # Copy every defaults dict so per-call updates never leak into the class.
        merged = {
            modality: dict(ModelProcessorKwargs._defaults.get(modality, {}))
            for modality in ModelProcessorKwargs._defaults
        }
        for modality_kwargs in merged.values():
            for key in modality_kwargs:
                if key in tokenizer_init_kwargs:
                    modality_kwargs[key] = getattr(self.tokenizer, key, tokenizer_init_kwargs[key])

        text_kwargs = merged.setdefault("text_kwargs", {})
        images_kwargs = merged.setdefault("images_kwargs", {})

        # Flat kwargs (e.g. ``processor(text=..., return_tensors="pt")``) used to be
        # silently dropped. Route them by the keys declared in the TypedDicts.
        annotations = getattr(ModelProcessorKwargs, "__annotations__", {})
        text_keys = set(getattr(annotations.get("text_kwargs"), "__annotations__", {}))
        image_keys = set(getattr(Gemma3ImagesKwargs, "__annotations__", {}))
        ignored = []
        for key, value in kwargs.items():
            if key in merged or key == "common_kwargs":
                continue  # nested dicts are applied below
            if key == "return_tensors" or key in text_keys:
                text_kwargs[key] = value
            elif key in image_keys:
                images_kwargs[key] = value
            else:
                ignored.append(key)
        if ignored:
            logger.warning_once(f"Ignoring unsupported processor kwargs: {sorted(ignored)}")

        # ``return_tensors`` is consumed from the text kwargs in ``__call__``.
        if isinstance(kwargs.get("common_kwargs"), dict):
            text_kwargs.update(kwargs["common_kwargs"])
        for modality in merged:
            if modality in kwargs:
                merged[modality].update(kwargs[modality] or {})

        text_kwargs["truncation"] = False
        text_kwargs.setdefault("max_length", DEFAULT_MAX_LENGTH)
        return merged

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        """Ceil-divide ``audio_frames`` by ``compression_rate`` then ``qformer_compression_rate``."""
        result = math.ceil(audio_frames / self.compression_rate)
        return math.ceil(result / self.qformer_compression_rate)

    def _expand_image_tokens(self, images, text, images_kwargs):
        """Preprocess ``images`` and expand every image placeholder in ``text``.

        Args:
            images: Input accepted by ``make_nested_list_of_images``.
            text: List of prompts, or ``None``/empty to generate one placeholder per image.
            images_kwargs: Kwargs forwarded to the image processor.

        Returns:
            ``(prompts, image_inputs)``, where:
            - ``prompts`` is a new list in which each placeholder is replaced by
              ``full_image_sequence``, with crop placeholders inserted first when
              the image processor reports crops.
            - ``image_inputs`` holds the image-processor outputs minus ``num_crops``.

        Raises:
            ValueError: If prompt and image counts disagree.
        """
        batched_images = make_nested_list_of_images(images)
        image_inputs = self.image_processor(batched_images, **images_kwargs)
        if not text:
            text = [" ".join([self.boi_token] * len(sample_images)) for sample_images in batched_images]
        if len(batched_images) != len(text):
            raise ValueError(
                f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
            )
        # ``num_crops`` is flat across the whole batch; regroup it per prompt.
        num_crops = to_py_obj(image_inputs.pop("num_crops"))
        batch_num_crops = [[num_crops.pop(0) for _ in sample_images] for sample_images in batched_images]
        # Escape so the placeholder is matched literally even with regex metacharacters.
        boi_pattern = re.compile(re.escape(self.boi_token))
        prompts = []
        for prompt, sample_images, crops in zip(text, batched_images, batch_num_crops):
            image_indexes = [m.start() for m in boi_pattern.finditer(prompt)]
            if len(sample_images) != len(image_indexes):
                raise ValueError(
                    f"Prompt has {len(image_indexes)} image tokens but received {len(sample_images)} images"
                )
            # Rewrite right-to-left so earlier match offsets stay valid.
            for num, idx in reversed(list(zip(crops, image_indexes))):
                if num:
                    formatted_image_text = (
                        f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                        + " ".join([self.boi_token] * num)
                    )
                    prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
            prompts.append(prompt.replace(self.boi_token, self.full_image_sequence))
        return prompts, image_inputs

    def _expand_audio_tokens(self, audio, text):
        """Run the audio processor and expand every ``<start_of_audio>`` marker in ``text``.

        Returns:
            ``(prompts, audio_inputs)``: the new prompt list and the audio-processor outputs.
        """
        audio_inputs = self.audio_processor(audio, "pt")
        audio_embeds = audio_inputs["input_audio_embeds"]
        # Sized from the padded (longest) clip, so every marker in every prompt gets
        # the same number of soft tokens.
        # NOTE(review): shorter clips in a mixed-length batch are over-allocated. Confirm
        # the model expects the padded length rather than per-clip ``audio_embed_sizes``.
        audio_frames = audio_embeds.shape[1] * self.feat_stride
        audio_seq_length = self._compute_audio_embed_size(audio_frames)
        boa_token, eoa_token = "<start_of_audio>", "<end_of_audio>"
        audio_sequence = f"\n\n{boa_token}{self.audio_token * audio_seq_length}{eoa_token}\n\n"
        return [prompt.replace(boa_token, audio_sequence) for prompt in text], audio_inputs

    def __call__(
        self,
        images=None,
        text=None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs]
    ) -> BatchFeature:
        """Tokenize prompts and preprocess their images and audio.

        Args:
            images: Images, one list per prompt (see ``make_nested_list_of_images``).
            text: A prompt or list of prompts. May be omitted when ``images`` is given;
                one image placeholder per image is then generated.
            videos: Accepted for API compatibility; unused.
            audio: Waveforms forwarded to ``audio_processor``.
            **kwargs: Either modality dicts (``text_kwargs``, ``images_kwargs``, ...) or
                flat options such as ``return_tensors`` and ``padding``.

        Returns:
            BatchFeature containing:
            - the tokenizer outputs;
            - ``token_type_ids`` (0 = text, 1 = image token, 2 = audio token);
            - the image and audio processor outputs.

        Raises:
            ValueError: On missing inputs, malformed ``text``, or image/prompt mismatches.
        """
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")
        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs
        )
        if isinstance(text, str):
            text = [text]
        elif text is not None:
            if not isinstance(text, list) or not all(isinstance(t, str) for t in text):
                raise ValueError("Input text must be a string or list of strings")
            # Copy so the prompt rewriting below never mutates the caller's list.
            text = list(text)

        image_inputs = {}
        if images is not None:
            text, image_inputs = self._expand_image_tokens(images, text, output_kwargs["images_kwargs"])

        audio_inputs = {}
        if audio is not None:
            text, audio_inputs = self._expand_audio_tokens(audio, text)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")

        for i, (prompt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
            logger.debug(
                "Sample %d: Audio tokens in text=%d, in input_ids=%d, Text length=%d, Input IDs length=%d",
                i, prompt.count(self.audio_token), list(ids).count(self.audio_token_id), len(prompt), len(ids),
            )
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "audio"])

        array_ids = text_inputs["input_ids"]
        mm_token_type_ids = np.zeros_like(array_ids)
        mm_token_type_ids[array_ids == self.image_token_id] = 1  # image token type
        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # audio token type
        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Deduplicated input names of the tokenizer, image processor and audio processor."""
        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_inputs = self.image_processor.model_input_names
        # Without the audio names, ``input_audio_embeds`` etc. could be filtered out.
        audio_processor_inputs = list(getattr(self.audio_processor, "model_input_names", []))
        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))
# ──────────────────────────────────────────────────────────────────────────────
# exports
# ──────────────────────────────────────────────────────────────────────────────
# Public names exported by ``from <module> import *``.
__all__ = [
    "Gemma3OmniProcessor",
    "Gemma3AudioFeatureExtractor",
]
|