import math
import re
from typing import Any, Dict, List, Optional, Union

import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from transformers.utils import TensorType, logging, to_py_obj

# Constants
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 2
DEFAULT_FEAT_STRIDE = 4
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

# Rate (Hz) Gemma3AudioFeatureExtractor assumes for input waveforms when the
# caller does not specify one. This was previously a literal inside __call__.
# NOTE(review): 22500 Hz is unusual (22050 Hz is the common rate) — confirm.
DEFAULT_SOURCE_SAMPLING_RATE = 22500

# Fallback audio special tokens. They are used only when the tokenizer does not
# define `audio_token` / `boa_token` / `eoa_token` attributes. They must be
# non-empty: `str.replace("", ...)` inserts at every character boundary.
# NOTE(review): the names follow the `<image_soft_token>` / `<start_of_image>`
# convention — confirm they exist in the tokenizer's vocabulary.
DEFAULT_AUDIO_TOKEN = "<audio_soft_token>"
DEFAULT_BOA_TOKEN = "<start_of_audio>"
DEFAULT_EOA_TOKEN = "<end_of_audio>"
# Audio soft-token id the processor expects; a warning is logged at
# construction time when the tokenizer assigns a different id.
EXPECTED_AUDIO_TOKEN_ID = 262143

logger = logging.get_logger(__name__)


def _typed_dict_keys(typed_dict: Any) -> set:
    """Return the declared keys of a TypedDict class, or an empty set if unavailable."""
    annotations = getattr(typed_dict, "__annotations__", None)
    return set(annotations) if isinstance(annotations, dict) else set()


def create_mel_filterbank(
    sampling_rate: int,
    n_fft: int,
    n_mels: int,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
) -> np.ndarray:
    """Create a triangular Mel filterbank for audio processing.

    Args:
        sampling_rate: Sampling rate of the signal in Hz.
        n_fft: FFT size; the filterbank spans ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of Mel bands.
        fmin: Lower edge of the first band in Hz.
        fmax: Upper edge of the last band in Hz; ``None`` means the Nyquist
            frequency ``sampling_rate / 2``.

    Returns:
        A ``float32`` array of shape ``(n_mels, n_fft // 2 + 1)``.
    """
    if fmax is None:
        # Only None means "use Nyquist"; an explicit value is taken as given.
        fmax = sampling_rate / 2

    def hz_to_mel(f: float) -> float:
        # Natural-log Mel scale: 1127 * ln(1 + f / 700).
        return 1127.0 * math.log(1 + f / 700.0)

    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)

    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1 : m + 2]
        # Rising then falling edge of the m-th triangle. When neighbouring
        # bins coincide, the edge is empty and the row stays zero there.
        if center > left:
            filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
        if right > center:
            filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)
    return filterbank


