import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2, ColQwen2Processor

# Hard cap on images per request; large batches blow up GPU memory/latency.
MAX_IMAGES_PER_REQUEST = 4


class EndpointHandler:
    """Hugging Face Inference Endpoints handler that embeds images with ColQwen2.

    Loads the model once at startup and, per request, decodes base64-encoded
    images and returns their multi-vector embeddings.
    """

    def __init__(self, path: str = ""):
        """Load the ColQwen2 model and processor from ``path``.

        The model is loaded in bfloat16 on ``cuda:0`` and uses
        FlashAttention-2 when the package is available.

        Args:
            path: Local model directory or hub repo id supplied by the
                Inference Endpoints runtime.
        """
        use_fa2 = is_flash_attn_2_available()
        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",  # or "mps" if on Apple Silicon
            attn_implementation="flash_attention_2" if use_fa2 else None,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)
        # self.model = torch.compile(self.model)
        print(f"Model and processor loaded {'with' if use_fa2 else 'without'} FA2")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed a batch of base64-encoded images.

        Expects data in the following format:
        {
            "inputs": [
                "base64_encoded_image1",
                "base64_encoded_image2",
                ...
            ]
        }
        A single string (not wrapped in a list) is also accepted.

        Returns:
            ``{"embeddings": [...]}`` on success, where the outer list is
            aligned one-to-one with the submitted images, or
            ``{"message": ...}`` describing the problem on invalid input.
        """
        base64_images = data.get("inputs", [])
        if not isinstance(base64_images, list):
            # Allow a bare string payload for convenience.
            base64_images = [base64_images]
        elif len(base64_images) > MAX_IMAGES_PER_REQUEST:
            return {"message": "Send a maximum of 4 images at once. We recommend sending one by one."}

        # Decode each image from base64 into an RGB PIL Image, tracking
        # failures by index instead of silently dropping them: skipping a bad
        # image would misalign the returned embeddings with the inputs.
        decoded_images = []
        failed_indexes = []
        for idx, img_str in enumerate(base64_images):
            try:
                img_data = base64.b64decode(img_str)
                image = Image.open(io.BytesIO(img_data)).convert("RGB")
                decoded_images.append(image)
            except Exception as e:
                print(f"Error decoding image at index {idx}: {e}")
                failed_indexes.append(idx)

        if failed_indexes:
            # Fail the whole request so the caller knows exactly which
            # inputs were bad, rather than receiving a shorter, misaligned
            # embeddings list.
            return {"message": f"Could not decode image(s) at index(es): {failed_indexes}"}
        if not decoded_images:
            # Guard the processor against an empty batch, which would raise
            # an opaque error deep inside process_images.
            return {"message": "No images provided."}

        # Process the images using the processor and move to the model device.
        batch_images = self.processor.process_images(decoded_images).to(self.model.device)

        # Forward pass through the model; no_grad avoids building autograd
        # state for pure inference.
        with torch.no_grad():
            image_embeddings = self.model(**batch_images).tolist()

        return {"embeddings": image_embeddings}