##!/usr/bin/python3 # -*- coding: utf-8 -*- import os, random, sys import numpy as np import requests import torch from pathlib import Path import pandas as pd import concurrent.futures import faiss import gradio as gr from PIL import Image import torch.nn.functional as F # 新增此行 from huggingface_hub import hf_hub_download, snapshot_download from scipy.ndimage import binary_dilation, binary_erosion from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor) from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler from diffusers.image_processor import VaeImageProcessor from app.src.vlm_pipeline import ( vlm_response_editing_type, vlm_response_object_wait_for_edit, vlm_response_mask, vlm_response_prompt_after_apply_instruction ) from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline from app.utils.utils import load_grounding_dino_model from app.src.vlm_template import vlms_template from app.src.base_model_template import base_models_template from app.src.aspect_ratio_template import aspect_ratios from openai import OpenAI base_openai_url = "https://api.deepseek.com/" base_api_key = "sk-d145b963a92649a88843caeb741e8bbc" from transformers import BlipProcessor, BlipForConditionalGeneration from transformers import CLIPProcessor, CLIPModel from app.deepseek.instructions import ( create_apply_editing_messages_deepseek, create_decomposed_query_messages_deepseek ) from clip_retrieval.clip_client import ClipClient #### Description #### logo = r"""
BrushEdit logo
""" head = r"""

基于扩散模型先验和大语言模型的零样本组合查询图像检索

Project Page