class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    """Convert mono waveforms into ``(T, n_mels)`` log-Mel frames.

    With the defaults, each clip is resampled to 16 kHz and framed with a
    400-sample Hamming window and a 160-sample hop. Each frame is transformed
    with a 512-point FFT and projected onto 80 Mel bands.
    """

    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        compression_rate: int = DEFAULT_COMPRESSION_RATE,
        qformer_rate: int = DEFAULT_QFORMER_RATE,
        feat_stride: int = DEFAULT_FEAT_STRIDE,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        n_fft: int = DEFAULT_N_FFT,
        win_length: int = DEFAULT_WIN_LENGTH,
        hop_length: int = DEFAULT_HOP_LENGTH,
        n_mels: int = DEFAULT_N_MELS,
        **kwargs,
    ):
        """Store the framing parameters and precompute the window and Mel filterbank.

        Args:
            compression_rate: First ceil-division factor applied by
                ``_calculate_embed_length``.
            qformer_rate: Second ceil-division factor applied by
                ``_calculate_embed_length``.
            feat_stride: Multiplier applied to frame counts when building
                ``audio_attention_mask``.
            sampling_rate: Target rate (Hz) every waveform is resampled to.
            n_fft: FFT size.
            win_length: Analysis window length, in samples.
            hop_length: Hop between successive frames, in samples.
            n_mels: Number of Mel bands; also the feature size.
            **kwargs: Forwarded to ``SequenceFeatureExtractor``.
        """
        super().__init__(n_mels, sampling_rate, 0.0, **kwargs)
        self.compression_rate = compression_rate
        self.qformer_rate = qformer_rate
        self.feat_stride = feat_stride
        self.sampling_rate = sampling_rate
        self.window = np.hamming(win_length).astype(np.float32)
        # Transposed so that a (frames, n_fft // 2 + 1) power spectrum can be
        # projected onto the Mel bands with a single matrix product.
        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def __call__(
        self,
        audios: List[AudioInput],
        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """Extract log-Mel features for a batch of waveforms.

        Args:
            audios: Waveforms (arrays, tensors or lists).
                - A bare 1-D array or tensor is treated as a batch of one.
                - An item may instead be a ``(waveform, sampling_rate)`` tuple
                  carrying its own rate.
            return_tensors: Tensor type of the returned ``BatchFeature``.
            sampling_rate: Rate (Hz) of items that do not carry their own rate.
                Defaults to ``DEFAULT_SOURCE_SAMPLING_RATE``, the value this
                method always assumed before it became configurable.

        Returns:
            A ``BatchFeature`` containing:
                - ``input_audio_embeds``: features zero-padded to the longest clip.
                - ``audio_embed_sizes``: one embedding count per clip.
                - ``audio_attention_mask``: present only for batches of more
                  than one clip.

        Raises:
            ValueError: If ``audios`` is empty.
        """
        if sampling_rate is None:
            sampling_rate = DEFAULT_SOURCE_SAMPLING_RATE
        if isinstance(audios, (np.ndarray, torch.Tensor)) and audios.ndim == 1:
            # A single 1-D waveform is one clip, not a batch of scalar samples.
            audios = [audios]
        if len(audios) == 0:
            raise ValueError("`audios` must contain at least one waveform.")

        features, sizes, frames = [], [], []
        for item in audios:
            wav, source_sr = self._split_audio_item(item, sampling_rate)
            processed_wav = self._preprocess_audio(wav, source_sr)
            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
            features.append(feature_tensor)
            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
            frames.append(feature_tensor.shape[0] * self.feat_stride)

        audio_embeds = pad_sequence(features, batch_first=True)
        size_tensor = torch.stack(sizes)

        output_data = {
            "input_audio_embeds": audio_embeds,
            "audio_embed_sizes": size_tensor,
        }
        if len(audios) > 1:
            # NOTE(review): the mask spans max_frames * feat_stride positions,
            # while input_audio_embeds holds max_frames frames per clip.
            # Confirm that the consumer expects stride-scaled lengths.
            frame_lengths = torch.tensor(frames)
            output_data["audio_attention_mask"] = (
                torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)
            )
        return BatchFeature(data=output_data, tensor_type=return_tensors)

    @staticmethod
    def _split_audio_item(item: Any, default_sr: int):
        """Return ``(waveform, sampling_rate)`` for a bare waveform or a ``(waveform, rate)`` pair."""
        if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], (int, np.integer)):
            return item[0], int(item[1])
        return item, default_sr

    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
        """Down-mix to mono, resample to ``self.sampling_rate`` and peak-normalise.

        Args:
            wav: Waveform as an array, tensor or list.
            source_sr: Sampling rate of ``wav`` in Hz.

        Returns:
            A 1-D float array whose absolute peak is 1, or all zeros for
            silent or empty input.
        """
        # detach/cpu let gradient-tracking or GPU tensors be converted as well.
        wav = torch.as_tensor(wav).detach().cpu().float().numpy()
        if wav.ndim > 1:
            # Averages over axis 0; assumes a (channels, samples) layout — TODO confirm.
            wav = wav.mean(axis=0)
        if source_sr != self.sampling_rate and wav.size:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
        peak = np.abs(wav).max() if wav.size else 0.0
        # The 1e-6 floor avoids dividing silent input by zero.
        return wav / max(peak, 1e-6)

    def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
        """Return ``log(max(mel_power, 1))`` frames of shape ``(frames, n_mels)`` as float32."""
        if len(wav) < self.win_length:
            # Zero-pad clips shorter than one window so they produce one frame
            # instead of a zero or negative frame count.
            wav = np.pad(wav, (0, self.win_length - len(wav)))
        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
        stride = wav.strides[0]
        # Overlapping windows as a strided view; copied so that the in-place
        # windowing below cannot touch the source buffer.
        frames = np.lib.stride_tricks.as_strided(
            wav,
            shape=(frame_count, self.win_length),
            strides=(stride * self.hop_length, stride),
            writeable=False,
        ).copy()
        frames *= self.window
        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
        power = np.abs(spectrum) ** 2
        mel_spectrogram = np.dot(power, self.mel_filterbank)
        # Clamping at 1.0 avoids log(0) and keeps the log non-negative.
        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
        return np.log(mel_spectrogram, dtype=np.float32)

    def _calculate_embed_length(self, frame_count: int) -> int:
        """Return ``ceil(ceil(frame_count / compression_rate) / qformer_rate)``."""
        compressed = math.ceil(frame_count / self.compression_rate)
        return math.ceil(compressed / self.qformer_rate)


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Dict[str, Any]
    audio_kwargs: Dict[str, Any]
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {},
    }


