# AnimeIns_CPU / seg_script.py
# Uploaded by svjack ("Upload seg_script.py", commit e81f668, verified; 3.59 kB).
'''
python seg_script.py Genshin_Impact_Images Genshin_Impact_Images_Seg
'''
import os
import cv2
import argparse
from PIL import Image
import numpy as np
from tqdm import tqdm
from pathlib import Path
from animeinsseg import AnimeInsSeg, AnimeInstances
from animeinsseg.anime_instances import get_color
# --- Model configuration --------------------------------------------------
# Checkpoint of the instance-segmentation model. This is a relative path, so it
# is presumably resolved against the current working directory.
ckpt = r'models/AnimeInstanceSegmentation/rtmdetl_e60.ckpt'
# Forwarded to AnimeInsSeg as its `mask_thr` argument.
mask_thres = 0.3
# Forwarded to net.infer as `pred_score_thr` (minimum score for a kept instance).
instance_thres = 0.3
# Mask-refinement options; set this to None to run without refinenet.
refine_kwargs = dict(refine_method='refinenet_isnet')

# Build the model once at import time; process_image reuses this instance.
net = AnimeInsSeg(
    ckpt,
    mask_thr=mask_thres,
    refine_kwargs=refine_kwargs,
)
def _clip_xywh(xywh, im_w, im_h):
    """Convert an (x, y, w, h) box to integer (x1, y1, x2, y2) clipped to the image.

    Predicted boxes can extend past the image border; clipping keeps the numpy
    slices that use these coordinates from wrapping around on negative indices.
    """
    x1 = min(max(int(xywh[0]), 0), im_w)
    y1 = min(max(int(xywh[1]), 0), im_h)
    x2 = min(max(int(xywh[0] + xywh[2]), 0), im_w)
    y2 = min(max(int(xywh[1] + xywh[3]), 0), im_h)
    return x1, y1, x2, y2


def _overlay_mask(canvas, mask, color, alpha=0.5):
    """Alpha-blend `color` over `canvas`, weighted by `mask`; return a new uint8 image."""
    weight = (alpha * mask.astype(np.float32))[..., None]
    tint = np.asarray(color, dtype=np.float32)  # broadcasts over every pixel
    blended = canvas * (1 - weight) + weight * tint
    return blended.astype(np.uint8)


def process_image(image_path, output_dir):
    """Segment the instances in one image and write the results to `output_dir`.

    For an input named ``<stem>.<ext>`` the following files are written into
    `output_dir` (created if missing):

    * ``<stem>_drawed.png`` -- the input with every instance's bounding box and
      a semi-transparent mask overlay drawn on it;
    * ``<stem>_cropped_<i>.png`` -- the RGB crop of instance ``i``'s box;
    * ``<stem>_segmented_<i>.png`` -- the same crop as RGBA, with instance
      ``i``'s mask as the alpha channel.

    Nothing is written when the image cannot be read or nothing is detected.
    Output names use the file stem only, so inputs sharing a stem (e.g.
    ``a.png`` and ``a.jpg``) overwrite each other's results.

    Args:
        image_path: path (str or Path) of the image to segment.
        output_dir: directory (str or Path) to write the outputs into.
    """
    # cv2.imread returns None (it does not raise) for missing/unreadable files;
    # str() because older OpenCV bindings do not accept pathlib.Path.
    img = cv2.imread(str(image_path))
    if img is None:
        tqdm.write(f"Skipping unreadable image: {image_path}")
        return
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    instances: AnimeInstances = net.infer(
        img,
        output_type='numpy',
        pred_score_thr=instance_thres,
    )
    # No detections: nothing to draw or crop, so no files are written.
    if instances.bboxes is None or len(instances.bboxes) == 0:
        return

    im_h, im_w = img.shape[:2]
    base_name = Path(image_path).stem
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Drawing parameters do not depend on the instance.
    mask_alpha = 0.5
    linewidth = max(round(sum(img.shape) / 2 * 0.003), 2)

    drawed = img.copy()
    for ii, (xywh, mask) in enumerate(zip(instances.bboxes, instances.masks)):
        color = get_color(ii)

        # Bounding box; cv2 clips out-of-image drawing by itself.
        p1 = (int(xywh[0]), int(xywh[1]))
        p2 = (int(xywh[0] + xywh[2]), int(xywh[1] + xywh[3]))
        cv2.rectangle(drawed, p1, p2, color, thickness=linewidth, lineType=cv2.LINE_AA)

        # Semi-transparent mask overlay.
        drawed = _overlay_mask(drawed, mask, color, mask_alpha)

        # Crop with coordinates clipped to the image so slicing stays correct.
        x1, y1, x2, y2 = _clip_xywh(xywh, im_w, im_h)
        if x2 <= x1 or y2 <= y1:
            continue  # box has no area inside the image; nothing to save
        cropped_img = img[y1:y2, x1:x2]
        cropped_mask = mask[y1:y2, x1:x2]

        # RGBA cut-out: the instance mask scaled to 0-255 becomes the alpha channel.
        alpha_channel = (cropped_mask * 255).astype(np.uint8)
        rgba_image = np.dstack((cropped_img, alpha_channel))

        Image.fromarray(cropped_img).save(output_path / f"{base_name}_cropped_{ii}.png")
        Image.fromarray(rgba_image, 'RGBA').save(output_path / f"{base_name}_segmented_{ii}.png")

    # Save the annotated image once, after all instances have been drawn on it.
    Image.fromarray(drawed).save(output_path / f"{base_name}_drawed.png")
# File suffixes (compared case-insensitively) picked up when scanning a folder.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def _find_images(root, exclude_dir=None):
    """Return the image files under `root` (recursively), sorted for a stable order.

    Files located inside `exclude_dir` are skipped, so re-running the script with
    the output folder nested in the input folder does not re-segment old outputs.
    """
    excluded = exclude_dir.resolve() if exclude_dir is not None else None
    images = []
    for path in root.rglob("*"):
        if not path.is_file() or path.suffix.lower() not in _IMAGE_SUFFIXES:
            continue
        if excluded is not None and excluded in path.resolve().parents:
            continue
        images.append(path)
    return sorted(images)


def main():
    """Command-line entry point: segment one image or every image in a folder.

    Usage: ``python seg_script.py INPUT_PATH OUTPUT_DIR``. An INPUT_PATH that is
    neither a file nor a directory is reported through argparse and the process
    exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Anime Instance Segmentation")
    parser.add_argument("input_path", type=str, help="Path to the input image or folder")
    parser.add_argument("output_dir", type=str, help="Path to the output directory")
    args = parser.parse_args()

    input_path = Path(args.input_path)
    output_dir = Path(args.output_dir)

    if input_path.is_file():
        process_image(input_path, output_dir)
    elif input_path.is_dir():
        image_paths = _find_images(input_path, exclude_dir=output_dir)
        for image_path in tqdm(image_paths, desc="Processing images"):
            process_image(image_path, output_dir)
    else:
        # parser.error prints usage to stderr and exits non-zero, so callers
        # (shell scripts, CI) can detect the failure.
        parser.error("Invalid input path. Please provide a valid image or folder path.")


if __name__ == "__main__":
    main()