diff --git "a/brushedit_app_new_jietu2.py" "b/brushedit_app_new_jietu2.py"
new file mode 100644--- /dev/null
+++ "b/brushedit_app_new_jietu2.py"
@@ -0,0 +1,3453 @@
+##!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import os, random, sys
+import numpy as np
+import requests
+import torch
+
+from pathlib import Path
+import pandas as pd
+import concurrent.futures
+import faiss
+import gradio as gr
+
+from pathlib import Path
+import os
+import json
+
+import logging
+from datetime import datetime
+import time
+
+from PIL import Image
+from torch import nn
+from typing import Dict, List, Tuple
+
+import torch.nn.functional as F
+from huggingface_hub import hf_hub_download, snapshot_download
+from scipy.ndimage import binary_dilation, binary_erosion
+from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration, Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
+
+from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
+from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
+from diffusers.image_processor import VaeImageProcessor
+
+
+from app.src.vlm_pipeline_noqwen import (
+ vlm_response_editing_type,
+ vlm_response_object_wait_for_edit,
+ vlm_response_mask,
+ vlm_response_prompt_after_apply_instruction
+)
+from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
+from app.utils.utils import load_grounding_dino_model
+
+from app.src.vlm_template import vlms_template
+from app.src.base_model_template import base_models_template
+from app.src.aspect_ratio_template import aspect_ratios
+
+from openai import OpenAI
+base_openai_url = "https://api.deepseek.com/"
+base_api_key = "sk-d145b963a92649a88843caeb741e8bbc"
+
+
+
+from transformers import BlipProcessor, BlipForConditionalGeneration
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+from transformers import CLIPProcessor, CLIPModel
+
+from app.deepseek.instructions import (
+ create_apply_editing_messages_deepseek,
+ create_decomposed_query_messages_deepseek
+)
+from clip_retrieval.clip_client import ClipClient
+
+
+#### Description ####
+logo = r"""
+
+"""
+head = r"""
+
+
+ 基于扩散模型先验和大语言模型的
+ 零样本组合查询图像检索
+
+
+
+
+"""
+
+descriptions = r"""
+
+
+
+ 🎨
+ 一个无需训练的组合图像检索的交互系统,支持通过文本指令修改参考图像并进行语义检索。
+
+
+"""
+
+instructions = r"""
+
+
+
+ - 上传图像:点击画布或上传按钮添加参考图像
+ - 输入指令:在文本框中描述您想对图像进行的修改
+ - 生成掩膜:使用掩膜工具精确控制编辑区域
+ - 智能增强:系统会自动生成图像描述,并可进一步优化
+ - 执行编辑:点击"图像编辑"按钮生成修改后的图像
+ - 检索结果:点击"开始检索"获取相似图像结果
+
+
+"""
+
+tips = r"""
+
+
+ 🖌️ 图像编辑功能
+
+
+ - 支持画笔工具创建精确掩膜
+ - 提供掩膜放大/缩小、翻转等操作
+ - 多参数控制生成效果
+
+
+
+ 🧠 智能描述系统
+
+
+ - 自动生成图像描述(BLIP2)
+ - 指令增强生成优化提示词
+ - 结构化分解复杂描述
+
+
+
+ 🔍 高级检索能力
+
+
+ - 零样本学习无需训练
+ - 结合视觉-语言模型理解
+ - 支持多模态查询(图像+文本)
+
+
+
+ ⚙️ 技术参数调整
+
+
+ - 可调节控制强度、引导尺度等
+ - 支持多种基础模型选择
+ - 自定义输出尺寸和比例
+
+
+
+ 💡 使用建议
+
+
+ - 清晰具体的指令会得到更好结果
+ - 合理使用掩膜提高编辑精度
+ - 尝试不同参数组合优化效果
+
+
+"""
+
+citation = r"""
+"""
+
+# - - - - - examples - - - - - #
+EXAMPLES = [
+
+ [
+ Image.open("./assets/frog/frog.jpeg").convert("RGBA"),
+ "add a magic hat on frog head.",
+ 642087011,
+ "frog",
+ "frog",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/chinese_girl/chinese_girl.png").convert("RGBA"),
+ "replace the background to ancient China.",
+ 648464818,
+ "chinese_girl",
+ "chinese_girl",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/angel_christmas/angel_christmas.png").convert("RGBA"),
+ "remove the deer.",
+ 648464818,
+ "angel_christmas",
+ "angel_christmas",
+ False,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/sunflower_girl/sunflower_girl.png").convert("RGBA"),
+ "add a wreath on head.",
+ 648464818,
+ "sunflower_girl",
+ "sunflower_girl",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/girl_on_sun/girl_on_sun.png").convert("RGBA"),
+ "add a butterfly fairy.",
+ 648464818,
+ "girl_on_sun",
+ "girl_on_sun",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/spider_man_rm/spider_man.png").convert("RGBA"),
+ "remove the christmas hat.",
+ 642087011,
+ "spider_man_rm",
+ "spider_man_rm",
+ False,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/anime_flower/anime_flower.png").convert("RGBA"),
+ "remove the flower.",
+ 642087011,
+ "anime_flower",
+ "anime_flower",
+ False,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/chenduling/chengduling.jpg").convert("RGBA"),
+ "replace the clothes to a delicated floral skirt.",
+ 648464818,
+ "chenduling",
+ "chenduling",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+ [
+ Image.open("./assets/hedgehog_rp_bg/hedgehog.png").convert("RGBA"),
+ "make the hedgehog in Italy.",
+ 648464818,
+ "hedgehog_rp_bg",
+ "hedgehog_rp_bg",
+ True,
+ False,
+ "GPT4-o (Highly Recommended)"
+ ],
+
+]
+
+INPUT_IMAGE_PATH = {
+ "frog": "./assets/frog/frog.jpeg",
+ "chinese_girl": "./assets/chinese_girl/chinese_girl.png",
+ "angel_christmas": "./assets/angel_christmas/angel_christmas.png",
+ "sunflower_girl": "./assets/sunflower_girl/sunflower_girl.png",
+ "girl_on_sun": "./assets/girl_on_sun/girl_on_sun.png",
+ "spider_man_rm": "./assets/spider_man_rm/spider_man.png",
+ "anime_flower": "./assets/anime_flower/anime_flower.png",
+ "chenduling": "./assets/chenduling/chengduling.jpg",
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/hedgehog.png",
+}
+MASK_IMAGE_PATH = {
+ "frog": "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
+ "chinese_girl": "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
+ "angel_christmas": "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
+ "sunflower_girl": "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
+ "girl_on_sun": "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
+ "spider_man_rm": "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
+ "anime_flower": "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
+ "chenduling": "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
+}
+MASKED_IMAGE_PATH = {
+ "frog": "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
+ "chinese_girl": "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
+ "angel_christmas": "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
+ "sunflower_girl": "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
+ "girl_on_sun": "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
+ "spider_man_rm": "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
+ "anime_flower": "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
+ "chenduling": "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
+}
+OUTPUT_IMAGE_PATH = {
+ "frog": "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
+ "chinese_girl": "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
+ "angel_christmas": "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
+ "sunflower_girl": "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
+ "girl_on_sun": "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
+ "spider_man_rm": "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
+ "anime_flower": "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
+ "chenduling": "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
+ "hedgehog_rp_bg": "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
+}
+
+
+
+VLM_MODEL_NAMES = list(vlms_template.keys())
+DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
+
+
+BASE_MODELS = list(base_models_template.keys())
+DEFAULT_BASE_MODEL = "realisticVision (Default)"
+
+ASPECT_RATIO_LABELS = list(aspect_ratios)
+DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
+
+
+## init device
+try:
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif sys.platform == "darwin" and torch.backends.mps.is_available():
+ device = "mps"
+ else:
+ device = "cpu"
+except:
+ device = "cpu"
+
+# ## init torch dtype
+# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+# torch_dtype = torch.bfloat16
+# else:
+# torch_dtype = torch.float16
+
+# if device == "mps":
+# torch_dtype = torch.float16
+
+torch_dtype = torch.float16
+
+
+
+# download hf models
+BrushEdit_path = "models/"
+if not os.path.exists(BrushEdit_path):
+ BrushEdit_path = snapshot_download(
+ repo_id="TencentARC/BrushEdit",
+ local_dir=BrushEdit_path,
+ token=os.getenv("HF_TOKEN"),
+ )
+
+## init default VLM
+vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
+if vlm_processor != "" and vlm_model != "":
+ vlm_model.to(device)
+else:
+ raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
+
+## init default LLM
+llm_model = OpenAI(api_key=base_api_key, base_url=base_openai_url)
+
+## init base model
+base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
+brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
+sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
+groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
+
+
+# input brushnetX ckpt path
+brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
+pipe = StableDiffusionBrushNetPipeline.from_pretrained(
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+ )
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+# pipe.enable_xformers_memory_efficient_attention()
+pipe.enable_model_cpu_offload()
+
+
+## init SAM
+sam = build_sam(checkpoint=sam_path)
+sam.to(device=device)
+sam_predictor = SamPredictor(sam)
+sam_automask_generator = SamAutomaticMaskGenerator(sam)
+
+## init groundingdino_model
+config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
+groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
+
+
+
+blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24")
+blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", revision="51572668da0eb669e01a189dc22abe6088589a24", torch_dtype=torch.float16).to(device)
+
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",torch_dtype=torch.float16).to(device)
+
+# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+# clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14",torch_dtype=torch.float16).to(device)
+
+
+
+# Ordinary function
+def crop_and_resize(image: Image.Image,
+ target_width: int,
+ target_height: int) -> Image.Image:
+ """
+ Crops and resizes an image while preserving the aspect ratio.
+
+ Args:
+ image (Image.Image): Input PIL image to be cropped and resized.
+ target_width (int): Target width of the output image.
+ target_height (int): Target height of the output image.
+
+ Returns:
+ Image.Image: Cropped and resized image.
+ """
+ # Original dimensions
+ original_width, original_height = image.size
+ original_aspect = original_width / original_height
+ target_aspect = target_width / target_height
+
+ # Calculate crop box to maintain aspect ratio
+ if original_aspect > target_aspect:
+ # Crop horizontally
+ new_width = int(original_height * target_aspect)
+ new_height = original_height
+ left = (original_width - new_width) / 2
+ top = 0
+ right = left + new_width
+ bottom = original_height
+ else:
+ # Crop vertically
+ new_width = original_width
+ new_height = int(original_width / target_aspect)
+ left = 0
+ top = (original_height - new_height) / 2
+ right = original_width
+ bottom = top + new_height
+
+ # Crop and resize
+ cropped_image = image.crop((left, top, right, bottom))
+ resized_image = cropped_image.resize((target_width, target_height), Image.NEAREST)
+ return resized_image
+
+
+def resize(image: Image.Image,
+ target_width: int,
+ target_height: int) -> Image.Image:
+ """
+ Crops and resizes an image while preserving the aspect ratio.
+
+ Args:
+ image (Image.Image): Input PIL image to be cropped and resized.
+ target_width (int): Target width of the output image.
+ target_height (int): Target height of the output image.
+
+ Returns:
+ Image.Image: Cropped and resized image.
+ """
+ # Original dimensions
+ resized_image = image.resize((target_width, target_height), Image.NEAREST)
+ return resized_image
+
+
+def move_mask_func(mask, direction, units):
+ binary_mask = mask.squeeze()>0
+ rows, cols = binary_mask.shape
+ moved_mask = np.zeros_like(binary_mask, dtype=bool)
+
+ if direction == 'down':
+ # move down
+ moved_mask[max(0, units):, :] = binary_mask[:rows - units, :]
+
+ elif direction == 'up':
+ # move up
+ moved_mask[:rows - units, :] = binary_mask[units:, :]
+
+ elif direction == 'right':
+ # move left
+ moved_mask[:, max(0, units):] = binary_mask[:, :cols - units]
+
+ elif direction == 'left':
+ # move right
+ moved_mask[:, :cols - units] = binary_mask[:, units:]
+
+ return moved_mask
+
+
+def random_mask_func(mask, dilation_type='square', dilation_size=20):
+ # Randomly select the size of dilation
+ binary_mask = mask.squeeze()>0
+
+ if dilation_type == 'square_dilation':
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
+ dilated_mask = binary_dilation(binary_mask, structure=structure)
+ elif dilation_type == 'square_erosion':
+ structure = np.ones((dilation_size, dilation_size), dtype=bool)
+ dilated_mask = binary_erosion(binary_mask, structure=structure)
+ elif dilation_type == 'bounding_box':
+ # find the most left top and left bottom point
+ rows, cols = np.where(binary_mask)
+ if len(rows) == 0 or len(cols) == 0:
+ return mask # return original mask if no valid points
+
+ min_row = np.min(rows)
+ max_row = np.max(rows)
+ min_col = np.min(cols)
+ max_col = np.max(cols)
+
+ # create a bounding box
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
+ dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
+
+ elif dilation_type == 'bounding_ellipse':
+ # find the most left top and left bottom point
+ rows, cols = np.where(binary_mask)
+ if len(rows) == 0 or len(cols) == 0:
+ return mask # return original mask if no valid points
+
+ min_row = np.min(rows)
+ max_row = np.max(rows)
+ min_col = np.min(cols)
+ max_col = np.max(cols)
+
+ # calculate the center and axis length of the ellipse
+ center = ((min_col + max_col) // 2, (min_row + max_row) // 2)
+ a = (max_col - min_col) // 2 # half long axis
+ b = (max_row - min_row) // 2 # half short axis
+
+ # create a bounding ellipse
+ y, x = np.ogrid[:mask.shape[0], :mask.shape[1]]
+ ellipse_mask = ((x - center[0])**2 / a**2 + (y - center[1])**2 / b**2) <= 1
+ dilated_mask = np.zeros_like(binary_mask, dtype=bool)
+ dilated_mask[ellipse_mask] = True
+ else:
+ ValueError("dilation_type must be 'square' or 'ellipse'")
+
+ # use binary dilation
+ dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
+ return dilated_mask
+
+
+## Gradio component function
+def update_vlm_model(vlm_name):
+ global vlm_model, vlm_processor
+ if vlm_model is not None:
+ del vlm_model
+ torch.cuda.empty_cache()
+
+ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]
+
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
+ if vlm_type == "llava-next":
+ if vlm_processor != "" and vlm_model != "":
+ vlm_model.to(device)
+ return vlm_model_dropdown
+ else:
+ if os.path.exists(vlm_local_path):
+ vlm_processor = LlavaNextProcessor.from_pretrained(vlm_local_path)
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
+ else:
+ if vlm_name == "llava-v1.6-mistral-7b-hf (Preload)":
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "llama3-llava-next-8b-hf (Preload)":
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llama3-llava-next-8b-hf", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "llava-v1.6-vicuna-13b-hf (Preload)":
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "llava-v1.6-34b-hf (Preload)":
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-34b-hf")
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-34b-hf", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "llava-next-72b-hf (Preload)":
+ vlm_processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-next-72b-hf")
+ vlm_model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-next-72b-hf", torch_dtype="auto", device_map="auto")
+ elif vlm_type == "qwen2-vl":
+ if vlm_processor != "" and vlm_model != "":
+ vlm_model.to(device)
+ return vlm_model_dropdown
+ else:
+ if os.path.exists(vlm_local_path):
+ vlm_processor = Qwen2VLProcessor.from_pretrained(vlm_local_path)
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained(vlm_local_path, torch_dtype="auto", device_map="auto")
+ else:
+ if vlm_name == "qwen2-vl-2b-instruct (Preload)":
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "qwen2-vl-7b-instruct (Preload)":
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto")
+ elif vlm_name == "qwen2-vl-72b-instruct (Preload)":
+ vlm_processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-72B-Instruct")
+ vlm_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-72B-Instruct", torch_dtype="auto", device_map="auto")
+ elif vlm_type == "openai":
+ pass
+ return "success"
+
+
+def update_base_model(base_model_name):
+ global pipe
+ ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
+ if pipe is not None:
+ del pipe
+ torch.cuda.empty_cache()
+ base_model_path, pipe = base_models_template[base_model_name]
+ if pipe != "":
+ pipe.to(device)
+ else:
+ if os.path.exists(base_model_path):
+ pipe = StableDiffusionBrushNetPipeline.from_pretrained(
+ base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
+ )
+ # pipe.enable_xformers_memory_efficient_attention()
+ pipe.enable_model_cpu_offload()
+ else:
+ raise gr.Error(f"The base model {base_model_name} does not exist")
+ return "success"
+
+
+def process_random_mask(input_image,
+ original_image,
+ original_mask,
+ resize_default,
+ aspect_ratio_name,
+ ):
+
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
+ random_mask = random_mask_func(original_mask, dilation_type).squeeze()
+
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
+
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
+
+
+def process_dilation_mask(input_image,
+ original_image,
+ original_mask,
+ resize_default,
+ aspect_ratio_name,
+ dilation_size=20):
+
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ dilation_type = np.random.choice(['square_dilation'])
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
+
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
+
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
+
+
+def process_erosion_mask(input_image,
+ original_image,
+ original_mask,
+ resize_default,
+ aspect_ratio_name,
+ dilation_size=20):
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ dilation_type = np.random.choice(['square_erosion'])
+ random_mask = random_mask_func(original_mask, dilation_type, dilation_size).squeeze()
+
+ mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
+
+ masked_image = original_image * (1 - (random_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+
+ return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
+
+
+def move_mask_left(input_image,
+ original_image,
+ original_mask,
+ moving_pixels,
+ resize_default,
+ aspect_ratio_name):
+
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ moved_mask = move_mask_func(original_mask, 'left', int(moving_pixels)).squeeze()
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
+
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+ if moved_mask.max() <= 1:
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
+ original_mask = moved_mask
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
+
+
+def move_mask_right(input_image,
+ original_image,
+ original_mask,
+ moving_pixels,
+ resize_default,
+ aspect_ratio_name):
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ moved_mask = move_mask_func(original_mask, 'right', int(moving_pixels)).squeeze()
+
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
+
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+
+ if moved_mask.max() <= 1:
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
+ original_mask = moved_mask
+
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
+
+
+def move_mask_up(input_image,
+ original_image,
+ original_mask,
+ moving_pixels,
+ resize_default,
+ aspect_ratio_name):
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ moved_mask = move_mask_func(original_mask, 'up', int(moving_pixels)).squeeze()
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
+
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+ if moved_mask.max() <= 1:
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
+ original_mask = moved_mask
+
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
+
+
+def move_mask_down(input_image,
+ original_image,
+ original_mask,
+ moving_pixels,
+ resize_default,
+ aspect_ratio_name):
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ moved_mask = move_mask_func(original_mask, 'down', int(moving_pixels)).squeeze()
+ mask_image = Image.fromarray(((moved_mask>0).astype(np.uint8)*255)).convert("RGB")
+
+ masked_image = original_image * (1 - (moved_mask[:,:,None]>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+ if moved_mask.max() <= 1:
+ moved_mask = ((moved_mask * 255)[:,:,None]).astype(np.uint8)
+ original_mask = moved_mask
+
+ return [masked_image], [mask_image], original_mask.astype(np.uint8)
+
+
+def invert_mask(input_image,
+ original_image,
+ original_mask,
+ ):
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+ if input_mask.max() == 0:
+ original_mask = 1 - (original_mask>0).astype(np.uint8)
+ else:
+ original_mask = 1 - (input_mask>0).astype(np.uint8)
+
+ if original_mask is None:
+ raise gr.Error('Please generate mask first')
+
+ original_mask = original_mask.squeeze()
+ mask_image = Image.fromarray(original_mask*255).convert("RGB")
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ if original_mask.max() <= 1:
+ original_mask = (original_mask * 255).astype(np.uint8)
+
+ masked_image = original_image * (1 - (original_mask>0))
+ masked_image = masked_image.astype(original_image.dtype)
+ masked_image = Image.fromarray(masked_image)
+
+ return [masked_image], [mask_image], original_mask, True
+
+
+def reset_func(input_image,
+ original_image,
+ original_mask,
+ prompt,
+ blip2_output,
+ target_prompt,
+ ):
+ input_image = None
+ original_image = None
+ original_mask = None
+ prompt = ''
+ mask_gallery = []
+ masked_gallery = []
+ result_gallery = []
+ blip2_output = ''
+ target_prompt = ''
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, blip2_output, target_prompt, True, False
+
+
+def update_example(example_type,
+ prompt,
+ example_change_times):
+ input_image = INPUT_IMAGE_PATH[example_type]
+ image_pil = Image.open(input_image).convert("RGB")
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[example_type]).convert("L")]
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")]
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")]
+ width, height = image_pil.size
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
+ image_pil = image_pil.resize((width_new, height_new))
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
+
+ original_image = np.array(image_pil)
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
+ aspect_ratio = "Custom resolution"
+ example_change_times += 1
+ return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
+
+
+def generate_caption(blip2_processor, blip2_model, input_image, device):
+ image_pil = input_image["background"].convert("RGB")
+ inputs = blip2_processor(images=image_pil, return_tensors="pt").to(device, torch.float16)
+ generated_ids = blip2_model.generate(**inputs)
+ caption = blip2_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+ return caption
+
+
+def generate_blip2_description(input_image):
+ try:
+ description = generate_caption(blip2_processor, blip2_model, input_image, device)
+ return description, description
+ except Exception as e:
+ return "", f"Caption generation failed: {str(e)}"
+
+
+def generate_target_prompt(input_image,
+ original_image,
+ prompt):
+ # load example image
+ if isinstance(original_image, str):
+ original_image = input_image
+
+ image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
+
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
+ vlm_processor,
+ vlm_model,
+ llm_model,
+ original_image,
+ image_caption,
+ prompt,
+ device)
+ return prompt_after_apply_instruction
+
+
+def init_img(base,
+ init_type,
+ prompt,
+ aspect_ratio,
+ example_change_times
+ ):
+ image_pil = base["background"].convert("RGB")
+ original_image = np.array(image_pil)
+ if max(original_image.shape[0], original_image.shape[1]) * 1.0 / min(original_image.shape[0], original_image.shape[1])>2.0:
+ raise gr.Error('image aspect ratio cannot be larger than 2.0')
+ if init_type in MASK_IMAGE_PATH.keys() and example_change_times < 2:
+ mask_gallery = [Image.open(MASK_IMAGE_PATH[init_type]).convert("L")]
+ masked_gallery = [Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")]
+ result_gallery = [Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")]
+ width, height = image_pil.size
+ image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
+ height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
+ image_pil = image_pil.resize((width_new, height_new))
+ mask_gallery[0] = mask_gallery[0].resize((width_new, height_new))
+ masked_gallery[0] = masked_gallery[0].resize((width_new, height_new))
+ result_gallery[0] = result_gallery[0].resize((width_new, height_new))
+ original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
+ return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "Custom resolution", False, False, example_change_times
+ else:
+ if aspect_ratio not in ASPECT_RATIO_LABELS:
+ aspect_ratio = "Custom resolution"
+ return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
+
+
+def process_mask(input_image,
+ original_image,
+ prompt,
+ resize_default,
+ aspect_ratio_name):
+ if original_image is None:
+ raise gr.Error('Please upload the input image')
+ if prompt is None:
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
+
+ ## load mask
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.array(alpha_mask)
+
+ # load example image
+ if isinstance(original_image, str):
+ original_image = input_image["background"]
+
+ image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
+
+ if input_mask.max() == 0:
+ # category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
+ category = vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device)
+
+ object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
+ vlm_model,
+ llm_model,
+ original_image,
+ image_caption,
+ category,
+ prompt,
+ device)
+ # original mask: h,w,1 [0, 255]
+ original_mask = vlm_response_mask(
+ vlm_processor,
+ vlm_model,
+ category,
+ original_image,
+ prompt,
+ object_wait_for_edit,
+ sam,
+ sam_predictor,
+ sam_automask_generator,
+ groundingdino_model,
+ device).astype(np.uint8)
+ else:
+ original_mask = input_mask.astype(np.uint8)
+ category = None
+
+ ## resize mask if needed
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+ mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
+
+ masked_image = original_image * (1 - (original_mask>0))
+ masked_image = masked_image.astype(np.uint8)
+ masked_image = Image.fromarray(masked_image)
+
+ return [masked_image], [mask_image], original_mask.astype(np.uint8), category
+
+
+def process(input_image,
+ original_image,
+ original_mask,
+ prompt,
+ negative_prompt,
+ control_strength,
+ seed,
+ randomize_seed,
+ guidance_scale,
+ num_inference_steps,
+ num_samples,
+ blending,
+ category,
+ target_prompt,
+ resize_default,
+ aspect_ratio_name,
+ invert_mask_state):
+ if original_image is None:
+ if input_image is None:
+ raise gr.Error('Please upload the input image')
+ else:
+ image_pil = input_image["background"].convert("RGB")
+ original_image = np.array(image_pil)
+ if prompt is None or prompt == "":
+ if target_prompt is None or target_prompt == "":
+ raise gr.Error("Please input your instructions, e.g., remove the xxx")
+
+ alpha_mask = input_image["layers"][0].split()[3]
+ input_mask = np.asarray(alpha_mask)
+ output_w, output_h = aspect_ratios[aspect_ratio_name]
+ if output_w == "" or output_h == "":
+ output_h, output_w = original_image.shape[:2]
+
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ else:
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ pass
+ else:
+ if resize_default:
+ short_side = min(output_w, output_h)
+ scale_ratio = 640 / short_side
+ output_w = int(output_w * scale_ratio)
+ output_h = int(output_h * scale_ratio)
+ gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
+ original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
+ original_image = np.array(original_image)
+ if input_mask is not None:
+ input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
+ input_mask = np.array(input_mask)
+ if original_mask is not None:
+ original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
+ original_mask = np.array(original_mask)
+
+ if invert_mask_state:
+ original_mask = original_mask
+ else:
+ if input_mask.max() == 0:
+ original_mask = original_mask
+ else:
+ original_mask = input_mask
+
+ image_caption = generate_caption(blip2_processor, blip2_model, input_image, device)
+
+ # inpainting directly if target_prompt is not None
+ if category is not None:
+ pass
+ elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
+ pass
+ else:
+ try:
+ # category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
+ category = vlm_response_editing_type(vlm_processor, vlm_model, llm_model, original_image, image_caption, prompt, device)
+ print(category)
+ except Exception as e:
+ raise gr.Error("Time1. Please select the correct VLM model and input the correct API Key first!")
+
+
+ if original_mask is not None:
+ original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
+ else:
+ try:
+ object_wait_for_edit = vlm_response_object_wait_for_edit(
+ vlm_processor,
+ vlm_model,
+ llm_model,
+ original_image,
+ image_caption,
+ category,
+ prompt,
+ device)
+ print(object_wait_for_edit)
+
+ original_mask = vlm_response_mask(vlm_processor,
+ vlm_model,
+ category,
+ original_image,
+ prompt,
+ object_wait_for_edit,
+ sam,
+ sam_predictor,
+ sam_automask_generator,
+ groundingdino_model,
+ device).astype(np.uint8)
+ print("Got genarated mask!")
+ except Exception as e:
+ raise gr.Error("Time2. Please select the correct VLM model and input the correct API Key first!")
+
+ if original_mask.ndim == 2:
+ original_mask = original_mask[:,:,None]
+
+
+ if target_prompt is not None and len(target_prompt) >= 1:
+ prompt_after_apply_instruction = target_prompt
+
+ else:
+ try:
+ prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
+ vlm_processor,
+ vlm_model,
+ llm_model,
+ original_image,
+ image_caption,
+ prompt,
+ device)
+ except Exception as e:
+ raise gr.Error("Time3. Please select the correct VLM model and input the correct API Key first!")
+
+ generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
+
+
+ with torch.autocast(device):
+ image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
+ prompt_after_apply_instruction,
+ original_mask,
+ original_image,
+ generator,
+ num_inference_steps,
+ guidance_scale,
+ control_strength,
+ negative_prompt,
+ num_samples,
+ blending)
+ original_image = np.array(init_image_np)
+ masked_image = original_image * (1 - (mask_np>0))
+ masked_image = masked_image.astype(np.uint8)
+ masked_image = Image.fromarray(masked_image)
+
+ return image, [mask_image], [masked_image], prompt, prompt_after_apply_instruction, False
+
+
+def submit_GPT4o_KEY(GPT4o_KEY):
+ global vlm_model, vlm_processor
+ if vlm_model is not None:
+ del vlm_model
+ torch.cuda.empty_cache()
+ try:
+ vlm_model = OpenAI(api_key=GPT4o_KEY, base_url="https://api.deepseek.com")
+ vlm_processor = ""
+ response = vlm_model.chat.completions.create(
+ model="deepseek-chat",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello."}
+ ]
+ )
+ response_str = response.choices[0].message.content
+
+ return "Success. " + response_str, "GPT4-o (Highly Recommended)"
+ except Exception as e:
+ return "Invalid GPT4o API Key", "GPT4-o (Highly Recommended)"
+
+
+def verify_deepseek_api():
+ try:
+ response = llm_model.chat.completions.create(
+ model="deepseek-chat",
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello."}
+ ]
+ )
+ response_str = response.choices[0].message.content
+
+ return True, "Success. " + response_str
+
+ except Exception as e:
+ return False, "Invalid DeepSeek API Key"
+
+
+def llm_enhanced_prompt_after_apply_instruction(image_caption, editing_prompt):
+ try:
+ messages = create_apply_editing_messages_deepseek(image_caption, editing_prompt)
+ response = llm_model.chat.completions.create(
+ model="deepseek-chat",
+ messages=messages
+ )
+ response_str = response.choices[0].message.content
+ return response_str
+ except Exception as e:
+ raise gr.Error(f"整合指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
+
+
+def llm_decomposed_prompt_after_apply_instruction(integrated_query):
+ try:
+ messages = create_decomposed_query_messages_deepseek(integrated_query)
+ response = llm_model.chat.completions.create(
+ model="deepseek-chat",
+ messages=messages
+ )
+ response_str = response.choices[0].message.content
+ return response_str
+ except Exception as e:
+ raise gr.Error(f"分解指令时遇到错误: {str(e)},请检查控制台日志获取详细信息")
+
+
+def enhance_description(blip2_description, prompt):
+ try:
+ if not prompt or not blip2_description:
+ print("Empty prompt or blip2_description detected")
+ return "", ""
+
+ print(f"Enhancing with prompt: {prompt}")
+ enhanced_description = llm_enhanced_prompt_after_apply_instruction(blip2_description, prompt)
+ return enhanced_description, enhanced_description
+
+ except Exception as e:
+ print(f"Enhancement failed: {str(e)}")
+ return "Error occurred", "Error occurred"
+
+
+def decompose_description(enhanced_description):
+ try:
+ if not enhanced_description:
+ print("Empty enhanced_description detected")
+ return "", ""
+
+ print(f"Decomposing the enhanced description: {enhanced_description}")
+ decomposed_description = llm_decomposed_prompt_after_apply_instruction(enhanced_description)
+ return decomposed_description, decomposed_description
+
+ except Exception as e:
+ print(f"Decomposition failed: {str(e)}")
+ return "Error occurred", "Error occurred"
+
+
+@torch.no_grad()
+def mix_and_search(enhanced_text: str, gallery_images: list):
+ # 获取最新生成的图像元组
+ latest_item = gallery_images[-1] if gallery_images else None
+
+ # 初始化特征列表
+ features = []
+
+ # 图像特征提取
+ if latest_item and isinstance(latest_item, tuple):
+ try:
+ image_path = latest_item[0]
+ pil_image = Image.open(image_path).convert("RGB")
+
+ # 使用 CLIPProcessor 处理图像
+ image_inputs = clip_processor(
+ images=pil_image,
+ return_tensors="pt"
+ ).to(device)
+
+ image_features = clip_model.get_image_features(**image_inputs)
+ features.append(F.normalize(image_features, dim=-1))
+ except Exception as e:
+ print(f"图像处理失败: {str(e)}")
+
+ # 文本特征提取
+ if enhanced_text.strip():
+ text_inputs = clip_processor(
+ text=enhanced_text,
+ return_tensors="pt",
+ padding=True,
+ truncation=True
+ ).to(device)
+
+ text_features = clip_model.get_text_features(**text_inputs)
+ features.append(F.normalize(text_features, dim=-1))
+
+ if not features:
+ return []
+
+
+ # 特征融合逻辑
+ if len(features) == 2:
+ # 图文双特征时加权:40%图像 + 60%文本
+ mixed = 0.4 * features[0] + 0.6 * features[1]
+ else:
+ # 单特征时直接求和(保持原逻辑)
+ mixed = sum(features)
+
+ mixed = F.normalize(mixed, dim=-1)
+
+ # 加载Faiss索引和图片路径映射
+ index_path = "/home/zt/data/BrushEdit/CIRR/dev/dev_knn.index"
+ input_data_dir = Path("/home/zt/data/BrushEdit/CIRR/dev/dev_embedding_folder/metadata")
+ base_image_dir = Path("/home/zt/data/BrushEdit/CIRR/")
+
+ # 按文件名中的数字排序并直接读取parquet文件
+ parquet_files = sorted(
+ input_data_dir.glob('*.parquet'),
+ key=lambda x: int(x.stem.split("_")[-1])
+ )
+
+ # 合并所有parquet数据
+ dfs = [pd.read_parquet(file) for file in parquet_files] # 直接内联读取
+ df = pd.concat(dfs, ignore_index=True)
+ image_paths = df["image_path"].tolist()
+
+ index = faiss.read_index(index_path)
+ assert mixed.shape[1] == index.d, "特征维度不匹配"
+ mixed = mixed.cpu().detach().numpy().astype('float32')
+ distances, indices = index.search(mixed, 50)
+
+ # 获取并验证图片路径
+ retrieved_images = []
+ for idx in indices[0]:
+ if 0 <= idx < len(image_paths):
+ img_path = base_image_dir / image_paths[idx]
+ try:
+ if img_path.exists():
+ retrieved_images.append(Image.open(img_path).convert("RGB"))
+ else:
+ print(f"警告:文件缺失 {img_path}")
+ except Exception as e:
+ print(f"图片加载失败: {str(e)}")
+
+ return retrieved_images if retrieved_images else ([])
+
+# 问题如下:在cap_file中,每张图片会充当多次reference的值,能够将同一张图片区分开来的是cap_file中的“caption”。我可以如何调整现在的保存逻辑,使得能够区分同一reference不同caption的情况,而不是直接覆盖呢?
+# def process_cirr_images():
+
+# if not all([vlm_model, sam_predictor, groundingdino_model]):
+# raise RuntimeError("Required models not initialized")
+
+# # Define paths
+# dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
+# cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
+# output_dirs = {
+# "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited"),
+# "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_mask"),
+# "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint/cirr_masked")
+# }
+# output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint.json")
+# descriptions = {}
+
+# # Create output directories
+# for dir_path in output_dirs.values():
+# dir_path.mkdir(parents=True, exist_ok=True)
+
+# # Load captions
+# with open(cap_file, 'r') as f:
+# captions = json.load(f)
+
+# for img_path in dev_dir.glob("*.png"):
+# base_name = img_path.stem
+# caption = next((item["caption"] for item in captions if item.get("reference") == base_name), None)
+
+# if not caption:
+# print(f"Warning: No caption for {base_name}")
+# continue
+
+# try:
+# # 构造空alpha通道(全0)
+# rgb_image = Image.open(img_path).convert("RGB")
+# empty_alpha = Image.new("L", rgb_image.size, 0) # 全透明alpha通道
+# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+# # 调用init_img初始化
+# base = {"background": image, "layers": [image]}
+# init_results = init_img(
+# base=base,
+# init_type="custom", # 使用自定义初始化
+# prompt=caption,
+# aspect_ratio="Custom resolution",
+# example_change_times=0
+# )
+
+# # 获取初始化后的参数
+# input_image = init_results[0]
+# original_image = init_results[1]
+# original_mask = init_results[2]
+
+# # 正确设置process参数
+# result_images, mask_images, masked_images, _, target_description, _ = process(
+# input_image=input_image,
+# original_image=original_image,
+# original_mask=original_mask, # 传递初始化后的mask
+# prompt=caption,
+# negative_prompt="ugly, low quality",
+# control_strength=1.0,
+# seed=648464818,
+# randomize_seed=False,
+# guidance_scale=7.5,
+# num_inference_steps=50,
+# num_samples=1,
+# blending=True,
+# category=None,
+# target_prompt="",
+# resize_default=True,
+# aspect_ratio_name="Custom resolution",
+# invert_mask_state=False
+# )
+
+# # Save images
+# output_dirs["edited"].mkdir(exist_ok=True)
+# result_images[0].save(output_dirs["edited"] / f"{base_name}.png")
+# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.png")
+# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.png")
+
+# # Generate BLIP2 description
+# blip2_desc, _ = generate_blip2_description(input_image)
+
+# descriptions[base_name] = {
+# "original_caption": caption,
+# "blip2_description": blip2_desc,
+# "llm_enhanced_caption": target_description
+# }
+
+# with open(output_json_path, 'w') as f:
+# json.dump(descriptions, f, indent=4) # indent保持可读性
+
+# print(f"Processed {base_name}")
+
+# except Exception as e:
+# print(f"Error processing {base_name}: {str(e)}")
+# continue
+
+# print("Processing completed!")
+
+
+# cirr 的 val数据集。目前来看这个函数没有太大问题
+# def process_cirr_images():
+
+# if not all([vlm_model, sam_predictor, groundingdino_model]):
+# raise RuntimeError("Required models not initialized")
+
+# # Define paths
+# # dev_dir = Path("/home/zt/data/BrushEdit/cirr/img_raw/dev")
+# # cap_file = Path("/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json")
+# # output_dirs = {
+# # "edited": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_edited"),
+# # "mask": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_mask"),
+# # "masked": Path("/home/zt/data/BrushEdit/cirr/img_paint_pairid/cirr_masked")
+# # }
+# # output_json_path = Path("/home/zt/data/BrushEdit/cirr/image_paint_pairid.json")
+
+
+# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/dev")
+# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/val_deepseek_missed_174.json")
+# output_dirs = {
+# "edited": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_edited"),
+# "mask": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_mask"),
+# "masked": Path("/home/zt/data/BrushEdit/CIRR/img_paint_pairid/missed/cirr_masked")
+# }
+# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/cirr/image_paint_deepseek_missed_pairid.json")
+# descriptions = {}
+
+# # Create output directories
+# for dir_path in output_dirs.values():
+# dir_path.mkdir(parents=True, exist_ok=True)
+
+# # Load captions
+# with open(cap_file, 'r') as f:
+# captions = json.load(f)
+
+# for img_path in dev_dir.glob("*.png"):
+# base_name = img_path.stem
+# # 获取所有匹配的caption条目
+# matched_items = [item for item in captions if item.get("reference") == base_name]
+
+# if not matched_items:
+# print(f"Warning: No captions for {base_name}")
+# continue
+
+# for item in matched_items:
+# # 验证必要字段存在
+# pairid = item.get("pairid")
+# caption = item.get("caption")
+
+# if not all([pairid, caption]):
+# print(f"Skipping invalid item for {base_name}: {item}")
+# continue
+
+# # 使用pairid构造唯一标识
+# processed_base = f"{base_name}_{pairid}"
+
+# try:
+# # 构造空alpha通道
+# rgb_image = Image.open(img_path).convert("RGB")
+# empty_alpha = Image.new("L", rgb_image.size, 0)
+# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+# # 初始化图像
+# base = {"background": image, "layers": [image]}
+# init_results = init_img(
+# base=base,
+# init_type="custom",
+# prompt=caption,
+# aspect_ratio="Custom resolution",
+# example_change_times=0
+# )
+
+# # 获取处理参数
+# input_image = init_results[0]
+# original_image = init_results[1]
+# original_mask = init_results[2]
+
+# # 正确设置process参数
+# result_images, mask_images, masked_images, _, target_description, _ = process(
+# input_image=input_image,
+# original_image=original_image,
+# original_mask=original_mask, # 传递初始化后的mask
+# prompt=caption,
+# negative_prompt="ugly, low quality",
+# control_strength=1.0,
+# seed=648464818,
+# randomize_seed=False,
+# guidance_scale=7.5,
+# num_inference_steps=50,
+# num_samples=1,
+# blending=True,
+# category=None,
+# target_prompt="",
+# resize_default=True,
+# aspect_ratio_name="Custom resolution",
+# invert_mask_state=False
+# )
+
+# # 保存文件(使用pairid标识)
+# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
+# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
+# masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
+
+# # 生成描述
+# blip2_desc, _ = generate_blip2_description(input_image)
+
+# # 使用pairid作为主键存储元数据
+# descriptions[pairid] = {
+# "reference": base_name,
+# "user_editing_prompt": caption,
+# "blip2_description": blip2_desc,
+# "llm_enhanced_caption": target_description,
+# "processed_files": {
+# "edited": f"{processed_base}.png",
+# "mask": f"{processed_base}_mask.png",
+# "masked": f"{processed_base}_masked.png"
+# }
+# }
+
+# print(f"Processed {processed_base}")
+
+# except Exception as e:
+# print(f"Error processing {processed_base}: {str(e)}")
+# continue
+
+# # 保存元数据
+# with open(output_json_path, 'w') as f:
+# json.dump(descriptions, f, indent=4)
+
+# print("Processing completed!")
+
+# cirr 的 test1数据集。
+# def process_cirr_images():
+
+# if not all([vlm_model, sam_predictor, groundingdino_model]):
+# raise RuntimeError("Required models not initialized")
+
+# # Define paths
+# dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1")
+# cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json")
+# output_dirs = {
+# "edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_edited"),
+# "mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_mask"),
+# "masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/cirr_masked")
+# }
+# output_json_path = Path("/home/zt/data/BrushEdit/CIRR/test1_image_paint_pairid.json")
+# descriptions = {}
+
+# # Create output directories
+# for dir_path in output_dirs.values():
+# dir_path.mkdir(parents=True, exist_ok=True)
+
+# # Load captions
+# with open(cap_file, 'r') as f:
+# captions = json.load(f)
+
+# for img_path in dev_dir.glob("*.png"):
+# base_name = img_path.stem
+# # 获取所有匹配的caption条目
+# matched_items = [item for item in captions if item.get("reference") == base_name]
+
+# if not matched_items:
+# print(f"Warning: No captions for {base_name}")
+# continue
+
+# for item in matched_items:
+# # 验证必要字段存在
+# pairid = item.get("pairid")
+# caption = item.get("caption")
+
+# if not all([pairid, caption]):
+# print(f"Skipping invalid item for {base_name}: {item}")
+# continue
+
+# # 使用pairid构造唯一标识
+# processed_base = f"{base_name}_{pairid}"
+
+# try:
+# # 构造空alpha通道
+# rgb_image = Image.open(img_path).convert("RGB")
+# empty_alpha = Image.new("L", rgb_image.size, 0)
+# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+# # 初始化图像
+# base = {"background": image, "layers": [image]}
+# init_results = init_img(
+# base=base,
+# init_type="custom",
+# prompt=caption,
+# aspect_ratio="Custom resolution",
+# example_change_times=0
+# )
+
+# # 获取处理参数
+# input_image = init_results[0]
+# original_image = init_results[1]
+# original_mask = init_results[2]
+
+# # 正确设置process参数
+# result_images, mask_images, masked_images, _, target_description, _ = process(
+# input_image=input_image,
+# original_image=original_image,
+# original_mask=original_mask, # 传递初始化后的mask
+# prompt=caption,
+# negative_prompt="ugly, low quality",
+# control_strength=1.0,
+# seed=648464818,
+# randomize_seed=False,
+# guidance_scale=7.5,
+# num_inference_steps=50,
+# num_samples=1,
+# blending=True,
+# category=None,
+# target_prompt="",
+# resize_default=True,
+# aspect_ratio_name="Custom resolution",
+# invert_mask_state=False
+# )
+
+# # 保存文件(使用pairid标识)
+# result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
+# mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
+# masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
+
+# # 生成描述
+# blip2_desc, _ = generate_blip2_description(input_image)
+
+# # 使用pairid作为主键存储元数据
+# descriptions[pairid] = {
+# "reference": base_name,
+# "user_editing_prompt": caption,
+# "blip2_description": blip2_desc,
+# "llm_enhanced_caption": target_description,
+# "processed_files": {
+# "edited": f"{processed_base}.png",
+# "mask": f"{processed_base}_mask.png",
+# "masked": f"{processed_base}_masked.png"
+# }
+# }
+
+# print(f"Processed {processed_base}")
+
+# except Exception as e:
+# print(f"Error processing {processed_base}: {str(e)}")
+# continue
+
+# # 保存元数据
+# with open(output_json_path, 'w') as f:
+# json.dump(descriptions, f, indent=4)
+
+# print("Processing completed!")
+
+
+def process_cirr_images():
+ if not all([vlm_model, sam_predictor, groundingdino_model]):
+ raise RuntimeError("Required models not initialized")
+
+ # Define paths
+ dev_dir = Path("/home/zt/data/BrushEdit/CIRR/test1")
+ cap_file = Path("/home/zt/data/BrushEdit/CIRR/cirr/captions/cap.rc2.test1.json")
+ output_dirs = {
+ "edited": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_edited"),
+ "mask": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_mask"),
+ "masked": Path("/home/zt/data/BrushEdit/CIRR/test1_img_paint_pairid/qw_cirr_masked")
+ }
+ output_json_path = Path("/home/zt/data/BrushEdit/CIRR/qw_test1_image_paint_pairid.json")
+
+ # Create output directories
+ for dir_path in output_dirs.values():
+ dir_path.mkdir(parents=True, exist_ok=True)
+
+ # 1. 加载已有处理结果
+ processed_pairids = set()
+ if output_json_path.exists():
+ try:
+ with open(output_json_path, 'r') as f:
+ descriptions = json.load(f)
+ processed_pairids = set(descriptions.keys())
+ print(f"Loaded {len(processed_pairids)} previously processed pairids")
+ except Exception as e:
+ print(f"Error loading existing results: {str(e)}, starting from scratch")
+ descriptions = {}
+ else:
+ descriptions = {}
+
+ # 2. 创建临时文件写入器
+ temp_json_path = output_json_path.with_suffix(".tmp")
+
+ # Load captions
+ with open(cap_file, 'r') as f:
+ captions = json.load(f)
+
+ # 3. 处理进度跟踪
+ total = 0
+ processed = 0
+ for img_path in dev_dir.glob("*.png"):
+ base_name = img_path.stem
+ matched_items = [item for item in captions if item.get("reference") == base_name]
+
+ if not matched_items:
+ print(f"Warning: No captions for {base_name}")
+ continue
+
+ for item in matched_items:
+ total += 1
+ pairid = str(item.get("pairid")) # 确保字符串类型
+ caption = item.get("caption")
+
+ if not all([pairid, caption]):
+ print(f"Skipping invalid item for {base_name}: {item}")
+ continue
+
+ # 4. 跳过已处理的pairid
+ if pairid in processed_pairids:
+ print(f"Skipping already processed pairid: {pairid}")
+ continue
+
+ processed_base = f"{base_name}_{pairid}"
+ print(f"Processing {processed_base} ({processed+1}/{total})")
+
+ try:
+ # 构造空alpha通道
+ rgb_image = Image.open(img_path).convert("RGB")
+ empty_alpha = Image.new("L", rgb_image.size, 0)
+ image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+ # 初始化图像
+ base = {"background": image, "layers": [image]}
+ init_results = init_img(
+ base=base,
+ init_type="custom",
+ prompt=caption,
+ aspect_ratio="Custom resolution",
+ example_change_times=0
+ )
+
+ # 获取处理参数
+ input_image = init_results[0]
+ original_image = init_results[1]
+ original_mask = init_results[2]
+
+ # 正确设置process参数
+ result_images, mask_images, masked_images, _, target_description, _ = process(
+ input_image=input_image,
+ original_image=original_image,
+ original_mask=original_mask, # 传递初始化后的mask
+ prompt=caption,
+ negative_prompt="ugly, low quality",
+ control_strength=1.0,
+ seed=648464818,
+ randomize_seed=False,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ num_samples=1,
+ blending=True,
+ category=None,
+ target_prompt="",
+ resize_default=True,
+ aspect_ratio_name="Custom resolution",
+ invert_mask_state=False
+ )
+
+ # 保存文件(使用pairid标识)
+ result_images[0].save(output_dirs["edited"] / f"{processed_base}.png")
+ mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.png")
+ masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.png")
+
+ # 生成描述
+ blip2_desc, _ = generate_blip2_description(input_image)
+
+ # 更新描述信息 使用pairid作为主键存储元数据
+ descriptions[pairid] = {
+ "reference": base_name,
+ "user_editing_prompt": caption,
+ "blip2_description": blip2_desc,
+ "llm_enhanced_caption": target_description,
+ "processed_files": {
+ "edited": f"{processed_base}.png",
+ "mask": f"{processed_base}_mask.png",
+ "masked": f"{processed_base}_masked.png"
+ }
+ }
+
+ # 5. 原子化写入:先写入临时文件,再替换原文件
+ with open(temp_json_path, 'w') as f:
+ json.dump(descriptions, f, indent=4)
+ temp_json_path.replace(output_json_path)
+
+ processed +=1
+ processed_pairids.add(pairid)
+ print(f"Successfully processed {pairid}")
+
+ except Exception as e:
+ print(f"Error processing {pairid}: {str(e)}")
+ # 删除可能生成的不完整文件
+ for ext in ["", "_mask.png", "_masked.png"]:
+ incomplete_file = output_dirs["edited"] / f"{processed_base}{ext}"
+ if incomplete_file.exists():
+ incomplete_file.unlink()
+ continue
+
+ print(f"Processing completed! Total processed: {processed}/{total}")
+
+# circo 的 val数据集。目前来看这个函数没有太大问题
+def process_circo_val_images():
+
+ if not all([vlm_model, sam_predictor, groundingdino_model]):
+ raise RuntimeError("Required models not initialized")
+
+ # Define paths
+ # 实际上不单单是dev数据集的图片存储空间,所有的(包括test集合&没有用到的)图片都存在这里
+ dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017")
+ cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json")
+ output_dirs = {
+ "edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_edited"),
+ "mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_mask"),
+ "masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint_pairid/circo_masked")
+ }
+ output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint_pairid.json")
+ descriptions = {}
+
+ # Create output directories
+ for dir_path in output_dirs.values():
+ dir_path.mkdir(parents=True, exist_ok=True)
+
+ # Load captions
+ with open(cap_file, 'r') as f:
+ captions = json.load(f)
+
+
+ for img_path in dev_dir.glob("*.jpg"):
+ # 不包含扩展名
+ base_name = img_path.stem
+ # 提取后六位作为参考ID,reference_part是字符串类型
+ reference_part = base_name[-6:]
+ # 将JSON中的reference_img_id(原本为int),转换为字符串后比较
+ matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part]
+
+ if not matched_items:
+ print(f"Warning: No captions for {base_name}")
+ continue
+
+ for item in matched_items:
+ # 验证必要字段存在
+ pairid = item.get("id")
+ caption = item.get("relative_caption")
+
+ if not all([pairid, caption]):
+ print(f"Skipping invalid item for {base_name}: {item}")
+ continue
+
+ # 使用pairid构造唯一标识
+ # 使用f-string进行字符串格式化时,它会自动将非字符串类型的变量转换为字符串类型
+ processed_base = f"{reference_part}_{pairid}"
+
+ try:
+ # 构造空alpha通道
+ rgb_image = Image.open(img_path).convert("RGB")
+ empty_alpha = Image.new("L", rgb_image.size, 0)
+ image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+ # 初始化图像
+ base = {"background": image, "layers": [image]}
+ init_results = init_img(
+ base=base,
+ init_type="custom",
+ prompt=caption,
+ aspect_ratio="Custom resolution",
+ example_change_times=0
+ )
+
+ # 获取处理参数
+ input_image = init_results[0]
+ original_image = init_results[1]
+ original_mask = init_results[2]
+
+ # 正确设置process参数
+ result_images, mask_images, masked_images, _, target_description, _ = process(
+ input_image=input_image,
+ original_image=original_image,
+ original_mask=original_mask, # 传递初始化后的mask
+ prompt=caption,
+ negative_prompt="ugly, low quality",
+ control_strength=1.0,
+ seed=648464818,
+ randomize_seed=False,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ num_samples=1,
+ blending=True,
+ category=None,
+ target_prompt="",
+ resize_default=True,
+ aspect_ratio_name="Custom resolution",
+ invert_mask_state=False
+ )
+
+ # 保存文件(使用pairid标识)
+ result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg")
+ mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg")
+ masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg")
+
+ # 生成描述
+ blip2_desc, _ = generate_blip2_description(input_image)
+
+ # 使用pairid作为主键存储元数据
+ descriptions[pairid] = {
+ "reference": int(reference_part),
+ "user_editing_prompt": caption,
+ "blip2_description": blip2_desc,
+ "llm_enhanced_caption": target_description,
+ "processed_files": {
+ "edited": f"{processed_base}.jpg",
+ "mask": f"{processed_base}_mask.jpg",
+ "masked": f"{processed_base}_masked.jpg"
+ }
+ }
+
+ print(f"Processed {processed_base}")
+
+ except Exception as e:
+ print(f"Error processing {processed_base}: {str(e)}")
+ continue
+
+ # 保存元数据
+ with open(output_json_path, 'w') as f:
+ json.dump(descriptions, f, indent=4)
+
+ print("Processing completed!")
+
+
+# circo 的 test数据集。目前来看这个函数没有太大问题
+def process_circo_test_images():
+
+ if not all([vlm_model, sam_predictor, groundingdino_model]):
+ raise RuntimeError("Required models not initialized")
+
+ # Define paths
+ # 实际上不单单是dev数据集的图片存储空间,所有的(包括test集合&没有用到的)图片都存在这里
+ dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/COCO2017_unlabeled/unlabeled2017")
+ cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/test.json")
+ output_dirs = {
+ "edited": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_edited"),
+ "mask": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_mask"),
+ "masked": Path("/home/zt/data/BrushEdit/CIRCO/test_img_paint_pairid/circo_masked")
+ }
+ output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/test_image_paint_pairid.json")
+ descriptions = {}
+
+ # Create output directories
+ for dir_path in output_dirs.values():
+ dir_path.mkdir(parents=True, exist_ok=True)
+
+ # Load captions
+ with open(cap_file, 'r') as f:
+ captions = json.load(f)
+
+
+ for img_path in dev_dir.glob("*.jpg"):
+ # 不包含扩展名
+ base_name = img_path.stem
+ # 提取后六位作为参考ID,reference_part是字符串类型
+ reference_part = base_name[-6:]
+ # 将JSON中的reference_img_id(原本为int),转换为字符串后比较
+ matched_items = [item for item in captions if str(item.get("reference_img_id")) == reference_part]
+
+ if not matched_items:
+ print(f"Warning: No captions for {base_name}")
+ continue
+
+ for item in matched_items:
+ # 验证必要字段存在
+ pairid = item.get("id")
+ caption = item.get("relative_caption")
+
+ if not all([pairid, caption]):
+ print(f"Skipping invalid item for {base_name}: {item}")
+ continue
+
+ # 使用pairid构造唯一标识
+ # 使用f-string进行字符串格式化时,它会自动将非字符串类型的变量转换为字符串类型
+ processed_base = f"{reference_part}_{pairid}"
+
+ try:
+ # 构造空alpha通道
+ rgb_image = Image.open(img_path).convert("RGB")
+ empty_alpha = Image.new("L", rgb_image.size, 0)
+ image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+ # 初始化图像
+ base = {"background": image, "layers": [image]}
+ init_results = init_img(
+ base=base,
+ init_type="custom",
+ prompt=caption,
+ aspect_ratio="Custom resolution",
+ example_change_times=0
+ )
+
+ # 获取处理参数
+ input_image = init_results[0]
+ original_image = init_results[1]
+ original_mask = init_results[2]
+
+ # 正确设置process参数
+ result_images, mask_images, masked_images, _, target_description, _ = process(
+ input_image=input_image,
+ original_image=original_image,
+ original_mask=original_mask, # 传递初始化后的mask
+ prompt=caption,
+ negative_prompt="ugly, low quality",
+ control_strength=1.0,
+ seed=648464818,
+ randomize_seed=False,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ num_samples=1,
+ blending=True,
+ category=None,
+ target_prompt="",
+ resize_default=True,
+ aspect_ratio_name="Custom resolution",
+ invert_mask_state=False
+ )
+
+ # 保存文件(使用pairid标识)
+ result_images[0].save(output_dirs["edited"] / f"{processed_base}.jpg")
+ mask_images[0].save(output_dirs["mask"] / f"{processed_base}_mask.jpg")
+ masked_images[0].save(output_dirs["masked"] / f"{processed_base}_masked.jpg")
+
+ # 生成描述
+ blip2_desc, _ = generate_blip2_description(input_image)
+
+ # 使用pairid作为主键存储元数据
+ descriptions[pairid] = {
+ "reference": int(reference_part),
+ "user_editing_prompt": caption,
+ "blip2_description": blip2_desc,
+ "llm_enhanced_caption": target_description,
+ "processed_files": {
+ "edited": f"{processed_base}.jpg",
+ "mask": f"{processed_base}_mask.jpg",
+ "masked": f"{processed_base}_masked.jpg"
+ }
+ }
+
+ print(f"Processed {processed_base}")
+
+ except Exception as e:
+ print(f"Error processing {processed_base}: {str(e)}")
+ continue
+
+ # 保存元数据
+ with open(output_json_path, 'w') as f:
+ json.dump(descriptions, f, indent=4)
+
+ print("Processing completed!")
+
+
+
+
+
+# 目前来看这个函数没有太大问题
+@torch.no_grad()
+def batch_mix_and_search_cirr(
+ json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json",
+ image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited",
+ alpha: float = 0.8,
+ batch_size: int = 32,
+ output_json_path: str = "retrieval_results_pairid.json"
+) -> Dict[str, List[Dict]]:
+ # 加载索引和元数据
+ index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index")
+ metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
+ all_index_ids = metadata["image_path"].tolist()
+
+ # 加载并验证输入数据
+ with open(json_path) as f:
+ samples = json.load(f)
+
+ valid_samples = []
+ image_dir = Path(image_dir)
+ for pair_id, sample_info in samples.items():
+ reference = sample_info["reference"]
+ img_path = image_dir / f"{reference}.png"
+ if img_path.exists():
+ valid_samples.append((
+ pair_id,
+ reference,
+ img_path,
+ sample_info['llm_enhanced_caption'],
+ sample_info['user_editing_prompt']
+ ))
+
+ # 初始化结果字典
+ results = {}
+ total_samples = len(valid_samples)
+
+ # 分批次处理
+ for batch_idx in range(0, total_samples, batch_size):
+ batch_end = min(batch_idx + batch_size, total_samples)
+ current_batch = valid_samples[batch_idx:batch_end]
+
+ # 批量处理图像(保持原逻辑)
+ batch_images = [Image.open(s[2]).convert("RGB") for s in current_batch]
+ image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
+ image_features = clip_model.get_image_features(**image_inputs)
+ image_features = nn.functional.normalize(image_features, dim=-1)
+
+ # 批量处理文本(保持原逻辑)
+ batch_texts = [s[3] for s in current_batch]
+ text_inputs = clip_processor(
+ text=batch_texts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True
+ ).to(device)
+ text_features = clip_model.get_text_features(**text_inputs)
+ text_features = nn.functional.normalize(text_features, dim=-1)
+
+ # 混合特征
+ mixed_features = (1 - alpha) * image_features + alpha * text_features
+ mixed_features = nn.functional.normalize(mixed_features, dim=-1)
+
+ # 批量检索
+ query_features = mixed_features.cpu().numpy().astype("float32")
+ distances, indices = index.search(query_features, 100)
+
+ # 保存当前批次结果
+ for (pair_id, reference, _, enhanced_cap, editing_prompt), dist_row, idx_row in zip(current_batch, distances, indices):
+ results[pair_id] = {
+ "reference": reference,
+ "llm_enhanced_caption": enhanced_cap,
+ "user_editing_prompt": editing_prompt,
+ "retrieved_results": []
+ }
+ for distance, idx in zip(dist_row, idx_row):
+ if 0 <= idx < len(all_index_ids):
+ raw_id = all_index_ids[idx]
+ base_name = os.path.basename(raw_id)
+ file_name = os.path.splitext(base_name)[0]
+ results[pair_id]["retrieved_results"].append({
+ "retrieved_id": file_name,
+ "score": float(distance),
+ })
+
+ # 保存结果到JSON
+ with open(output_json_path, 'w') as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+
+ print("Retrieving completed!")
+ return results
+
+# @torch.no_grad()
+# def batch_mix_and_search_cirr(
+# json_path: str = "/home/zt/data/BrushEdit/cirr/image_paint_pairid.json",
+# image_dir: str = "/home/zt/data/BrushEdit/cirr/img_paint/cirr_edited",
+# alpha: float = 0.8,
+# batch_size: int = 32,
+# output_json_path: str = "retrieval_results_pairid.json"
+# ) -> Dict[str, List[Dict]]:
+# # 加载索引和元数据
+# index = faiss.read_index("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_knn.index")
+# metadata = pd.read_parquet("/home/zt/data/BrushEdit/cirr/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
+# all_index_ids = metadata["image_path"].tolist()
+
+# # 加载并验证输入数据
+# with open(json_path) as f:
+# samples = json.load(f)
+
+# valid_samples = []
+# image_dir = Path(image_dir)
+# for image_id, sample_info in samples.items():
+# img_path = image_dir / f"{image_id}.png"
+# if img_path.exists():
+# valid_samples.append((
+# image_id,
+# img_path,
+# sample_info['llm_enhanced_caption'],
+# sample_info['user_editing_prompt']
+# ))
+
+# # 初始化结果字典
+# results = {}
+# total_samples = len(valid_samples)
+
+# # 分批次处理
+# for batch_idx in range(0, total_samples, batch_size):
+# batch_end = min(batch_idx + batch_size, total_samples)
+# current_batch = valid_samples[batch_idx:batch_end]
+
+# # 批量处理图像
+# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch]
+# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
+# image_features = clip_model.get_image_features(**image_inputs)
+# image_features = nn.functional.normalize(image_features, dim=-1)
+
+# # 批量处理文本
+# batch_texts = [s[2] for s in current_batch]
+# text_inputs = clip_processor(
+# text=batch_texts,
+# return_tensors="pt",
+# padding=True,
+# truncation=True
+# ).to(device)
+# text_features = clip_model.get_text_features(**text_inputs)
+# text_features = nn.functional.normalize(text_features, dim=-1)
+
+# # 混合特征
+# mixed_features = (1 - alpha) * image_features + alpha * text_features
+# mixed_features = nn.functional.normalize(mixed_features, dim=-1)
+
+# # 批量检索
+# query_features = mixed_features.cpu().numpy().astype("float32")
+# distances, indices = index.search(query_features, 100)
+
+# # 保存当前批次结果
+# for (sample_id, _, enhanced_cap, original_cap), dist_row, idx_row in zip(current_batch, distances, indices):
+# results[sample_id] = {
+# "llm_enhanced_caption": enhanced_cap, # 增强描述
+# "original_caption": original_cap, # 原始描述
+# "retrieved_results": []
+# }
+# for distance, idx in zip(dist_row, idx_row):
+# if 0 <= idx < len(all_index_ids):
+# raw_id = all_index_ids[idx]
+# base_name = os.path.basename(raw_id)
+# file_name = os.path.splitext(base_name)[0]
+# results[sample_id]["retrieved_results"].append({
+# "retrieved_id": file_name,
+# "score": float(distance),
+# })
+
+# # 保存结果到JSON
+# with open(output_json_path, 'w') as f:
+# json.dump(results, f, indent=2, ensure_ascii=False)
+
+# print("Retrieving completed!")
+# return results
+
+
+# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
+
+# # 设置数据集和检索结果的路径
+# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
+# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchong.json"
+
+# # 加载数据集
+# with open(dataset_path, 'r') as f:
+# dataset = json.load(f)
+# print(len(dataset))
+
+# # 加载检索结果
+# with open(retrieval_results_path, 'r') as f:
+# retrieval_results = json.load(f)
+# print(len(retrieval_results))
+
+# # 数据结构初始化
+# all_target_captions_soft = []
+# all_set_member_idx = []
+# nn_result = []
+
+# # 构建匹配数据结构
+# for sample in dataset:
+# all_target_captions_soft.append(sample["target_soft"])
+# all_set_member_idx.append(sample["img_set"]["members"])
+# query_id = str(sample["reference"])
+# retrieved_items = retrieval_results.get(query_id, [])
+# nn_result.append([item["retrieved_id"] for item in retrieved_items])
+
+# # 计算召回指标
+# out = []
+# # Recall@K (全局检索)
+# for k in [1, 5, 10, 50]:
+# total_score = 0.0
+# for i in range(len(dataset)):
+# query_id = str(dataset[i]["reference"]) # 获取当前查询的参考ID
+# # 过滤掉参考图像本身
+# filtered_results = [rid for rid in nn_result[i] if rid != query_id]
+# top_k = filtered_results[:k]
+
+# best_score = 0.0
+# for target_id, score in all_target_captions_soft[i].items():
+# if target_id in top_k:
+# best_score = max(best_score, score)
+# total_score += best_score
+# recall = total_score / len(dataset) * 100
+
+# out.append((f"recall_top{k}_correct_composition", recall))
+
+# # Recall_subset@K (子集检索)
+# for k in [1, 2, 3]:
+# total_score = 0.0
+# for i in range(len(dataset)):
+# query_id = str(dataset[i]["reference"])
+# # 双重过滤:子集成员 + 排除参考图像
+# subset_results = [
+# rid for rid in nn_result[i]
+# if rid in all_set_member_idx[i] and rid != query_id
+# ]
+# top_k_subset = subset_results[:k]
+
+# best_score = 0.0
+# for target_id, score in all_target_captions_soft[i].items():
+# if target_id in top_k_subset:
+# best_score = max(best_score, score)
+# total_score += best_score
+# recall = total_score / len(dataset) * 100
+
+# out.append((f"recall_inset_top{k}_correct_composition", recall))
+
+# # 打印和保存���果
+# print("\n" + "="*30 + " Evaluation Results " + "="*30)
+# for metric, value in out:
+# print(f"{metric:<40}: {value:.4f}")
+
+# output_dir = os.path.dirname(retrieval_results_path)
+# output_filename = "evaluation_results_quchong.txt"
+# output_path = os.path.join(output_dir, output_filename)
+
+# with open(output_path, 'w') as f:
+# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
+# for metric, value in out:
+# line = f"{metric:<40}: {value:.4f}\n"
+# f.write(line)
+# print(f"\nResults saved to {output_path}")
+
+# return out
+
+
+# # 这是因为dataset中的样本的reference字段大量重复(CIRR数据集,每个参考图像对应多个不同的目标描述caption)
+# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
+# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
+# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_noquery_a08.json"
+
+# # 加载数据集
+# with open(dataset_path, 'r') as f:
+# dataset = json.load(f)
+
+# # 加载检索结果
+# with open(retrieval_results_path, 'r') as f:
+# retrieval_results = json.load(f)
+
+# # 调试:检查数据一致性
+# print(len(dataset))
+# print(len(retrieval_results))
+
+# # 过滤有效样本(确保类型严格一致)
+# valid_samples = []
+# for sample in dataset:
+# query_id = str(sample["reference"])
+# if query_id in retrieval_results:
+# valid_samples.append(sample)
+
+# print(f"Valid samples count: {len(valid_samples)} (must be 2086)")
+
+# # 构建数据结构
+# all_target_captions_soft = []
+# all_set_member_idx = []
+# nn_result = []
+
+# for sample in valid_samples:
+# all_target_captions_soft.append(sample["target_soft"])
+# all_set_member_idx.append(sample["img_set"]["members"])
+# query_id = str(sample["reference"])
+# # 直接获取该query_id的检索结果列表
+# retrieved_items = retrieval_results[query_id]
+# nn_result.append([item["retrieved_id"] for item in retrieved_items])
+
+# out = []
+# # Recall@K 计算(使用有效样本数量和过滤后的数据)
+# for k in [1, 5, 10, 50]:
+# total_score = 0.0
+# for i in range(len(valid_samples)):
+# query_id = str(valid_samples[i]["reference"])
+# filtered_results = [rid for rid in nn_result[i] if rid != query_id]
+# top_k = filtered_results[:k]
+
+# best_score = 0.0
+# for target_id, score in all_target_captions_soft[i].items():
+# if target_id in top_k:
+# best_score = max(best_score, score)
+# total_score += best_score
+
+# recall = total_score / len(valid_samples) * 100
+# print(len(valid_samples))
+# out.append((f"recall_top{k}_correct_composition", recall))
+
+# # Recall_subset@K 计算
+# for k in [1, 2, 3]:
+# total_score = 0.0
+# for i in range(len(valid_samples)):
+# query_id = str(valid_samples[i]["reference"])
+# subset_results = [
+# rid for rid in nn_result[i]
+# if rid in all_set_member_idx[i] and rid != query_id
+# ]
+# top_k_subset = subset_results[:k]
+
+# best_score = 0.0
+# for target_id, score in all_target_captions_soft[i].items():
+# if target_id in top_k_subset:
+# best_score = max(best_score, score)
+# total_score += best_score
+
+# recall = total_score / len(valid_samples) * 100
+# out.append((f"recall_inset_top{k}_correct_composition", recall))
+
+# # 输出结果(保持不变)
+# print("\n" + "="*30 + " Evaluation Results " + "="*30)
+# for metric, value in out:
+# print(f"{metric:<40}: {value:.4f}")
+
+# output_dir = os.path.dirname(retrieval_results_path)
+# output_filename = "evaluation_results_valid.txt"
+# output_path = os.path.join(output_dir, output_filename)
+
+# with open(output_path, 'w') as f:
+# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
+# for metric, value in out:
+# line = f"{metric:<40}: {value:.4f}\n"
+# f.write(line)
+# print(f"\nResults saved to {output_path}")
+
+# return out
+
+# # 可以截取前n个计算分数
+# def evaluate_cirr_scores() -> List[Tuple[str, float]]:
+
+# # 设置数据集和检索结果的路径
+# dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
+# retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json"
+
+# # 加载数据集
+# with open(dataset_path, 'r') as f:
+# dataset = json.load(f)
+# print(f"Total samples in dataset: {len(dataset)}")
+
+# # 加载���索结果
+# with open(retrieval_results_path, 'r') as f:
+# retrieval_results = json.load(f)
+# print(f"Total queries in retrieval results: {len(retrieval_results)}")
+
+# # 数据结构初始化
+# valid_query_ids = [] # 存储匹配成功的query_id
+# all_target_captions_soft = [] # 存储匹配样本的target_soft
+# all_set_member_idx = [] # 存储匹配样本的set members
+# nn_result = [] # 存储匹配样本的检索结果
+
+# # 构建匹配数据结构(新增双重验证)
+# for sample in dataset:
+# query_id = str(sample["reference"])
+# caption = sample["caption"]
+
+# # 获取对应的检索结果条目
+# retrieved_entry = retrieval_results.get(query_id)
+
+# # 双重验证:query_id存在且caption匹配
+# if retrieved_entry and retrieved_entry["original_caption"] == caption:
+# valid_query_ids.append(query_id)
+# all_target_captions_soft.append(sample["target_soft"])
+# all_set_member_idx.append(sample["img_set"]["members"])
+# # 提取检索结果并转换为id列表
+# retrieved_items = retrieved_entry["retrieved_results"]
+# nn_result.append([item["retrieved_id"] for item in retrieved_items])
+
+# print(f"Valid matched samples after verification: {len(valid_query_ids)}")
+
+# ##############################################
+# # 新增修改:仅取前100个样本
+# # 截取前100个有效样本(若不足100则取全部)
+# valid_query_ids = valid_query_ids[:300]
+# all_target_captions_soft = all_target_captions_soft[:300]
+# all_set_member_idx = all_set_member_idx[:300]
+# nn_result = nn_result[:300]
+# total_samples = len(valid_query_ids)
+# ##############################################
+
+# print(f"Evaluating on first {total_samples} samples")
+
+# # 计算召回指标(仅使用有效样本)
+# out = []
+# total_samples = len(valid_query_ids)
+
+# # Recall@K (全局检索)
+# for k in [1, 5, 10, 50]:
+# total_score = 0.0
+# for i in range(total_samples):
+# # 过滤参考图像本身
+# filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]]
+# top_k = filtered_results[:k]
+
+# # 计算最佳匹配分数
+# best_score = max(
+# (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k),
+# default=0.0
+# )
+# total_score += best_score
+
+# recall = (total_score / total_samples) * 100
+# out.append((f"recall_top{k}_correct_composition", recall))
+
+# # Recall_subset@K (子集检索)
+# for k in [1, 2, 3]:
+# total_score = 0.0
+# for i in range(total_samples):
+# # 双重过滤:子集成员且非参考图像
+# subset_results = [
+# rid for rid in nn_result[i]
+# if rid in all_set_member_idx[i] and rid != valid_query_ids[i]
+# ]
+# top_k_subset = subset_results[:k]
+
+# # 计算最佳匹配分数
+# best_score = max(
+# (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset),
+# default=0.0
+# )
+# total_score += best_score
+
+# recall = (total_score / total_samples) * 100
+# out.append((f"recall_inset_top{k}_correct_composition", recall))
+
+# # 打印和保存结果
+# print("\n" + "="*30 + " Evaluation Results " + "="*30)
+# for metric, value in out:
+# print(f"{metric:<40}: {value:.4f}")
+
+# output_dir = os.path.dirname(retrieval_results_path)
+# output_filename = "evaluation_results_quchongnew.txt"
+# output_path = os.path.join(output_dir, output_filename)
+
+# with open(output_path, 'w') as f:
+# f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
+# for metric, value in out:
+# line = f"{metric:<40}: {value:.4f}\n"
+# f.write(line)
+# print(f"\nResults saved to {output_path}")
+
+# return out
+
+
+# 目前来看这个函数没有太大问题
+def evaluate_cirr_scores() -> List[Tuple[str, float]]:
+
+ # 设置数据集和检索结果的路径
+ dataset_path = "/home/zt/data/BrushEdit/cirr/captions/cap.rc2.val.json"
+ retrieval_results_path = "/home/zt/data/BrushEdit/retrieval_results_quchongnew.json"
+
+ # 加载数据集
+ with open(dataset_path, 'r') as f:
+ dataset = json.load(f)
+ print(f"Total samples in dataset: {len(dataset)}")
+
+ # 加载检索结果
+ with open(retrieval_results_path, 'r') as f:
+ retrieval_results = json.load(f)
+ print(f"Total queries in retrieval results: {len(retrieval_results)}")
+
+ # 数据结构初始化
+ valid_query_ids = [] # 存储匹配成功的query_id
+ all_target_captions_soft = [] # 存储匹配样本的target_soft
+ all_set_member_idx = [] # 存储匹配样本的set members
+ nn_result = [] # 存储匹配样本的检索结果
+
+ # 构建匹配数据结构(新增双重验证)
+ for sample in dataset:
+ query_id = str(sample["reference"])
+ caption = sample["caption"]
+
+ # 获取对应的检索结果条目
+ retrieved_entry = retrieval_results.get(query_id)
+
+ # 双重验证:query_id存在且caption匹配
+ if retrieved_entry and retrieved_entry["original_caption"] == caption:
+ valid_query_ids.append(query_id)
+ all_target_captions_soft.append(sample["target_soft"])
+ all_set_member_idx.append(sample["img_set"]["members"])
+ # 提取检索结果并转换为id列表
+ retrieved_items = retrieved_entry["retrieved_results"]
+ nn_result.append([item["retrieved_id"] for item in retrieved_items])
+
+ print(f"Valid matched samples after verification: {len(valid_query_ids)}")
+
+ # 计算召回指标(仅使用有效样本)
+ out = []
+ total_samples = len(valid_query_ids)
+
+ # Recall@K (全局检索)
+ for k in [1, 5, 10, 50]:
+ total_score = 0.0
+ for i in range(total_samples):
+ # 过滤参考图像本身
+ filtered_results = [rid for rid in nn_result[i] if rid != valid_query_ids[i]]
+ top_k = filtered_results[:k]
+
+ # 计算最佳匹配分数
+ best_score = max(
+ (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k),
+ default=0.0
+ )
+ total_score += best_score
+
+ recall = (total_score / total_samples) * 100
+ out.append((f"recall_top{k}_correct_composition", recall))
+
+ # Recall_subset@K (子集检索)
+ for k in [1, 2, 3]:
+ total_score = 0.0
+ for i in range(total_samples):
+ # 双重过滤:子集成员且非参考图像
+ subset_results = [
+ rid for rid in nn_result[i]
+ if rid in all_set_member_idx[i] and rid != valid_query_ids[i]
+ ]
+ top_k_subset = subset_results[:k]
+
+ # 计算最佳匹配分数
+ best_score = max(
+ (score for target_id, score in all_target_captions_soft[i].items() if target_id in top_k_subset),
+ default=0.0
+ )
+ total_score += best_score
+
+ recall = (total_score / total_samples) * 100
+ out.append((f"recall_inset_top{k}_correct_composition", recall))
+
+ # 打印和保存结果
+ print("\n" + "="*30 + " Evaluation Results " + "="*30)
+ for metric, value in out:
+ print(f"{metric:<40}: {value:.4f}")
+
+ output_dir = os.path.dirname(retrieval_results_path)
+ output_filename = "evaluation_results_quchongnew.txt"
+ output_path = os.path.join(output_dir, output_filename)
+
+ with open(output_path, 'w') as f:
+ f.write("\n" + "="*30 + " Evaluation Results " + "="*30 + "\n")
+ for metric, value in out:
+ line = f"{metric:<40}: {value:.4f}\n"
+ f.write(line)
+ print(f"\nResults saved to {output_path}")
+
+ return out
+
+
+# if __name__ == "__main__":
+#注意使用的是qwen
+
+ # process_circo_val_images()
+ # process_circo_test_images()
+ # process_cirr_images()
+
+
+
+
+
+
+# def process_circo_images():
+
+# if not all([vlm_model, sam_predictor, groundingdino_model]):
+# raise RuntimeError("Required models not initialized")
+
+# # Define paths
+# dev_dir = Path("/home/zt/data/BrushEdit/CIRCO/img_raw/dev")
+# cap_file = Path("/home/zt/data/BrushEdit/CIRCO/annotations/val.json")
+
+# output_dirs = {
+# "edited": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited"),
+# "mask": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_mask"),
+# "masked": Path("/home/zt/data/BrushEdit/CIRCO/img_paint/circo_masked")
+# }
+# output_json_path = Path("/home/zt/data/BrushEdit/CIRCO/image_paint.json")
+# descriptions = {}
+
+# # Create output directories
+# for dir_path in output_dirs.values():
+# dir_path.mkdir(parents=True, exist_ok=True)
+
+# # Load captions
+# with open(cap_file, 'r') as f:
+# captions = json.load(f)
+
+# for img_path in dev_dir.glob("*.jpg"):
+# base_name = img_path.stem
+# # 提取后六位作为参考ID
+# reference_part = base_name[-6:]
+# # 将JSON中的reference_img_id转换为字符串后比较
+# caption = next(
+# (item["relative_caption"] for item in captions
+# if str(item.get("reference_img_id")) == reference_part),
+# None
+# )
+
+# if not caption:
+# print(f"Warning: No caption for {base_name}")
+# continue
+
+# try:
+# # 构造空alpha通道(全0)
+# rgb_image = Image.open(img_path).convert("RGB")
+# empty_alpha = Image.new("L", rgb_image.size, 0) # 全透明alpha通道
+# image = Image.merge("RGBA", (*rgb_image.split(), empty_alpha))
+
+# # 调用init_img初始化
+# base = {"background": image, "layers": [image]}
+# init_results = init_img(
+# base=base,
+# init_type="custom", # 使用自定义初始化
+# prompt=caption,
+# aspect_ratio="Custom resolution",
+# example_change_times=0
+# )
+
+# # 获取初始化后的参数
+# input_image = init_results[0]
+# original_image = init_results[1]
+# original_mask = init_results[2]
+
+# # 正确设置process参数
+# result_images, mask_images, masked_images, _, target_description, _ = process(
+# input_image=input_image,
+# original_image=original_image,
+# original_mask=original_mask, # 传递初始化后的mask
+# prompt=caption,
+# negative_prompt="ugly, low quality",
+# control_strength=1.0,
+# seed=648464818,
+# randomize_seed=False,
+# guidance_scale=7.5,
+# num_inference_steps=50,
+# num_samples=1,
+# blending=True,
+# category=None,
+# target_prompt="",
+# resize_default=True,
+# aspect_ratio_name="Custom resolution",
+# invert_mask_state=False
+# )
+
+# # Save images
+# output_dirs["edited"].mkdir(exist_ok=True)
+# result_images[0].save(output_dirs["edited"] / f"{base_name}.jpg")
+# mask_images[0].save(output_dirs["mask"] / f"{base_name}_mask.jpg")
+# masked_images[0].save(output_dirs["masked"] / f"{base_name}_masked.jpg")
+
+# # Generate BLIP2 description
+# blip2_desc, _ = generate_blip2_description(input_image)
+
+# descriptions[base_name] = {
+# "original_caption": caption,
+# "blip2_description": blip2_desc,
+# "llm_enhanced_caption": target_description
+# }
+
+# with open(output_json_path, 'w') as f:
+# json.dump(descriptions, f, indent=4) # indent保持可读性
+
+# print(f"Processed {base_name}")
+
+# except Exception as e:
+# print(f"Error processing {base_name}: {str(e)}")
+# continue
+
+# print("Processing completed!")
+
+
+# @torch.no_grad()
+# def batch_mix_and_search_circo(
+# json_path: str = "/home/zt/data/BrushEdit/CIRCO/image_paint.json",
+# image_dir: str = "/home/zt/data/BrushEdit/CIRCO/img_paint/circo_edited",
+# alpha: float = 0.6,
+# batch_size: int = 32,
+# output_json_path: str = "circo_retrieval_results.json"
+# ) -> Dict[str, List[Dict]]:
+
+# # 加载索引和元数据
+# index = faiss.read_index("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_knn.index")
+# metadata = pd.read_parquet("/home/zt/data/BrushEdit/CIRCO/img_raw/dev/dev_embedding_folder/metadata/metadata_0.parquet")
+# all_index_ids = metadata["image_path"].tolist()
+
+# # 加载并验证输入数据
+# with open(json_path) as f:
+# samples = json.load(f)
+
+# valid_samples = []
+# image_dir = Path(image_dir)
+# for image_id, sample_info in samples.items(): # 关键修改点
+# img_path = image_dir / f"{image_id}.jpg" # 直接用字典的key作为image_id
+# if img_path.exists():
+# valid_samples.append(
+# (image_id, img_path, sample_info['llm_enhanced_caption']) # 从value中取caption
+# )
+
+# # 初始化结果字典
+# results = {}
+# total_samples = len(valid_samples)
+
+# # 分批次处理
+# for batch_idx in range(0, total_samples, batch_size):
+# batch_end = min(batch_idx + batch_size, total_samples)
+# current_batch = valid_samples[batch_idx:batch_end]
+
+# # 批量处理图像
+# batch_images = [Image.open(s[1]).convert("RGB") for s in current_batch]
+# image_inputs = clip_processor(images=batch_images, return_tensors="pt").to(device)
+# image_features = clip_model.get_image_features(**image_inputs)
+# image_features = nn.functional.normalize(image_features, dim=-1)
+
+# # 批量处理文本
+# batch_texts = [s[2] for s in current_batch]
+# text_inputs = clip_processor(
+# text=batch_texts,
+# return_tensors="pt",
+# padding=True,
+# truncation=True
+# ).to(device)
+# text_features = clip_model.get_text_features(**text_inputs)
+# text_features = nn.functional.normalize(text_features, dim=-1)
+
+# # 混合特征
+# mixed_features = (1 - alpha) * image_features + alpha * text_features
+# mixed_features = nn.functional.normalize(mixed_features, dim=-1)
+
+# # 批量检索
+# query_features = mixed_features.cpu().numpy().astype("float32")
+# distances, indices = index.search(query_features, 100)
+
+# # 保存当前批次结果
+# for (sample_id, _, _), dist_row, idx_row in zip(current_batch, distances, indices):
+# results[sample_id] = []
+# for distance, idx in zip(dist_row, idx_row):
+# if 0 <= idx < len(all_index_ids):
+# results[sample_id].append({
+# "retrieved_id": all_index_ids[idx],
+# "score": float(distance)
+# })
+
+# # 保存结果到JSON
+# with open(output_json_path, 'w') as f:
+# json.dump(results, f, indent=2, ensure_ascii=False)
+
+# print("Retrieving completed!")
+# return results
+
+
+block = gr.Blocks(
+ theme=gr.themes.Soft(
+ radius_size=gr.themes.sizes.radius_none,
+ text_size=gr.themes.sizes.text_md
+ )
+ )
+
+with block as demo:
+ with gr.Row():
+ with gr.Column():
+ gr.HTML(head)
+ gr.HTML(descriptions)
+ with gr.Accordion(label="🧭 小白也能秒懂的魔法指南:", open=True, elem_id="accordion"):
+ with gr.Row(equal_height=True):
+ gr.HTML(instructions)
+
+ original_image = gr.State(value=None)
+ original_mask = gr.State(value=None)
+ category = gr.State(value=None)
+ status = gr.State(value=None)
+ invert_mask_state = gr.State(value=False)
+ example_change_times = gr.State(value=0)
+ deepseek_verified = gr.State(value=False)
+ blip2_description = gr.State(value="")
+ enhanced_description = gr.State(value="")
+ decomposed_description = gr.State(value="")
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ input_image = gr.ImageEditor(
+ label="参考图像",
+ type="pil",
+ brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
+ layers = False,
+ interactive=True,
+ # height=1024,
+ height=408,
+ sources=["upload"],
+ placeholder="🫧 点击此处或下面的图标上传图像 🫧",
+ )
+ prompt = gr.Textbox(label="修改指令", placeholder="😜 在此处输入你对参考图像的修改预期 😜", value="",lines=2)
+
+ with gr.Group():
+
+ with gr.Row():
+ mask_button = gr.Button("💎 掩膜生成")
+ invert_mask_button = gr.Button("👐 掩膜翻转")
+ # random_mask_button = gr.Button("⭕️ 随机掩膜")
+ with gr.Row():
+ masked_gallery = gr.Gallery(label="掩膜图像", show_label=True, preview=True, height=360)
+ mask_gallery = gr.Gallery(label="掩膜", show_label=True, preview=True, height=360)
+
+
+ with gr.Accordion("高级掩膜选项", open=False, elem_id="accordion1"):
+ dilation_size = gr.Slider(
+ label="每次放缩的尺度: ", show_label=True,minimum=0, maximum=50, step=1, value=20
+ )
+ with gr.Row():
+ dilation_mask_button = gr.Button("放大掩膜")
+ erosion_mask_button = gr.Button("缩小掩膜")
+
+ moving_pixels = gr.Slider(
+ label="每次移动的像素:", show_label=True, minimum=0, maximum=50, value=4, step=1
+ )
+ with gr.Row():
+ move_left_button = gr.Button("左移")
+ move_right_button = gr.Button("右移")
+ with gr.Row():
+ move_up_button = gr.Button("上移")
+ move_down_button = gr.Button("下移")
+
+
+
+ with gr.Column():
+ with gr.Row():
+ deepseek_key = gr.Textbox(label="LLM API密钥", value="sk-d145b963a92649a88843caeb741e8bbc", lines=1, container=False, type="password")
+ verify_deepseek = gr.Button("🔑 验证密钥", scale=0)
+
+ blip2_output = gr.Textbox(label="1. 原图描述(BLIP2生成)", placeholder="🖼️ 上传图片后自动生成图片描述 🖼️", lines=2, interactive=True)
+
+ with gr.Row():
+ target_prompt = gr.Textbox(label="2. 整合增强版", lines=4, interactive=True, placeholder="🚀 点击图片编辑同时生成增强描述 or 点击右侧按钮单独生成增强描述 🚀")
+ enhance_button = gr.Button("✨ 智能整合")
+
+ with gr.Row():
+ decomposed_output = gr.Textbox(label="3. 结构分解版", lines=4, interactive=True, placeholder="📝 点击右侧按钮生成结构化描述 📝")
+ decompose_button = gr.Button("🔧 结构分解")
+
+
+
+ with gr.Group():
+ run_button = gr.Button("💫 图像编辑")
+ result_gallery = gr.Gallery(label="💥 编辑结果", show_label=True, columns=2, preview=True, height=360)
+
+ with gr.Accordion("高级编辑选项", open=False, elem_id="accordion1"):
+ vlm_model_dropdown = gr.Dropdown(label="VLM 模型", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
+
+ with gr.Group():
+ with gr.Row():
+ # GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
+ GPT4o_KEY = gr.Textbox(label="VLM API密钥", value="", lines=2, container=False, type="password")
+ GPT4o_KEY_submit = gr.Button("🔑 验证密钥", scale=0)
+
+ aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
+ resize_default = gr.Checkbox(label="短边裁剪到640像素", value=True)
+ base_model_dropdown = gr.Dropdown(label="基础模型", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
+ negative_prompt = gr.Textbox(label="负向提示", max_lines=5, placeholder="请输入你的负向提示", value='ugly, low quality',lines=1)
+ control_strength = gr.Slider(label="控制强度: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01)
+ with gr.Group():
+ seed = gr.Slider(label="种子: ", minimum=0, maximum=2147483647, step=1, value=648464818)
+ randomize_seed = gr.Checkbox(label="随机种子", value=False)
+ blending = gr.Checkbox(label="混合模式", value=True)
+ num_samples = gr.Slider(label="生成个数", minimum=0, maximum=4, step=1, value=2)
+ with gr.Group():
+ with gr.Row():
+ guidance_scale = gr.Slider(label="指导尺度", minimum=1, maximum=12, step=0.1, value=7.5)
+ num_inference_steps = gr.Slider(label="推理步数", minimum=1, maximum=50, step=1, value=50)
+ # target_prompt = gr.Textbox(label="Input Target Prompt", max_lines=5, placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)", value='', lines=2)
+
+
+
+ init_type = gr.Textbox(label="Init Name", value="", visible=False)
+ example_type = gr.Textbox(label="Example Name", value="", visible=False)
+
+ with gr.Row():
+ reset_button = gr.Button("🆕 重置检索")
+ retrieve_button = gr.Button("🔍 开始检索")
+
+ with gr.Row():
+ retrieve_gallery = gr.Gallery(label="🎊 检索结果", show_label=True, columns=10, preview=True, height=660)
+
+
+ with gr.Row():
+ example = gr.Examples(
+ label="Quick Example",
+ examples=EXAMPLES,
+ inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
+ examples_per_page=10,
+ cache_examples=False,
+ visible=False
+ )
+
+
+ with gr.Accordion(label="🎬 隐藏玩法大公开:", open=True, elem_id="accordion"):
+ with gr.Row(equal_height=True):
+ gr.HTML(tips)
+
+ with gr.Row():
+ gr.HTML(citation)
+
+ ## gr.examples can not be used to update the gr.Gallery, so we need to use the following two functions to update the gr.Gallery.
+ ## And we need to solve the conflict between the upload and change example functions.
+ input_image.upload(
+ init_img,
+ [input_image, init_type, prompt, aspect_ratio, example_change_times],
+ [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
+
+ )
+ example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
+
+
+ ## vlm and base model dropdown
+ vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
+ base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
+
+ GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
+
+
+ ips=[input_image,
+ original_image,
+ original_mask,
+ prompt,
+ negative_prompt,
+ control_strength,
+ seed,
+ randomize_seed,
+ guidance_scale,
+ num_inference_steps,
+ num_samples,
+ blending,
+ category,
+ target_prompt,
+ resize_default,
+ aspect_ratio,
+ invert_mask_state]
+
+ ## run brushedit
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
+
+
+ ## mask func
+ mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
+ # random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
+ dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
+ erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
+ invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
+
+ ## reset func
+ reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, blip2_output, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, blip2_output, target_prompt, resize_default, invert_mask_state])
+
+ input_image.upload(fn=generate_blip2_description, inputs=[input_image], outputs=[blip2_description, blip2_output])
+ verify_deepseek.click(fn=verify_deepseek_api, outputs=[deepseek_verified, deepseek_key])
+ # enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, enhanced_output])
+ # decompose_button.click(fn=decompose_description, inputs=[enhanced_output], outputs=[decomposed_description, decomposed_output])
+ # retrieve_button.click(fn=mix_and_search, inputs=[enhanced_output, result_gallery], outputs=[retrieve_gallery])
+
+ enhance_button.click(fn=enhance_description, inputs=[blip2_output, prompt], outputs=[enhanced_description, target_prompt])
+ decompose_button.click(fn=decompose_description, inputs=[target_prompt], outputs=[decomposed_description, decomposed_output])
+ retrieve_button.click(fn=mix_and_search, inputs=[target_prompt, result_gallery], outputs=[retrieve_gallery])
+
+demo.launch(server_name="0.0.0.0", server_port=1234, share=True)
+
+