"""Image processor class for KimiVL."""
|
|
|
import math |
|
import numpy as np |
|
from PIL import Image |
|
from typing import Optional, Union |
|
|
|
import torch |
|
from torchvision.transforms import functional as TF |
|
from transformers.image_utils import ImageInput, make_list_of_images, valid_images |
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
from transformers.utils import TensorType |
|
|
|
|
|
# Per-channel normalization statistics. They serve as the default
# `image_mean` / `image_std` of `KimiVLImageProcessor`, which converts every
# image to RGB before normalizing, so the entries are ordered (R, G, B).
OPENAI_DATASET_MEAN = (
    0.48145466,  # R
    0.4578275,  # G
    0.40821073,  # B
)
OPENAI_DATASET_STD = (
    0.26862954,  # R
    0.26130258,  # G
    0.27577711,  # B
)
|
|
|
|
|
class KimiVLImageProcessor(BaseImageProcessor):
    """Image processor that turns images into normalized patch tensors for KimiVL.

    Each image is (optionally) downscaled to respect a patch budget, padded or
    center-cropped so that both sides are aligned to the patch grid, converted
    to an RGB tensor, normalized per channel and finally cut into
    `patch_size x patch_size` patches.

    Args:
        patch_size (`int`, *optional*, defaults to 14):
            Side length, in pixels, of one square patch.
        pad_input (`bool`, *optional*, defaults to `False`):
            If `True`, pad the right/bottom edges so height and width become
            multiples of `merge_kernel_size * patch_size`. If `False`,
            center-crop height and width down to multiples of `patch_size`.
        image_mean (`tuple[float, float, float]`, *optional*):
            Per-channel mean used for normalization.
        image_std (`tuple[float, float, float]`, *optional*):
            Per-channel standard deviation used for normalization.
        in_token_limit (`int`, *optional*, defaults to 4096):
            Patch budget. An image whose patch count
            `(w // patch_size) * (h // patch_size)` exceeds it is downscaled.
        merge_kernel_size (`list[int]` or `tuple[int, int]`, *optional*, defaults to `(2, 2)`):
            `(height, width)` multiples, in patches, that padded images are
            aligned to. Only used when `pad_input` is `True`.
        **kwargs:
            Forwarded to `BaseImageProcessor.__init__`.
    """

    model_type = "kimi_vl"

    def __init__(
        self,
        patch_size: int = 14,
        pad_input: bool = False,
        image_mean: tuple[float, float, float] = OPENAI_DATASET_MEAN,
        image_std: tuple[float, float, float] = OPENAI_DATASET_STD,
        in_token_limit: int = 4096,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.pad_input = pad_input
        self.image_mean = image_mean
        self.image_std = image_std
        # Copy into a fresh list: the default must never be shared between
        # instances, and a list keeps the attribute shaped as before.
        self.merge_kernel_size = list(merge_kernel_size)

    def rescale(
        self,
        image: Image.Image,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
    ) -> Image.Image:
        """Resize and align a PIL image to the patch grid.

        Args:
            image (`PIL.Image.Image`):
                Image to resize/align.
            merge_kernel_size (`list[int]` or `tuple[int, int]`, *optional*, defaults to `(2, 2)`):
                `(height, width)` multiples, in patches, used to compute the
                padding when `self.pad_input` is `True`.

        Returns:
            `PIL.Image.Image`: The downscaled (if over budget) and padded or
            center-cropped image.

        Raises:
            ValueError: If either side of the result spans 512 or more patches.
        """
        patch_size = self.patch_size
        w, h = image.size

        # Downscale both sides by the same factor when the patch count is over
        # budget; the factor is the square root of the allowed/actual ratio.
        num_patches = (w // patch_size) * (h // patch_size)
        if num_patches > self.in_token_limit:
            scale = math.sqrt(self.in_token_limit / num_patches)
            image = image.resize((int(w * scale), int(h * scale)), Image.Resampling.BICUBIC)

        w, h = image.size
        if self.pad_input:
            # NOTE(review): padding happens after the budget-driven downscale,
            # so the padded grid can end up slightly above `in_token_limit`.
            # Confirm this is intended.
            pad_size_h = merge_kernel_size[0] * patch_size
            pad_size_w = merge_kernel_size[1] * patch_size
            # The outer modulo keeps the padding at 0 for already-aligned sides.
            pad_h = (pad_size_h - h % pad_size_h) % pad_size_h
            pad_w = (pad_size_w - w % pad_size_w) % pad_size_w
            # Padding order is (left, top, right, bottom): right/bottom only.
            image = TF.pad(image, (0, 0, pad_w, pad_h))
        else:
            # Drop the remainder pixels so both sides are multiples of patch_size.
            image = TF.center_crop(image, (h - h % patch_size, w - w % patch_size))

        w, h = image.size
        if w // patch_size >= 512 or h // patch_size >= 512:
            raise ValueError("Exceed pos emb")

        return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to an RGB tensor via `torchvision`'s `to_tensor`."""
        return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Normalize an image tensor with `self.image_mean` and `self.image_std`."""
        return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Cut a `(C, H, W)` tensor into square patches.

        Args:
            image (`torch.Tensor`):
                Tensor of shape `(C, H, W)`; `H` and `W` are expected to be
                multiples of `self.patch_size`.

        Returns:
            `tuple[torch.Tensor, tuple[int, int]]`: The patches with shape
            `(grid_h * grid_w, C, patch_size, patch_size)`, ordered row by row,
            and the patch grid `(grid_h, grid_w)`.

        Raises:
            ValueError: If the image is smaller than one patch on either side.
        """
        patch_size = self.patch_size
        C, H, W = image.shape
        grid_h, grid_w = H // patch_size, W // patch_size
        if grid_h == 0 or grid_w == 0:
            # Without this guard the reshape/view below fails with an opaque
            # RuntimeError about tensor shapes.
            raise ValueError(
                f"Image of size {H}x{W} (HxW) is smaller than one "
                f"{patch_size}x{patch_size} patch."
            )

        # (C, H, W) -> (C, grid_h, p, grid_w, p) -> (grid_h, grid_w, C, p, p)
        patches = image.reshape(C, grid_h, patch_size, grid_w, patch_size)
        patches = patches.permute(1, 3, 0, 2, 4)
        patches = patches.contiguous().view(-1, C, patch_size, patch_size)
        return patches, (grid_h, grid_w)

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Preprocess a single image and patchify it.

        Args:
            image (`ImageInput`):
                Image to preprocess. A `PIL.Image.Image` is used as is; any
                other input is first converted with `torchvision`'s
                `to_pil_image` and therefore has to follow that function's
                layout conventions (`C x H x W` tensors, `H x W x C` arrays).

        Returns:
            patches: torch.Tensor
            grid_hw: tuple[int, int]
        """
        if not isinstance(image, Image.Image):
            # `rescale` relies on the PIL API (`image.size`, `image.resize`).
            image = TF.to_pil_image(image)
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return self.patchify(image)

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images (`ImageInput`):
                A single image or a list of images.
            return_tensors (`str` or `TensorType`, *optional*):
                Passed to `BatchFeature` as `tensor_type`.

        Returns:
            `BatchFeature`: Holds `pixel_values`, the patches of all images
            concatenated along dim 0, and `image_grid_hws`, an array with one
            `(grid_h, grid_w)` row per image.

        Raises:
            ValueError: If `images` fails `valid_images` validation.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        processed = [self._preprocess(image) for image in images]
        pixel_values = torch.concat([patches for patches, _ in processed], dim=0)
        image_grid_hws = np.array([grid_hw for _, grid_hw in processed])
        data = {"pixel_values": pixel_values, "image_grid_hws": image_grid_hws}

        return BatchFeature(data=data, tensor_type=return_tensors)
|
|