import torch
import torch.nn as nn
import numpy as np
from PIL import Image

from .model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from .model.conversation import SeparatorStyle, conv_templates
from .model.mm_utils import KeywordsStoppingCriteria, process_image, tokenizer_image_token
from .model import get_model_name_from_path, load_pretrained_model
from transformers import TextIteratorStreamer
from threading import Thread


class DescribeAnythingModel(nn.Module):
    def __init__(self, model_path, conv_mode, prompt_mode, temperature, top_p, num_beams, max_new_tokens, **kwargs):
        super().__init__()

        self.model_path = model_path
        self.conv_mode = conv_mode
        self.prompt_mode = prompt_mode
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, None, **kwargs)
        model.config.image_processor = image_processor

        self.tokenizer = tokenizer
        self.model = model
        self.context_len = context_len
        self.model_name = get_model_name_from_path(model_path)

    def get_prompt(self, qs):
        if DEFAULT_IMAGE_TOKEN not in qs:
            raise ValueError(f"no {DEFAULT_IMAGE_TOKEN} tag found in input.")

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        return prompt, conv

    @staticmethod
    def mask_to_box(mask_np):
        # Tight bounding box (x0, y0, w, h) around the nonzero region of the mask.
        mask_coords = np.argwhere(mask_np)
        y0, x0 = mask_coords.min(axis=0)
        y1, x1 = mask_coords.max(axis=0) + 1

        h = y1 - y0
        w = x1 - x0

        return x0, y0, w, h

    @classmethod
    def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
        if crop_mode == "full":
            # no crop
            info = dict(mask_np=mask_np)
            return pil_img, info

        if crop_mode == "crop":
            # crop image and mask to the tight bounding box of the mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "context_crop":
            # crop image and mask, keeping up to one box-width/height of context on each side
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]
            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "focal_crop":
            # crop image and mask around an enlarged, re-centered box
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
            img_h, img_w = img_np.shape[:2]

            xc, yc = x0 + w/2, y0 + h/2
            # focal_crop: the box needs at least min_box_w x min_box_h pixels,
            # otherwise resizing to (384, 384) leads to artifacts that may be OOD
            w, h = max(w, min_box_w), max(h, min_box_h)
            x0, y0 = int(xc - w / 2), int(yc - h / 2)

            cropped_mask_np = mask_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_img_np = img_np[max(y0-h, 0):min(y0+2*h, img_h), max(x0-w, 0):min(x0+2*w, img_w)]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        elif crop_mode == "crop_mask":
            # crop to the tight box and zero out pixels outside the mask
            x0, y0, w, h = cls.mask_to_box(mask_np)
            img_np = np.asarray(pil_img)
            assert img_np.shape[:2] == mask_np.shape, f"image shape {img_np.shape} does not match mask shape {mask_np.shape}"
            cropped_mask_np = mask_np[y0:y0+h, x0:x0+w]
            cropped_img_np = img_np[y0:y0+h, x0:x0+w]
            # Mask the image
            cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
            cropped_pil_img = Image.fromarray(cropped_img_np)
        else:
            raise ValueError(f"Unsupported crop_mode: {crop_mode}")

        info = dict(mask_np=cropped_mask_np)

        return cropped_pil_img, info
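
    # Sanity-check sketch for the box/crop helpers (hypothetical toy values,
    # not part of the library): mask_to_box returns the tight bounding box of
    # the nonzero mask region as (x0, y0, w, h).
    #
    #   mask = np.zeros((8, 8), dtype=np.uint8)
    #   mask[2:5, 3:6] = 1
    #   DescribeAnythingModel.mask_to_box(mask)   # -> (3, 2, 3, 3)
    #
    # With crop_mode="crop", crop_image would then return the 3x3 region at
    # rows 2:5, cols 3:6, together with the matching cropped mask.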

    def get_description(self, image_pil, mask_pil, query, streaming=False):
        # Accepts a single (image, mask) pair or two parallel lists of pairs.
        prompt, conv = self.get_prompt(query)
        if not isinstance(image_pil, (list, tuple)):
            assert not isinstance(mask_pil, (list, tuple)), "image_pil and mask_pil must either both be lists/tuples or both be single items."
            image_pils = [image_pil]
            mask_pils = [mask_pil]
        else:
            image_pils = image_pil
            mask_pils = mask_pil
        description = self.get_description_from_prompt(image_pils, mask_pils, prompt, conv, streaming=streaming)

        return description

    def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
        # binarize the mask: any non-zero pixel is treated as True
        mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
        images_tensor, image_info = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode))
        images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)

        # with crop_mode == "full" (asserted by the caller), this is the original, uncropped mask
        mask_np = image_info["mask_np"]
        mask_pil = Image.fromarray(mask_np * 255)

        masks_tensor = process_image(mask_pil, self.model.config, None)
        masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)

        # append one mask channel to the image channels
        images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)

        if crop_mode2 is not None:
            images_tensor2, image_info2 = process_image(image_pil, self.model.config, None, pil_preprocess_fn=lambda pil_img: self.crop_image(pil_img, mask_np=mask_np, crop_mode=crop_mode2))
            images_tensor2 = images_tensor2[None].to(self.model.device, dtype=torch.float16)

            mask_np2 = image_info2["mask_np"]
            mask_pil2 = Image.fromarray(mask_np2 * 255)

            masks_tensor2 = process_image(mask_pil2, self.model.config, None)
            masks_tensor2 = masks_tensor2[None].to(self.model.device, dtype=torch.float16)

            images_tensor2 = torch.cat((images_tensor2, masks_tensor2[:, :1, ...]), dim=1)
        else:
            images_tensor2 = None

        return torch.cat((images_tensor, images_tensor2), dim=1) if images_tensor2 is not None else images_tensor
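
    # Shape sketch for get_image_tensor (assuming process_image returns a
    # 3 x H x W tensor): each view is image (3ch) + mask (1ch) -> [1, 4, H, W];
    # with a second crop mode the two views are concatenated along dim=1,
    # giving [1, 8, H, W] per (image, mask) pair.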

    def get_description_from_prompt(self, image_pils, mask_pils, prompt, conv, streaming=False):
        if streaming:
            return self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=True)
        else:
            # If streaming is False, there will be only one output
            output = self.get_description_from_prompt_iterator(image_pils, mask_pils, prompt, conv, streaming=False)
            return next(output)

    def get_description_from_prompt_iterator(self, image_pils, mask_pils, prompt, conv, streaming=False):
        crop_mode, crop_mode2 = self.prompt_mode.split("+")
        assert crop_mode == "full", "Current prompt only supports first crop as full (non-cropped). If you need other specifications, please update the prompt."
        assert len(image_pils) == len(mask_pils), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
        image_tensors = [self.get_image_tensor(image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2) for image_pil, mask_pil in zip(image_pils, mask_pils)]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) if streaming else None

        generation_kwargs = dict(
            input_ids=input_ids,
            images=image_tensors,
            do_sample=self.temperature > 0,
            temperature=self.temperature,
            top_p=self.top_p,
            num_beams=self.num_beams,
            max_new_tokens=self.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            streamer=streamer,
        )

        if streaming:
            # generate on a background thread; the streamer yields text chunks as they arrive
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                if stop_str in generated_text:
                    generated_text = generated_text[:generated_text.find(stop_str)]
                    break
                yield new_text

            thread.join()
        else:
            with torch.inference_mode():
                output_ids = self.model.generate(**generation_kwargs)

            outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            yield outputs
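

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. The checkpoint path,
    # file names, conv_mode value, and sampling settings below are placeholder
    # assumptions; run as a module (python -m ...) so the package-relative
    # imports above resolve.
    dam = DescribeAnythingModel(
        model_path="path/to/checkpoint",  # placeholder
        conv_mode="v1",                   # assumption: a valid conv_templates key
        prompt_mode="full+focal_crop",    # first view full, second view focal_crop
        temperature=0.2,
        top_p=0.9,
        num_beams=1,
        max_new_tokens=512,
    )
    img = Image.open("example.jpg").convert("RGB")
    # mask must be the same size as the image; nonzero pixels mark the region
    mask = Image.open("example_mask.png").convert("L")
    # the query must contain the image token, or get_prompt raises ValueError
    query = DEFAULT_IMAGE_TOKEN + "\nDescribe the masked region in detail."
    print(dam.get_description(img, mask, query))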