diff --git "a/brushedit_app_new.py" "b/brushedit_app_new.py" new file mode 100644--- /dev/null +++ "b/brushedit_app_new.py" @@ -0,0 +1,3451 @@ +##!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os, random, sys +import numpy as np +import requests +import torch + +from pathlib import Path +import pandas as pd +import concurrent.futures +import faiss +import gradio as gr + +from pathlib import Path +import os +import json + +import logging +from datetime import datetime +import time + +from PIL import Image +from torch import nn +from typing import Dict, List, Tuple + +import torch.nn.functional as F +from huggingface_hub import hf_hub_download, snapshot_download +from scipy.ndimage import binary_dilation, binary_erosion +from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor) + +from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator +from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler +from diffusers.image_processor import VaeImageProcessor + + +from app.src.vlm_pipeline import ( + vlm_response_editing_type, + vlm_response_object_wait_for_edit, + vlm_response_mask, + vlm_response_prompt_after_apply_instruction +) +from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline +from app.utils.utils import load_grounding_dino_model + +from app.src.vlm_template import vlms_template +from app.src.base_model_template import base_models_template +from app.src.aspect_ratio_template import aspect_ratios + +from openai import OpenAI +base_openai_url = "https://api.deepseek.com/" +base_api_key = "sk-d145b963a92649a88843caeb741e8bbc" + + + +from transformers import BlipProcessor, BlipForConditionalGeneration +from transformers import Blip2Processor, Blip2ForConditionalGeneration + +from transformers import CLIPProcessor, CLIPModel + +from app.deepseek.instructions import ( + create_apply_editing_messages_deepseek, + create_decomposed_query_messages_deepseek +) +from clip_retrieval.clip_client import ClipClient + + +#### Description #### +logo = r""" +
BrushEdit logo
+""" +head = r""" +
+

+ Zero-Shot Composed Query Image Retrieval
+ Based on Diffusion Model Priors and Large Language Models
+

+
+ Project Page + + +
+
+
+""" + +descriptions = r""" +
+ +

+ 🎨
+ A training-free interactive system for composed image retrieval: modify a reference image with text instructions and retrieve semantically similar images.
+

+
+""" + +instructions = r""" +
+
+   1. Upload an image: click the canvas or the upload button to add a reference image
+   2. Enter an instruction: describe in the text box the edit you want to apply to the image
+   3. Generate a mask: use the mask tools to precisely control the region to be edited
+   4. Smart enhancement: the system automatically generates an image caption, which can be refined further
+   5. Run the edit: click the "图像编辑" (Image Editing) button to generate the edited image
+   6. Retrieve results: click "开始检索" (Start Retrieval) to get similar-image results
+
+""" + +tips = r""" +
+

+ 🖌️ Image Editing Features
+

+ + +

+ 🧠 Smart Captioning System
+

+ + +

+ 🔍 Advanced Retrieval Capabilities
+

+ + +

+ ⚙️ Technical Parameter Tuning
+

+ + +

+ 💡 Usage Tips
+

+ +
+""" + +citation = r""" +""" + +# - - - - - examples - - - - - # +EXAMPLES = [ + + [ + Image.open("./assets/frog/frog.jpeg").convert("RGBA"), + "add a magic hat on frog head.", + 642087011, + "frog", + "frog", + True, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"), + "replace the background to ancient China.", + 648464818, + "chinese_girl", + "chinese_girl", + True, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"), + "remove the deer.", + 648464818, + "angel_christmas", + "angel_christmas", + False, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"), + "add a wreath on head.", + 648464818, + "sunflower_girl", + "sunflower_girl", + True, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"), + "add a butterfly fairy.", + 648464818, + "girl_on_sun", + "girl_on_sun", + True, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"), + "remove the christmas hat.", + 642087011, + "spider_man_rm", + "spider_man_rm", + False, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"), + "remove the flower.", + 642087011, + "anime_flower", + "anime_flower", + False, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"), + "replace the clothes to a delicated floral skirt.", + 648464818, + "chenduling", + "chenduling", + True, + False, + "GPT4-o (Highly Recommended)" + ], + [ + Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"), + "make the hedgehog in Italy.", + 648464818, + "hedgehog_rp_bg", + "hedgehog_rp_bg", + True, + False, + "GPT4-o (Highly Recommended)" + ], + +] + +INPUT_IMAGE_PATH = { + "frog": "./assets/frog/frog.jpeg", + "chinese_girl": "./assets/chinese_girl/chinese_girl.png", + "angel_christmas": "./assets/angel_christmas/angel_christmas.png", + "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png", + "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png", + "spider_man_rm": "./assets/spider_man_rm/spider_man.png", + "anime_flower": "./assets/anime_flower/anime_flower.png", + "chenduling": "./assets/chenduling/chengduling.jpg", + "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png", +} +MASK_IMAGE_PATH = { + "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png", + "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png", + "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", + "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", + "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png", + "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", + "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png", + "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", + "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", +} +MASKED_IMAGE_PATH = { + "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png", + "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png", + 
"angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png", + "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png", + "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png", + "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png", + "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png", + "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png", + "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png", +} +OUTPUT_IMAGE_PATH = { + "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png", + "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png", + "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png", + "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png", + "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png", + "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png", + "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png", + "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png", + "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png", +} + + + +VLM_MODEL_NAMES = list(vlms_template.keys()) +DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)" + + +BASE_MODELS = list(base_models_template.keys()) +DEFAULT_BASE_MODEL = "realisticVision (Default)" + +ASPECT_RATIO_LABELS = list(aspect_ratios) +DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0] + + +## init device +try: + if torch.cuda.is_available(): + device = "cuda" + elif sys.platform == "darwin" and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" +except: + device = "cpu" + +# ## init torch dtype +# if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): +# torch_dtype = torch.bfloat16 +# else: +# torch_dtype = torch.float16 + +# if device == "mps": +# torch_dtype = torch.float16 + +torch_dtype = torch.float16 + + + +# download hf models +BrushEdit_path = "models/" +if not os.path.exists(BrushEdit_path): + BrushEdit_path = snapshot_download( + repo_id="TencentARC/BrushEdit", + local_dir=BrushEdit_path, + token=os.getenv("HF_TOKEN"), + ) + +## init default VLM +vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME] +if vlm_processor != "" and vlm_model != "": + vlm_model.to(device) +else: + raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.") + +## init default LLM +llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url) + +## init base model +base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE") +brushnet_path = os.path.join(BrushEdit_path, "brushnetX") +sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth") +groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth") + + +# input brushnetX ckpt path +brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype) +pipe = StableDiffusionBrushNetPipeline.from_pretrained( + base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False 
+ ) +# speed up diffusion process with faster scheduler and memory optimization +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) +# remove following line if xformers is not installed or when using Torch 2.0. +# pipe.enable_xformers_memory_efficient_attention() +pipe.enable_model_cpu_offload() + + +## init SAM +sam = build_sam(checkpoint=sam_path) +sam.to(device=device) +sam_predictor = SamPredictor(sam) +sam_automask_generator = SamAutomaticMaskGenerator(sam) + +## init groundingdino_model +config_file = 'app/utils/GroundingDINO_SwinT_OGC.py' +groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device) + + + +blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24") +blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24", torch_dtype=torch.float16).to(device) + +clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") +clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device) + +# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") +# clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14",torch_dtype=torch.float16).to(device) + + + +# Ordinary function +def crop_and_resize(image: Image.Image, + target_width: int, + target_height: int) -> Image.Image: + """ + Crops and resizes an image while preserving the aspect ratio. + + Args: + image (Image.Image): Input PIL image to be cropped and resized. + target_width (int): Target width of the output image. + target_height (int): Target height of the output image. + + Returns: + Image.Image: Cropped and resized image. + """ + # Original dimensions + original_width, original_height = image.size + original_aspect = original_width / original_height + target_aspect = target_width / target_height + + # Calculate crop box to maintain aspect ratio + if original_aspect > target_aspect: + # Crop horizontally + new_width = int(original_height * target_aspect) + new_height = original_height + left = (original_width - new_width) / 2 + top = 0 + right = left + new_width + bottom = original_height + else: + # Crop vertically + new_width = original_width + new_height = int(original_width / target_aspect) + left = 0 + top = (original_height - new_height) / 2 + right = original_width + bottom = top + new_height + + # Crop and resize + cropped_image = image.crop((left, top, right, bottom)) + resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST) + return resized_image + + +def resize(image: Image.Image, + target_width: int, + target_height: int) -> Image.Image: + """ + Crops and resizes an image while preserving the aspect ratio. + + Args: + image (Image.Image): Input PIL image to be cropped and resized. + target_width (int): Target width of the output image. + target_height (int): Target height of the output image. + + Returns: + Image.Image: Cropped and resized image. 
+ """ + # Original dimensions + resized_image = image.resize((target_width, target_height), Image.NEAREST) + return resized_image + + +def move_mask_func(mask, direction, units): + binary_mask = mask.squeeze()>0 + rows, cols = binary_mask.shape + moved_mask = np.zeros_like(binary_mask, dtype=bool) + + if direction == 'down': + # move down + moved_mask[max(0, units):, :] = binary_mask[:rows - units, :] + + elif direction == 'up': + # move up + moved_mask[:rows - units, :] = binary_mask[units:, :] + + elif direction == 'right': + # move left + moved_mask[:, max(0, units):] = binary_mask[:, :cols - units] + + elif direction == 'left': + # move right + moved_mask[:, :cols - units] = binary_mask[:, units:] + + return moved_mask + + +def random_mask_func(mask, dilation_type='square', dilation_size=20): + # Randomly select the size of dilation + binary_mask = mask.squeeze()>0 + + if dilation_type == 'square_dilation': + structure = np.ones((dilation_size, dilation_size), dtype=bool) + dilated_mask = binary_dilation(binary_mask, structure=structure) + elif dilation_type == 'square_erosion': + structure = np.ones((dilation_size, dilation_size), dtype=bool) + dilated_mask = binary_erosion(binary_mask, structure=structure) + elif dilation_type == 'bounding_box': + # find the most left top and left bottom point + rows, cols = np.where(binary_mask) + if len(rows) == 0 or len(cols) == 0: + return mask # return original mask if no valid points + + min_row = np.min(rows) + max_row = np.max(rows) + min_col = np.min(cols) + max_col = np.max(cols) + + # create a bounding box + dilated_mask = np.zeros_like(binary_mask, dtype=bool) + dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True + + elif dilation_type == 'bounding_ellipse': + # find the most left top and left bottom point + rows, cols = np.where(binary_mask) + if len(rows) == 0 or len(cols) == 0: + return mask # return original mask if no valid points + + min_row = np.min(rows) + max_row = np.max(rows) + min_col = np.min(cols) + max_col = np.max(cols) + + # calculate the center and axis length of the ellipse + center = ((min_col + max_col) // 2, (min_row + max_row) // 2) + a = (max_col - min_col) // 2 # half long axis + b = (max_row - min_row) // 2 # half short axis + + # create a bounding ellipse + y, x = np.ogrid[:mask.shape[0], :mask.shape[1]] + ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1 + dilated_mask = np.zeros_like(binary_mask, dtype=bool) + dilated_mask[ellipse_mask] = True + else: + ValueError("dilation_type must be 'square' or 'ellipse'") + + # use binary dilation + dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255 + return dilated_mask + + +## Gradio component function +def update_vlm_model(vlm_name): + global vlm_model, vlm_processor + if vlm_model is not None: + del vlm_model + torch.cuda.empty_cache() + + vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name] + + ## we recommend using preload models, otherwise it will take a long time to download the model. 
you can edit the code via vlm_template.py + if vlm_type == "llava-next": + if vlm_processor != "" and vlm_model != "": + vlm_model.to(device) + return vlm_model_dropdown + else: + if os.path.exists(vlm_local_path): + vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path) + vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") + else: + if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)": + vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto") + elif vlm_name == "llama3-llava-next-8b-hf (Preload)": + vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf") + vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto") + elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)": + vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf") + vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto") + elif vlm_name == "llava-v1.6-34b-hf (Preload)": + vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf") + vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto") + elif vlm_name == "llava-next-72b-hf (Preload)": + vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf") + vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto") + elif vlm_type == "qwen2-vl": + if vlm_processor != "" and vlm_model != "": + vlm_model.to(device) + return vlm_model_dropdown + else: + if os.path.exists(vlm_local_path): + vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path) + vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto") + else: + if vlm_name == "qwen2-vl-2b-instruct (Preload)": + vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") + vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto") + elif vlm_name == "qwen2-vl-7b-instruct (Preload)": + vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto") + elif vlm_name == "qwen2-vl-72b-instruct (Preload)": + vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct") + vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto") + elif vlm_type == "openai": + pass + return "success" + + +def update_base_model(base_model_name): + global pipe + ## we recommend using preload models, otherwise it will take a long time to download the model. 
you can edit the code via base_model_template.py + if pipe is not None: + del pipe + torch.cuda.empty_cache() + base_model_path, pipe = base_models_template[base_model_name] + if pipe != "": + pipe.to(device) + else: + if os.path.exists(base_model_path): + pipe = StableDiffusionBrushNetPipeline.from_pretrained( + base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False + ) + # pipe.enable_xformers_memory_efficient_attention() + pipe.enable_model_cpu_offload() + else: + raise gr.Error(f"The base model {base_model_name} does not exist") + return "success" + + +def process_random_mask(input_image, + original_image, + original_mask, + resize_default, + aspect_ratio_name, + ): + + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse']) + random_mask = random_mask_func(original_mask, dilation_type).squeeze() + + mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") + + masked_image = original_image * (1 - (random_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + + return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) + + +def process_dilation_mask(input_image, + original_image, + original_mask, + resize_default, + aspect_ratio_name, + dilation_size=20): + + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + + 
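+    # Same resize preamble as the other mask helpers: look up the preset output size,
+    # fall back to the original image size when the preset is empty, and, if
+    # resize_default is set, rescale so the short side becomes 640 before the image,
+    # the drawn mask, and the stored mask are resized consistently.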
output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + dilation_type = np.random.choice(['square_dilation']) + random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() + + mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") + + masked_image = original_image * (1 - (random_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) + + +def process_erosion_mask(input_image, + original_image, + original_mask, + resize_default, + aspect_ratio_name, + dilation_size=20): + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + 
original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + dilation_type = np.random.choice(['square_erosion']) + random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze() + + mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB") + + masked_image = original_image * (1 - (random_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + + return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8) + + +def move_mask_left(input_image, + original_image, + original_mask, + moving_pixels, + resize_default, + aspect_ratio_name): + + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = 
resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze() + mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") + + masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + if moved_mask.max() <= 1: + moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) + original_mask = moved_mask + return [masked_image], [mask_image], original_mask.astype(np.uint8) + + +def move_mask_right(input_image, + original_image, + original_mask, + moving_pixels, + resize_default, + aspect_ratio_name): + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze() + + mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") + + masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + + if moved_mask.max() 
<= 1: + moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) + original_mask = moved_mask + + return [masked_image], [mask_image], original_mask.astype(np.uint8) + + +def move_mask_up(input_image, + original_image, + original_mask, + moving_pixels, + resize_default, + aspect_ratio_name): + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze() + mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") + + masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + if moved_mask.max() <= 1: + moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) + original_mask = moved_mask + + return [masked_image], [mask_image], original_mask.astype(np.uint8) + + +def move_mask_down(input_image, + original_image, + original_mask, + moving_pixels, + resize_default, + aspect_ratio_name): + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), 
target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + if original_mask is None: + raise gr.Error('Please generate mask first') + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze() + mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB") + + masked_image = original_image * (1 - (moved_mask[:,:,None]>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + if moved_mask.max() <= 1: + moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8) + original_mask = moved_mask + + return [masked_image], [mask_image], original_mask.astype(np.uint8) + + +def invert_mask(input_image, + original_image, + original_mask, + ): + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + if input_mask.max() == 0: + original_mask = 1 - (original_mask>0).astype(np.uint8) + else: + original_mask = 1 - (input_mask>0).astype(np.uint8) + + if original_mask is None: + raise gr.Error('Please generate mask first') + + original_mask = original_mask.squeeze() + mask_image = Image.fromarray(original_mask*255).convert("RGB") + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + if original_mask.max() <= 1: + original_mask = (original_mask * 255).astype(np.uint8) + + masked_image = original_image * (1 - (original_mask>0)) + masked_image = masked_image.astype(original_image.dtype) + masked_image = Image.fromarray(masked_image) + + return [masked_image], [mask_image], original_mask, True + + +def reset_func(input_image, + original_image, + original_mask, + prompt, + blip2_output, + target_prompt, + ): + input_image = None + original_image = None + original_mask = None + prompt = '' + mask_gallery = [] + masked_gallery = [] + result_gallery = [] + blip2_output = '' + target_prompt = '' + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, 
blip2_output, target_prompt, True, False + + +def update_example(example_type, + prompt, + example_change_times): + input_image = INPUT_IMAGE_PATH[example_type] + image_pil = Image.open(input_image).convert("RGB") + mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")] + masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")] + result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")] + width, height = image_pil.size + image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) + height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) + image_pil = image_pil.resize((width_new, height_new)) + mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) + masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) + result_gallery[0] = result_gallery[0].resize((width_new, height_new)) + + original_image = np.array(image_pil) + original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 + aspect_ratio = "Custom resolution" + example_change_times += 1 + return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times + + +def generate_caption(blip2_processor, blip2_model, input_image, device): + image_pil = input_image["background"].convert("RGB") + inputs = blip2_processor(images=image_pil, return_tensors="pt").to(device, torch.float16) + generated_ids = blip2_model.generate(**inputs) + caption = blip2_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + return caption + + +def generate_blip2_description(input_image): + try: + description = generate_caption(blip2_processor, blip2_model, input_image, device) + return description, description + except Exception as e: + return "", f"Caption generation failed: {str(e)}" + + +def generate_target_prompt(input_image, + original_image, + prompt): + # load example image + if isinstance(original_image, str): + original_image = input_image + + image_caption = generate_caption(blip2_processor, blip2_model, input_image, device) + + prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( + vlm_processor, + vlm_model, + llm_model, + original_image, + image_caption, + prompt, + device) + return prompt_after_apply_instruction + + +def init_img(base, + init_type, + prompt, + aspect_ratio, + example_change_times + ): + image_pil = base["background"].convert("RGB") + original_image = np.array(image_pil) + if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0: + raise gr.Error('image aspect ratio cannot be larger than 2.0') + if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2: + mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")] + masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")] + result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")] + width, height = image_pil.size + image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True) + height_new, width_new = image_processor.get_default_height_width(image_pil, height, width) + image_pil = image_pil.resize((width_new, height_new)) + mask_gallery[0] = mask_gallery[0].resize((width_new, height_new)) + masked_gallery[0] = masked_gallery[0].resize((width_new, height_new)) + result_gallery[0] = result_gallery[0].resize((width_new, height_new)) + 
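+        # Convert the cached example mask (an L-mode PNG) to the h x w x 1 uint8
+        # layout expected by the mask-editing helpers and the BrushEdit pipeline.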
original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1 + return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times + else: + if aspect_ratio not in ASPECT_RATIO_LABELS: + aspect_ratio = "Custom resolution" + return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0 + + +def process_mask(input_image, + original_image, + prompt, + resize_default, + aspect_ratio_name): + if original_image is None: + raise gr.Error('Please upload the input image') + if prompt is None: + raise gr.Error("Please input your instructions, e.g., remove the xxx") + + ## load mask + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.array(alpha_mask) + + # load example image + if isinstance(original_image, str): + original_image = input_image["background"] + + image_caption = generate_caption(blip2_processor, blip2_model, input_image, device) + + if input_mask.max() == 0: + # category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) + category = vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device) + + object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor, + vlm_model, + llm_model, + original_image, + image_caption, + category, + prompt, + device) + # original mask: h,w,1 [0, 255] + original_mask = vlm_response_mask( + vlm_processor, + vlm_model, + category, + original_image, + prompt, + object_wait_for_edit, + sam, + sam_predictor, + sam_automask_generator, + groundingdino_model, + device).astype(np.uint8) + else: + original_mask = input_mask.astype(np.uint8) + category = None + + ## resize mask if needed + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = 
np.array(original_mask) + + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB") + + masked_image = original_image * (1 - (original_mask>0)) + masked_image = masked_image.astype(np.uint8) + masked_image = Image.fromarray(masked_image) + + return [masked_image], [mask_image], original_mask.astype(np.uint8), category + + +def process(input_image, + original_image, + original_mask, + prompt, + negative_prompt, + control_strength, + seed, + randomize_seed, + guidance_scale, + num_inference_steps, + num_samples, + blending, + category, + target_prompt, + resize_default, + aspect_ratio_name, + invert_mask_state): + if original_image is None: + if input_image is None: + raise gr.Error('Please upload the input image') + else: + image_pil = input_image["background"].convert("RGB") + original_image = np.array(image_pil) + if prompt is None or prompt == "": + if target_prompt is None or target_prompt == "": + raise gr.Error("Please input your instructions, e.g., remove the xxx") + + alpha_mask = input_image["layers"][0].split()[3] + input_mask = np.asarray(alpha_mask) + output_w, output_h = aspect_ratios[aspect_ratio_name] + if output_w == "" or output_h == "": + output_h, output_w = original_image.shape[:2] + + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + else: + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + pass + else: + if resize_default: + short_side = min(output_w, output_h) + scale_ratio = 640 / short_side + output_w = int(output_w * scale_ratio) + output_h = int(output_h * scale_ratio) + gr.Info(f"Output aspect ratio: {output_w}:{output_h}") + original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h)) + original_image = np.array(original_image) + if input_mask is not None: + input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h)) + input_mask = np.array(input_mask) + if original_mask is not None: + original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h)) + original_mask = np.array(original_mask) + + if invert_mask_state: + original_mask = original_mask + else: + if input_mask.max() == 0: + original_mask = original_mask + else: + original_mask = input_mask + + image_caption = generate_caption(blip2_processor, blip2_model, input_image, device) + + # inpainting directly if target_prompt is not None + if category is not None: + pass + elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None: + pass + else: + try: + # category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device) + category = 
vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device) + print(category) + except Exception as e: + raise gr.Error("Time1. Please select the correct VLM model and input the correct API Key first!") + + + if original_mask is not None: + original_mask = np.clip(original_mask, 0, 255).astype(np.uint8) + else: + try: + object_wait_for_edit = vlm_response_object_wait_for_edit( + vlm_processor, + vlm_model, + llm_model, + original_image, + image_caption, + category, + prompt, + device) + print(object_wait_for_edit) + + original_mask = vlm_response_mask(vlm_processor, + vlm_model, + category, + original_image, + prompt, + object_wait_for_edit, + sam, + sam_predictor, + sam_automask_generator, + groundingdino_model, + device).astype(np.uint8) + print("Got genarated mask!") + except Exception as e: + raise gr.Error("Time2. Please select the correct VLM model and input the correct API Key first!") + + if original_mask.ndim == 2: + original_mask = original_mask[:,:,None] + + + if target_prompt is not None and len(target_prompt) >= 1: + prompt_after_apply_instruction = target_prompt + + else: + try: + prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction( + vlm_processor, + vlm_model, + llm_model, + original_image, + image_caption, + prompt, + device) + except Exception as e: + raise gr.Error("Time3. Please select the correct VLM model and input the correct API Key first!") + + generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed) + + + with torch.autocast(device): + image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe, + prompt_after_apply_instruction, + original_mask, + original_image, + generator, + num_inference_steps, + guidance_scale, + control_strength, + negative_prompt, + num_samples, + blending) + original_image = np.array(init_image_np) + masked_image = original_image * (1 - (mask_np>0)) + masked_image = masked_image.astype(np.uint8) + masked_image = Image.fromarray(masked_image) + + return image, [mask_image], [masked_image], prompt, prompt_after_apply_instruction, False + + +def submit_GPT4o_KEY(GPT4o_KEY): + global vlm_model, vlm_processor + if vlm_model is not None: + del vlm_model + torch.cuda.empty_cache() + try: + vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com") + vlm_processor = "" + response = vlm_model.chat.completions.create( + model="deepseek-chat", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello."} + ] + ) + response_str = response.choices[0].message.content + + return "Success. " + response_str, "GPT4-o (Highly Recommended)" + except Exception as e: + return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)" + + +def verify_deepseek_api(): + try: + response = llm_model.chat.completions.create( + model="deepseek-chat", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello."} + ] + ) + response_str = response.choices[0].message.content + + return True, "Success. 
" + response_str + + except Exception as e: + return False, "Invalid DeepSeek API Key" + + +def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt): + try: + messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt) + response = llm_model.chat.completions.create( + model="deepseek-chat", + messages=messages + ) + response_str = response.choices[0].message.content + return response_str + except Exception as e: + raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") + + +def llm_decomposed_prompt_after_apply_instruction(integrated_query): + try: + messages = create_decomposed_query_messages_deepseek(integrated_query) + response = llm_model.chat.completions.create( + model="deepseek-chat", + messages=messages + ) + response_str = response.choices[0].message.content + return response_str + except Exception as e: + raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息") + + +def enhance_description(blip2_description, prompt): + try: + if not prompt or not blip2_description: + print("Empty prompt or blip2_description detected") + return "", "" + + print(f"Enhancing with prompt: {prompt}") + enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip2_description, prompt) + return enhanced_description, enhanced_description + + except Exception as e: + print(f"Enhancement failed: {str(e)}") + return "Error occurred", "Error occurred" + + +def decompose_description(enhanced_description): + try: + if not enhanced_description: + print("Empty enhanced_description detected") + return "", "" + + print(f"Decomposing the enhanced description: {enhanced_description}") + decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description) + return decomposed_description, decomposed_description + + except Exception as e: + print(f"Decomposition failed: {str(e)}") + return "Error occurred", "Error occurred" + + +@torch.no_grad() +def mix_and_search(enhanced_text: str, gallery_images: list): + # 获取最新生成的图像元组 + latest_item = gallery_images[-1] if gallery_images else None + + # 初始化特征列表 + features = [] + + # 图像特征提取 + if latest_item and isinstance(latest_item, tuple): + try: + image_path = latest_item[0] + pil_image = Image.open(image_path).convert("RGB") + + # 使用 CLIPProcessor 处理图像 + image_inputs = clip_processor( + images=pil_image, + return_tensors="pt" + ).to(device) + + image_features = clip_model.get_image_features(**image_inputs) + features.append(F.normalize(image_features, dim=-1)) + except Exception as e: + print(f"图像处理失败: {str(e)}") + + # 文本特征提取 + if enhanced_text.strip(): + text_inputs = clip_processor( + text=enhanced_text, + return_tensors="pt", + padding=True, + truncation=True + ).to(device) + + text_features = clip_model.get_text_features(**text_inputs) + features.append(F.normalize(text_features, dim=-1)) + + if not features: + return [] + + + # 特征融合逻辑 + if len(features) == 2: + # 图文双特征时加权:40%图像 + 60%文本 + mixed = 0.4 * features[0] + 0.6 * features[1] + else: + # 单特征时直接求和(保持原逻辑) + mixed = sum(features) + + mixed = F.normalize(mixed, dim=-1) + + # 加载Faiss索引和图片路径映射 + index_path = "/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index" + input_data_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata") + base_image_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/") + + # 按文件名中的数字排序并直接读取parquet文件 + parquet_files = sorted( + input_data_dir.glob('*.parquet'), + key=lambda x: int(x.stem.split("_")[-1]) + ) + + # 合并所有parquet数据 + dfs = [pd.read_parquet(file) for file in parquet_files] # 直接内联读取 + 
df = pd.concat(dfs, ignore_index=True) + image_paths = df["image_path"].tolist() + + index = faiss.read_index(index_path) + assert mixed.shape[1] == index.d, "特征维度不匹配" + mixed = mixed.cpu().detach().numpy().astype('float32') + distances, indices = index.search(mixed, 50) + + # 获取并验证图片路径 + retrieved_images = [] + for idx in indices[0]: + if 0 <= idx < len(image_paths): + img_path = base_image_dir / image_paths[idx] + try: + if img_path.exists(): + retrieved_images.append(Image.open(img_path).convert("RGB")) + else: + print(f"警告:文件缺失 {img_path}") + except Exception as e: + print(f"图片加载失败: {str(e)}") + + return retrieved_images if retrieved_images else ([]) + +# 问题如下:在cap_file中,每张图片会充当多次reference的值,能够将同一张图片区分开来的是cap_file中的“caption”。我可以如何调整现在的保存逻辑,使得能够区分同一reference不同caption的情况,而不是直接覆盖呢? +# def process_cirr_images(): + +# if not all([vlm_model, sam_predictor, groundingdino_model]): +# raise RuntimeError("Required models not initialized") + +# # Define paths +# dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev") +# cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json") +# output_dirs = { +# "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited"), +# "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_mask"), +# "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_masked") +# } +# output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint.json") +# descriptions = {} + +# # Create output directories +# for dir_path in output_dirs.values(): +# dir_path.mkdir(parents=True, exist_ok=True) + +# # Load captions +# with open(cap_file, 'r') as f: +# captions = json.load(f) + +# for img_path in dev_dir.glob("*.png"): +# base_name = img_path.stem +# caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None) + +# if not caption: +# print(f"Warning: No caption for {base_name}") +# continue + +# try: +# # 构造空alpha通道(全0) +# rgb_image = Image.open(img_path).convert("RGB") +# empty_alpha = Image.new("L", rgb_image.size, 0) # 全透明alpha通道 +# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + +# # 调用init_img初始化 +# base = {"background": image, "layers": [image]} +# init_results = init_img( +# base=base, +# init_type="custom", # 使用自定义初始化 +# prompt=caption, +# aspect_ratio="Custom resolution", +# example_change_times=0 +# ) + +# # 获取初始化后的参数 +# input_image = init_results[0] +# original_image = init_results[1] +# original_mask = init_results[2] + +# # 正确设置process参数 +# result_images, mask_images, masked_images, _, target_description, _ = process( +# input_image=input_image, +# original_image=original_image, +# original_mask=original_mask, # 传递初始化后的mask +# prompt=caption, +# negative_prompt="ugly, low quality", +# control_strength=1.0, +# seed=648464818, +# randomize_seed=False, +# guidance_scale=7.5, +# num_inference_steps=50, +# num_samples=1, +# blending=True, +# category=None, +# target_prompt="", +# resize_default=True, +# aspect_ratio_name="Custom resolution", +# invert_mask_state=False +# ) + +# # Save images +# output_dirs["edited"].mkdir(exist_ok=True) +# result_images[0].save(output_dirs["edited"] / f"{base_name}.png") +# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png") +# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png") + +# # Generate BLIP2 description +# blip2_desc, _ = generate_blip2_description(input_image) + +# descriptions[base_name] = { +# "original_caption": caption, +# "blip2_description": blip2_desc, +# "llm_enhanced_caption": target_description +# } 
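+#             # (Re the question above: keying `descriptions` by `base_name` is exactly what causes the
+#             # overwriting. A minimal sketch of the fix — assuming the loop is changed to iterate over every
+#             # matched caption entry `item`, each carrying a unique "pairid" as in the active version further
+#             # below — is to key all outputs by reference + pairid, e.g.
+#             #     processed_base = f"{base_name}_{item['pairid']}"
+#             #     descriptions[str(item["pairid"])] = {"reference": base_name, "caption": item["caption"], ...}
+#             # so the same reference image with different captions no longer collides.)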
+ +# with open(output_json_path, 'w') as f: +# json.dump(descriptions, f, indent=4) # indent保持可读性 + +# print(f"Processed {base_name}") + +# except Exception as e: +# print(f"Error processing {base_name}: {str(e)}") +# continue + +# print("Processing completed!") + + +# cirr 的 val数据集。目前来看这个函数没有太大问题 +# def process_cirr_images(): + +# if not all([vlm_model, sam_predictor, groundingdino_model]): +# raise RuntimeError("Required models not initialized") + +# # Define paths +# # dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev") +# # cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json") +# # output_dirs = { +# # "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_edited"), +# # "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_mask"), +# # "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_masked") +# # } +# # output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint_pairid.json") + + +# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/dev") +# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/val_deepseek_missed_174.json") +# output_dirs = { +# "edited": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_edited"), +# "mask": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_mask"), +# "masked": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_masked") +# } +# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/cirr/image_paint_deepseek_missed_pairid.json") +# descriptions = {} + +# # Create output directories +# for dir_path in output_dirs.values(): +# dir_path.mkdir(parents=True, exist_ok=True) + +# # Load captions +# with open(cap_file, 'r') as f: +# captions = json.load(f) + +# for img_path in dev_dir.glob("*.png"): +# base_name = img_path.stem +# # 获取所有匹配的caption条目 +# matched_items = [item for item in captions if item.get("reference") == base_name] + +# if not matched_items: +# print(f"Warning: No captions for {base_name}") +# continue + +# for item in matched_items: +# # 验证必要字段存在 +# pairid = item.get("pairid") +# caption = item.get("caption") + +# if not all([pairid, caption]): +# print(f"Skipping invalid item for {base_name}: {item}") +# continue + +# # 使用pairid构造唯一标识 +# processed_base = f"{base_name}_{pairid}" + +# try: +# # 构造空alpha通道 +# rgb_image = Image.open(img_path).convert("RGB") +# empty_alpha = Image.new("L", rgb_image.size, 0) +# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + +# # 初始化图像 +# base = {"background": image, "layers": [image]} +# init_results = init_img( +# base=base, +# init_type="custom", +# prompt=caption, +# aspect_ratio="Custom resolution", +# example_change_times=0 +# ) + +# # 获取处理参数 +# input_image = init_results[0] +# original_image = init_results[1] +# original_mask = init_results[2] + +# # 正确设置process参数 +# result_images, mask_images, masked_images, _, target_description, _ = process( +# input_image=input_image, +# original_image=original_image, +# original_mask=original_mask, # 传递初始化后的mask +# prompt=caption, +# negative_prompt="ugly, low quality", +# control_strength=1.0, +# seed=648464818, +# randomize_seed=False, +# guidance_scale=7.5, +# num_inference_steps=50, +# num_samples=1, +# blending=True, +# category=None, +# target_prompt="", +# resize_default=True, +# aspect_ratio_name="Custom resolution", +# invert_mask_state=False +# ) + +# # 保存文件(使用pairid标识) +# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png") +# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png") +# 
masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png") + +# # 生成描述 +# blip2_desc, _ = generate_blip2_description(input_image) + +# # 使用pairid作为主键存储元数据 +# descriptions[pairid] = { +# "reference": base_name, +# "user_editing_prompt": caption, +# "blip2_description": blip2_desc, +# "llm_enhanced_caption": target_description, +# "processed_files": { +# "edited": f"{processed_base}.png", +# "mask": f"{processed_base}_mask.png", +# "masked": f"{processed_base}_masked.png" +# } +# } + +# print(f"Processed {processed_base}") + +# except Exception as e: +# print(f"Error processing {processed_base}: {str(e)}") +# continue + +# # 保存元数据 +# with open(output_json_path, 'w') as f: +# json.dump(descriptions, f, indent=4) + +# print("Processing completed!") + +# cirr 的 test1数据集。 +# def process_cirr_images(): + +# if not all([vlm_model, sam_predictor, groundingdino_model]): +# raise RuntimeError("Required models not initialized") + +# # Define paths +# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1") +# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json") +# output_dirs = { +# "edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_edited"), +# "mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_mask"), +# "masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_masked") +# } +# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/test1_image_paint_pairid.json") +# descriptions = {} + +# # Create output directories +# for dir_path in output_dirs.values(): +# dir_path.mkdir(parents=True, exist_ok=True) + +# # Load captions +# with open(cap_file, 'r') as f: +# captions = json.load(f) + +# for img_path in dev_dir.glob("*.png"): +# base_name = img_path.stem +# # 获取所有匹配的caption条目 +# matched_items = [item for item in captions if item.get("reference") == base_name] + +# if not matched_items: +# print(f"Warning: No captions for {base_name}") +# continue + +# for item in matched_items: +# # 验证必要字段存在 +# pairid = item.get("pairid") +# caption = item.get("caption") + +# if not all([pairid, caption]): +# print(f"Skipping invalid item for {base_name}: {item}") +# continue + +# # 使用pairid构造唯一标识 +# processed_base = f"{base_name}_{pairid}" + +# try: +# # 构造空alpha通道 +# rgb_image = Image.open(img_path).convert("RGB") +# empty_alpha = Image.new("L", rgb_image.size, 0) +# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + +# # 初始化图像 +# base = {"background": image, "layers": [image]} +# init_results = init_img( +# base=base, +# init_type="custom", +# prompt=caption, +# aspect_ratio="Custom resolution", +# example_change_times=0 +# ) + +# # 获取处理参数 +# input_image = init_results[0] +# original_image = init_results[1] +# original_mask = init_results[2] + +# # 正确设置process参数 +# result_images, mask_images, masked_images, _, target_description, _ = process( +# input_image=input_image, +# original_image=original_image, +# original_mask=original_mask, # 传递初始化后的mask +# prompt=caption, +# negative_prompt="ugly, low quality", +# control_strength=1.0, +# seed=648464818, +# randomize_seed=False, +# guidance_scale=7.5, +# num_inference_steps=50, +# num_samples=1, +# blending=True, +# category=None, +# target_prompt="", +# resize_default=True, +# aspect_ratio_name="Custom resolution", +# invert_mask_state=False +# ) + +# # 保存文件(使用pairid标识) +# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png") +# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png") +# 
masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png") + +# # 生成描述 +# blip2_desc, _ = generate_blip2_description(input_image) + +# # 使用pairid作为主键存储元数据 +# descriptions[pairid] = { +# "reference": base_name, +# "user_editing_prompt": caption, +# "blip2_description": blip2_desc, +# "llm_enhanced_caption": target_description, +# "processed_files": { +# "edited": f"{processed_base}.png", +# "mask": f"{processed_base}_mask.png", +# "masked": f"{processed_base}_masked.png" +# } +# } + +# print(f"Processed {processed_base}") + +# except Exception as e: +# print(f"Error processing {processed_base}: {str(e)}") +# continue + +# # 保存元数据 +# with open(output_json_path, 'w') as f: +# json.dump(descriptions, f, indent=4) + +# print("Processing completed!") + + +def process_cirr_images(): + if not all([vlm_model, sam_predictor, groundingdino_model]): + raise RuntimeError("Required models not initialized") + + # Define paths + dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1") + cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json") + output_dirs = { + "edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_edited"), + "mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_mask"), + "masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_masked") + } + output_json_path = Path("/home/zt/data/BrushEdit/CIRR/qw_test1_image_paint_pairid.json") + + # Create output directories + for dir_path in output_dirs.values(): + dir_path.mkdir(parents=True, exist_ok=True) + + # 1. 加载已有处理结果 + processed_pairids = set() + if output_json_path.exists(): + try: + with open(output_json_path, 'r') as f: + descriptions = json.load(f) + processed_pairids = set(descriptions.keys()) + print(f"Loaded {len(processed_pairids)} previously processed pairids") + except Exception as e: + print(f"Error loading existing results: {str(e)}, starting from scratch") + descriptions = {} + else: + descriptions = {} + + # 2. 创建临时文件写入器 + temp_json_path = output_json_path.with_suffix(".tmp") + + # Load captions + with open(cap_file, 'r') as f: + captions = json.load(f) + + # 3. 处理进度跟踪 + total = 0 + processed = 0 + for img_path in dev_dir.glob("*.png"): + base_name = img_path.stem + matched_items = [item for item in captions if item.get("reference") == base_name] + + if not matched_items: + print(f"Warning: No captions for {base_name}") + continue + + for item in matched_items: + total += 1 + pairid = str(item.get("pairid")) # 确保字符串类型 + caption = item.get("caption") + + if not all([pairid, caption]): + print(f"Skipping invalid item for {base_name}: {item}") + continue + + # 4. 
跳过已处理的pairid + if pairid in processed_pairids: + print(f"Skipping already processed pairid: {pairid}") + continue + + processed_base = f"{base_name}_{pairid}" + print(f"Processing {processed_base} ({processed+1}/{total})") + + try: + # 构造空alpha通道 + rgb_image = Image.open(img_path).convert("RGB") + empty_alpha = Image.new("L", rgb_image.size, 0) + image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + + # 初始化图像 + base = {"background": image, "layers": [image]} + init_results = init_img( + base=base, + init_type="custom", + prompt=caption, + aspect_ratio="Custom resolution", + example_change_times=0 + ) + + # 获取处理参数 + input_image = init_results[0] + original_image = init_results[1] + original_mask = init_results[2] + + # 正确设置process参数 + result_images, mask_images, masked_images, _, target_description, _ = process( + input_image=input_image, + original_image=original_image, + original_mask=original_mask, # 传递初始化后的mask + prompt=caption, + negative_prompt="ugly, low quality", + control_strength=1.0, + seed=648464818, + randomize_seed=False, + guidance_scale=7.5, + num_inference_steps=50, + num_samples=1, + blending=True, + category=None, + target_prompt="", + resize_default=True, + aspect_ratio_name="Custom resolution", + invert_mask_state=False + ) + + # 保存文件(使用pairid标识) + result_images[0].save(output_dirs["edited"] / f"{processed_base}.png") + mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png") + masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png") + + # 生成描述 + blip2_desc, _ = generate_blip2_description(input_image) + + # 更新描述信息 使用pairid作为主键存储元数据 + descriptions[pairid] = { + "reference": base_name, + "user_editing_prompt": caption, + "blip2_description": blip2_desc, + "llm_enhanced_caption": target_description, + "processed_files": { + "edited": f"{processed_base}.png", + "mask": f"{processed_base}_mask.png", + "masked": f"{processed_base}_masked.png" + } + } + + # 5. 原子化写入:先写入临时文件,再替换原文件 + with open(temp_json_path, 'w') as f: + json.dump(descriptions, f, indent=4) + temp_json_path.replace(output_json_path) + + processed +=1 + processed_pairids.add(pairid) + print(f"Successfully processed {pairid}") + + except Exception as e: + print(f"Error processing {pairid}: {str(e)}") + # 删除可能生成的不完整文件 + for ext in ["", "_mask.png", "_masked.png"]: + incomplete_file = output_dirs["edited"] / f"{processed_base}{ext}" + if incomplete_file.exists(): + incomplete_file.unlink() + continue + + print(f"Processing completed! 
Total processed: {processed}/{total}") + +# circo 的 val数据集。目前来看这个函数没有太大问题 +def process_circo_val_images(): + + if not all([vlm_model, sam_predictor, groundingdino_model]): + raise RuntimeError("Required models not initialized") + + # Define paths + # 实际上不单单是dev数据集的图片存储空间,所有的(包括test集合&没有用到的)图片都存在这里 + dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017") + cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json") + output_dirs = { + "edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_edited"), + "mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_mask"), + "masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_masked") + } + output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint_pairid.json") + descriptions = {} + + # Create output directories + for dir_path in output_dirs.values(): + dir_path.mkdir(parents=True, exist_ok=True) + + # Load captions + with open(cap_file, 'r') as f: + captions = json.load(f) + + + for img_path in dev_dir.glob("*.jpg"): + # 不包含扩展名 + base_name = img_path.stem + # 提取后六位作为参考ID,reference_part是字符串类型 + reference_part = base_name[-6:] + # 将JSON中的reference_img_id(原本为int),转换为字符串后比较 + matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part] + + if not matched_items: + print(f"Warning: No captions for {base_name}") + continue + + for item in matched_items: + # 验证必要字段存在 + pairid = item.get("id") + caption = item.get("relative_caption") + + if not all([pairid, caption]): + print(f"Skipping invalid item for {base_name}: {item}") + continue + + # 使用pairid构造唯一标识 + # 使用f-string进行字符串格式化时,它会自动将非字符串类型的变量转换为字符串类型 + processed_base = f"{reference_part}_{pairid}" + + try: + # 构造空alpha通道 + rgb_image = Image.open(img_path).convert("RGB") + empty_alpha = Image.new("L", rgb_image.size, 0) + image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + + # 初始化图像 + base = {"background": image, "layers": [image]} + init_results = init_img( + base=base, + init_type="custom", + prompt=caption, + aspect_ratio="Custom resolution", + example_change_times=0 + ) + + # 获取处理参数 + input_image = init_results[0] + original_image = init_results[1] + original_mask = init_results[2] + + # 正确设置process参数 + result_images, mask_images, masked_images, _, target_description, _ = process( + input_image=input_image, + original_image=original_image, + original_mask=original_mask, # 传递初始化后的mask + prompt=caption, + negative_prompt="ugly, low quality", + control_strength=1.0, + seed=648464818, + randomize_seed=False, + guidance_scale=7.5, + num_inference_steps=50, + num_samples=1, + blending=True, + category=None, + target_prompt="", + resize_default=True, + aspect_ratio_name="Custom resolution", + invert_mask_state=False + ) + + # 保存文件(使用pairid标识) + result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg") + mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg") + masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg") + + # 生成描述 + blip2_desc, _ = generate_blip2_description(input_image) + + # 使用pairid作为主键存储元数据 + descriptions[pairid] = { + "reference": int(reference_part), + "user_editing_prompt": caption, + "blip2_description": blip2_desc, + "llm_enhanced_caption": target_description, + "processed_files": { + "edited": f"{processed_base}.jpg", + "mask": f"{processed_base}_mask.jpg", + "masked": f"{processed_base}_masked.jpg" + } + } + + print(f"Processed {processed_base}") + + except Exception as e: + print(f"Error processing 
{processed_base}: {str(e)}") + continue + + # 保存元数据 + with open(output_json_path, 'w') as f: + json.dump(descriptions, f, indent=4) + + print("Processing completed!") + + +# circo 的 test数据集。目前来看这个函数没有太大问题 +def process_circo_test_images(): + + if not all([vlm_model, sam_predictor, groundingdino_model]): + raise RuntimeError("Required models not initialized") + + # Define paths + # 实际上不单单是dev数据集的图片存储空间,所有的(包括test集合&没有用到的)图片都存在这里 + dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017") + cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/test.json") + output_dirs = { + "edited": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_edited"), + "mask": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_mask"), + "masked": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_masked") + } + output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/test_image_paint_pairid.json") + descriptions = {} + + # Create output directories + for dir_path in output_dirs.values(): + dir_path.mkdir(parents=True, exist_ok=True) + + # Load captions + with open(cap_file, 'r') as f: + captions = json.load(f) + + + for img_path in dev_dir.glob("*.jpg"): + # 不包含扩展名 + base_name = img_path.stem + # 提取后六位作为参考ID,reference_part是字符串类型 + reference_part = base_name[-6:] + # 将JSON中的reference_img_id(原本为int),转换为字符串后比较 + matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part] + + if not matched_items: + print(f"Warning: No captions for {base_name}") + continue + + for item in matched_items: + # 验证必要字段存在 + pairid = item.get("id") + caption = item.get("relative_caption") + + if not all([pairid, caption]): + print(f"Skipping invalid item for {base_name}: {item}") + continue + + # 使用pairid构造唯一标识 + # 使用f-string进行字符串格式化时,它会自动将非字符串类型的变量转换为字符串类型 + processed_base = f"{reference_part}_{pairid}" + + try: + # 构造空alpha通道 + rgb_image = Image.open(img_path).convert("RGB") + empty_alpha = Image.new("L", rgb_image.size, 0) + image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + + # 初始化图像 + base = {"background": image, "layers": [image]} + init_results = init_img( + base=base, + init_type="custom", + prompt=caption, + aspect_ratio="Custom resolution", + example_change_times=0 + ) + + # 获取处理参数 + input_image = init_results[0] + original_image = init_results[1] + original_mask = init_results[2] + + # 正确设置process参数 + result_images, mask_images, masked_images, _, target_description, _ = process( + input_image=input_image, + original_image=original_image, + original_mask=original_mask, # 传递初始化后的mask + prompt=caption, + negative_prompt="ugly, low quality", + control_strength=1.0, + seed=648464818, + randomize_seed=False, + guidance_scale=7.5, + num_inference_steps=50, + num_samples=1, + blending=True, + category=None, + target_prompt="", + resize_default=True, + aspect_ratio_name="Custom resolution", + invert_mask_state=False + ) + + # 保存文件(使用pairid标识) + result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg") + mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg") + masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg") + + # 生成描述 + blip2_desc, _ = generate_blip2_description(input_image) + + # 使用pairid作为主键存储元数据 + descriptions[pairid] = { + "reference": int(reference_part), + "user_editing_prompt": caption, + "blip2_description": blip2_desc, + "llm_enhanced_caption": target_description, + "processed_files": { + "edited": f"{processed_base}.jpg", + "mask": 
f"{processed_base}_mask.jpg", + "masked": f"{processed_base}_masked.jpg" + } + } + + print(f"Processed {processed_base}") + + except Exception as e: + print(f"Error processing {processed_base}: {str(e)}") + continue + + # 保存元数据 + with open(output_json_path, 'w') as f: + json.dump(descriptions, f, indent=4) + + print("Processing completed!") + + + + + +# 目前来看这个函数没有太大问题 +@torch.no_grad() +def batch_mix_and_search_cirr( + json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json", + image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited", + alpha: float = 0.8, + batch_size: int = 32, + output_json_path: str = "retrieval_results_pairid.json" +) -> Dict[str, List[Dict]]: + # 加载索引和元数据 + index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index") + metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet") + all_index_ids = metadata["image_path"].tolist() + + # 加载并验证输入数据 + with open(json_path) as f: + samples = json.load(f) + + valid_samples = [] + image_dir = Path(image_dir) + for pair_id, sample_info in samples.items(): + reference = sample_info["reference"] + img_path = image_dir / f"{reference}.png" + if img_path.exists(): + valid_samples.append(( + pair_id, + reference, + img_path, + sample_info['llm_enhanced_caption'], + sample_info['user_editing_prompt'] + )) + + # 初始化结果字典 + results = {} + total_samples = len(valid_samples) + + # 分批次处理 + for batch_idx in range(0, total_samples, batch_size): + batch_end = min(batch_idx + batch_size, total_samples) + current_batch = valid_samples[batch_idx:batch_end] + + # 批量处理图像(保持原逻辑) + batch_images = [Image.open(s[2]).convert("RGB") for s in current_batch] + image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device) + image_features = clip_model.get_image_features(**image_inputs) + image_features = nn.functional.normalize(image_features, dim=-1) + + # 批量处理文本(保持原逻辑) + batch_texts = [s[3] for s in current_batch] + text_inputs = clip_processor( + text=batch_texts, + return_tensors="pt", + padding=True, + truncation=True + ).to(device) + text_features = clip_model.get_text_features(**text_inputs) + text_features = nn.functional.normalize(text_features, dim=-1) + + # 混合特征 + mixed_features = (1 - alpha) * image_features + alpha * text_features + mixed_features = nn.functional.normalize(mixed_features, dim=-1) + + # 批量检索 + query_features = mixed_features.cpu().numpy().astype("float32") + distances, indices = index.search(query_features, 100) + + # 保存当前批次结果 + for (pair_id, reference, _, enhanced_cap, editing_prompt), dist_row, idx_row in zip(current_batch, distances, indices): + results[pair_id] = { + "reference": reference, + "llm_enhanced_caption": enhanced_cap, + "user_editing_prompt": editing_prompt, + "retrieved_results": [] + } + for distance, idx in zip(dist_row, idx_row): + if 0 <= idx < len(all_index_ids): + raw_id = all_index_ids[idx] + base_name = os.path.basename(raw_id) + file_name = os.path.splitext(base_name)[0] + results[pair_id]["retrieved_results"].append({ + "retrieved_id": file_name, + "score": float(distance), + }) + + # 保存结果到JSON + with open(output_json_path, 'w') as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + print("Retrieving completed!") + return results + +# @torch.no_grad() +# def batch_mix_and_search_cirr( +# json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json", +# image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited", +# alpha: float = 0.8, +# 
batch_size: int = 32, +# output_json_path: str = "retrieval_results_pairid.json" +# ) -> Dict[str, List[Dict]]: +# # 加载索引和元数据 +# index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index") +# metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet") +# all_index_ids = metadata["image_path"].tolist() + +# # 加载并验证输入数据 +# with open(json_path) as f: +# samples = json.load(f) + +# valid_samples = [] +# image_dir = Path(image_dir) +# for image_id, sample_info in samples.items(): +# img_path = image_dir / f"{image_id}.png" +# if img_path.exists(): +# valid_samples.append(( +# image_id, +# img_path, +# sample_info['llm_enhanced_caption'], +# sample_info['user_editing_prompt'] +# )) + +# # 初始化结果字典 +# results = {} +# total_samples = len(valid_samples) + +# # 分批次处理 +# for batch_idx in range(0, total_samples, batch_size): +# batch_end = min(batch_idx + batch_size, total_samples) +# current_batch = valid_samples[batch_idx:batch_end] + +# # 批量处理图像 +# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch] +# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device) +# image_features = clip_model.get_image_features(**image_inputs) +# image_features = nn.functional.normalize(image_features, dim=-1) + +# # 批量处理文本 +# batch_texts = [s[2] for s in current_batch] +# text_inputs = clip_processor( +# text=batch_texts, +# return_tensors="pt", +# padding=True, +# truncation=True +# ).to(device) +# text_features = clip_model.get_text_features(**text_inputs) +# text_features = nn.functional.normalize(text_features, dim=-1) + +# # 混合特征 +# mixed_features = (1 - alpha) * image_features + alpha * text_features +# mixed_features = nn.functional.normalize(mixed_features, dim=-1) + +# # 批量检索 +# query_features = mixed_features.cpu().numpy().astype("float32") +# distances, indices = index.search(query_features, 100) + +# # 保存当前批次结果 +# for (sample_id, _, enhanced_cap, original_cap), dist_row, idx_row in zip(current_batch, distances, indices): +# results[sample_id] = { +# "llm_enhanced_caption": enhanced_cap, # 增强描述 +# "original_caption": original_cap, # 原始描述 +# "retrieved_results": [] +# } +# for distance, idx in zip(dist_row, idx_row): +# if 0 <= idx < len(all_index_ids): +# raw_id = all_index_ids[idx] +# base_name = os.path.basename(raw_id) +# file_name = os.path.splitext(base_name)[0] +# results[sample_id]["retrieved_results"].append({ +# "retrieved_id": file_name, +# "score": float(distance), +# }) + +# # 保存结果到JSON +# with open(output_json_path, 'w') as f: +# json.dump(results, f, indent=2, ensure_ascii=False) + +# print("Retrieving completed!") +# return results + + +# def evaluate_cirr_scores() -> List[Tuple[str, float]]: + +# # 设置数据集和检索结果的路径 +# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json" +# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchong.json" + +# # 加载数据集 +# with open(dataset_path, 'r') as f: +# dataset = json.load(f) +# print(len(dataset)) + +# # 加载检索结果 +# with open(retrieval_results_path, 'r') as f: +# retrieval_results = json.load(f) +# print(len(retrieval_results)) + +# # 数据结构初始化 +# all_target_captions_soft = [] +# all_set_member_idx = [] +# nn_result = [] + +# # 构建匹配数据结构 +# for sample in dataset: +# all_target_captions_soft.append(sample["target_soft"]) +# all_set_member_idx.append(sample["img_set"]["members"]) +# query_id = str(sample["reference"]) +# retrieved_items = retrieval_results.get(query_id, []) +# 
nn_result.append([item["retrieved_id"] for item in retrieved_items]) + +# # 计算召回指标 +# out = [] +# # Recall@K (全局检索) +# for k in [1, 5, 10, 50]: +# total_score = 0.0 +# for i in range(len(dataset)): +# query_id = str(dataset[i]["reference"]) # 获取当前查询的参考ID +# # 过滤掉参考图像本身 +# filtered_results = [rid for rid in nn_result[i] if rid != query_id] +# top_k = filtered_results[:k] + +# best_score = 0.0 +# for target_id, score in all_target_captions_soft[i].items(): +# if target_id in top_k: +# best_score = max(best_score, score) +# total_score += best_score +# recall = total_score / len(dataset) * 100 + +# out.append((f"recall_top{k}_correct_composition", recall)) + +# # Recall_subset@K (子集检索) +# for k in [1, 2, 3]: +# total_score = 0.0 +# for i in range(len(dataset)): +# query_id = str(dataset[i]["reference"]) +# # 双重过滤:子集成员 + 排除参考图像 +# subset_results = [ +# rid for rid in nn_result[i] +# if rid in all_set_member_idx[i] and rid != query_id +# ] +# top_k_subset = subset_results[:k] + +# best_score = 0.0 +# for target_id, score in all_target_captions_soft[i].items(): +# if target_id in top_k_subset: +# best_score = max(best_score, score) +# total_score += best_score +# recall = total_score / len(dataset) * 100 + +# out.append((f"recall_inset_top{k}_correct_composition", recall)) + +# # 打印和保存结果 +# print("\n" + "="*30 + " Evaluation Results " + "="*30) +# for metric, value in out: +# print(f"{metric:<40}: {value:.4f}") + +# output_dir = os.path.dirname(retrieval_results_path) +# output_filename = "evaluation_results_quchong.txt" +# output_path = os.path.join(output_dir, output_filename) + +# with open(output_path, 'w') as f: +# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n") +# for metric, value in out: +# line = f"{metric:<40}: {value:.4f}\n" +# f.write(line) +# print(f"\nResults saved to {output_path}") + +# return out + + +# # 这是因为dataset中的样本的reference字段大量重复(CIRR数据集,每个参考图像对应多个不同的目标描述caption) +# def evaluate_cirr_scores() -> List[Tuple[str, float]]: +# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json" +# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_noquery_a08.json" + +# # 加载数据集 +# with open(dataset_path, 'r') as f: +# dataset = json.load(f) + +# # 加载检索结果 +# with open(retrieval_results_path, 'r') as f: +# retrieval_results = json.load(f) + +# # 调试:检查数据一致性 +# print(len(dataset)) +# print(len(retrieval_results)) + +# # 过滤有效样本(确保类型严格一致) +# valid_samples = [] +# for sample in dataset: +# query_id = str(sample["reference"]) +# if query_id in retrieval_results: +# valid_samples.append(sample) + +# print(f"Valid samples count: {len(valid_samples)} (must be 2086)") + +# # 构建数据结构 +# all_target_captions_soft = [] +# all_set_member_idx = [] +# nn_result = [] + +# for sample in valid_samples: +# all_target_captions_soft.append(sample["target_soft"]) +# all_set_member_idx.append(sample["img_set"]["members"]) +# query_id = str(sample["reference"]) +# # 直接获取该query_id的检索结果列表 +# retrieved_items = retrieval_results[query_id] +# nn_result.append([item["retrieved_id"] for item in retrieved_items]) + +# out = [] +# # Recall@K 计算(使用有效样本数量和过滤后的数据) +# for k in [1, 5, 10, 50]: +# total_score = 0.0 +# for i in range(len(valid_samples)): +# query_id = str(valid_samples[i]["reference"]) +# filtered_results = [rid for rid in nn_result[i] if rid != query_id] +# top_k = filtered_results[:k] + +# best_score = 0.0 +# for target_id, score in all_target_captions_soft[i].items(): +# if target_id in top_k: +# best_score = max(best_score, score) +# total_score += best_score + +# 
recall = total_score / len(valid_samples) * 100 +# print(len(valid_samples)) +# out.append((f"recall_top{k}_correct_composition", recall)) + +# # Recall_subset@K 计算 +# for k in [1, 2, 3]: +# total_score = 0.0 +# for i in range(len(valid_samples)): +# query_id = str(valid_samples[i]["reference"]) +# subset_results = [ +# rid for rid in nn_result[i] +# if rid in all_set_member_idx[i] and rid != query_id +# ] +# top_k_subset = subset_results[:k] + +# best_score = 0.0 +# for target_id, score in all_target_captions_soft[i].items(): +# if target_id in top_k_subset: +# best_score = max(best_score, score) +# total_score += best_score + +# recall = total_score / len(valid_samples) * 100 +# out.append((f"recall_inset_top{k}_correct_composition", recall)) + +# # 输出结果(保持不变) +# print("\n" + "="*30 + " Evaluation Results " + "="*30) +# for metric, value in out: +# print(f"{metric:<40}: {value:.4f}") + +# output_dir = os.path.dirname(retrieval_results_path) +# output_filename = "evaluation_results_valid.txt" +# output_path = os.path.join(output_dir, output_filename) + +# with open(output_path, 'w') as f: +# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n") +# for metric, value in out: +# line = f"{metric:<40}: {value:.4f}\n" +# f.write(line) +# print(f"\nResults saved to {output_path}") + +# return out + +# # 可以截取前n个计算分数 +# def evaluate_cirr_scores() -> List[Tuple[str, float]]: + +# # 设置数据集和检索结果的路径 +# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json" +# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json" + +# # 加载数据集 +# with open(dataset_path, 'r') as f: +# dataset = json.load(f) +# print(f"Total samples in dataset: {len(dataset)}") + +# # 加载检索结果 +# with open(retrieval_results_path, 'r') as f: +# retrieval_results = json.load(f) +# print(f"Total queries in retrieval results: {len(retrieval_results)}") + +# # 数据结构初始化 +# valid_query_ids = [] # 存储匹配成功的query_id +# all_target_captions_soft = [] # 存储匹配样本的target_soft +# all_set_member_idx = [] # 存储匹配样本的set members +# nn_result = [] # 存储匹配样本的检索结果 + +# # 构建匹配数据结构(新增双重验证) +# for sample in dataset: +# query_id = str(sample["reference"]) +# caption = sample["caption"] + +# # 获取对应的检索结果条目 +# retrieved_entry = retrieval_results.get(query_id) + +# # 双重验证:query_id存在且caption匹配 +# if retrieved_entry and retrieved_entry["original_caption"] == caption: +# valid_query_ids.append(query_id) +# all_target_captions_soft.append(sample["target_soft"]) +# all_set_member_idx.append(sample["img_set"]["members"]) +# # 提取检索结果并转换为id列表 +# retrieved_items = retrieved_entry["retrieved_results"] +# nn_result.append([item["retrieved_id"] for item in retrieved_items]) + +# print(f"Valid matched samples after verification: {len(valid_query_ids)}") + +# ############################################## +# # 新增修改:仅取前100个样本 +# # 截取前100个有效样本(若不足100则取全部) +# valid_query_ids = valid_query_ids[:300] +# all_target_captions_soft = all_target_captions_soft[:300] +# all_set_member_idx = all_set_member_idx[:300] +# nn_result = nn_result[:300] +# total_samples = len(valid_query_ids) +# ############################################## + +# print(f"Evaluating on first {total_samples} samples") + +# # 计算召回指标(仅使用有效样本) +# out = [] +# total_samples = len(valid_query_ids) + +# # Recall@K (全局检索) +# for k in [1, 5, 10, 50]: +# total_score = 0.0 +# for i in range(total_samples): +# # 过滤参考图像本身 +# filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]] +# top_k = filtered_results[:k] + +# # 计算最佳匹配分数 +# best_score = max( +# (score for 
target_id, score in all_target_captions_soft[i].items() if target_id in top_k), +# default=0.0 +# ) +# total_score += best_score + +# recall = (total_score / total_samples) * 100 +# out.append((f"recall_top{k}_correct_composition", recall)) + +# # Recall_subset@K (子集检索) +# for k in [1, 2, 3]: +# total_score = 0.0 +# for i in range(total_samples): +# # 双重过滤:子集成员且非参考图像 +# subset_results = [ +# rid for rid in nn_result[i] +# if rid in all_set_member_idx[i] and rid != valid_query_ids[i] +# ] +# top_k_subset = subset_results[:k] + +# # 计算最佳匹配分数 +# best_score = max( +# (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset), +# default=0.0 +# ) +# total_score += best_score + +# recall = (total_score / total_samples) * 100 +# out.append((f"recall_inset_top{k}_correct_composition", recall)) + +# # 打印和保存结果 +# print("\n" + "="*30 + " Evaluation Results " + "="*30) +# for metric, value in out: +# print(f"{metric:<40}: {value:.4f}") + +# output_dir = os.path.dirname(retrieval_results_path) +# output_filename = "evaluation_results_quchongnew.txt" +# output_path = os.path.join(output_dir, output_filename) + +# with open(output_path, 'w') as f: +# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n") +# for metric, value in out: +# line = f"{metric:<40}: {value:.4f}\n" +# f.write(line) +# print(f"\nResults saved to {output_path}") + +# return out + + +# 目前来看这个函数没有太大问题 +def evaluate_cirr_scores() -> List[Tuple[str, float]]: + + # 设置数据集和检索结果的路径 + dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json" + retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json" + + # 加载数据集 + with open(dataset_path, 'r') as f: + dataset = json.load(f) + print(f"Total samples in dataset: {len(dataset)}") + + # 加载检索结果 + with open(retrieval_results_path, 'r') as f: + retrieval_results = json.load(f) + print(f"Total queries in retrieval results: {len(retrieval_results)}") + + # 数据结构初始化 + valid_query_ids = [] # 存储匹配成功的query_id + all_target_captions_soft = [] # 存储匹配样本的target_soft + all_set_member_idx = [] # 存储匹配样本的set members + nn_result = [] # 存储匹配样本的检索结果 + + # 构建匹配数据结构(新增双重验证) + for sample in dataset: + query_id = str(sample["reference"]) + caption = sample["caption"] + + # 获取对应的检索结果条目 + retrieved_entry = retrieval_results.get(query_id) + + # 双重验证:query_id存在且caption匹配 + if retrieved_entry and retrieved_entry["original_caption"] == caption: + valid_query_ids.append(query_id) + all_target_captions_soft.append(sample["target_soft"]) + all_set_member_idx.append(sample["img_set"]["members"]) + # 提取检索结果并转换为id列表 + retrieved_items = retrieved_entry["retrieved_results"] + nn_result.append([item["retrieved_id"] for item in retrieved_items]) + + print(f"Valid matched samples after verification: {len(valid_query_ids)}") + + # 计算召回指标(仅使用有效样本) + out = [] + total_samples = len(valid_query_ids) + + # Recall@K (全局检索) + for k in [1, 5, 10, 50]: + total_score = 0.0 + for i in range(total_samples): + # 过滤参考图像本身 + filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]] + top_k = filtered_results[:k] + + # 计算最佳匹配分数 + best_score = max( + (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k), + default=0.0 + ) + total_score += best_score + + recall = (total_score / total_samples) * 100 + out.append((f"recall_top{k}_correct_composition", recall)) + + # Recall_subset@K (子集检索) + for k in [1, 2, 3]: + total_score = 0.0 + for i in range(total_samples): + # 双重过滤:子集成员且非参考图像 + subset_results = [ + rid for rid 
in nn_result[i] + if rid in all_set_member_idx[i] and rid != valid_query_ids[i] + ] + top_k_subset = subset_results[:k] + + # 计算最佳匹配分数 + best_score = max( + (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset), + default=0.0 + ) + total_score += best_score + + recall = (total_score / total_samples) * 100 + out.append((f"recall_inset_top{k}_correct_composition", recall)) + + # 打印和保存结果 + print("\n" + "="*30 + " Evaluation Results " + "="*30) + for metric, value in out: + print(f"{metric:<40}: {value:.4f}") + + output_dir = os.path.dirname(retrieval_results_path) + output_filename = "evaluation_results_quchongnew.txt" + output_path = os.path.join(output_dir, output_filename) + + with open(output_path, 'w') as f: + f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n") + for metric, value in out: + line = f"{metric:<40}: {value:.4f}\n" + f.write(line) + print(f"\nResults saved to {output_path}") + + return out + + +if __name__ == "__main__": + # process_circo_val_images() + # process_circo_test_images() + process_cirr_images() + + + + + + +# def process_circo_images(): + +# if not all([vlm_model, sam_predictor, groundingdino_model]): +# raise RuntimeError("Required models not initialized") + +# # Define paths +# dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/img_raw/dev") +# cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json") + +# output_dirs = { +# "edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited"), +# "mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_mask"), +# "masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_masked") +# } +# output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint.json") +# descriptions = {} + +# # Create output directories +# for dir_path in output_dirs.values(): +# dir_path.mkdir(parents=True, exist_ok=True) + +# # Load captions +# with open(cap_file, 'r') as f: +# captions = json.load(f) + +# for img_path in dev_dir.glob("*.jpg"): +# base_name = img_path.stem +# # 提取后六位作为参考ID +# reference_part = base_name[-6:] +# # 将JSON中的reference_img_id转换为字符串后比较 +# caption = next( +# (item["relative_caption"] for item in captions +# if str(item.get("reference_img_id")) == reference_part), +# None +# ) + +# if not caption: +# print(f"Warning: No caption for {base_name}") +# continue + +# try: +# # 构造空alpha通道(全0) +# rgb_image = Image.open(img_path).convert("RGB") +# empty_alpha = Image.new("L", rgb_image.size, 0) # 全透明alpha通道 +# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha)) + +# # 调用init_img初始化 +# base = {"background": image, "layers": [image]} +# init_results = init_img( +# base=base, +# init_type="custom", # 使用自定义初始化 +# prompt=caption, +# aspect_ratio="Custom resolution", +# example_change_times=0 +# ) + +# # 获取初始化后的参数 +# input_image = init_results[0] +# original_image = init_results[1] +# original_mask = init_results[2] + +# # 正确设置process参数 +# result_images, mask_images, masked_images, _, target_description, _ = process( +# input_image=input_image, +# original_image=original_image, +# original_mask=original_mask, # 传递初始化后的mask +# prompt=caption, +# negative_prompt="ugly, low quality", +# control_strength=1.0, +# seed=648464818, +# randomize_seed=False, +# guidance_scale=7.5, +# num_inference_steps=50, +# num_samples=1, +# blending=True, +# category=None, +# target_prompt="", +# resize_default=True, +# aspect_ratio_name="Custom resolution", +# invert_mask_state=False +# ) + +# # Save images +# output_dirs["edited"].mkdir(exist_ok=True) +# 
result_images[0].save(output_dirs["edited"] / f"{base_name}.jpg") +# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.jpg") +# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.jpg") + +# # Generate BLIP2 description +# blip2_desc, _ = generate_blip2_description(input_image) + +# descriptions[base_name] = { +# "original_caption": caption, +# "blip2_description": blip2_desc, +# "llm_enhanced_caption": target_description +# } + +# with open(output_json_path, 'w') as f: +# json.dump(descriptions, f, indent=4) # indent保持可读性 + +# print(f"Processed {base_name}") + +# except Exception as e: +# print(f"Error processing {base_name}: {str(e)}") +# continue + +# print("Processing completed!") + + +# @torch.no_grad() +# def batch_mix_and_search_circo( +# json_path: str = "/home/zt/data/BrushEdit/CIRCO/image_paint.json", +# image_dir: str = "/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited", +# alpha: float = 0.6, +# batch_size: int = 32, +# output_json_path: str = "circo_retrieval_results.json" +# ) -> Dict[str, List[Dict]]: + +# # 加载索引和元数据 +# index = faiss.read_index("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_knn.index") +# metadata = pd.read_parquet("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet") +# all_index_ids = metadata["image_path"].tolist() + +# # 加载并验证输入数据 +# with open(json_path) as f: +# samples = json.load(f) + +# valid_samples = [] +# image_dir = Path(image_dir) +# for image_id, sample_info in samples.items(): # 关键修改点 +# img_path = image_dir / f"{image_id}.jpg" # 直接用字典的key作为image_id +# if img_path.exists(): +# valid_samples.append( +# (image_id, img_path, sample_info['llm_enhanced_caption']) # 从value中取caption +# ) + +# # 初始化结果字典 +# results = {} +# total_samples = len(valid_samples) + +# # 分批次处理 +# for batch_idx in range(0, total_samples, batch_size): +# batch_end = min(batch_idx + batch_size, total_samples) +# current_batch = valid_samples[batch_idx:batch_end] + +# # 批量处理图像 +# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch] +# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device) +# image_features = clip_model.get_image_features(**image_inputs) +# image_features = nn.functional.normalize(image_features, dim=-1) + +# # 批量处理文本 +# batch_texts = [s[2] for s in current_batch] +# text_inputs = clip_processor( +# text=batch_texts, +# return_tensors="pt", +# padding=True, +# truncation=True +# ).to(device) +# text_features = clip_model.get_text_features(**text_inputs) +# text_features = nn.functional.normalize(text_features, dim=-1) + +# # 混合特征 +# mixed_features = (1 - alpha) * image_features + alpha * text_features +# mixed_features = nn.functional.normalize(mixed_features, dim=-1) + +# # 批量检索 +# query_features = mixed_features.cpu().numpy().astype("float32") +# distances, indices = index.search(query_features, 100) + +# # 保存当前批次结果 +# for (sample_id, _, _), dist_row, idx_row in zip(current_batch, distances, indices): +# results[sample_id] = [] +# for distance, idx in zip(dist_row, idx_row): +# if 0 <= idx < len(all_index_ids): +# results[sample_id].append({ +# "retrieved_id": all_index_ids[idx], +# "score": float(distance) +# }) + +# # 保存结果到JSON +# with open(output_json_path, 'w') as f: +# json.dump(results, f, indent=2, ensure_ascii=False) + +# print("Retrieving completed!") +# return results + + +# block = gr.Blocks( +# theme=gr.themes.Soft( +# radius_size=gr.themes.sizes.radius_none, +# text_size=gr.themes.sizes.text_md +# ) +# ) + +# with block as demo: +# 
with gr.Row(): +# with gr.Column(): +# gr.HTML(head) +# gr.HTML(descriptions) +# with gr.Accordion(label="🧭 小白也能秒懂的魔法指南:", open=True, elem_id="accordion"): +# with gr.Row(equal_height=True): +# gr.HTML(instructions) + +# original_image = gr.State(value=None) +# original_mask = gr.State(value=None) +# category = gr.State(value=None) +# status = gr.State(value=None) +# invert_mask_state = gr.State(value=False) +# example_change_times = gr.State(value=0) +# deepseek_verified = gr.State(value=False) +# blip2_description = gr.State(value="") +# enhanced_description = gr.State(value="") +# decomposed_description = gr.State(value="") + +# with gr.Row(): +# with gr.Column(): +# with gr.Group(): +# input_image = gr.ImageEditor( +# label="参考图像", +# type="pil", +# brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"), +# layers = False, +# interactive=True, +# # height=1024, +# height=408, +# sources=["upload"], +# placeholder="🫧 点击此处或下面的图标上传图像 🫧", +# ) +# prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=2) + +# with gr.Group(): + +# with gr.Row(): +# mask_button = gr.Button("💎 掩膜生成") +# invert_mask_button = gr.Button("👐 掩膜翻转") +# # random_mask_button = gr.Button("⭕️ 随机掩膜") +# with gr.Row(): +# masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360) +# mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360) + + +# with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"): +# dilation_size = gr.Slider( +# label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20 +# ) +# with gr.Row(): +# dilation_mask_button = gr.Button("放大掩膜") +# erosion_mask_button = gr.Button("缩小掩膜") + +# moving_pixels = gr.Slider( +# label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1 +# ) +# with gr.Row(): +# move_left_button = gr.Button("左移") +# move_right_button = gr.Button("右移") +# with gr.Row(): +# move_up_button = gr.Button("上移") +# move_down_button = gr.Button("下移") + + + +# with gr.Column(): +# with gr.Row(): +# deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1, container=False, type="password") +# verify_deepseek = gr.Button("🔑 验证密钥", scale=0) + +# blip2_output = gr.Textbox(label="1. 原图描述(BLIP2生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True) + +# with gr.Row(): +# target_prompt = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击图片编辑同时生成增强描述 or 点击右侧按钮单独生成增强描述 🚀") +# enhance_button = gr.Button("✨ 智能整合") + +# with gr.Row(): +# decomposed_output = gr.Textbox(label="3. 
结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝") +# decompose_button = gr.Button("🔧 结构分解") + + + +# with gr.Group(): +# run_button = gr.Button("💫 图像编辑") +# result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360) + +# with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"): +# vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True) + +# with gr.Group(): +# with gr.Row(): +# # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1) +# GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="", lines=2, container=False, type="password") +# GPT4o_KEY_submit = gr.Button("🔑 验证密钥", scale=0) + +# aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO) +# resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True) +# base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True) +# negative_prompt = gr.Textbox(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1) +# control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01) +# with gr.Group(): +# seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818) +# randomize_seed = gr.Checkbox(label="随机种子", value=False) +# blending = gr.Checkbox(label="混合模式", value=True) +# num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2) +# with gr.Group(): +# with gr.Row(): +# guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5) +# num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50) +# # target_prompt = gr.Textbox(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2) + + + +# init_type = gr.Textbox(label="Init Name", value="", visible=False) +# example_type = gr.Textbox(label="Example Name", value="", visible=False) + +# with gr.Row(): +# reset_button = gr.Button("Reset") +# retrieve_button = gr.Button("🔍 开始检索") + +# with gr.Row(): +# retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660) + + +# with gr.Row(): +# example = gr.Examples( +# label="Quick Example", +# examples=EXAMPLES, +# inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown], +# examples_per_page=10, +# cache_examples=False, +# visible=False +# ) + + +# with gr.Accordion(label="🎬 隐藏玩法大公开:", open=True, elem_id="accordion"): +# with gr.Row(equal_height=True): +# gr.HTML(tips) + +# with gr.Row(): +# gr.HTML(citation) + +# ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery. +# ## And we need to solve the conflict between the upload and change example functions. 
+# input_image.upload( +# init_img, +# [input_image, init_type, prompt, aspect_ratio, example_change_times], +# [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times] + +# ) +# example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times]) + + +# ## vlm and base model dropdown +# vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status]) +# base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status]) + +# GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown]) + + +# ips=[input_image, +# original_image, +# original_mask, +# prompt, +# negative_prompt, +# control_strength, +# seed, +# randomize_seed, +# guidance_scale, +# num_inference_steps, +# num_samples, +# blending, +# category, +# target_prompt, +# resize_default, +# aspect_ratio, +# invert_mask_state] + +# ## run brushedit +# run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state]) + + +# ## mask func +# mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category]) +# # random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask]) +# dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) +# erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask]) +# invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state]) + +# ## reset func +# reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, blip2_output, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, blip2_output, target_prompt, resize_default, invert_mask_state]) + +# input_image.upload(fn=generate_blip2_description, inputs=[input_image], outputs=[blip2_description, blip2_output]) +# verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key]) +# # enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, enhanced_output]) +# # decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output]) +# # retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery]) + +# enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, target_prompt]) +# decompose_button.click(fn=decompose_description, inputs=[target_prompt], outputs=[decomposed_description, decomposed_output]) +# 
retrieve_button.click(fn=mix_and_search, inputs=[target_prompt, result_gallery], outputs=[retrieve_gallery]) + +# demo.launch(server_name="0.0.0.0", server_port=12345, share=True) + +
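+
+
+# Illustrative sketch (not wired into the app): the alpha-weighted CLIP fusion that both mix_and_search
+# (fixed 0.4 image / 0.6 text split) and batch_mix_and_search_cirr (configurable alpha) implement,
+# reduced to one reusable helper. It assumes the globals `clip_model`, `clip_processor` and `device`
+# defined earlier in this file, plus a faiss index whose vectors share CLIP's embedding dimension;
+# `id_list` maps faiss row positions back to image identifiers, and `alpha` / `top_k` are placeholders.
+@torch.no_grad()
+def fuse_and_retrieve(pil_image: Image.Image, text: str, index: faiss.Index,
+                      id_list: List[str], alpha: float = 0.8, top_k: int = 50) -> List[Tuple[str, float]]:
+    # Encode and L2-normalize the reference-image feature.
+    image_inputs = clip_processor(images=pil_image, return_tensors="pt").to(device)
+    image_feat = F.normalize(clip_model.get_image_features(**image_inputs), dim=-1)
+    # Encode and L2-normalize the (enhanced) text feature.
+    text_inputs = clip_processor(text=text, return_tensors="pt", padding=True, truncation=True).to(device)
+    text_feat = F.normalize(clip_model.get_text_features(**text_inputs), dim=-1)
+    # Convex combination of the two modalities, renormalized before searching the index.
+    mixed = F.normalize((1 - alpha) * image_feat + alpha * text_feat, dim=-1)
+    query = mixed.cpu().numpy().astype("float32")
+    distances, indices = index.search(query, top_k)
+    # Map faiss row positions back to image identifiers, skipping out-of-range hits.
+    return [(id_list[i], float(d)) for d, i in zip(distances[0], indices[0]) if 0 <= i < len(id_list)]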