import torch
import argparse
import PIL
from PIL import Image
import os
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from conversation import conv_templates, SeparatorStyle
from torchvision import transforms
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from threading import Thread
from unitok.config import Args
from unitok.model import UniTok
from model.builder import load_pretrained_model
from mm_utils import tokenizer_image_token, get_model_name_from_path

# NOTE(review): this rebinds the name imported from `constants` above, so the
# value used below is always -200 regardless of what `constants` defines.
# Kept for backward compatibility -- confirm the two agree and drop one.
IMAGE_TOKEN_INDEX = -200


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, keeping the content centred.

    Args:
        pil_img: Source PIL image.
        background_color: Fill colour for the padding, in a form accepted by
            ``Image.new`` for ``pil_img.mode``.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer edge of ``pil_img``, with the original
        pasted in the centre along the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Exactly one of the two offsets is non-zero: the one on the shorter axis.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas


def main(args):
    """Generate text for one image + prompt pair and print it.

    Loads the UniTok tokenizer checkpoint and the multimodal LLM, converts the
    image at ``args.image_path`` into discrete codes, builds a ``llava_v1``
    conversation prompt from ``args.prompt``, runs ``generate_mllm`` in a
    background thread and prints the streamed text once generation finishes.

    Args:
        args: Parsed CLI namespace with ``unitok_path``, ``mllm_path``,
            ``prompt``, ``image_path`` and ``load_8bit``.
    """
    # --- UniTok image tokenizer -------------------------------------------
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only point --unitok_path at checkpoints from a trusted source.
    ckpt = torch.load(args.unitok_path, map_location='cpu')
    vae_cfg = Args()
    vae_cfg.load_state_dict(ckpt['args'])
    vq_model = UniTok(vae_cfg)
    vq_model.load_state_dict(ckpt['trainer']['unitok'])
    vq_model.to('cuda')
    vq_model.eval()

    # --- Multimodal LLM ---------------------------------------------------
    model_path = os.path.expanduser(args.mllm_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, vqllm, _image_processor, _context_len = load_pretrained_model(
        model_path, model_name, load_8bit=args.load_8bit
    )

    # --- Prompt -----------------------------------------------------------
    # The previous code prepended an empty string, so the prompt carried no
    # image placeholder for tokenizer_image_token to act on.
    # NOTE(review): presumably DEFAULT_IMAGE_TOKEN is the marker that
    # tokenizer_image_token replaces with IMAGE_TOKEN_INDEX -- confirm against
    # constants / mm_utils.
    qs = DEFAULT_IMAGE_TOKEN + '\n' + args.prompt
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    print(prompt)

    # --- Image -> discrete codes ------------------------------------------
    crop_size = 256
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # convert() returns a fully loaded copy, so the file handle can be closed
    # as soon as the context manager exits.
    with Image.open(args.image_path) as src:
        image = src.convert('RGB')
    pad_image = expand2square(image, (122, 116, 104))
    img = transform(pad_image).unsqueeze(0).to('cuda')

    with torch.no_grad():
        vq_code = vq_model.img_to_idx(img)
        image_codes = vq_code.unsqueeze(0)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

    # --- Generation -------------------------------------------------------
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {
        "inputs": input_ids.unsqueeze(0).to("cuda:0"),
        "images": image_codes.to("cuda:0"),
        "max_new_tokens": 1024,
        "bos_token_id": tokenizer.bos_token_id,  # Begin of sequence token
        "eos_token_id": tokenizer.eos_token_id,  # End of sequence token
        "pad_token_id": tokenizer.pad_token_id,  # Pad token
        "streamer": streamer,
    }

    # Run generation in a separate thread so the streamer can be consumed
    # from this one while tokens are still being produced.
    thread = Thread(target=vqllm.generate_mllm, kwargs=generation_kwargs)
    thread.start()
    generated_text = "".join(streamer)
    # The streamer is exhausted only after generation ends; join so the worker
    # is fully finished before the function returns.
    thread.join()

    print(generated_text)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate text from an image and a prompt with a UniTok multimodal LLM.'
    )
    parser.add_argument('--unitok_path', type=str, default=r'D:\projects\liquid_app\UniTok\UniTok_weights\unitok_tokenizer\unitok_tokenizer.pth', required=False)
    parser.add_argument('--mllm_path', type=str, default=r'C:\debug_ckpts\unitok_mllm', required=False)
    parser.add_argument('--prompt', type=str, required=True, help='input text prompt')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--load_8bit', action='store_true', default=False, help='use 8bit to save memory')
    args = parser.parse_args()
    main(args)