# Kimi-VL-A3B-Instruct / image_processing_kimi_vl.py
"""Image processor class for KimiVL."""
import math
from typing import Optional, Union

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images, valid_images
from transformers.utils import TensorType

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

class KimiVLImageProcessor(BaseImageProcessor):
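    """Image processor for Kimi-VL.

    Images are resized so that their patch grid fits within `in_token_limit`,
    padded or center-cropped to a patch-aligned size, normalized with the
    OpenAI CLIP mean/std, and split into `patch_size` x `patch_size` patches.
    """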
model_type = "kimi_vl"

    def __init__(
        self,
        patch_size: int = 14,
        pad_input: bool = False,
        image_mean: tuple[float, float, float] = OPENAI_DATASET_MEAN,
        image_std: tuple[float, float, float] = OPENAI_DATASET_STD,
        in_token_limit: int = 4096,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
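        """
        Args:
            patch_size: Side length, in pixels, of each square patch.
            pad_input: Whether to pad (rather than center-crop) images to a
                patch-aligned size in `rescale`.
            image_mean: Per-channel mean used for normalization.
            image_std: Per-channel standard deviation used for normalization.
            in_token_limit: Maximum number of patches before an image is downscaled.
            merge_kernel_size: (height, width) factors; when padding, image sides
                are aligned to multiples of `merge_kernel_size * patch_size` pixels.
        """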
super().__init__(**kwargs)
self.in_token_limit = in_token_limit
self.patch_size = patch_size
self.pad_input = pad_input
self.image_mean = image_mean
self.image_std = image_std
self.merge_kernel_size = merge_kernel_size

    def rescale(
        self, image: Image.Image, merge_kernel_size: tuple[int, int] = (2, 2)
    ) -> Image.Image:
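        """Resize `image` so its patch grid fits within `in_token_limit`, then
        make both sides patch-aligned.

        If `pad_input` is True, the image is padded on the right and bottom so
        each side is a multiple of `merge_kernel_size * patch_size`; otherwise it
        is center-cropped to a multiple of `patch_size`.
        """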
w, h = image.size
patch_size = self.patch_size
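        # Downscale (roughly preserving aspect ratio) if the patch grid would exceed the token budget.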
if (w // patch_size) * (h // patch_size) > self.in_token_limit:
scale = math.sqrt(self.in_token_limit / ((w // patch_size) * (h // patch_size)))
new_w, new_h = int(w * scale), int(h * scale)
image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)
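        # Align both sides to the patch grid: pad on the right/bottom, or center-crop.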
if self.pad_input:
new_w, new_h = image.size
pad_size_h = merge_kernel_size[0] * patch_size
pad_size_w = merge_kernel_size[1] * patch_size
pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w
image = TF.pad(image, (0, 0, pad_w, pad_h))
else:
new_w, new_h = image.size
new_w = new_w - new_w % patch_size
new_h = new_h - new_h % patch_size
image = TF.center_crop(image, (new_h, new_w))
w, h = image.size
if w // patch_size >= 512 or h // patch_size >= 512:
raise ValueError("Exceed pos emb")
return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
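        """Split a (C, H, W) image tensor into non-overlapping square patches.

        Returns:
            patches: Tensor of shape (num_patches, C, patch_size, patch_size),
                ordered row-major over the patch grid.
            grid_hw: (H // patch_size, W // patch_size), the patch-grid shape.
        """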
patch_size = self.patch_size
C, H, W = image.shape
patches = image.reshape(C, H // patch_size, patch_size, W // patch_size, patch_size)
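        # (C, grid_h, p, grid_w, p) -> (grid_h, grid_w, C, p, p): patches become row-major over the grid.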
patches = patches.permute(1, 3, 0, 2, 4)
patches = patches.contiguous().view(-1, C, patch_size, patch_size)
grid_hw = (H // patch_size, W // patch_size)
return patches, grid_hw

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Rescale, normalize, and patchify a single image.

        Args:
            image (`ImageInput`):
                Image to preprocess, with pixel values in the range 0-255.

        Returns:
            patches (`torch.Tensor`): Patches of shape (num_patches, C, patch_size, patch_size).
            grid_hw (`tuple[int, int]`): Patch-grid shape (grid_height, grid_width).
        """
image = self.rescale(image, self.merge_kernel_size)
image = self.to_tensor(image)
image = self.normalize(image)
patches, grid_hw = self.patchify(image)
return patches, grid_hw

    def preprocess(
self,
images: ImageInput,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
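        """
        Preprocess one image or a batch of images into a `BatchFeature`.

        Each image is patchified independently; the resulting patches are
        concatenated along the first dimension, and `image_grid_hws` holds one
        (grid_height, grid_width) pair per image.
        """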
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
pixel_values, image_grid_hws = [], []
for image in images:
patches, image_grid_hw = self._preprocess(image)
pixel_values.append(patches)
image_grid_hws.append(image_grid_hw)
pixel_values = torch.concat(pixel_values, dim=0)
image_grid_hws = np.array(image_grid_hws)
data = {"pixel_values": pixel_values, "image_grid_hws": image_grid_hws}
return BatchFeature(data=data, tensor_type=return_tensors)
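

if __name__ == "__main__":
    # A minimal usage sketch: constructs the processor with its defaults and runs a
    # synthetic RGB image through `preprocess`. The noise image and its 448x336 size
    # are arbitrary and for illustration only.
    processor = KimiVLImageProcessor()
    demo_image = Image.fromarray(
        np.random.randint(0, 256, size=(336, 448, 3), dtype=np.uint8)
    )
    features = processor.preprocess(demo_image, return_tensors="pt")
    # pixel_values: (num_patches, 3, patch_size, patch_size); image_grid_hws: one (grid_h, grid_w) row per image.
    print(features["pixel_values"].shape, features["image_grid_hws"])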