class Gemma3OmniProcessor(ProcessorMixin):
    """Gemma 3 processor that combines image, audio and text inputs.

    Placeholder expansion:
        - Each ``boi_token`` in a prompt becomes ``image_seq_length`` image soft
          tokens, plus pan-and-scan crops when the image processor reports them.
        - Each ``boa_token`` becomes a run of ``audio_token`` whose length is
          derived from the extracted audio features.

    ``token_type_ids`` marks image tokens with 1 and audio tokens with 2.
    """

    attributes = ["image_processor", "tokenizer", "audio_processor"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Top-level call kwargs that are routed to the audio feature extractor.
    _AUDIO_CALL_KWARGS = ("sampling_rate",)

    def __init__(
        self,
        image_processor,
        audio_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs,
    ):
        """Resolve the special tokens and ids used for placeholder expansion.

        Args:
            image_processor: Processor called on the batched images.
            audio_processor: Feature extractor called on the audio.
            tokenizer: Must expose ``image_token_id``, ``boi_token``,
                ``image_token`` and ``eoi_token``.
            chat_template: Optional chat template forwarded to ``ProcessorMixin``.
            image_seq_length: Number of image soft tokens per image.
            **kwargs: Forwarded to ``ProcessorMixin``.
        """
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        # Prefer the tokenizer's own audio tokens and fall back to the module defaults.
        self.audio_token = self._resolve_special_token(tokenizer, "audio_token", DEFAULT_AUDIO_TOKEN)
        self.boa_token = self._resolve_special_token(tokenizer, "boa_token", DEFAULT_BOA_TOKEN)
        self.eoa_token = self._resolve_special_token(tokenizer, "eoa_token", DEFAULT_EOA_TOKEN)
        self.expected_audio_token_id = EXPECTED_AUDIO_TOKEN_ID
        self.full_image_sequence = (
            f"\n\n{tokenizer.boi_token}"
            f"{''.join([tokenizer.image_token] * image_seq_length)}"
            f"{tokenizer.eoi_token}\n\n"
        )
        # Audio token budget (see _compute_audio_embed_size):
        #   ceil(ceil(frames / compression_rate) / qformer_compression_rate)
        self.compression_rate = 8
        self.qformer_compression_rate = 1
        self.feat_stride = 1

        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        if self.audio_token_id != self.expected_audio_token_id:
            logger.warning(
                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match expected ID {self.expected_audio_token_id}. "
                "Using assigned ID. Model embedding layer may need resizing."
            )

        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs,
        )

    @staticmethod
    def _resolve_special_token(tokenizer, name: str, default: str) -> str:
        """Return ``tokenizer.<name>`` as a string when it is set and non-empty, else ``default``."""
        value = getattr(tokenizer, name, None)
        return str(value) if value else default

    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
        """Build the per-modality kwargs for the tokenizer, image processor and audio extractor.

        Precedence, from lowest to highest:
            1. ``ModelProcessorKwargs._defaults``.
            2. Tokenizer settings, for default keys present in ``tokenizer_init_kwargs``.
            3. Recognised top-level call kwargs, routed by key:
               - text keys and ``return_tensors`` -> ``text_kwargs``;
               - image keys (e.g. ``do_pan_and_scan``) -> ``images_kwargs``;
               - ``sampling_rate`` -> ``audio_kwargs``.
            4. Explicit per-modality dicts such as ``text_kwargs={...}``.

        Unrecognised top-level keys are logged and ignored. Truncation is
        always disabled.

        Returns:
            A dict mapping each modality name to its kwargs dict.
        """
        default_kwargs = {
            modality: dict(defaults or {})
            for modality, defaults in ModelProcessorKwargs._defaults.items()
        }

        # Seed defaults from the tokenizer config; a live tokenizer attribute
        # wins over the raw init kwarg.
        for modality_kwargs in default_kwargs.values():
            for key in modality_kwargs:
                if key in tokenizer_init_kwargs:
                    modality_kwargs[key] = (
                        getattr(self.tokenizer, key)
                        if hasattr(self.tokenizer, key)
                        else tokenizer_init_kwargs[key]
                    )

        # Route top-level kwargs (e.g. return_tensors="pt", padding=True),
        # which were previously dropped silently.
        text_keys = _typed_dict_keys(getattr(ModelProcessorKwargs, "__annotations__", {}).get("text_kwargs"))
        text_keys.add("return_tensors")
        image_keys = _typed_dict_keys(Gemma3ImagesKwargs)
        ignored = []
        for key, value in kwargs.items():
            if key in default_kwargs or key.endswith("_kwargs"):
                continue  # per-modality dicts are merged below
            if key in text_keys and "text_kwargs" in default_kwargs:
                default_kwargs["text_kwargs"][key] = value
            elif key in image_keys and "images_kwargs" in default_kwargs:
                default_kwargs["images_kwargs"][key] = value
            elif key in self._AUDIO_CALL_KWARGS and "audio_kwargs" in default_kwargs:
                default_kwargs["audio_kwargs"][key] = value
            else:
                ignored.append(key)
        if ignored:
            logger.warning(f"Ignoring unrecognized processor kwargs: {sorted(ignored)}")

        # Explicit per-modality dicts have the final say.
        for modality, modality_kwargs in default_kwargs.items():
            modality_kwargs.update(kwargs.get(modality) or {})

        # Ensure text_kwargs has truncation=False and a large max_length.
        # Presumably this keeps image/audio tokens from being cut off — confirm.
        default_kwargs["text_kwargs"]["truncation"] = False
        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get(
            "max_length", DEFAULT_MAX_LENGTH
        )
        return default_kwargs

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        """Return ``ceil(ceil(audio_frames / compression_rate) / qformer_compression_rate)``."""
        result = math.ceil(audio_frames / self.compression_rate)
        return math.ceil(result / self.qformer_compression_rate)

    def _expand_image_tokens(self, images, text: List[str], images_kwargs: Dict[str, Any]):
        """Process the images and expand each ``boi_token`` into the full image sequence.

        Args:
            images: Images, or nested per-prompt lists of images.
            text: Prompts. If empty, one prompt of ``boi_token`` placeholders is
                generated per image group.
            images_kwargs: Options forwarded to the image processor.

        Returns:
            ``(expanded_prompts, image_inputs)``; ``num_crops`` is removed from
            ``image_inputs``.

        Raises:
            ValueError: On a batch-size mismatch, a placeholder/image count
                mismatch, or too few crop counts.
        """
        batched_images = make_nested_list_of_images(images)
        image_inputs = self.image_processor(batched_images, **images_kwargs)

        if not text:
            text = [" ".join([self.boi_token] * len(sample_images)) for sample_images in batched_images]

        if len(batched_images) != len(text):
            raise ValueError(
                f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
            )

        # Split the flat per-image crop counts into per-prompt groups.
        total_images = sum(len(sample_images) for sample_images in batched_images)
        num_crops = image_inputs.pop("num_crops", None)
        flat_crops = [0] * total_images if num_crops is None else list(to_py_obj(num_crops))
        if len(flat_crops) < total_images:
            raise ValueError(
                f"Image processor returned {len(flat_crops)} crop counts for {total_images} images"
            )
        batch_num_crops, offset = [], 0
        for sample_images in batched_images:
            batch_num_crops.append(flat_crops[offset : offset + len(sample_images)])
            offset += len(sample_images)

        boi_pattern = re.compile(re.escape(self.boi_token))
        expanded = []
        for prompt, sample_images, crops in zip(text, batched_images, batch_num_crops):
            image_indexes = [m.start() for m in boi_pattern.finditer(prompt)]
            if len(sample_images) != len(image_indexes):
                raise ValueError(
                    f"Prompt has {len(image_indexes)} image tokens but received {len(sample_images)} images"
                )
            # Rewrite right-to-left so that earlier offsets stay valid.
            for num, idx in reversed(list(zip(crops, image_indexes))):
                if num:
                    formatted_image_text = (
                        f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                        + " ".join([self.boi_token] * num)
                    )
                    prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
            expanded.append(prompt.replace(self.boi_token, self.full_image_sequence))
        return expanded, image_inputs

    def _expand_audio_tokens(self, audio, text: List[str], audio_kwargs: Dict[str, Any]):
        """Extract audio features and replace each ``boa_token`` with the audio token run.

        Args:
            audio: Waveforms passed to ``audio_processor``.
            text: Prompts that may contain ``boa_token`` placeholders.
            audio_kwargs: Options forwarded to the audio processor.

        Returns:
            ``(expanded_prompts, audio_inputs)``.
        """
        # return_tensors is forced to "pt" because `.shape` is read below.
        audio_inputs = self.audio_processor(audio, **{**audio_kwargs, "return_tensors": "pt"})
        audio_embeds = audio_inputs["input_audio_embeds"]
        audio_frames = audio_embeds.shape[1] * self.feat_stride
        # NOTE(review): this uses the padded length, so every audio in a batch
        # gets the same token count even when `audio_embed_sizes` differ.
        # Confirm this matches how the model consumes the audio embeddings.
        audio_seq_length = self._compute_audio_embed_size(audio_frames)
        audio_sequence = (
            f"\n\n{self.boa_token}{self.audio_token * audio_seq_length}{self.eoa_token}\n\n"
        )
        # str.replace scans the original prompt once, so the boa_token inside
        # audio_sequence is not re-expanded.
        return [prompt.replace(self.boa_token, audio_sequence) for prompt in text], audio_inputs

    def __call__(
        self,
        images=None,
        text=None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        """Tokenize prompts and process their images and audio.

        Args:
            images: Images, or nested per-prompt lists of images.
            text: A prompt or a list of prompts.
                - Each ``boi_token`` marks an image.
                - Each ``boa_token`` marks where the audio sequence goes.
                - May be omitted when ``images`` is given; one placeholder per
                  image is then generated.
            videos: Accepted for API compatibility; unused.
            audio: Waveforms passed to ``audio_processor``.
            **kwargs: Tokenizer, image and audio options (see ``_merge_kwargs``).

        Returns:
            A ``BatchFeature`` containing:
                - the tokenizer outputs;
                - ``token_type_ids`` (1 = image token, 2 = audio token);
                - the image and audio processor outputs.

        Raises:
            ValueError: If neither text nor images are given, if ``text`` has
                the wrong type, or if batch sizes or placeholder counts do not
                match the inputs.
        """
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if text is None:
            text = []  # images-only call: placeholders are generated from the images
        elif isinstance(text, str):
            text = [text]
        elif isinstance(text, list) and all(isinstance(t, str) for t in text):
            text = list(text)  # copy so that the caller's list is never rewritten
        else:
            raise ValueError("Input text must be a string or list of strings")

        image_inputs = {}
        if images is not None:
            text, image_inputs = self._expand_image_tokens(images, text, output_kwargs["images_kwargs"])

        audio_inputs = {}
        if audio is not None:
            text, audio_inputs = self._expand_audio_tokens(audio, text, output_kwargs["audio_kwargs"])

        # return_tensors applies to the final BatchFeature; the tokenizer always
        # returns numpy so that token types can be computed with masks below.
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")

        if logger.isEnabledFor(logging.DEBUG):
            for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
                audio_text_count = txt.count(self.audio_token)
                audio_ids_count = list(ids).count(self.audio_token_id)
                logger.debug(
                    f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
                    f"Text length={len(txt)}, Input IDs length={len(ids)}"
                )

        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "audio"])

        array_ids = text_inputs["input_ids"]
        mm_token_type_ids = np.zeros_like(array_ids)
        mm_token_type_ids[array_ids == self.image_token_id] = 1  # image token type
        mm_token_type_ids[array_ids == self.audio_token_id] = 2  # audio token type
        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(
            data={**text_inputs, **image_inputs, **audio_inputs},
            tensor_type=return_tensors,
        )

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Return the de-duplicated input names produced by ``__call__``, including the audio inputs."""
        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_inputs = self.image_processor.model_input_names
        audio_processor_inputs = list(getattr(self.audio_processor, "model_input_names", []))
        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))


# ──────────────────────────────────────────────────────────────────────────────
# exports
# ──────────────────────────────────────────────────────────────────────────────
__all__ = [
    "Gemma3OmniProcessor",
    "Gemma3AudioFeatureExtractor",
]