"""Gradio app that removes eyeglasses from a portrait photo.

Pipeline: resize -> rembg person segmentation -> FastCUT (CUT) generator to
erase the glasses -> composite the edited face back over the original using
the segmentation mask -> CodeFormer face restoration for the final output.

Reference inference code:
https://www.jeremyafisher.com/running-cyclegan-programmatically.html
"""

import io
import os
import sys
import contextlib
from io import BytesIO
from pathlib import Path
from copy import deepcopy
from argparse import Namespace

import cv2
import requests
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageFile, ImageFilter, ImageEnhance, ImageOps
from rembg import remove, new_session
from codeformer.app import inference_app

from data.base_dataset import get_transform
from models.cut_model import CUTModel
from util.util import tensor2im

# CUTGAN (FastCUT) inference options.  These mirror the training-time CLI
# arguments of the CUT codebase; `name` is filled in below per checkpoint.
OPT = Namespace(
    batch_size=1,
    checkpoints_dir="cyclegan",
    crop_size=256,
    dataset_mode="unaligned",
    direction="AtoB",
    display_id=-1,
    display_winsize=256,
    epoch="latest",
    eval=False,
    gpu_ids=[],
    nce_layers="0,4,8,12,16",
    nce_idt=False,
    lambda_NCE=10.0,
    lambda_GAN=1.0,
    init_gain=0.02,
    nce_includes_all_negatives_from_minibatch=False,
    init_type="xavier",
    normG="instance",
    no_antialias=False,
    no_antialias_up=False,
    netF="mlp_sample",
    netF_nc=256,
    nce_T=0.07,
    num_patches=256,
    CUT_mode="FastCUT",
    input_nc=3,
    isTrain=False,
    load_iter=0,
    load_size=256,
    max_dataset_size=float("inf"),
    model="CUT",
    n_layers_D=3,
    name=None,
    ndf=64,
    netD="basic",
    netG="resnet_9blocks",
    ngf=64,
    no_dropout=True,
    no_flip=True,
    num_test=50,
    num_threads=4,
    output_nc=3,
    phase="test",
    preprocess="scale_width",
    random_scale_max=3.0,
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)


class SingleImageDataset(torch.utils.data.Dataset):
    """Dataset holding exactly one preprocessed image.

    Lets a single PIL image be fed through the standard CUT DataLoader
    machinery (which expects a Dataset).
    """

    def __init__(self, img, preprocess):
        # Apply the CUT test-time transform once, up front.
        self.img = preprocess(img)

    def __getitem__(self, i):
        return self.img

    def __len__(self):
        return 1


def _load_generator(options, weights_fp):
    """Build the CUT generator network and load its trained weights.

    Args:
        options: a Namespace of CUT model options (see ``OPT``).
        weights_fp: path to the generator's ``latest_net_G.pth`` state dict.

    Returns:
        The generator module (``netG``) ready for inference.
    """
    net = CUTModel(options).netG
    net.load_state_dict(torch.load(weights_fp))
    return net


# --- load the eyeglasses-removal generator at import time -------------------
fp = "cyclegan/EyeFastcut/latest_net_G.pth"
opt = deepcopy(OPT)
model_name = "EyeFastcut"
opt.name = model_name

if opt.verbose:
    model = _load_generator(opt, fp)
else:
    # CUTModel prints its architecture while being constructed; swallow
    # that output unless verbose mode was requested.
    with contextlib.redirect_stdout(io.StringIO()):
        model = _load_generator(opt, fp)


def cutgan(img: Image.Image) -> Image.Image:
    """Run the FastCUT generator on a single PIL image.

    Args:
        img: input image (any mode; converted to RGB).

    Returns:
        The generator's output as a PIL image.
    """
    img = img.convert("RGB")
    data_loader = torch.utils.data.DataLoader(
        SingleImageDataset(img, get_transform(opt)), batch_size=1
    )
    data = next(iter(data_loader))
    with torch.no_grad():
        pred = model(data)
    pred_arr = tensor2im(pred)
    return Image.fromarray(pred_arr)


def imsize(img, max_size=512, maintain_aspect_ratio=True):
    """Resize an image so its dimensions are multiples of 8.

    Args:
        img: PIL image to resize.
        max_size: target size for the longest side (or both sides when
            ``maintain_aspect_ratio`` is False).
        maintain_aspect_ratio: when True, only shrink images whose longest
            side exceeds ``max_size``, preserving aspect ratio; when False,
            force a ``max_size`` x ``max_size`` output.

    Returns:
        The resized PIL image, with width and height rounded down to the
        nearest multiple of 8 (a common requirement for conv-net inputs).
    """
    if maintain_aspect_ratio:
        if img.height > max_size or img.width > max_size:
            if img.width > img.height:
                desired_width = max_size
                desired_height = int(img.height / (img.width / max_size))
            elif img.height > img.width:
                desired_height = max_size
                desired_width = int(img.width / (img.height / max_size))
            else:
                # Square image: both sides become max_size.
                desired_height = max_size
                desired_width = max_size
        else:
            # Already small enough; keep original dimensions.
            desired_width = img.width
            desired_height = img.height
    else:
        desired_width = max_size
        desired_height = max_size

    # Round desired dimensions down to the nearest multiple of 8.
    desired_width = (desired_width // 8) * 8
    desired_height = (desired_height // 8) * 8

    return img.resize((desired_width, desired_height))


def rem_glass(input_img):
    """Remove eyeglasses from a portrait and restore the face.

    Args:
        input_img: PIL image of a person wearing glasses.

    Returns:
        The restored output image (read back from ``output/out.png``,
        where CodeFormer's ``inference_app`` writes its result).
    """
    w_0, h_0 = input_img.size

    # Work at the generator's native 256x256 resolution.
    im = imsize(input_img, max_size=256, maintain_aspect_ratio=False)
    ori_im = im.copy()

    # Person segmentation: one pass for the cutout, one for the mask.
    # Values match rembg's documented alpha-matting defaults.
    session = new_session("u2net_human_seg")
    im = remove(
        ori_im,
        alpha_matting=False,
        alpha_matting_foreground_threshold=240,
        alpha_matting_background_threshold=10,
        alpha_matting_erode_size=20,
        session=session,
        only_mask=False,
    )
    mask = remove(
        ori_im,
        alpha_matting=False,
        alpha_matting_foreground_threshold=240,
        alpha_matting_background_threshold=10,
        alpha_matting_erode_size=20,
        session=session,
        only_mask=True,
    )

    # Erase the glasses on the segmented person.
    im = cutgan(im)

    # Composite the edited person back over the original background.
    w, h = im.size
    ori_im = ori_im.resize((w, h))
    mask = mask.resize((w, h))
    img = Image.composite(im, ori_im, mask)

    # Scale back to the input's original size and run CodeFormer face
    # restoration (it writes its result to output/out.png).
    img = img.resize((w_0, h_0))
    img.save("removal.png")
    inference_app(
        image="removal.png",
        background_enhance=False,
        face_upsample=False,
        upscale=2,
        codeformer_fidelity=0.5,
    )
    return Image.open("output/out.png")


# NOTE(review): gr.inputs/gr.outputs are the gradio 3.x component API;
# gradio 4+ removed them in favor of gr.Image. Kept as-is for compatibility
# with the version this app was written against.
demo = gr.Interface(
    rem_glass,
    gr.inputs.Image(type="pil"),
    gr.outputs.Image(type="pil"),
)

if __name__ == "__main__":
    demo.launch()