""" descriptions = r""" Demo for ZS-CIR""" instructions = r""" Demo for ZS-CIR""" tips = r""" Demo for ZS-CIR """ citation = r""" Demo for ZS-CIR""" # - - - - - examples - - - - - # EXAMPLES = [ [ Image.open("./assets/frog/frog.jpeg").convert("RGBA"), "add a magic hat on frog head.", 642087011, "frog", "frog", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"), "replace the background to ancient China.", 648464818, "chinese_girl", "chinese_girl", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"), "remove the deer.", 648464818, "angel_christmas", "angel_christmas", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"), "add a wreath on head.", 648464818, "sunflower_girl", "sunflower_girl", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"), "add a butterfly fairy.", 648464818, "girl_on_sun", "girl_on_sun", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"), "remove the christmas hat.", 642087011, "spider_man_rm", "spider_man_rm", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"), "remove the flower.", 642087011, "anime_flower", "anime_flower", False, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"), "replace the clothes to a delicated floral skirt.", 648464818, "chenduling", "chenduling", True, False, "GPT4-o (Highly Recommended)" ], [ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"), "make the hedgehog in Italy.", 648464818, "hedgehog_rp_bg", "hedgehog_rp_bg", True, False, "GPT4-o (Highly Recommended)" ], ] INPUT_IMAGE_PATH = { "frog": "./assets/frog/frog.jpeg", "chinese_girl": "./assets/chinese_girl/chinese_girl.png", "angel_christmas": "./assets/angel_christmas/angel_christmas.png", "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png", "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png", "spider_man_rm": "./assets/spider_man_rm/spider_man.png", "anime_flower": "./assets/anime_flower/anime_flower.png", "chenduling": "./assets/chenduling/chengduling.jpg", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png", } MASK_IMAGE_PATH = { "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png", "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png", "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png", "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png", "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", } MASKED_IMAGE_PATH = { "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png", "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png", "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png", "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png", "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", } OUTPUT_IMAGE_PATH = { "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png", "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png", "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png", "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png", "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png", "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png", "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png", "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png", "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png", } # os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir' # os.makedirs('gradio_temp_dir', exist_ok=True) VLM_MODEL_NAMES = list(vlms_template.keys()) DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)" BASE_MODELS = list(base_models_template.keys()) DEFAULT_BASE_MODEL = "realisticVision (Default)" ASPECT_RATIO_LABELS = list(aspect_ratios) DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0] ## init device try: if torch.cuda.is_available(): device = "cuda" elif sys.platform == "darwin" and torch.backends.mps.is_available(): device = "mps" else: device = "cpu" except: device = "cpu" # ## init torch dtype # if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): # torch_dtype = torch.bfloat16 # else: # torch_dtype = torch.float16 # if device == "mps": # torch_dtype = torch.float16 torch_dtype = torch.float16 # download hf models BrushEdit_path = "models/" if not os.path.exists(BrushEdit_path): BrushEdit_path = snapshot_download( repo_id="TencentARC/BrushEdit", local_dir=BrushEdit_path, token=os.getenv("HF_TOKEN"), ) ## init default VLM vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME] if vlm_processor != "" and vlm_model != "": vlm_model.to(device) else: raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.") ## init default LLM llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url) ## init base model base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE") brushnet_path = os.path.join(BrushEdit_path, "brushnetX") sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth") groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth") # input brushnetX ckpt path brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype) pipe = StableDiffusionBrushNetPipeline.from_pretrained( base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False ) # speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) # remove following line if xformers is not installed or when using Torch 2.0. # pipe.enable_xformers_memory_efficient_attention() pipe.enable_model_cpu_offload() ## init SAM sam = build_sam(checkpoint=sam_path) sam.to(device=device) sam_predictor = SamPredictor(sam) sam_automask_generator = SamAutomaticMaskGenerator(sam) ## init groundingdino_model config_file = 'app/utils/GroundingDINO_SwinT_OGC.py' groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device) ## Ordinary function def crop_and_resize(image: Image.Image, target_width: int, target_height: int) -> Image.Image: """ Crops and resizes an image while preserving the aspect ratio. Args: image (Image.Image): Input PIL image to be cropped and resized. target_width (int): Target width of the output image. target_height (int): Target height of the output image. Returns: Image.Image: Cropped and resized image. """ # Original dimensions original_width, original_height = image.size original_aspect = original_width / original_height target_aspect = target_width / target_height # Calculate crop box to maintain aspect ratio if original_aspect > target_aspect: # Crop horizontally new_width = int(original_height * target_aspect) new_height = original_height left = (original_width - new_width) / 2 top = 0 right = left + new_width bottom = original_height else: # Crop vertically new_width = original_width new_height = int(original_width / target_aspect) left = 0 top = (original_height - new_height) / 2 right = original_width bottom = top + new_height # Crop and resize cropped_image = image.crop((left, top, right, bottom)) resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST) return resized_image ## Ordinary function def resize(image: Image.Image, target_width: int, target_height: int) -> Image.Image: """ Crops and resizes an image while preserving the aspect ratio. Args: image (Image.Image): Input PIL image to be cropped and resized. target_width (int): Target width of the output image. target_height (int): Target height of the output image. Returns: Image.Image: Cropped and resized image. """ # Original dimensions resized_image = image.resize((target_width, target_height), Image.NEAREST) return resized_image def move_mask_func(mask, direction, units): binary_mask = mask.squeeze()>0 rows, cols = binary_mask.shape moved_mask = np.zeros_like(binary_mask, dtype=bool) if direction == 'down': # move down moved_mask[max(0, units):, :] = binary_mask[:rows - units, :] elif direction == 'up': # move up moved_mask[:rows - units, :] = binary_mask[units:, :] elif direction == 'right': # move left moved_mask[:, max(0, units):] = binary_mask[:, :cols - units] elif direction == 'left': # move right moved_mask[:, :cols - units] = binary_mask[:, units:] return moved_mask def random_mask_func(mask, dilation_type='square', dilation_size=20): # Randomly select the size of dilation binary_mask = mask.squeeze()>0 if dilation_type == 'square_dilation': structure = np.ones((dilation_size, dilation_size), dtype=bool) dilated_mask = binary_dilation(binary_mask, structure=structure) elif dilation_type == 'square_erosion': structure = np.ones((dilation_size, dilation_size), dtype=bool) dilated_mask = binary_erosion(binary_mask, structure=structure) elif dilation_type == 'bounding_box': # find the most left top and left bottom point rows, cols = np.where(binary_mask) if len(rows) == 0 or len(cols) == 0: return mask # return original mask if no valid points min_row = np.min(rows) max_row = np.max(rows) min_col = np.min(cols) max_col = np.max(cols) # create a bounding box dilated_mask = np.zeros_like(binary_mask, dtype=bool) dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True elif dilation_type == 'bounding_ellipse': # find the most left top and left bottom point rows, cols = np.where(binary_mask) if len(rows) == 0 or len(cols) == 0: return mask # return original mask if no valid points min_row = np.min(rows) max_row = np.max(rows) min_col = np.min(cols) max_col = np.max(cols) # calculate the center and axis length of the ellipse center = ((min_col + max_col) // 2, (min_row + max_row) // 2) a = (max_col - min_col) // 2 # half long axis b = (max_row - min_row) // 2 # half short axis # create a bounding ellipse y, x = np.ogrid[:mask.shape[0], :mask.shape[1]] ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1 dilated_mask = np.zeros_like(binary_mask, dtype=bool) dilated_mask[ellipse_mask] = True else: ValueError("dilation_type must be 'square' or 'ellipse'") # use binary dilation dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255 return dilated_mask ## Gradio component function def update_vlm_model(vlm_name): global vlm_model, vlm_processor if vlm_model is not None: del vlm_model torch.cuda.empty_cache() vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name] ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py if vlm_type == "llava-next": if vlm_processor != "" and vlm_model != "": vlm_model.to(device) return vlm_model_dropdown else: if os.path.exists(vlm_local_path): vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path) vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") else: if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llama3-llava-next-8b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-v1.6-34b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto") elif vlm_name == "llava-next-72b-hf (Preload)": vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf") vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto") elif vlm_type == "qwen2-vl": if vlm_processor != "" and vlm_model != "": vlm_model.to(device) return vlm_model_dropdown else: if os.path.exists(vlm_local_path): vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path) vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") else: if vlm_name == "qwen2-vl-2b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_name == "qwen2-vl-7b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_name == "qwen2-vl-72b-instruct (Preload)": vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct") vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto") elif vlm_type == "openai": pass return "success" def update_base_model(base_model_name): global pipe ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py if pipe is not None: del pipe torch.cuda.empty_cache() base_model_path, pipe = base_models_template[base_model_name] if pipe != "": pipe.to(device) else: if os.path.exists(base_model_path): pipe = StableDiffusionBrushNetPipeline.from_pretrained( base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False ) # pipe.enable_xformers_memory_efficient_attention() pipe.enable_model_cpu_offload() else: raise gr.Error(f"The base model {base_model_name} does not exist") return "success" def process_random_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, ): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse']) random_mask = random_mask_func(original_mask, dilation_type).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def process_dilation_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['square_dilation']) random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def process_erosion_mask(input_image, original_image, original_mask, resize_default, aspect_ratio_name, dilation_size=20): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] dilation_type = np.random.choice(['square_erosion']) random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (random_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) def move_mask_left(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_right(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_up(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def move_mask_down(input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio_name): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask if original_mask is None: raise gr.Error('Please generate mask first') if original_mask.ndim == 2: original_mask = original_mask[:,:,None] moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze() mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) if moved_mask.max() <= 1: moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) original_mask = moved_mask return [masked_image], [mask_image], original_mask.astype(np.uint8) def invert_mask(input_image, original_image, original_mask, ): alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) if input_mask.max() == 0: original_mask = 1 - (original_mask>0).astype(np.uint8) else: original_mask = 1 - (input_mask>0).astype(np.uint8) if original_mask is None: raise gr.Error('Please generate mask first') original_mask = original_mask.squeeze() mask_image = Image.fromarray(original_mask*255).convert("RGB") if original_mask.ndim == 2: original_mask = original_mask[:,:,None] if original_mask.max() <= 1: original_mask = (original_mask * 255).astype(np.uint8) masked_image = original_image * (1 - (original_mask>0)) masked_image = masked_image.astype(original_image.dtype) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], original_mask, True def init_img(base, init_type, prompt, aspect_ratio, example_change_times ): image_pil = base["background"].convert("RGB") original_image = np.array(image_pil) if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0: raise gr.Error('image aspect ratio cannot be larger than 2.0') if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2: mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")] masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")] result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")] width, height = image_pil.size image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) image_pil = image_pil.resize((width_new, height_new)) mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) result_gallery[0] = result_gallery[0].resize((width_new, height_new)) original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times else: if aspect_ratio not in ASPECT_RATIO_LABELS: aspect_ratio = "Custom resolution" return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0 def reset_func(input_image, original_image, original_mask, prompt, target_prompt, ): input_image = None original_image = None original_mask = None prompt = '' mask_gallery = [] masked_gallery = [] result_gallery = [] target_prompt = '' if torch.cuda.is_available(): torch.cuda.empty_cache() return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False def update_example(example_type, prompt, example_change_times): input_image = INPUT_IMAGE_PATH[example_type] image_pil = Image.open(input_image).convert("RGB") mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")] masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")] result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")] width, height = image_pil.size image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) image_pil = image_pil.resize((width_new, height_new)) mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) result_gallery[0] = result_gallery[0].resize((width_new, height_new)) original_image = np.array(image_pil) original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 aspect_ratio = "Custom resolution" example_change_times += 1 return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times def generate_target_prompt(input_image, original_image, prompt): # load example image if isinstance(original_image, str): original_image = input_image prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( vlm_processor, vlm_model, original_image, prompt, device) return prompt_after_apply_instruction def process_mask(input_image, original_image, prompt, resize_default, aspect_ratio_name): if original_image is None: raise gr.Error('Please upload the input image') if prompt is None: raise gr.Error("Please input your instructions, e.g., remove the xxx") ## load mask alpha_mask = input_image["layers"][0].split()[3] input_mask = np.array(alpha_mask) # load example image if isinstance(original_image, str): original_image = input_image["background"] if input_mask.max() == 0: category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor, vlm_model, original_image, category, prompt, device) # original mask: h,w,1 [0, 255] original_mask = vlm_response_mask( vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit, sam, sam_predictor, sam_automask_generator, groundingdino_model, device).astype(np.uint8) else: original_mask = input_mask.astype(np.uint8) category = None ## resize mask if needed output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if original_mask.ndim == 2: original_mask = original_mask[:,:,None] mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB") masked_image = original_image * (1 - (original_mask>0)) masked_image = masked_image.astype(np.uint8) masked_image = Image.fromarray(masked_image) return [masked_image], [mask_image], original_mask.astype(np.uint8), category def process(input_image, original_image, original_mask, prompt, negative_prompt, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps, num_samples, blending, category, target_prompt, resize_default, aspect_ratio_name, invert_mask_state): if original_image is None: if input_image is None: raise gr.Error('Please upload the input image') else: image_pil = input_image["background"].convert("RGB") original_image = np.array(image_pil) if prompt is None or prompt == "": if target_prompt is None or target_prompt == "": raise gr.Error("Please input your instructions, e.g., remove the xxx") alpha_mask = input_image["layers"][0].split()[3] input_mask = np.asarray(alpha_mask) output_w, output_h = aspect_ratios[aspect_ratio_name] if output_w == "" or output_h == "": output_h, output_w = original_image.shape[:2] if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") else: gr.Info(f"Output aspect ratio: {output_w}:{output_h}") pass else: if resize_default: short_side = min(output_w, output_h) scale_ratio = 640 / short_side output_w = int(output_w * scale_ratio) output_h = int(output_h * scale_ratio) gr.Info(f"Output aspect ratio: {output_w}:{output_h}") original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) original_image = np.array(original_image) if input_mask is not None: input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) input_mask = np.array(input_mask) if original_mask is not None: original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) original_mask = np.array(original_mask) if invert_mask_state: original_mask = original_mask else: if input_mask.max() == 0: original_mask = original_mask else: original_mask = input_mask # inpainting directly if target_prompt is not None if category is not None: pass elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None: pass else: try: category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) except Exception as e: raise gr.Error("Please select the correct VLM model and input the correct API Key first!") if original_mask is not None: original_mask = np.clip(original_mask, 0, 255).astype(np.uint8) else: try: object_wait_for_edit = vlm_response_object_wait_for_edit( vlm_processor, vlm_model, original_image, category, prompt, device) original_mask = vlm_response_mask(vlm_processor, vlm_model, category, original_image, prompt, object_wait_for_edit, sam, sam_predictor, sam_automask_generator, groundingdino_model, device).astype(np.uint8) except Exception as e: raise gr.Error("Please select the correct VLM model and input the correct API Key first!") if original_mask.ndim == 2: original_mask = original_mask[:,:,None] if target_prompt is not None and len(target_prompt) >= 1: prompt_after_apply_instruction = target_prompt else: try: prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( vlm_processor, vlm_model, original_image, prompt, device) except Exception as e: raise gr.Error("Please select the correct VLM model and input the correct API Key first!") generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed) with torch.autocast(device): image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe, prompt_after_apply_instruction, original_mask, original_image, generator, num_inference_steps, guidance_scale, control_strength, negative_prompt, num_samples, blending) original_image = np.array(init_image_np) masked_image = original_image * (1 - (mask_np>0)) masked_image = masked_image.astype(np.uint8) masked_image = Image.fromarray(masked_image) # Save the images (optional) # import uuid # uuid = str(uuid.uuid4()) # image[0].save(f"outputs/image_edit_{uuid}_0.png") # image[1].save(f"outputs/image_edit_{uuid}_1.png") # image[2].save(f"outputs/image_edit_{uuid}_2.png") # image[3].save(f"outputs/image_edit_{uuid}_3.png") # mask_image.save(f"outputs/mask_{uuid}.png") # masked_image.save(f"outputs/masked_image_{uuid}.png") # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=20) return image, [mask_image], [masked_image], prompt, '', False # 新增事件处理函数 def generate_blip_description(input_image): if input_image is None: return "", "Input image cannot be None" try: image_pil = input_image["background"].convert("RGB") except KeyError: return "", "Input image missing 'background' key" except AttributeError as e: return "", f"Invalid image object: {str(e)}" try: description = generate_caption(blip_processor, blip_model, image_pil, device) return description, description # 同时更新state和显示组件 except Exception as e: return "", f"Caption generation failed: {str(e)}" from app.utils.utils import generate_caption blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device) clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device) def submit_GPT4o_KEY(GPT4o_KEY): global vlm_model, vlm_processor if vlm_model is not None: del vlm_model torch.cuda.empty_cache() try: vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com") vlm_processor = "" response = vlm_model.chat.completions.create( model="deepseek-chat", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."} ] ) response_str = response.choices[0].message.content return "Success. " + response_str, "GPT4-o (Highly Recommended)" except Exception as e: return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)" def verify_deepseek_api(): try: response = llm_model.chat.completions.create( model="deepseek-chat", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."} ] ) response_str = response.choices[0].message.content return True, "Success. " + response_str except Exception as e: return False, "Invalid DeepSeek API Key" def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt): try: messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt) response = llm_model.chat.completions.create( model="deepseek-chat", messages=messages ) response_str = response.choices[0].message.content return response_str except Exception as e: raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") def llm_decomposed_prompt_after_apply_instruction(integrated_query): try: messages = create_decomposed_query_messages_deepseek(integrated_query) response = llm_model.chat.completions.create( model="deepseek-chat", messages=messages ) response_str = response.choices[0].message.content return response_str except Exception as e: raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") def enhance_description(blip_description, prompt): try: if not prompt or not blip_description: print("Empty prompt or blip_description detected") return "", "" print(f"Enhancing with prompt: {prompt}") enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip_description, prompt) return enhanced_description, enhanced_description except Exception as e: print(f"Enhancement failed: {str(e)}") return "Error occurred", "Error occurred" def decompose_description(enhanced_description): try: if not enhanced_description: print("Empty enhanced_description detected") return "", "" print(f"Decomposing the enhanced description: {enhanced_description}") decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description) return decomposed_description, decomposed_description except Exception as e: print(f"Decomposition failed: {str(e)}") return "Error occurred", "Error occurred" @torch.no_grad() def mix_and_search(enhanced_text: str, gallery_images: list): # 获取最新生成的图像元组 latest_item = gallery_images[-1] if gallery_images else None # 初始化特征列表 features = [] # 图像特征提取 if latest_item and isinstance(latest_item, tuple): try: image_path = latest_item[0] pil_image = Image.open(image_path).convert("RGB") # 使用 CLIPProcessor 处理图像 image_inputs = clip_processor( images=pil_image, return_tensors="pt" ).to(device) image_features = clip_model.get_image_features(**image_inputs) features.append(F.normalize(image_features, dim=-1)) except Exception as e: print(f"图像处理失败: {str(e)}") # 文本特征提取 if enhanced_text.strip(): text_inputs = clip_processor( text=enhanced_text, return_tensors="pt", padding=True, truncation=True ).to(device) text_features = clip_model.get_text_features(**text_inputs) features.append(F.normalize(text_features, dim=-1)) if not features: return "## 错误:请先完成图像编辑并生成描述", [] # 特征融合与检索 mixed = sum(features) / len(features) mixed = F.normalize(mixed, dim=-1) # 加载Faiss索引和图片路径映射 index_path = "/home/zt/data/open-images/train/knn.index" input_data_dir = Path("/home/zt/data/open-images/train/embedding_folder/metadata") base_image_dir = Path("/home/zt/data/open-images/train/") # 按文件名中的数字排序并直接读取parquet文件 parquet_files = sorted( input_data_dir.glob('*.parquet'), key=lambda x: int(x.stem.split("_")[-1]) ) # 合并所有parquet数据 dfs = [pd.read_parquet(file) for file in parquet_files] # 直接内联读取 df = pd.concat(dfs, ignore_index=True) image_paths = df["image_path"].tolist() # 读取Faiss索引 index = faiss.read_index(index_path) assert mixed.shape[1] == index.d, "特征维度不匹配" # 执行检索 mixed = mixed.cpu().detach().numpy().astype('float32') distances, indices = index.search(mixed, 5) # 获取并验证图片路径 retrieved_images = [] for idx in indices[0]: if 0 <= idx < len(image_paths): img_path = base_image_dir / image_paths[idx] try: if img_path.exists(): retrieved_images.append(Image.open(img_path).convert("RGB")) else: print(f"警告:文件缺失 {img_path}") except Exception as e: print(f"图片加载失败: {str(e)}") return "## 检索到以下相似图片:", retrieved_images if retrieved_images else ("## 未找到匹配的图片", []) block = gr.Blocks( theme=gr.themes.Soft( radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md ) ) with block as demo: with gr.Row(): with gr.Column(): gr.HTML(head) gr.Markdown(descriptions) with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"): with gr.Row(equal_height=True): gr.Markdown(instructions) original_image = gr.State(value=None) original_mask = gr.State(value=None) category = gr.State(value=None) status = gr.State(value=None) invert_mask_state = gr.State(value=False) example_change_times = gr.State(value=0) deepseek_verified = gr.State(value=False) blip_description = gr.State(value="") enhanced_description = gr.State(value="") decomposed_description = gr.State(value="") with gr.Row(): with gr.Column(): with gr.Row(): input_image = gr.ImageEditor( label="参考图像", type="pil", brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"), layers = False, interactive=True, # height=1024, height=512, sources=["upload"], placeholder="🫧 点击此处或下面的图标上传图像 🫧", ) prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=1) run_button = gr.Button("💫 图像编辑") vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True) with gr.Group(): with gr.Row(): # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1) GPT4o_KEY = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1) GPT4o_KEY_submit = gr.Button("🙈 验证") aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO) resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True) with gr.Row(): mask_button = gr.Button("💎 掩膜生成") random_mask_button = gr.Button("Square/Circle Mask ") # 在分解按钮后添加 with gr.Group(): with gr.Row(): retrieve_button = gr.Button("🔍 开始检索") with gr.Row(): retrieve_output = gr.Markdown(elem_id="accordion") with gr.Row(): retrieve_gallery = gr.Gallery(label="🎊 检索结果",show_label=True, elem_id="gallery", preview=True, height=400) # 新增Gallery组件 with gr.Row(): generate_target_prompt_button = gr.Button("Generate Target Prompt") target_prompt = gr.Text( label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2 ) with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"): base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True) negative_prompt = gr.Text( label="Negative Prompt", max_lines=5, placeholder="Please input your negative prompt", value='ugly, low quality',lines=1 ) control_strength = gr.Slider( label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01 ) with gr.Group(): seed = gr.Slider( label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818 ) randomize_seed = gr.Checkbox(label="Randomize seed", value=False) blending = gr.Checkbox(label="Blending mode", value=True) num_samples = gr.Slider( label="Num samples", minimum=0, maximum=4, step=1, value=4 ) with gr.Group(): with gr.Row(): guidance_scale = gr.Slider( label="Guidance scale", minimum=1, maximum=12, step=0.1, value=7.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=50, ) with gr.Group(visible=True): # BLIP生成的描述 blip_output = gr.Textbox(label="原图描述", placeholder="💭 BLIP生成的图像基础描述 💭", interactive=True, lines=1) # DeepSeek API验证 with gr.Row(): deepseek_key = gr.Textbox(label="密钥输入", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1) verify_deepseek = gr.Button("🙈 验证") # 整合后的描述区域 with gr.Row(): enhanced_output = gr.Textbox(label="描述整合", placeholder="💭 DeepSeek生成的增强描述 💭", interactive=True, lines=3) enhance_button = gr.Button("✨ 整合") # 分解后的描述区域 with gr.Row(): decomposed_output = gr.Textbox(label="描述分解", placeholder="💭 DeepSeek生成的分解描述 💭", interactive=True, lines=3) decompose_button = gr.Button("🔧 分解") with gr.Row(): with gr.Tab(elem_classes="feedback", label="Masked Image"): masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360) with gr.Tab(elem_classes="feedback", label="Mask"): mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360) invert_mask_button = gr.Button("Invert Mask") dilation_size = gr.Slider( label="Dilation size: ", minimum=0, maximum=50, step=1, value=20 ) with gr.Row(): dilation_mask_button = gr.Button("Dilation Generated Mask") erosion_mask_button = gr.Button("Erosion Generated Mask") moving_pixels = gr.Slider( label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1 ) with gr.Row(): move_left_button = gr.Button("Move Left") move_right_button = gr.Button("Move Right") with gr.Row(): move_up_button = gr.Button("Move Up") move_down_button = gr.Button("Move Down") with gr.Tab(elem_classes="feedback", label="Output"): result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400) target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False) reset_button = gr.Button("Reset") init_type = gr.Textbox(label="Init Name", value="", visible=False) example_type = gr.Textbox(label="Example Name", value="", visible=False) with gr.Row(): example = gr.Examples( label="Quick Example", examples=EXAMPLES, inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown], examples_per_page=10, cache_examples=False, ) with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"): with gr.Row(equal_height=True): gr.Markdown(tips) with gr.Row(): gr.Markdown(citation) ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery. ## And we need to solve the conflict between the upload and change example functions. input_image.upload( init_img, [input_image, init_type, prompt, aspect_ratio, example_change_times], [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times] ) example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times]) ## vlm and base model dropdown vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status]) base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status]) GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown]) invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state]) ips=[input_image, original_image, original_mask, prompt, negative_prompt, control_strength, seed, randomize_seed, guidance_scale, num_inference_steps, num_samples, blending, category, target_prompt, resize_default, aspect_ratio, invert_mask_state] ## run brushedit run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state]) ## mask func mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category]) random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) ## reset func reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state]) # 绑定事件处理 input_image.upload(fn=generate_blip_description, inputs=[input_image], outputs=[blip_description, blip_output]) verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key]) enhance_button.click(fn=enhance_description, inputs=[blip_output, prompt], outputs=[enhanced_description, enhanced_output]) decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output]) # 修改事件绑定 retrieve_button.click( fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_output, retrieve_gallery] ) demo.launch(server_name="0.0.0.0", server_port=12345, share=True)