import hashlib
import re

import cv2
import gradio as gr
import numpy as np
import torch
from loguru import logger
from PIL import Image, ImageDraw
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Set up model and transformations
def get_background_removal_model():
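    """Load the BiRefNet segmentation model and move it to GPU if available.

    Returns a (model, device) tuple, or (None, None) if loading fails.
    """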
try:
# Using BiRefNet model for background removal
model = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
# Use CPU if CUDA is not available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
return model, device
    except Exception as e:
        logger.error(f"Error loading background removal model: {e}")
return None, None
# Image preprocessing: resize to the model's 1024x1024 input and normalize
# with the standard ImageNet mean/std
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
# In-memory cache for background-removal results, keyed by image hash.
# Unbounded, which is fine for a short-lived demo Space.
bg_removal_cache = {}
def get_image_hash(image):
"""Generate a hash for an image to use as cache key"""
if image is None:
return None
# Convert to bytes and generate hash
img_byte_arr = image.tobytes()
img_hash = hashlib.md5(img_byte_arr).hexdigest()
    # Append the dimensions, since tobytes() alone does not encode image size
return f"{img_hash}_{image.width}_{image.height}"
def remove_background(image, model_data):
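    """Segment the person from the background using the BiRefNet model.

    Returns (RGBA image with transparent background, mask as a numpy array),
    or (None, None) on failure. Results are memoized in bg_removal_cache,
    keyed by the image hash.
    """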
    if image is None or model_data[0] is None:
        return None, None
# Generate a hash for the image to use as cache key
img_hash = get_image_hash(image)
# Check if result is already in cache
if img_hash in bg_removal_cache:
logger.info("Using cached background removal result")
return bg_removal_cache[img_hash]
model, device = model_data
try:
logger.info("Starting background removal process")
# Convert image to RGB if needed
if image.mode != "RGB":
image = image.convert("RGB")
# Store original size for later resizing
image_size = image.size
# Apply transformations and move to device
input_images = transform_image(image).unsqueeze(0).to(device)
# Run prediction
with torch.no_grad():
preds = model(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
# Convert prediction to PIL image
pred_pil = transforms.ToPILImage()(pred)
# Resize mask back to original image size
mask = pred_pil.resize(image_size)
# Create a copy of the original image and apply alpha channel
result_image = image.copy()
result_image.putalpha(mask)
# Cache the result
result = (result_image, np.array(mask))
bg_removal_cache[img_hash] = result
logger.info("Background removal process completed")
return result
except Exception as e:
logger.error(f"Error during background removal: {e}")
return None, None
def parse_color(color_str):
"""Parse different color formats including rgba strings"""
if isinstance(color_str, tuple):
# If it's already a tuple, make sure it has alpha
if len(color_str) == 3:
return color_str + (255,)
return color_str
if isinstance(color_str, str):
# Handle hex color format
if color_str.startswith("#"):
if len(color_str) == 7: # #RRGGBB format
r = int(color_str[1:3], 16)
g = int(color_str[3:5], 16)
b = int(color_str[5:7], 16)
return (r, g, b, 255)
else:
# Fallback to white if format is unexpected
return (255, 255, 255, 255)
# Handle rgba() format from Gradio color picker
rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
if rgba_match:
values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
r = min(255, int(values[0]))
g = min(255, int(values[1]))
b = min(255, int(values[2]))
# Handle alpha if present
a = 255
if len(values) > 3:
a = min(255, int(values[3] * 255))
return (r, g, b, a)
# For named colors, return as is for PIL to handle
return color_str
# Default fallback
return (255, 255, 255, 255) # White
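# Example: the Gradio color picker emits strings like "rgba(253, 201, 21, 1)";
# parse_color maps that to the tuple (253, 201, 21, 255).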
def add_person_border(image, mask, border_size, border_color="white"):
"""Add a border around the person based on the segmentation mask"""
if border_size == 0:
return image
    # Binarize the mask; the low threshold (> 4) treats almost any non-zero
    # mask value as foreground
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255
# Dilate the mask to create the border
kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)
# Create border mask (includes both the person area and border area)
border_mask_pil = Image.fromarray(dilated_mask)
    # Fill an image with the requested border color
    border_color_rgba = parse_color(border_color)
border_img = Image.new("RGBA", image.size, color=border_color_rgba)
# Create transparent image for result
result = Image.new("RGBA", image.size, (0, 0, 0, 0))
# First paste the white border shape (which includes both border and person area)
result.paste(border_img, (0, 0), border_mask_pil)
# Then paste the original image on top, but only the non-transparent parts
# This will show the original person on top of the white area
result.paste(image, (0, 0), Image.fromarray(binary_mask))
return result
def detect_face(image):
"""Detect the largest face in the image and return its bounding box"""
logger.info("Starting face detection")
# Convert PIL image to OpenCV format
img_cv = np.array(image.convert("RGB"))
img_cv = img_cv[:, :, ::-1].copy() # Convert RGB to BGR for OpenCV
# Load the Haar cascade for face detection
face_cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
face_cascade = cv2.CascadeClassifier(face_cascade_path)
# Convert to grayscale for face detection
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Detect faces (scaleFactor=1.1, minNeighbors=4)
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)
if len(faces) == 0:
logger.warning("No faces detected")
return None
# Find the largest face
largest_face = None
max_area = 0
for x, y, w, h in faces:
if w * h > max_area:
max_area = w * h
largest_face = (x, y, w, h)
logger.info(f"Largest face detected at: {largest_face}")
return largest_face
def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
"""Center the portrait based on face position and crop to avoid blurriness"""
    if face_box is None:
        # No face detected: fall back to cropping from the top-left corner
        return portrait.crop((0, 0, target_width, target_height)), (0, 0)
x, y, w, h = face_box
# Calculate face center
face_center_x = x + w // 2
face_center_y = y + h // 2
# Calculate crop box dimensions
crop_width = int(target_width / zoom_level)
crop_height = int(target_height / zoom_level)
# Ensure the crop box stays within the image bounds
left = max(0, face_center_x - crop_width // 2)
top = max(0, face_center_y - crop_height // 2)
right = min(portrait.width, left + crop_width)
bottom = min(portrait.height, top + crop_height)
# Adjust left and top if the crop box is smaller than the target dimensions
left = max(0, right - crop_width)
top = max(0, bottom - crop_height)
# Crop the image
cropped_img = portrait.crop((left, top, right, bottom))
# Center the cropped image on a transparent canvas
centered_img = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
offset_x = (target_width - cropped_img.width) // 2
offset_y = (target_height - cropped_img.height) // 2
centered_img.paste(cropped_img, (offset_x, offset_y), cropped_img)
return centered_img, (offset_x, offset_y)
def process_portrait(
input_image, border_size=10, bg_color="#0000FF", zoom_level=1.0, erode_size=5, circular_overlay=False
):
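    """Run the full avatar pipeline: remove the background, erode the mask,
    draw a border around the person, place the cutout on a colored
    background, center on the detected face, apply zoom, square-crop, and
    optionally mask the result to a circle.
    """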
if input_image is None:
return None
# Global model instance to avoid reloading
global model_instance
if "model_instance" not in globals():
logger.info("Loading background removal model...")
model_instance = get_background_removal_model()
logger.info("Processing image...")
result = remove_background(input_image, model_instance)
if result[0] is None:
logger.warning("Failed to remove background, returning original image")
return input_image
person_img, mask = result
# Detect face before any transformations
face_box = detect_face(input_image)
if face_box:
logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, falling back to a top-left crop")
    # Erode the mask by erode_size pixels to trim halo artifacts along the
    # subject's edge
    eroded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    mask = Image.fromarray(eroded_mask)
logger.info("Adding white border...")
# Add white border only around the person
bordered_img = add_person_border(person_img, mask, border_size, "white")
logger.info(f"Creating colored background with color: {bg_color}")
# Parse the background color
bg_color_rgba = parse_color(bg_color)
# Create colored background
width, height = bordered_img.size
bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)
# Center the portrait based on face location and apply zoom
logger.info(f"Applying zoom level: {zoom_level}")
centered_portrait, offset = center_portrait(
bordered_img, face_box, width, height, zoom_level
)
# Create the final composite
final_image = Image.alpha_composite(bg_image, centered_portrait)
# Crop the final image to the target dimensions
crop_width = int(width / zoom_level)
crop_height = int(height / zoom_level)
left = (width - crop_width) // 2
top = (height - crop_height) // 2
right = left + crop_width
bottom = top + crop_height
final_image = final_image.crop((left, top, right, bottom))
# Convert back to RGB for display
final_image = final_image.convert("RGB")
# Ensure the final image is square
width, height = final_image.size
square_size = min(width, height)
left = (width - square_size) // 2
top = (height - square_size) // 2
right = left + square_size
bottom = top + square_size
final_image = final_image.crop((left, top, right, bottom))
if circular_overlay:
# Create a circular mask
mask = Image.new("L", (square_size, square_size), 0)
draw = ImageDraw.Draw(mask)
draw.ellipse((0, 0, square_size, square_size), fill=255)
# Apply the circular mask to the final image
final_image.putalpha(mask)
logger.info(
f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
)
return final_image
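# Programmatic usage (a minimal sketch; "portrait.jpg" is a placeholder
# path, not a file shipped with this Space):
#
#   img = Image.open("portrait.jpg")
#   avatar = process_portrait(img, border_size=12, bg_color="#fdc915", zoom_level=1.5)
#   avatar.save("avatar.png")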
# Create Gradio interface
with gr.Blocks(title="Cool Avatar Creator") as app:
gr.Markdown("# Cool Avatar Creator")
gr.Markdown(
"Upload a portrait image to remove the background, add a white border, and place on a colored background."
)
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Input Image")
border_slider = gr.Slider(
minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
)
bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
zoom_slider = gr.Slider(
minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
)
erode_slider = gr.Slider(
minimum=1, maximum=30, value=15, step=1, label="Erode Size"
)
circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
process_button = gr.Button("Process Image")
with gr.Column():
output_image = gr.Image(type="pil", label="Processed Image")
# Add example images
examples = [
[
"https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
26,
"#fdc915",
1.85,
],
[
"https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
23,
"#00FF00",
1.4,
],
["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
]
gr.Examples(
examples=examples,
inputs=[input_image, border_slider, bg_color, zoom_slider],
outputs=output_image,
fn=process_portrait,
cache_examples=False
)
process_button.click(
fn=process_portrait,
inputs=[input_image, border_slider, bg_color, zoom_slider, erode_slider, circular_overlay_toggle],
outputs=output_image,
)
if __name__ == "__main__":
    app.launch(share=False)  # set share=True to create a public link