ColPali v1.3 ONNX
```python
import os

import onnxruntime as ort
import torch
from PIL import Image
from transformers import ColPaliProcessor

# Local paths to the ONNX repos (clone them from the Hub first, e.g. with
# `git clone` or `huggingface_hub.snapshot_download`); the processor can
# also load directly from the Hub IDs.
MODEL_IMAGE_PATH = "ssonpull519/colpali-v1.3-hf-image-onnx-fp16"
MODEL_TEXT_PATH = "ssonpull519/colpali-v1.3-hf-text-onnx-fp16"

processor = ColPaliProcessor.from_pretrained(MODEL_IMAGE_PATH)
# Your inputs
images = [
    Image.open("image1.png"),
    Image.open("image2.png"),
]
queries = [
    "Who printed the edition of Romeo and Juliet?",
    "When was the United States Declaration of Independence proclaimed?",
]
# Process the inputs
batch_images = processor(images=images, return_tensors="pt")  # ['input_ids', 'attention_mask', 'pixel_values']; (B, 1030), (B, 3, 448, 448); input_ids consist of <image> tokens followed by the prefix prompt.
batch_queries = processor(text=queries, return_tensors="pt")  # ['input_ids', 'attention_mask']; (B, S)

# Convert the inputs to numpy arrays for onnxruntime (the inputs stay on CPU;
# GPU execution is handled by the CUDAExecutionProvider below)
inputs_images_onnx = {name: tensor.numpy() for name, tensor in batch_images.items()}
inputs_queries_onnx = {name: tensor.numpy() for name, tensor in batch_queries.items()}
# Run the ONNX models (onnxruntime falls back to CPU if CUDA is unavailable)
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
sess_image = ort.InferenceSession(os.path.join(MODEL_IMAGE_PATH, "model.onnx"), providers=providers)
sess_text = ort.InferenceSession(os.path.join(MODEL_TEXT_PATH, "model.onnx"), providers=providers)
onnx_output_images = sess_image.run(None, inputs_images_onnx)
onnx_output_queries = sess_text.run(None, inputs_queries_onnx)
# Score the queries against the images (cast the fp16 outputs to float32 for scoring)
scores = processor.score_retrieval(
    torch.from_numpy(onnx_output_queries[0]).float(),
    torch.from_numpy(onnx_output_images[0]).float(),
)  # (Bt, Bi, S, 1030) -> (Bt, Bi)

print("onnx_output size [images]:", onnx_output_images[0].shape)
print("onnx_output size [queries]:", onnx_output_queries[0].shape)
print("scores:")
print(scores)
```
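
To turn the score matrix into a ranking, take the argmax (or a top-k) over the image axis for each query. A minimal sketch, reusing the `scores` and `queries` variables from above:

```python
# Rank images per query from the (num_queries, num_images) score matrix.
best = scores.argmax(dim=1)  # index of the best-scoring image for each query
for qi, query in enumerate(queries):
    print(f"{query!r} -> image {best[qi].item()} (score {scores[qi, best[qi]].item():.2f})")
```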
ONNX export of ColPali via Optimum is currently in a pull request (not merged yet):
```bash
optimum-cli export onnx --model vidore/colpali-v1.3-hf ./onnx_output --task feature-extraction --variant vision --dtype fp16
```
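
The text encoder is exported the same way with `--variant text` (the output directory name below is arbitrary):

```bash
optimum-cli export onnx --model vidore/colpali-v1.3-hf ./onnx_output_text --task feature-extraction --variant text --dtype fp16
```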
For fp16, there is a known issue in transformers that has not been fixed yet, so use the script below instead:
```python
from pathlib import Path

import torch
from optimum.exporters import TasksManager
from optimum.exporters.onnx import export
from transformers import ColPaliForRetrieval

MODEL_PATH = "vidore/colpali-v1.3-hf"
VARIANT = "vision"  # one of "vision" or "text"
ONNX_PATH = f"onnx/{VARIANT}/model.onnx"
MODEL_DTYPE = torch.float16  # one of torch.float32 or torch.float16

base_model = ColPaliForRetrieval.from_pretrained(MODEL_PATH)
base_model = base_model.to(dtype=MODEL_DTYPE)

onnx_path = Path(ONNX_PATH)
onnx_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists

onnx_config_constructor = TasksManager.get_exporter_config_constructor("onnx", base_model)
onnx_config = onnx_config_constructor(base_model.config)
onnx_config.variant = VARIANT

onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_path, onnx_config.DEFAULT_ONNX_OPSET)
# -- validate the exported model --
import onnx

onnx.checker.check_model(ONNX_PATH)  # pass the path so models >2GB are handled correctly

from optimum.exporters.onnx import validate_model_outputs

validate_model_outputs(
    onnx_config, base_model, onnx_path, ["embeddings"], onnx_config.ATOL_FOR_VALIDATION, use_subprocess=False
)
```
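
Once exported, you can sanity-check the graph's input/output signature with onnxruntime. A quick sketch (for the vision variant; the `embeddings` output name matches the one validated above):

```python
import onnxruntime as ort

sess = ort.InferenceSession("onnx/vision/model.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])   # e.g. input_ids, attention_mask, pixel_values
print([o.name for o in sess.get_outputs()])  # e.g. embeddings
```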