import math
import re
from typing import Any, Dict, List, Optional, Union

import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence

from transformers.audio_utils import AudioInput
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from transformers.utils import TensorType, logging, to_py_obj

DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 2
DEFAULT_FEAT_STRIDE = 4
IMAGE_TOKEN_PATTERN = r"<\|image_\d+\|>"
AUDIO_TOKEN_PATTERN = r"<\|audio_\d+\|>"
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)


def create_mel_filterbank(sampling_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0,
                          fmax: Optional[float] = None) -> np.ndarray:
    """Create a Mel filterbank matrix of shape (n_mels, n_fft // 2 + 1) for audio processing."""
    fmax = fmax or sampling_rate / 2

    def hz_to_mel(f: float) -> float:
        return 1127.0 * math.log(1 + f / 700.0)

    # n_mels + 2 points evenly spaced on the Mel scale, mapped back to Hz and then to FFT bin indices.
    mel_points = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
    freq_points = 700.0 * (np.exp(mel_points / 1127.0) - 1)
    bins = np.floor((n_fft + 1) * freq_points / sampling_rate).astype(int)

    # Each filter is a triangle that rises from bins[m - 1] to bins[m] and falls back to bins[m + 1].
    filterbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=np.float32)
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1:m + 2]
        filterbank[m - 1, left:center] = (np.arange(left, center) - left) / (center - left)
        filterbank[m - 1, center:right] = (right - np.arange(center, right)) / (right - center)

    return filterbank
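
# Quick illustrative check (not part of the processing path): with the module defaults, the
# filterbank has one row per Mel band and one column per rFFT bin.
#
#     fb = create_mel_filterbank(DEFAULT_SAMPLING_RATE, DEFAULT_N_FFT, DEFAULT_N_MELS)
#     assert fb.shape == (DEFAULT_N_MELS, DEFAULT_N_FFT // 2 + 1)  # (80, 257)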


class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    """Converts a 16 kHz mono waveform into (T, 80) log-Mel frames."""

    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        compression_rate: int = DEFAULT_COMPRESSION_RATE,
        qformer_rate: int = DEFAULT_QFORMER_RATE,
        feat_stride: int = DEFAULT_FEAT_STRIDE,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        n_fft: int = DEFAULT_N_FFT,
        win_length: int = DEFAULT_WIN_LENGTH,
        hop_length: int = DEFAULT_HOP_LENGTH,
        n_mels: int = DEFAULT_N_MELS,
        **kwargs
    ):
        super().__init__(feature_size=n_mels, sampling_rate=sampling_rate, padding_value=0.0, **kwargs)
        self.compression_rate = compression_rate
        self.qformer_rate = qformer_rate
        self.feat_stride = feat_stride
        self.sampling_rate = sampling_rate

        self.window = np.hamming(win_length).astype(np.float32)
        # Stored transposed, (n_fft // 2 + 1, n_mels), so the power spectrum can be projected
        # with a single matrix product in _compute_log_mel_spectrogram.
        self.mel_filterbank = create_mel_filterbank(sampling_rate, n_fft, n_mels).T
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def __call__(
        self,
        audios: List[AudioInput],
        sampling_rate: Optional[int] = None,
        return_tensors: Union[TensorType, str, None] = TensorType.PYTORCH
    ) -> BatchFeature:
        # Unless a source sampling rate is given, assume the clips already match the
        # extractor's configured rate (16 kHz by default) and skip resampling.
        source_sr = sampling_rate or self.sampling_rate
        features, sizes, frames = [], [], []

        for wav in audios:
            processed_wav = self._preprocess_audio(wav, source_sr)
            mel_spectrogram = self._compute_log_mel_spectrogram(processed_wav)
            feature_tensor = torch.tensor(mel_spectrogram, dtype=torch.float32)
            features.append(feature_tensor)
            sizes.append(torch.tensor(self._calculate_embed_length(feature_tensor.shape[0])))
            frames.append(feature_tensor.shape[0] * self.feat_stride)

        audio_embeds = pad_sequence(features, batch_first=True)
        size_tensor = torch.stack(sizes)

        # A padding mask is only needed when clips of different lengths are batched together.
        attention_mask = None
        if len(audios) > 1:
            frame_lengths = torch.tensor(frames)
            attention_mask = torch.arange(frame_lengths.max()).unsqueeze(0) < frame_lengths.unsqueeze(1)

        output_data = {
            "input_audio_embeds": audio_embeds,
            "audio_embed_sizes": size_tensor
        }
        if attention_mask is not None:
            output_data["audio_attention_mask"] = attention_mask

        return BatchFeature(data=output_data, tensor_type=return_tensors)

    def _preprocess_audio(self, wav: np.ndarray, source_sr: int) -> np.ndarray:
        # Accept lists, NumPy arrays, or torch tensors; work in float32 NumPy from here on.
        wav = torch.as_tensor(wav).float().numpy()
        # Mix multi-channel audio down to mono.
        if wav.ndim > 1:
            wav = wav.mean(axis=0)
        if source_sr != self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, source_sr)
        # Peak-normalize, guarding against division by zero for silent clips.
        return wav / max(np.abs(wav).max(), 1e-6)

    def _compute_log_mel_spectrogram(self, wav: np.ndarray) -> np.ndarray:
        # Slice the waveform into overlapping frames with a zero-copy strided view,
        # then copy so the Hamming window can be applied in place.
        frame_count = 1 + (len(wav) - self.win_length) // self.hop_length
        strides = wav.strides[0]
        frames = np.lib.stride_tricks.as_strided(
            wav,
            shape=(frame_count, self.win_length),
            strides=(strides * self.hop_length, strides),
            writeable=False
        ).copy()
        frames *= self.window

        # Power spectrum -> Mel projection -> floored log.
        spectrum = np.fft.rfft(frames, n=self.n_fft).astype(np.complex64)
        power = np.abs(spectrum) ** 2
        mel_spectrogram = np.dot(power, self.mel_filterbank)
        mel_spectrogram = np.clip(mel_spectrogram, 1.0, None)
        return np.log(mel_spectrogram, dtype=np.float32)

    def _calculate_embed_length(self, frame_count: int) -> int:
        # Two rounds of ceiling division: temporal compression, then the Q-Former rate.
        compressed = math.ceil(frame_count / self.compression_rate)
        return math.ceil(compressed / self.qformer_rate)
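
# Illustrative usage sketch (assumes a 1-second, 16 kHz mono clip and the default settings):
#
#     extractor = Gemma3AudioFeatureExtractor()
#     wav = np.zeros(16000, dtype=np.float32)
#     out = extractor([wav], return_tensors="pt")
#     out["input_audio_embeds"].shape   # (1, 98, 80): 98 frames of 80 log-Mel bins
#     out["audio_embed_sizes"]          # tensor([13]): ceil(ceil(98 / 4) / 2)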


class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Gemma3ImagesKwargs
    audio_kwargs: Dict[str, Any]
    _defaults = {
        "text_kwargs": {"padding": False, "truncation": False, "max_length": DEFAULT_MAX_LENGTH},
        "images_kwargs": {},
        "audio_kwargs": {},
    }
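
# Per-call overrides are passed as nested per-modality dicts and merged over _defaults by
# Gemma3OmniProcessor._merge_kwargs below, e.g. (illustrative):
#
#     processor(text=prompt, images=image, images_kwargs={"do_pan_and_scan": True})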


class Gemma3OmniProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer", "audio_processor"]
    valid_kwargs = ["chat_template", "image_seq_length"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        audio_processor,
        tokenizer,
        chat_template=None,
        image_seq_length: int = 256,
        **kwargs
    ):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
        self.image_token = tokenizer.image_token
        self.audio_token = "<audio_soft_token>"
        self.expected_audio_token_id = 262143
        # Every image placeholder in a prompt is later expanded to this full soft-token sequence.
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{''.join([tokenizer.image_token] * image_seq_length)}{tokenizer.eoi_token}\n\n"

        # Temporal compression factors used by _compute_audio_embed_size to size the audio block.
        self.compression_rate = 8
        self.qformer_compression_rate = 1
        self.feat_stride = 1

        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        if self.audio_token_id != self.expected_audio_token_id:
            logger.warning(
                f"Assigned ID {self.audio_token_id} for '{self.audio_token}' does not match the expected ID "
                f"{self.expected_audio_token_id}. Using the assigned ID; the model's embedding layer may need resizing."
            )

        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            chat_template=chat_template,
            **kwargs
        )
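
    # For reference, __call__ expands each boi_token placeholder found in a prompt into
    # full_image_sequence, i.e. "\n\n" + boi_token + image_token * image_seq_length + eoi_token + "\n\n",
    # and each "<start_of_audio>" placeholder into an analogous run of "<audio_soft_token>" entries.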

    def _merge_kwargs(self, ModelProcessorKwargs, tokenizer_init_kwargs, **kwargs):
        default_kwargs = {}
        for modality in ModelProcessorKwargs._defaults:
            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()

        # Tokenizer-level settings take precedence over the class defaults.
        for modality in default_kwargs:
            modality_kwargs = default_kwargs[modality]
            for key in modality_kwargs:
                if key in tokenizer_init_kwargs:
                    value = (
                        getattr(self.tokenizer, key)
                        if hasattr(self.tokenizer, key)
                        else tokenizer_init_kwargs[key]
                    )
                    modality_kwargs[key] = value

        # Per-call overrides are passed as nested dicts, e.g. text_kwargs={...}, images_kwargs={...}.
        for modality in default_kwargs:
            if modality in kwargs:
                default_kwargs[modality].update(kwargs[modality])

        # Also honor a flat `return_tensors=...` argument, the common calling convention.
        if "return_tensors" in kwargs:
            default_kwargs["text_kwargs"].setdefault("return_tensors", kwargs["return_tensors"])

        default_kwargs["text_kwargs"]["truncation"] = False
        default_kwargs["text_kwargs"]["max_length"] = default_kwargs["text_kwargs"].get("max_length", DEFAULT_MAX_LENGTH)

        return default_kwargs

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        result = math.ceil(audio_frames / self.compression_rate)
        return math.ceil(result / self.qformer_compression_rate)
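
    # Worked example (defaults set in __init__): 98 Mel frames with feat_stride=1 give
    # ceil(98 / 8) = 13, then ceil(13 / 1) = 13 audio soft tokens per clip.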

    def __call__(
        self,
        images=None,
        text=None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs]
    ) -> BatchFeature:
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs
        )

        # Normalize `text` to a list of strings; when only images are given, placeholder
        # prompts are built in the image branch below.
        if text is None:
            text = []
        elif isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
            raise ValueError("Input text must be a string or a list of strings")

        image_inputs = {}
        if images is not None:
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])

            # When no prompts were supplied, create one image placeholder per image.
            if not text:
                text = [" ".join([self.boi_token] * len(images)) for images in batched_images]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Inconsistent batch sizes: {len(batched_images)} images, {len(text)} texts"
                )

            # Insert extra image placeholders for any pan-and-scan crops.
            num_crops = to_py_obj(image_inputs.pop("num_crops"))
            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]

            for batch_idx, (prompt, images, crops) in enumerate(zip(text, batched_images, batch_num_crops)):
                image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
                if len(images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt has {len(image_indexes)} image tokens but received {len(images)} images"
                    )

                # Expand in reverse order so earlier indexes stay valid while the prompt grows.
                for num, idx in reversed(list(zip(crops, image_indexes))):
                    if num:
                        formatted_image_text = (
                            f"Here is the original image {self.boi_token} and here are some crops to help you see better "
                            + " ".join([self.boi_token] * num)
                        )
                        prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token):]
                        text[batch_idx] = prompt

            # Finally, expand every image placeholder into the full soft-token sequence.
            text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]

        audio_inputs = {}
        if audio is not None:
            audio_inputs = self.audio_processor(audio, return_tensors="pt")
            audio_embeds = audio_inputs["input_audio_embeds"]
            audio_frames = audio_embeds.shape[1] * self.feat_stride
            audio_seq_length = self._compute_audio_embed_size(audio_frames)

            audio_tokens = {
                "boa_token": "<start_of_audio>",
                "eoa_token": "<end_of_audio>",
                "audio_token": "<audio_soft_token>",
                "boa_token_id": 256001,
                "eoa_token_id": 256002,
                "audio_token_id": self.audio_token_id
            }

            # Expand each "<start_of_audio>" placeholder into the full run of audio soft tokens.
            audio_sequence = f"\n\n{audio_tokens['boa_token']}{''.join([audio_tokens['audio_token']] * audio_seq_length)}{audio_tokens['eoa_token']}\n\n"
            text = [prompt.replace(audio_tokens['boa_token'], audio_sequence) for prompt in text]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        # Tokenize to NumPy first so the multimodal bookkeeping below can use array ops;
        # the requested tensor type is applied when the final BatchFeature is built.
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")

        for i, (txt, ids) in enumerate(zip(text, text_inputs["input_ids"])):
            audio_text_count = txt.count(self.audio_token)
            audio_ids_count = list(ids).count(self.audio_token_id)
            logger.debug(
                f"Sample {i}: Audio tokens in text={audio_text_count}, in input_ids={audio_ids_count}, "
                f"Text length={len(txt)}, Input IDs length={len(ids)}"
            )

        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "audio"])

        # Token type ids: 0 for text, 1 for image soft tokens, 2 for audio soft tokens.
        array_ids = text_inputs["input_ids"]
        mm_token_type_ids = np.zeros_like(array_ids)
        mm_token_type_ids[array_ids == self.image_token_id] = 1
        mm_token_type_ids[array_ids == self.audio_token_id] = 2
        text_inputs = {k: v.tolist() for k, v in text_inputs.items()}
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
        image_processor_inputs = self.image_processor.model_input_names
        audio_processor_inputs = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs + audio_processor_inputs))


__all__ = [
    "Gemma3OmniProcessor",
    "Gemma3AudioFeatureExtractor",
]
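
# Illustrative end-to-end sketch (the checkpoint path is hypothetical; `wav` is assumed to be a
# 1-D NumPy waveform at 16 kHz, otherwise pass `sampling_rate` to the feature extractor):
#
#     from transformers import AutoProcessor
#     processor = AutoProcessor.from_pretrained("path/to/gemma3-omni-checkpoint", trust_remote_code=True)
#     prompt = "<start_of_audio> Transcribe the audio."
#     batch = processor(text=prompt, audio=[wav], return_tensors="pt")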