import hashlib
import re

import cv2
import gradio as gr
import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import AutoModelForImageSegmentation


# Set up the background removal model
def get_background_removal_model():
    try:
        # Use the BiRefNet model for background removal
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        # Fall back to CPU if CUDA is not available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()  # inference mode: disables dropout/batch-norm updates
        return model, device
    except Exception as e:
        logger.error(f"Error loading background removal model: {e}")
        return None, None
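
# Optional speed-up, left disabled: relaxing float32 matmul precision can make
# inference faster on recent GPUs. Whether the precision trade-off is
# acceptable here is an assumption about the deployment hardware.
# torch.set_float32_matmul_precision("high")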


# Image preprocessing expected by the model: 1024x1024 resize plus
# normalization with the standard ImageNet mean and std
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# Cache for storing background removal results
bg_removal_cache = {}


def get_image_hash(image):
    """Generate a hash for an image to use as a cache key."""
    if image is None:
        return None
    # Convert to bytes and generate the hash
    img_byte_arr = image.tobytes()
    img_hash = hashlib.md5(img_byte_arr).hexdigest()
    # Include the image dimensions in the key to ensure uniqueness
    return f"{img_hash}_{image.width}_{image.height}"


def remove_background(image, model_data):
    """Remove the background, returning (RGBA image, mask array) or (None, None)."""
    if model_data[0] is None:
        return None, None
    # Generate a hash for the image to use as the cache key
    img_hash = get_image_hash(image)
    # Check if the result is already in the cache
    if img_hash in bg_removal_cache:
        logger.info("Using cached background removal result")
        return bg_removal_cache[img_hash]
    model, device = model_data
    try:
        logger.info("Starting background removal process")
        # Convert the image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Store the original size for later resizing
        image_size = image.size
        # Apply transformations and move to the model's device
        input_images = transform_image(image).unsqueeze(0).to(device)
        # Run prediction; the model returns several outputs and the last one
        # is the final segmentation map
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # Convert the prediction to a PIL image
        pred_pil = transforms.ToPILImage()(pred)
        # Resize the mask back to the original image size
        mask = pred_pil.resize(image_size)
        # Copy the original image and attach the mask as its alpha channel
        result_image = image.copy()
        result_image.putalpha(mask)
        # Cache the result
        result = (result_image, np.array(mask))
        bg_removal_cache[img_hash] = result
        logger.info("Background removal process completed")
        return result
    except Exception as e:
        logger.error(f"Error during background removal: {e}")
        return None, None
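

# The cache above grows without bound in a long-running Space. A minimal
# eviction helper as a sketch (nothing calls it yet, and the 32-entry cap is
# an illustrative assumption, not a tuned value). Dicts preserve insertion
# order in Python 3.7+, so popping the first key drops the oldest entry.
MAX_CACHE_ENTRIES = 32


def trim_bg_removal_cache():
    while len(bg_removal_cache) > MAX_CACHE_ENTRIES:
        bg_removal_cache.pop(next(iter(bg_removal_cache)))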


def parse_color(color_str):
    """Parse different color formats, including rgba() strings, into RGBA tuples."""
    if isinstance(color_str, tuple):
        # If it's already a tuple, make sure it has an alpha channel
        if len(color_str) == 3:
            return color_str + (255,)
        return color_str
    if isinstance(color_str, str):
        # Handle hex colors (only the 7-character #RRGGBB form)
        if color_str.startswith("#"):
            if len(color_str) == 7:  # #RRGGBB format
                r = int(color_str[1:3], 16)
                g = int(color_str[3:5], 16)
                b = int(color_str[5:7], 16)
                return (r, g, b, 255)
            else:
                # Fall back to white if the format is unexpected
                return (255, 255, 255, 255)
        # Handle the rgba() format from the Gradio color picker
        rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
        if rgba_match:
            values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
            r = min(255, int(values[0]))
            g = min(255, int(values[1]))
            b = min(255, int(values[2]))
            # Handle alpha if present (CSS alpha is 0-1, PIL expects 0-255)
            a = 255
            if len(values) > 3:
                a = min(255, int(values[3] * 255))
            return (r, g, b, a)
        # For named colors, return the string as-is for PIL to handle
        return color_str
    # Default fallback: white
    return (255, 255, 255, 255)
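
# Illustrative expected values (worked by hand, not executed):
#   parse_color("#fdc915")               -> (253, 201, 21, 255)
#   parse_color("rgba(255, 0, 0, 0.5)")  -> (255, 0, 0, 127)
#   parse_color((10, 20, 30))            -> (10, 20, 30, 255)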


def add_person_border(image, mask, border_size, border_color="white"):
    """Add a border around the person based on the segmentation mask."""
    if border_size == 0:
        return image
    # Binarize the mask (the low threshold keeps soft mask edges as foreground)
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255
    # Dilate the mask to create the border
    kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
    dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    # The dilated mask covers both the person area and the border ring
    border_mask_pil = Image.fromarray(dilated_mask)
    # Create an image filled with the border color
    border_color_rgba = parse_color(border_color)
    border_img = Image.new("RGBA", image.size, color=border_color_rgba)
    # Create a transparent canvas for the result
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))
    # First paste the border shape (which covers both border and person area)
    result.paste(border_img, (0, 0), border_mask_pil)
    # Then paste the original image on top, masked to the person area, so the
    # person shows over the border color
    result.paste(image, (0, 0), Image.fromarray(binary_mask))
    return result
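
# The square kernel above grows the silhouette by border_size pixels in
# Chebyshev distance, which squares off corners slightly. For a rounder
# border, an elliptical structuring element is a drop-in alternative
# (sketch with the same kernel dimensions, not wired in):
#   kernel = cv2.getStructuringElement(
#       cv2.MORPH_ELLIPSE, (border_size * 2 + 1, border_size * 2 + 1)
#   )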


def detect_face(image):
    """Detect the largest face in the image and return its bounding box."""
    logger.info("Starting face detection")
    # Convert the PIL image to OpenCV format
    img_cv = np.array(image.convert("RGB"))
    img_cv = img_cv[:, :, ::-1].copy()  # Convert RGB to BGR for OpenCV
    # Load the Haar cascade for frontal face detection
    face_cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    face_cascade = cv2.CascadeClassifier(face_cascade_path)
    # Convert to grayscale for face detection
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Detect faces
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)
    if len(faces) == 0:
        logger.warning("No faces detected")
        return None
    # Keep the face with the largest bounding-box area
    largest_face = None
    max_area = 0
    for x, y, w, h in faces:
        if w * h > max_area:
            max_area = w * h
            largest_face = (x, y, w, h)
    logger.info(f"Largest face detected at: {largest_face}")
    return largest_face


def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
    """Center the portrait on the detected face and crop instead of upscaling."""
    if face_box is None:
        # No face detected: fall back to cropping from the top-left corner
        return portrait.crop((0, 0, target_width, target_height)), (0, 0)
    x, y, w, h = face_box
    # Calculate the face center
    face_center_x = x + w // 2
    face_center_y = y + h // 2
    # A higher zoom level means a smaller crop box around the face
    crop_width = int(target_width / zoom_level)
    crop_height = int(target_height / zoom_level)
    # Keep the crop box within the image bounds
    left = max(0, face_center_x - crop_width // 2)
    top = max(0, face_center_y - crop_height // 2)
    right = min(portrait.width, left + crop_width)
    bottom = min(portrait.height, top + crop_height)
    # Re-anchor left/top so the crop box keeps its full size even when it
    # was clamped at the right or bottom edge
    left = max(0, right - crop_width)
    top = max(0, bottom - crop_height)
    # Crop the image
    cropped_img = portrait.crop((left, top, right, bottom))
    # Center the cropped region on a transparent canvas
    centered_img = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
    offset_x = (target_width - cropped_img.width) // 2
    offset_y = (target_height - cropped_img.height) // 2
    centered_img.paste(cropped_img, (offset_x, offset_y), cropped_img)
    return centered_img, (offset_x, offset_y)
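
# Worked example (illustrative): for a 1000x1000 portrait with the face
# centered at (500, 400), target 1000x1000, and zoom 2.0, the crop box is
# 500x500 anchored at (250, 150), and the crop is pasted centered at offset
# ((1000 - 500) // 2, (1000 - 500) // 2) = (250, 250).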


def process_portrait(
    input_image,
    border_size=10,
    bg_color="#0000FF",
    zoom_level=1.0,
    erode_size=5,
    circular_overlay=False,
):
    if input_image is None:
        return None
    # Keep a single global model instance to avoid reloading it on every call
    global model_instance
    if "model_instance" not in globals():
        logger.info("Loading background removal model...")
        model_instance = get_background_removal_model()
    logger.info("Processing image...")
    result = remove_background(input_image, model_instance)
    if result[0] is None:
        logger.warning("Failed to remove background, returning original image")
        return input_image
    person_img, mask = result
    # Detect the face before any transformations
    face_box = detect_face(input_image)
    if face_box:
        logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, falling back to a top-left crop")
    # Erode the mask by erode_size pixels to tighten the cutout edge
    eroded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    mask = Image.fromarray(eroded_mask)
    logger.info("Adding white border...")
    # Add a white border only around the person
    bordered_img = add_person_border(person_img, mask, border_size, "white")
    logger.info(f"Creating colored background with color: {bg_color}")
    # Parse the background color
    bg_color_rgba = parse_color(bg_color)
    # Create the colored background
    width, height = bordered_img.size
    bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)
    # Center the portrait on the face and apply the zoom level
    logger.info(f"Applying zoom level: {zoom_level}")
    centered_portrait, offset = center_portrait(
        bordered_img, face_box, width, height, zoom_level
    )
    # Composite the portrait over the background
    final_image = Image.alpha_composite(bg_image, centered_portrait)
    # Crop the composite to the zoomed target dimensions
    crop_width = int(width / zoom_level)
    crop_height = int(height / zoom_level)
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    final_image = final_image.crop((left, top, right, bottom))
    # Convert back to RGB for display
    final_image = final_image.convert("RGB")
    # Crop to a centered square
    width, height = final_image.size
    square_size = min(width, height)
    left = (width - square_size) // 2
    top = (height - square_size) // 2
    right = left + square_size
    bottom = top + square_size
    final_image = final_image.crop((left, top, right, bottom))
    if circular_overlay:
        # Build a circular alpha mask
        mask = Image.new("L", (square_size, square_size), 0)
        draw = ImageDraw.Draw(mask)
        draw.ellipse((0, 0, square_size, square_size), fill=255)
        # Apply it to the final image (putalpha converts the RGB image to RGBA)
        final_image.putalpha(mask)
    logger.info(
        f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
    )
    return final_image


# Create the Gradio interface
with gr.Blocks(title="Cool Avatar Creator") as app:
    gr.Markdown("# Cool Avatar Creator")
    gr.Markdown(
        "Upload a portrait image to remove the background, add a white border, "
        "and place it on a colored background."
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            border_slider = gr.Slider(
                minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
            )
            bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
            zoom_slider = gr.Slider(
                minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
            )
            erode_slider = gr.Slider(
                minimum=1, maximum=30, value=15, step=1, label="Erode Size"
            )
            circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
            process_button = gr.Button("Process Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Processed Image")

    # Example images: [image URL, border size, background color, zoom level]
    examples = [
        [
            "https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
            26,
            "#fdc915",
            1.85,
        ],
        [
            "https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
            23,
            "#00FF00",
            1.4,
        ],
        ["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, border_slider, bg_color, zoom_slider],
        outputs=output_image,
        fn=process_portrait,
        cache_examples=False,
    )

    process_button.click(
        fn=process_portrait,
        inputs=[
            input_image,
            border_slider,
            bg_color,
            zoom_slider,
            erode_slider,
            circular_overlay_toggle,
        ],
        outputs=output_image,
    )


if __name__ == "__main__":
    app.launch(share=False)  # share=True would create a public link