import os
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from huggingface_hub import hf_hub_download
from gradio_toggle import Toggle

try:
    import spaces
except ImportError:
    pass

MAX_N = 6
FIX_MAX_N = 6
LENGTH = 480

placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6
HF = True
pre_device = "cpu" if HF else "cuda"
spaces_60_fn = spaces.GPU(duration=60) if HF else (lambda f: f)
spaces_120_fn = spaces.GPU(duration=120) if HF else (lambda f: f)
spaces_300_fn = spaces.GPU(duration=300) if HF else (lambda f: f)


def set_seed(seed):
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


device = "cuda"


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def unnormalize(x):
    return (((x + 1) / 2) * 255).astype(np.uint8)


def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W, C = img.shape
    # Create a figure and axis
    plt.figure()
    ax = plt.gca()
    # Plot joints as points
    ax.imshow(img)
    start_is = []
    if "right" in side:
        start_is.append(0)
    if "left" in side:
        start_is.append(21)
    for start_i in start_is:
        joints = all_joints[start_i : start_i + n_avail_joints]
        if len(joints) == 1:
            ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
        else:
            for connection, color in connections[: len(joints) - 1]:
                joint1 = joints[connection[0]]
                joint2 = joints[connection[1]]
                ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    # plt.subplots_adjust(wspace=0.01)
    # plt.show()
    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()
    # Convert BytesIO object to numpy array
    buf.seek(0)
    img_pil = Image.open(buf)
    img_pil = img_pil.resize((W, H))
    numpy_img = np.array(img_pil)
    return numpy_img


def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
    """Overlay mask on image for visualization purpose.
Args: image (H, W, 3) or (H, W): input image mask (H, W): mask to be overlaid color: the color of overlaid mask alpha: the transparency of the mask """ out = deepcopy(image) img = deepcopy(image) img[mask == 1] = color if transparent: out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out) else: out = img return out def scale_keypoint(keypoint, original_size, target_size): """Scale a keypoint based on the resizing of the image.""" keypoint_copy = keypoint.copy() keypoint_copy[:, 0] *= target_size[0] / original_size[0] keypoint_copy[:, 1] *= target_size[1] / original_size[1] return keypoint_copy print("Configure...") @dataclass class HandDiffOpts: run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5" sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt" log_dir: str = "/users/kchen157/scratch/log" data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff" image_size: tuple = (256, 256) latent_size: tuple = (32, 32) latent_dim: int = 4 mask_bg: bool = False kpts_form: str = "heatmap" n_keypoints: int = 42 n_mask: int = 1 noise_steps: int = 1000 test_sampling_steps: int = 250 ddim_steps: int = 100 ddim_discretize: str = "uniform" ddim_eta: float = 0.0 beta_start: float = 8.5e-4 beta_end: float = 0.012 latent_scaling_factor: float = 0.18215 cfg_pose: float = 5.0 cfg_appearance: float = 3.5 batch_size: int = 25 lr: float = 1e-5 max_epochs: int = 500 log_every_n_steps: int = 100 limit_val_batches: int = 1 n_gpu: int = 8 num_nodes: int = 1 precision: str = "16-mixed" profiler: str = "simple" swa_epoch_start: int = 10 swa_lrs: float = 1e-3 num_workers: int = 10 n_val_samples: int = 4 # load models token = os.getenv("HF_TOKEN") if NEW_MODEL: opts = HandDiffOpts() if MODEL_EPOCH == 7: model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt' elif MODEL_EPOCH == 6: model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt" if not os.path.exists(model_path): model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token) elif MODEL_EPOCH == 4: model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt" elif MODEL_EPOCH == 10: model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt" else: raise ValueError(f"new model epoch should be either 6 or 7, got {MODEL_EPOCH}") vae_path = './vae-ft-mse-840000-ema-pruned.ckpt' if not os.path.exists(vae_path): vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token) # sd_path = './sd-v1-4.ckpt' print('Load diffusion model...') diffusion = create_diffusion(str(opts.test_sampling_steps)) model = vit.DiT_XL_2( input_size=opts.latent_size[0], latent_dim=opts.latent_dim, in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask, learn_sigma=True, ).to(device) # ckpt_state_dict = torch.load(model_path)['model_state_dict'] ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict'] missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False) model = model.to(device) model.eval() print(missing_keys, extra_keys) assert len(missing_keys) == 0 vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict'] print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}") autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False) print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}") print(f"encoder before load_state_dict parameters min: {min([p.min() for p in 
autoencoder.encoder.parameters()])}") print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}") missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False) print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}") print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}") autoencoder = autoencoder.to(device) autoencoder.eval() print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}") print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}") print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}") assert len(missing_keys) == 0 sam_path = "sam_vit_h_4b8939.pth" if not os.path.exists(sam_path): sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token) sam_predictor = init_sam(ckpt_path=sam_path, device=pre_device) print("Mediapipe hand detector and SAM ready...") mp_hands = mp.solutions.hands hands = mp_hands.Hands( static_image_mode=True, # Use False if image is part of a video stream max_num_hands=2, # Maximum number of hands to detect min_detection_confidence=0.1, ) no_hands_open = cv2.resize(np.array(Image.open("no_hands_open.jpeg"))[..., :3], (LENGTH, LENGTH)) def prepare_anno(ref, ref_is_user): if not ref_is_user: # no_hand_open.jpeg return gr.update(value=None), gr.update(value=None) if ref is None or ref["background"] is None or ref["background"].sum()==0: # clear_all return ( gr.update(value=None), gr.update(value=None), ) img = ref["composite"][..., :3] img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA) keypts = np.zeros((42, 2)) mp_pose = hands.process(img) if mp_pose.multi_hand_landmarks: # handedness is flipped assuming the input image is mirrored in MediaPipe for hand_landmarks, handedness in zip( mp_pose.multi_hand_landmarks, mp_pose.multi_handedness ): # actually right hand if handedness.classification[0].label == "Left": start_idx = 0 # actually left hand elif handedness.classification[0].label == "Right": start_idx = 21 for i, landmark in enumerate(hand_landmarks.landmark): keypts[start_idx + i] = [ landmark.x * opts.image_size[1], landmark.y * opts.image_size[0], ] print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}") return img, keypts else: return img, None def get_ref_anno(img, keypts, use_mask, use_pose): no_mask, no_pose = not use_mask, not use_pose if img.sum() == 0: # clear_all return None, gr.update(), None, gr.update(), True elif keypts is None: # hand not detected no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH)) return None, no_hands, None, no_hands_open, False missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False) if no_pose: keypts = np.zeros((42, 2)) else: if isinstance(keypts, list): if len(keypts[0]) == 0: keypts[0] = np.zeros((21, 2)) elif len(keypts[0]) == 21: keypts[0] = np.array(keypts[0], dtype=np.float32) else: gr.Info("Number of right hand keypoints should be either 0 or 21.") return None, None, None, gr.update(), gr.update() if len(keypts[1]) == 0: keypts[1] = np.zeros((21, 2)) elif len(keypts[1]) == 21: keypts[1] = np.array(keypts[1], dtype=np.float32) else: gr.Info("Number of left hand keypoints should be either 0 or 21.") return None, None, None, gr.update(), gr.update() keypts = 
np.concatenate(keypts, axis=0) if no_mask: hand_mask = np.zeros_like(img[:,:, 0]) ref_pose = visualize_hand(keypts, img) else: sam_predictor.set_image(img) if keypts[0].sum() != 0 and keypts[21].sum() != 0: # input_point = np.array([keypts[0], keypts[21]]) input_point = np.array(keypts) input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)]) # input_label = np.array([1, 1]) elif keypts[0].sum() != 0: # input_point = np.array(keypts[:1]) input_point = np.array(keypts[:21]) input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)]) # input_label = np.array([1]) elif keypts[21].sum() != 0: input_point = np.array(keypts[21:]) # input_label = np.array([1]) input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)]) input_label = np.ones_like(input_point[:, 0]).astype(np.int32) box_shift_ratio = 0.5 box_size_factor = 1.2 box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio) input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1) masks, _, _ = sam_predictor.predict( point_coords=input_point, point_labels=input_label, box=input_box[None, :], multimask_output=False, ) hand_mask = masks[0] masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None]) ref_pose = visualize_hand(keypts, masked_img) def make_ref_cond( img, keypts, hand_mask, device="cuda", target_size=(256, 256), latent_size=(32, 32), ): image_transform = Compose( [ ToTensor(), Resize(target_size), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) image = image_transform(img).to(device) kpts_valid = check_keypoints_validity(keypts, target_size) heatmaps = torch.tensor( keypoint_heatmap( scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0 ) * kpts_valid[:, None, None], dtype=torch.float, device=device )[None, ...] mask = torch.tensor( cv2.resize( hand_mask.astype(int), dsize=latent_size, interpolation=cv2.INTER_NEAREST, ), dtype=torch.float, device=device, ).unsqueeze(0)[None, ...] 
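        # Assuming the default target_size=(256, 256) and latent_size=(32, 32), and that
        # keypoint_heatmap returns one 32x32 map per keypoint, the returned tensors are
        # roughly image (1, 3, 256, 256), heatmaps (1, 42, 32, 32), and mask (1, 1, 32, 32).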
        return image[None, ...], heatmaps, mask

    print(f"img.max(): {img.max()}, img.min(): {img.min()}")
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device=pre_device,
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    print(f"image.max(): {image.max()}, image.min(): {image.min()}")
    print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
    print(f"autoencoder encoder before operating min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
    if no_mask:
        mask = torch.zeros_like(mask)
    print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
    print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")

    return img, ref_pose, ref_cond, gr.update(), True


def get_target_anno(img, keypts):
    if img.sum() == 0:  # clear_all
        return None, gr.update(), None, None, gr.update(), True
    if keypts is None:  # hands not detected
        no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
        return None, no_hands, None, None, no_hands_open, False
    if isinstance(keypts, list):
        if len(keypts[0]) == 0:
            keypts[0] = np.zeros((21, 2))
        elif len(keypts[0]) == 21:
            keypts[0] = np.array(keypts[0], dtype=np.float32)
        else:
            gr.Info("Number of right hand keypoints should be either 0 or 21.")
            return None, None, None, gr.update(), gr.update(), gr.update()
        if len(keypts[1]) == 0:
            keypts[1] = np.zeros((21, 2))
        elif len(keypts[1]) == 21:
            keypts[1] = np.array(keypts[1], dtype=np.float32)
        else:
            gr.Info("Number of left hand keypoints should be either 0 or 21.")
            return None, None, None, gr.update(), gr.update(), gr.update()
        keypts = np.concatenate(keypts, axis=0)
    target_pose = visualize_hand(keypts, img)
    kpts_valid = check_keypoints_validity(keypts, opts.image_size)
    target_heatmaps = torch.tensor(
        keypoint_heatmap(
            scale_keypoint(keypts, opts.image_size, opts.latent_size),
            opts.latent_size,
            var=1.0,
        )
        * kpts_valid[:, None, None],
        dtype=torch.float,
        device=pre_device,
    )[None, ...]
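    # The target conditioning carries no appearance information: it is the 42 keypoint
    # heatmaps plus one all-zero mask channel, i.e. opts.n_keypoints + opts.n_mask = 43
    # channels in total.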
    target_cond = torch.cat(
        [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
    )

    return img, target_pose, target_cond, keypts, gr.update(), True


def visualize_ref(ref, ex_mask):
    if ref is None:
        return None
    # from user or from example
    # h, w = ref["background"].shape[:2]
    # if ref["layers"][0].sum() == 0:
    #     if ref["background"][:, :, -1].sum() == h * w * 255:
    #         from_example = False
    #     else:
    #         from_example = True
    # else:
    #     from_example = False
    if ex_mask is None:
        from_example = False
    else:
        from_example = True

    # inpaint mask
    if from_example:
        # inpaint_mask = ref["background"][:, :, -1]
        inpaint_mask = (np.all(ex_mask > 245, axis=-1)).astype(np.uint8)*128 + 64
        inpainted = inpaint_mask.copy()
        inpaint_mask = cv2.resize(
            inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
        )
        inpaint_mask = (inpaint_mask > 128).astype(np.uint8)
        img = cv2.cvtColor(ref["background"], cv2.COLOR_RGBA2RGB)
    else:
        inpaint_mask = np.array(ref["layers"][0])[..., -1]
        inpaint_mask = cv2.resize(
            inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
        )
        inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
        inpainted = ref["layers"][0][..., -1]
        img = ref["background"][..., :3]

    # visualization
    mask = inpainted < 128
    img = mask_image(img, mask)
    if inpaint_mask.sum() == 0:
        gr.Warning("Run button not enabled? Please try again.", duration=10)

    return img, inpaint_mask


def make_composite(img, mask):
    if mask is None:
        return gr.update()
    mask = (np.all(mask > 245, axis=-1)).astype(np.uint8)*128 + 64
    # composite = np.concatenate((img[..., :3], mask[..., None]), axis=-1)
    composite = {
        "background": img[..., :3],
        "layers": [
            mask,
        ],
        "composite": np.concatenate((img[..., :3], mask[..., None]), axis=-1),
    }
    return composite


def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
    if keypoints is None:
        keypoints = [[], []]
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 21:
            gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[0].append(list(evt.index))
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 21:
            gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[1].append(list(evt.index))
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def undo_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 0:
            return img, keypoints
        keypoints[0].pop()
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 0:
            return img, keypoints
        keypoints[1].pop()
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def reset_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    if side == "right":
        keypoints[0] = []
    elif side == "left":
        keypoints[1] = []
    return img, keypoints


def read_kpts(kpts_path):
    if kpts_path is None or len(kpts_path) == 0:
        return None
    kpts = np.load(kpts_path)
    return kpts


def stay_crop(img, crop_coord):
    if img is not None:
        if crop_coord is None:
            crop_coord = [[0, 0], [img.shape[1], img.shape[0]]]
            cropped = img.copy()
            return crop_coord, cropped
        else:
            return gr.update(), gr.update()
    else:
        return None, None


def stash_original(img):
    if img is None:
        return None
    else:
        return img[:, :, :3]


def process_crop(img, crop_coord, evt: gr.SelectData):
    image = img.copy()
    if len(crop_coord) == 2:
        # will add first click
        crop_coord = [list(evt.index)]
        cropped = image
        cropped_vis = image.copy()
        alpha = np.ones_like(cropped_vis[:, :, -1]) * 255
        cv2.circle(alpha, tuple(crop_coord[0]), 5, 0, 4)
        cropped_vis[:, :, -1] = alpha
    elif len(crop_coord) == 1:
        new_coord = list(evt.index)
        if new_coord[0] <= crop_coord[0][0] or new_coord[1] <= crop_coord[0][1]:
            # will skip
            gr.Warning("The second click should be below and to the right of the first click. Please try the second click again.", duration=3)
            cropped = image
            cropped_vis = image.copy()
            cropped_vis[:, :, -1] = 255
        else:
            # will add second click
            crop_coord.append(new_coord)
            x1, y1 = crop_coord[0]
            x2, y2 = crop_coord[1]
            cropped = image[y1:y2, x1:x2]
            cropped_vis = image.copy()
            alpha = np.ones_like(cropped_vis[:, :, -1]) * 255
            cv2.rectangle(alpha, tuple([x1, y1]), tuple([x2, y2]), 0, 4)
            cropped_vis[:, :, -1] = alpha
    else:
        gr.Error("Something is wrong", duration=3)
    return crop_coord, cropped, cropped_vis


def disable_crop(crop_coord):
    if len(crop_coord) == 2:
        return gr.update(interactive=False)
    else:
        return gr.update(interactive=True)


@spaces_60_fn
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
    set_seed(seed)
    z = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device=device,
    )
    print(f"z.device: {z.device}")
    target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
    # novel view synthesis mode = off
    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg,
    )
    gr.Info("The process successfully started to run.
Please wait for 50s x Number of Generation.", duration=20) samples, _ = diffusion.p_sample_loop( model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=gr.Progress(), device=device, ).chunk(2) sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0) sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy()) results = [] results_pose = [] for i in range(MAX_N): if i < num_gen: results.append(sampled_images[i]) results_pose.append(visualize_hand(target_keypts, sampled_images[i])) else: results.append(placeholder) results_pose.append(placeholder) print(f"results[0].max(): {results[0].max()}") return results, results_pose @spaces_120_fn def ready_sample(img_cropped, img_original, ex_mask, inpaint_mask, keypts, keypts_np): if ex_mask is None: img = cv2.resize(img_cropped["background"][..., :3], opts.image_size, interpolation=cv2.INTER_AREA) else: img = cv2.resize(img_original[..., :3], opts.image_size, interpolation=cv2.INTER_AREA) sam_predictor.set_image(img) if keypts is None and keypts_np is not None: keypts = keypts_np else: if len(keypts[0]) == 0: keypts[0] = np.zeros((21, 2)) elif len(keypts[0]) == 21: keypts[0] = np.array(keypts[0], dtype=np.float32) else: gr.Info("Number of right hand keypoints should be either 0 or 21.") return None, None if len(keypts[1]) == 0: keypts[1] = np.zeros((21, 2)) elif len(keypts[1]) == 21: keypts[1] = np.array(keypts[1], dtype=np.float32) else: gr.Info("Number of left hand keypoints should be either 0 or 21.") return None, None keypts = np.concatenate(keypts, axis=0) keypts = scale_keypoint(keypts, (img_cropped["background"].shape[1], img_cropped["background"].shape[0]), opts.image_size) box_shift_ratio = 0.5 box_size_factor = 1.2 if keypts[0].sum() != 0 and keypts[21].sum() != 0: input_point = np.array(keypts) input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)]) elif keypts[0].sum() != 0: input_point = np.array(keypts[:21]) input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)]) elif keypts[21].sum() != 0: input_point = np.array(keypts[21:]) input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)]) else: raise ValueError( "Something wrong. If no hand detected, it should not reach here." ) input_label = np.ones_like(input_point[:, 0]).astype(np.int32) box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio) input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1) masks, _, _ = sam_predictor.predict( point_coords=input_point, point_labels=input_label, box=input_box[None, :], multimask_output=False, ) hand_mask = masks[0] inpaint_latent_mask = torch.tensor( cv2.resize( inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST ), dtype=torch.float, device=pre_device, ).unsqueeze(0)[None, ...] def make_ref_cond( img, keypts, hand_mask, device=device, target_size=(256, 256), latent_size=(32, 32), ): image_transform = Compose( [ ToTensor(), Resize(target_size), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) image = image_transform(img).to(device) kpts_valid = check_keypoints_validity(keypts, target_size) heatmaps = torch.tensor( keypoint_heatmap( scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0 ) * kpts_valid[:, None, None], dtype=torch.float, device=device, )[None, ...] 
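        # hand_mask arrives here as the SAM hand mask with the brushed (inpaint) region
        # zeroed out (see the call site below); nearest-neighbor resizing keeps it binary
        # at the latent resolution.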
mask = torch.tensor( cv2.resize( hand_mask.astype(int), dsize=latent_size, interpolation=cv2.INTER_NEAREST, ), dtype=torch.float, device=device, ).unsqueeze(0)[None, ...] return image[None, ...], heatmaps, mask image, heatmaps, mask = make_ref_cond( img, keypts, hand_mask * (1 - inpaint_mask), device=pre_device, target_size=opts.image_size, latent_size=opts.latent_size, ) latent = opts.latent_scaling_factor * autoencoder.encode(image).sample() target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1) ref_cond = torch.cat([latent, heatmaps, mask], 1) ref_cond = torch.zeros_like(ref_cond) img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST) assert mask.max() == 1 vis_mask32 = mask_image( img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False ).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy() assert np.unique(inpaint_mask).shape[0] <= 2 assert hand_mask.dtype == bool mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask) vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype( np.uint8 ) # 1 - mask256 return ( ref_cond, target_cond, latent, inpaint_latent_mask, keypts, vis_mask32, vis_mask256, ) def switch_mask_size(radio): if radio == "256x256": out = (gr.update(visible=False), gr.update(visible=True)) elif radio == "latent size (32x32)": out = (gr.update(visible=True), gr.update(visible=False)) return out @spaces_300_fn def sample_inpaint( ref_cond, target_cond, latent, inpaint_latent_mask, keypts, img_original, crop_coord, num_gen, seed, cfg, quality, ): if inpaint_latent_mask is None: return None, None, None set_seed(seed) N = num_gen jump_length = 10 jump_n_sample = quality cfg_scale = cfg z = torch.randn( (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device ) target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device) ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device) # novel view synthesis mode = off nvs = torch.zeros(N, dtype=torch.int, device=device) z = torch.cat([z, z], 0) model_kwargs = dict( target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]), ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]), nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]), cfg_scale=cfg_scale, ) gr.Info("The process successfully started to run. 
Please wait for around 3.5 minutes.", duration=220) samples, _ = diffusion.inpaint_p_sample_loop( model.forward_with_cfg, z.shape, latent.to(z.device), inpaint_latent_mask.to(z.device), z, clip_denoised=False, model_kwargs=model_kwargs, progress=gr.Progress(), device=z.device, jump_length=jump_length, jump_n_sample=jump_n_sample, ).chunk(2) sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0) sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy()) # visualize results = [] results_pose = [] results_original = [] for i in range(FIX_MAX_N): if i < num_gen: res =sampled_images[i] results.append(res) results_pose.append(visualize_hand(keypts, res)) res = cv2.resize(res, (crop_coord[1][0]-crop_coord[0][0], crop_coord[1][1]-crop_coord[0][1])) res_original = img_original.copy() res_original[crop_coord[0][1]:crop_coord[1][1], crop_coord[0][0]:crop_coord[1][0], :] = res results_original.append(res_original) else: results.append(placeholder) results_pose.append(placeholder) results_original.append(placeholder) return results, results_pose, results_original def flip_hand( img, img_raw, pose_img, pose_manual_img, manual_kp_right, manual_kp_left, cond, auto_cond, manual_cond, keypts=None, auto_keypts=None, manual_keypts=None ): if cond is None: # clear clicked return img["composite"] = img["composite"][:, ::-1, :] img["background"] = img["background"][:, ::-1, :] img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]] if img_raw is not None: img_raw = img_raw[:, ::-1, :] pose_img = pose_img[:, ::-1, :] if pose_manual_img is not None: pose_manual_img = pose_manual_img[:, ::-1, :] if manual_kp_right is not None: manual_kp_right = manual_kp_right[:, ::-1, :] if manual_kp_left is not None: manual_kp_left = manual_kp_left[:, ::-1, :] cond = cond.flip(-1) if auto_cond is not None: auto_cond = auto_cond.flip(-1) if manual_cond is not None: manual_cond = manual_cond.flip(-1) if keypts is not None: if keypts[:21, :].sum() != 0: keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0] if keypts[21:, :].sum() != 0: keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0] if auto_keypts is not None: if auto_keypts[:21, :].sum() != 0: auto_keypts[:21, 0] = opts.image_size[1] - auto_keypts[:21, 0] if auto_keypts[21:, :].sum() != 0: auto_keypts[21:, 0] = opts.image_size[1] - auto_keypts[21:, 0] if manual_keypts is not None: if manual_keypts[:21, :].sum() != 0: manual_keypts[:21, 0] = opts.image_size[1] - manual_keypts[:21, 0] if manual_keypts[21:, :].sum() != 0: manual_keypts[21:, 0] = opts.image_size[1] - manual_keypts[21:, 0] return img, img_raw, pose_img, pose_manual_img, manual_kp_right, manual_kp_left, cond, auto_cond, manual_cond, keypts, auto_keypts, manual_keypts def resize_to_full(img): img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH)) img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH)) img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]] return img def clear_all(): return ( None, [], None, None, None, None, None, None, False, None, None, [], None, None, None, None, None, None, False, None, None, 1, 42, 3.0, gr.update(interactive=False), False, ) def fix_clear_all(): return ( None, None, None, [], None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, 1, 42, 3.0, 10, None ) def enable_component(image1, image2): if image1 is None or image2 is None: return gr.update(interactive=False) if 
isinstance(image1, np.ndarray) and image1.sum() == 0: return gr.update(interactive=False) if isinstance(image2, np.ndarray) and image2.sum() == 0: return gr.update(interactive=False) if isinstance(image1, dict) and "background" in image1 and "layers" in image1 and "composite" in image1: if image1["background"] is None or ( image1["background"].sum() == 0 and (sum([im.sum() for im in image1["layers"]]) == 0) and image1["composite"].sum() == 0 ): return gr.update(interactive=False) if isinstance(image1, dict) and "background" in image2 and "layers" in image2 and "composite" in image2: if image2["background"] is None or ( image2["background"].sum() == 0 and (sum([im.sum() for im in image2["layers"]]) == 0) and image2["composite"].sum() == 0 ): return gr.update(interactive=False) return gr.update(interactive=True) def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=None, done_info=None): if kpts is None: kpts = [[], []] if "Right hand" not in checkbox: kpts[0] = [] vis_right = img_clean update_right = gr.update(visible=False) update_r_info = gr.update(visible=False) else: vis_right = img_pose_right update_right = gr.update(visible=True) update_r_info = gr.update(visible=True) if "Left hand" not in checkbox: kpts[1] = [] vis_left = img_clean update_left = gr.update(visible=False) update_l_info = gr.update(visible=False) else: vis_left = img_pose_left update_left = gr.update(visible=True) update_l_info = gr.update(visible=True) ret = [ kpts, vis_right, vis_left, update_right, update_right, update_right, update_left, update_left, update_left, update_r_info, update_l_info, ] if done is not None: if not checkbox: ret.append(gr.update(visible=False)) ret.append(gr.update(visible=False)) else: ret.append(gr.update(visible=True)) ret.append(gr.update(visible=True)) return tuple(ret) def set_unvisible(): return ( gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) ) def fix_set_unvisible(): return ( gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) ) def visible_component(decider, component): if decider is not None: update_component = gr.update(visible=True) else: update_component = gr.update(visible=False) return update_component def unvisible_component(decider, component): if decider is not None: update_component = gr.update(visible=False) else: update_component = gr.update(visible=True) return update_component example_ref_imgs = [ [ "sample_images/sample1.jpg", ], [ "sample_images/sample2.jpg", ], [ "sample_images/sample3.jpg", ], [ "sample_images/sample4.jpg", ], [ "sample_images/sample6.jpg", ], ] example_target_imgs = [ [ "sample_images/sample5.jpg", ], [ "sample_images/sample9.jpg", ], [ "sample_images/sample10.jpg", ], [ "sample_images/sample11.jpg", ], ["pose_images/pose1.jpg"], ] fix_example_imgs = [ ["bad_hands/1.jpg"], ["bad_hands/3.jpg"], 
["bad_hands/4.jpg"], ["bad_hands/5.jpg"], ["bad_hands/6.jpg"], ["bad_hands/7.jpg"], ] fix_example_brush = [ ["bad_hands/1_composite.png"], ["bad_hands/3_composite.png"], ["bad_hands/4_composite.png"], ["bad_hands/5_composite.png"], ["bad_hands/6_composite.png"], ["bad_hands/7_composite.png"], ] fix_example_kpts = [ ["bad_hands/1_kpts.png", 3.0, 1224], ["bad_hands/3_kpts.png", 1.0, 42], ["bad_hands/4_kpts.png", 2.0, 42], ["bad_hands/5_kpts.png", 3.0, 42], ["bad_hands/6_kpts.png", 3.0, 1348], ["bad_hands/7_kpts.png", 3.0, 42], ] fix_example_all = [ ["bad_hands/1.jpg", "bad_hands/1_composite.png", "bad_hands/1_mask.jpg", "bad_hands/1_kpts.png", 3.0, 1224], ["bad_hands/3.jpg", "bad_hands/3_composite.png", "bad_hands/3_mask.jpg", "bad_hands/3_kpts.png", 1.0, 42], ["bad_hands/4.jpg", "bad_hands/4_composite.png", "bad_hands/4_mask.jpg", "bad_hands/4_kpts.png", 2.0, 42], ["bad_hands/5.jpg", "bad_hands/5_composite.png", "bad_hands/5_mask.jpg", "bad_hands/5_kpts.png", 3.0, 42], ["bad_hands/6.jpg", "bad_hands/6_composite.png", "bad_hands/6_mask.jpg", "bad_hands/6_kpts.png", 3.0, 1348], ["bad_hands/7.jpg", "bad_hands/7_composite.png", "bad_hands/7_mask.jpg", "bad_hands/7_kpts.png", 3.0, 42], ] for i in range(len(fix_example_kpts)): npy_path = fix_example_kpts[i][0].replace("_kpts.png", ".npy") fix_example_kpts[i].append(npy_path) for i in range(len(fix_example_all)): npy_path = fix_example_all[i][3].replace("_kpts.png", ".npy") fix_example_all[i].append(npy_path) custom_css = """ .gradio-container .examples img { width: 240px !important; height: 240px !important; } #fix-tab-button { font-size: 18px !important; font-weight: bold !important; background-color: #FFDAB9 !important; } #repose-tab-button { font-size: 18px !important; font-weight: bold !important; background-color: #FFB6C1 !important; } #kpts_examples table tr th:nth-child(2), #kpts_examples table tr td:nth-child(2) { display: none !important; } #kpts_examples table tr th:nth-child(3), #kpts_examples table tr td:nth-child(3) { display: none !important; } #kpts_examples table tr th:nth-child(4), #kpts_examples table tr td:nth-child(4) { display: none !important; } #fix_examples_all table tr th:nth-child(2), #fix_examples_all table tr td:nth-child(2) { display: none !important; } #fix_examples_all table tr th:nth-child(3), #fix_examples_all table tr td:nth-child(3) { display: none !important; } #fix_examples_all table tr th:nth-child(4), #fix_examples_all table tr td:nth-child(4) { display: none !important; } #fix_examples_all table tr th:nth-child(5), #fix_examples_all table tr td:nth-child(5) { display: none !important; } #fix_examples_all table tr th:nth-child(6), #fix_examples_all table tr td:nth-child(6) { display: none !important; } #fix_examples_all table tr th:nth-child(7), #fix_examples_all table tr td:nth-child(7) { display: none !important; } #fix_examples_all table tr:first-child { display: none !important; } #repose_tutorial video { width: 50% !important; display: block; margin: 0 auto; padding: 0; } #accordion_bold button span { font-weight: bold !important; } #accordion_bold_large button span { font-weight: bold !important; font-size: 20px !important; } #accordion_bold_large_center button span { font-weight: bold !important; font-size: 20px !important; } #accordion_bold_large_center button { text-align: center !important; margin: 0 auto !important; display: block !important; } #fix_examples_all table tbody { display: flex !important; flex-direction: row; flex-wrap: nowrap; } #fix_examples_all table tr { display: flex !important; 
align-items: center; } #fix_examples_all table tr th, #fix_examples_all table tr td { display: table-cell; } #gradio-app { flex-direction: row; !important; } """ ##no_wrap_row { # display: flex !important; # flex-direction: row !important; # flex-wrap: nowrap !important; # } ##no_wrap_row > div { # flex: 1 1 auto !important; # min-width: 0; # } button_css = """ #clear_button { background-color: #f44336 !important; } #clear_button:hover { background-color: #d32f2f !important; } #run_button { background-color: #4CAF50 !important; cursor: pointer; transition: background-color 0.3s ease; } #run_button:hover { background-color: #388E3C !important; } """ tut1_custom = f""" """ tut1_example = f""" """ tut2_example = f""" """ _HEADER_ = '''
Kefan Chen1,2* Chaerin Min1* Linguang Zhang2 Shreyas Hampali2 Cem Keskin2 Srinath Sridhar1
1Brown University 2Meta Reality Labs
Below are two key capabilities of our model. First, it can automatically fix malformed hand images, following a user-provided target hand pose and the area to fix. Second, it can repose a hand given two hand images: one is the image to edit, and the other provides the target hand pose.
_CITE_ = r"""@article{chen2024foundhand, title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation}, author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath}, journal={arXiv preprint arXiv:2412.02690}, year={2024} }""" _ACK_ = r"""
Part of this work was done during Kefan (Arthur) Chen’s internship at Meta Reality Lab. This work was additionally supported by NSF CAREER grant #2143576, NASA grant #80NSSC23M0075, and an Amazon Cloud Credits Award.""" with gr.Blocks(css=custom_css, theme="soft") as demo: gr.HTML(f"") gr.Markdown(_HEADER_) with gr.Tab("Demo 1. Repose Hands", elem_id="repose-tab"): # ref states dump = gr.State(value=None) ref_img = gr.State(value=None) ref_im_raw = gr.State(value=None) ref_kp_raw = gr.State(value=0) ref_is_user = gr.State(value=True) ref_kp_got = gr.State(value=None) ref_manual_cond = gr.State(value=None) ref_auto_cond = gr.State(value=None) ref_cond = gr.State(value=None) # target states target_img = gr.State(value=None) target_im_raw = gr.State(value=None) target_kp_raw = gr.State(value=0) target_is_user = gr.State(value=True) target_kp_got = gr.State(value=None) target_manual_keypts = gr.State(value=None) target_auto_keypts = gr.State(value=None) target_keypts = gr.State(value=None) target_manual_cond = gr.State(value=None) target_auto_cond = gr.State(value=None) target_cond = gr.State(value=None) # config use_pose = gr.State(value=True) # main tabs with gr.Row(): # ref column with gr.Column(): gr.Markdown( """
1. Upload a hand image to repose or choose one below 📥
""" ) # gr.Markdown( # """Optionally crop the image
""" # ) ref = gr.ImageEditor( type="numpy", label="Reference", show_label=True, height=LENGTH, width=LENGTH, brush=False, layers=False, crop_size="1:1", ) gr.Examples(example_ref_imgs, [ref], examples_per_page=20) use_mask = Toggle(label="Use mask", value=False, interactive=True) with gr.Accordion(label="See hand pose & mask", open=False): with gr.Tab("Automatic hand keypoints"): ref_pose = gr.Image( type="numpy", label="Reference Pose", show_label=True, height=LENGTH, width=LENGTH, interactive=False, ) ref_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True) with gr.Tab("Manual hand keypoints"): ref_manual_checkbox_info = gr.Markdown( """Step 1. Tell us if this is right, left, or both hands.
""", visible=True, ) ref_manual_checkbox = gr.CheckboxGroup( ["Right hand", "Left hand"], show_label=False, visible=True, interactive=True, ) ref_manual_kp_r_info = gr.Markdown( """Step 2. Click on image to provide hand keypoints for right hand. See \"OpenPose Keypoint Convention\" for guidance.
""", visible=False, ) ref_manual_kp_right = gr.Image( type="numpy", label="Keypoint Selection (right hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): ref_manual_undo_right = gr.Button( value="Undo", interactive=True, visible=False ) ref_manual_reset_right = gr.Button( value="Reset", interactive=True, visible=False ) ref_manual_kp_l_info = gr.Markdown( """Step 2. Click on image to provide hand keypoints for left hand. See \"OpenPose keypoint convention\" for guidance.
""", visible=False ) ref_manual_kp_left = gr.Image( type="numpy", label="Keypoint Selection (left hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): ref_manual_undo_left = gr.Button( value="Undo", interactive=True, visible=False ) ref_manual_reset_left = gr.Button( value="Reset", interactive=True, visible=False ) ref_manual_done_info = gr.Markdown( """Step 3. Hit \"Done\" button to confirm.
""", visible=False, ) ref_manual_done = gr.Button(value="Done", interactive=True, visible=False) ref_manual_pose = gr.Image( type="numpy", label="Reference Pose", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False ) ref_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False) ref_manual_instruct = gr.Markdown( value="""OpenPose Keypoints Convention
""", visible=True ) ref_manual_openpose = gr.Image( value="openpose.png", type="numpy", show_label=False, height=LENGTH // 2, width=LENGTH // 2, interactive=False, visible=True ) # gr.Markdown( # """Optionally flip the hand
""" # ) ref_flip = gr.Checkbox( value=False, label="Flip Handedness (Reference)", interactive=False ) # target column with gr.Column(): gr.Markdown( """2. Upload a hand image for target hand pose or choose one below 📥
""" ) # gr.Markdown( # """Optionally crop the image
""" # ) target = gr.ImageEditor( type="numpy", label="Target", show_label=True, height=LENGTH, width=LENGTH, brush=False, layers=False, crop_size="1:1", ) gr.Examples(example_target_imgs, [target], examples_per_page=20) with gr.Accordion(label="See hand pose", open=False): with gr.Tab("Automatic hand keypoints"): target_pose = gr.Image( type="numpy", label="Target Pose", show_label=True, height=LENGTH, width=LENGTH, interactive=False, ) target_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True) with gr.Tab("Manual hand keypoints"): target_manual_checkbox_info = gr.Markdown( """Step 1. Tell us if this is right, left, or both hands.
""", visible=True, ) target_manual_checkbox = gr.CheckboxGroup( ["Right hand", "Left hand"], show_label=False, visible=True, interactive=True, ) target_manual_kp_r_info = gr.Markdown( """Step 2. Click on image to provide hand keypoints for right hand. See \"OpenPose Keypoint Convention\" for guidance.
""", visible=False, ) target_manual_kp_right = gr.Image( type="numpy", label="Keypoint Selection (right hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): target_manual_undo_right = gr.Button( value="Undo", interactive=True, visible=False ) target_manual_reset_right = gr.Button( value="Reset", interactive=True, visible=False ) target_manual_kp_l_info = gr.Markdown( """Step 2. Click on image to provide hand keypoints for left hand. See \"OpenPose keypoint convention\" for guidance.
""", visible=False ) target_manual_kp_left = gr.Image( type="numpy", label="Keypoint Selection (left hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): target_manual_undo_left = gr.Button( value="Undo", interactive=True, visible=False ) target_manual_reset_left = gr.Button( value="Reset", interactive=True, visible=False ) target_manual_done_info = gr.Markdown( """Step 3. Hit \"Done\" button to confirm.
""", visible=False, ) target_manual_done = gr.Button(value="Done", interactive=True, visible=False) target_manual_pose = gr.Image( type="numpy", label="Target Pose", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False ) target_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False) target_manual_instruct = gr.Markdown( value="""OpenPose Keypoints Convention
""", visible=True ) target_manual_openpose = gr.Image( value="openpose.png", type="numpy", show_label=False, height=LENGTH // 2, width=LENGTH // 2, interactive=False, visible=True ) # gr.Markdown( # """Optionally flip the hand
""" # ) target_flip = gr.Checkbox( value=False, label="Flip Handedness (Target)", interactive=False ) # result column with gr.Column(): gr.Markdown( """3. Press "Run" 🎯
""" ) run = gr.Button(value="Run", interactive=False, elem_id="run_button") # gr.Markdown( # """⚠️ ~50s per generation
""" # with RTX3090. ~50s with A100.✨ Hit "Clear" to restart from the beginning
""" # ) clear = gr.ClearButton(elem_id="clear_button") # more options with gr.Accordion(label="More options", open=False): with gr.Row(): n_generation = gr.Slider( label="Number of generations", value=1, minimum=1, maximum=MAX_N, step=1, randomize=False, interactive=True, ) seed = gr.Slider( label="Seed", value=42, minimum=0, maximum=10000, step=1, randomize=False, interactive=True, ) cfg = gr.Slider( label="Classifier free guidance scale", value=2.5, minimum=0.0, maximum=10.0, step=0.1, randomize=False, interactive=True, ) # tutorial video with gr.Accordion("Tutorial Video of Demo 1", elem_id="accordion_bold_large_center"): # gr.Markdown("""""") with gr.Row(variant="panel", elem_id="repose_tutorial"): with gr.Column(): # gr.Video( # "how_to_videos/subtitled_repose_hands.mp4", # label="Tutorial", # autoplay=True, # loop=True, # show_label=True, # ) gr.HTML(tut2_example) # reference listeners ref.change(prepare_anno, [ref, ref_is_user], [ref_im_raw, ref_kp_raw]) ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right) ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left) ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw, use_mask, use_pose], [ref_img, ref_pose, ref_auto_cond, ref, ref_is_user]) use_mask.input(get_ref_anno, [ref_im_raw, ref_kp_raw, use_mask, use_pose], [ref_img, ref_pose, ref_auto_cond, ref, ref_is_user]) ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto) ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip) ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond) ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond) ref_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3)) ref_manual_checkbox.select( set_visible, [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done], [ ref_kp_got, ref_manual_kp_right, ref_manual_kp_left, ref_manual_kp_right, ref_manual_undo_right, ref_manual_reset_right, ref_manual_kp_left, ref_manual_undo_left, ref_manual_reset_left, ref_manual_kp_r_info, ref_manual_kp_l_info, ref_manual_done, ref_manual_done_info ] ) ref_manual_kp_right.select( get_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got] ) ref_manual_undo_right.click( undo_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got] ) ref_manual_reset_right.click( reset_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got] ) ref_manual_kp_left.select( get_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got] ) ref_manual_undo_left.click( undo_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got] ) ref_manual_reset_left.click( reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got] ) ref_manual_done.click(visible_component, [gr.State(0), ref_manual_pose], ref_manual_pose) ref_manual_done.click(visible_component, [gr.State(0), ref_use_manual], ref_use_manual) ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got, use_mask, use_pose], [ref_img, ref_manual_pose, ref_manual_cond]) ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done) ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip) ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond) ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond) ref_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", 
duration=3)) ref_flip.select( flip_hand, [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond], [ref, ref_im_raw, ref_pose, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left, ref_cond, ref_auto_cond, ref_manual_cond] ) # target listeners target.change(prepare_anno, [target, target_is_user], [target_im_raw, target_kp_raw]) target_kp_raw.change(lambda x:x, target_im_raw, target_manual_kp_right) target_kp_raw.change(lambda x:x, target_im_raw, target_manual_kp_left) target_kp_raw.change(get_target_anno, [target_im_raw, target_kp_raw], [target_img, target_pose, target_auto_cond, target_auto_keypts, target, target_is_user]) target_pose.change(enable_component, [target_kp_raw, target_pose], target_use_auto) target_pose.change(enable_component, [target_img, target_pose], target_flip) target_auto_cond.change(lambda x: x, target_auto_cond, target_cond) target_auto_keypts.change(lambda x: x, target_auto_keypts, target_keypts) target_use_auto.click(lambda x: x, target_auto_cond, target_cond) target_use_auto.click(lambda x: x, target_auto_keypts, target_keypts) target_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Target'", duration=3)) target_manual_checkbox.select( set_visible, [target_manual_checkbox, target_kp_got, target_im_raw, target_manual_kp_right, target_manual_kp_left, target_manual_done], [ target_kp_got, target_manual_kp_right, target_manual_kp_left, target_manual_kp_right, target_manual_undo_right, target_manual_reset_right, target_manual_kp_left, target_manual_undo_left, target_manual_reset_left, target_manual_kp_r_info, target_manual_kp_l_info, target_manual_done, target_manual_done_info ] ) target_manual_kp_right.select( get_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got] ) target_manual_undo_right.click( undo_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got] ) target_manual_reset_right.click( reset_kps, [target_im_raw, target_kp_got, gr.State("right")], [target_manual_kp_right, target_kp_got] ) target_manual_kp_left.select( get_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got] ) target_manual_undo_left.click( undo_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got] ) target_manual_reset_left.click( reset_kps, [target_im_raw, target_kp_got, gr.State("left")], [target_manual_kp_left, target_kp_got] ) target_manual_done.click(visible_component, [gr.State(0), target_manual_pose], target_manual_pose) target_manual_done.click(visible_component, [gr.State(0), target_use_manual], target_use_manual) target_manual_done.click(get_target_anno, [target_im_raw, target_kp_got], [target_img, target_manual_pose, target_manual_cond, target_manual_keypts]) target_manual_pose.change(enable_component, [target_manual_pose, target_manual_pose], target_manual_done) target_manual_pose.change(enable_component, [target_img, target_manual_pose], target_flip) target_manual_cond.change(lambda x: x, target_manual_cond, target_cond) target_manual_keypts.change(lambda x: x, target_manual_keypts, target_keypts) target_use_manual.click(lambda x: x, target_manual_cond, target_cond) target_use_manual.click(lambda x: x, target_manual_keypts, target_keypts) target_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3)) target_flip.select( flip_hand, [target, target_im_raw, target_pose, 
target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts], [target, target_im_raw, target_pose, target_manual_pose, target_manual_kp_right, target_manual_kp_left, target_cond, target_auto_cond, target_manual_cond, target_keypts, target_auto_keypts, target_manual_keypts], ) # run listerners ref_cond.change(enable_component, [ref_cond, target_cond], run) target_cond.change(enable_component, [ref_cond, target_cond], run) run.click( sample_diff, [ref_cond, target_cond, target_keypts, n_generation, seed, cfg], [results, results_pose], ) clear.click( clear_all, [], [ ref, ref_manual_checkbox, ref_manual_kp_right, ref_manual_kp_left, ref_img, ref_pose, ref_manual_pose, ref_cond, ref_flip, target, target_keypts, target_manual_checkbox, target_manual_kp_right, target_manual_kp_left, target_img, target_pose, target_manual_pose, target_cond, target_flip, results, results_pose, n_generation, seed, cfg, ref_kp_raw, use_mask, ], ) clear.click( set_unvisible, [], [ ref_manual_kp_l_info, ref_manual_kp_r_info, ref_manual_kp_left, ref_manual_kp_right, ref_manual_undo_left, ref_manual_undo_right, ref_manual_reset_left, ref_manual_reset_right, ref_manual_done, ref_manual_done_info, ref_manual_pose, ref_use_manual, target_manual_kp_l_info, target_manual_kp_r_info, target_manual_kp_left, target_manual_kp_right, target_manual_undo_left, target_manual_undo_right, target_manual_reset_left, target_manual_reset_right, target_manual_done, target_manual_done_info, target_manual_pose, target_use_manual, ] ) with gr.Tab("Demo 2. Malformed Hand Correction", elem_id="fix-tab"): fix_inpaint_mask = gr.State(value=None) fix_original = gr.State(value=None) fix_crop_coord = gr.State(value=None) fix_img = gr.State(value=None) fix_kpts = gr.State(value=None) fix_kpts_path = gr.Textbox(visible=False) fix_kpts_np = gr.State(value=None) fix_ref_cond = gr.State(value=None) fix_target_cond = gr.State(value=None) fix_latent = gr.State(value=None) fix_inpaint_latent = gr.State(value=None) fix_ex_mask = gr.Image(value=None, visible=False) # more options with gr.Accordion(label="More options", open=False): gr.Markdown( "⚠️ Currently, Number of generation > 1 could lead to out-of-memory" ) with gr.Row(): fix_n_generation = gr.Slider( label="Number of generations", value=1, minimum=1, maximum=FIX_MAX_N, step=1, randomize=False, interactive=True, ) fix_seed = gr.Slider( label="Seed", value=42, minimum=0, maximum=10000, step=1, randomize=False, interactive=True, ) fix_cfg = gr.Slider( label="Classifier free guidance scale", value=3.0, minimum=0.0, maximum=10.0, step=0.1, randomize=False, interactive=True, ) fix_quality = gr.Slider( label="Quality", value=10, minimum=1, maximum=10, step=1, randomize=False, interactive=True, ) # main tabs with gr.Row(): # crop & brush with gr.Column(): gr.Markdown( """1. Upload a malformed hand image or choose one below 📥
""" ) fix_crop = gr.Image( type="numpy", sources=["upload", "webcam", "clipboard"], label="Input Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True, visible=True, image_mode="RGBA" ) # gr.Markdown( # """💡 If you crop, the model can focus on more details of the cropped area. Square crops might work better than rectangle crops.
""" # ) # fix_example = gr.Examples( # fix_example_imgs, # inputs=[fix_crop], # examples_per_page=20, # ) gr.Markdown( """To crop, click top left and bottom right
""" # of your desired bounding box around the hand) ) with gr.Column(): gr.Markdown( """2. Brush wrong finger
(⚠️and surrounding area)
Don't brush the entire hand!
""" # ) fix_ref = gr.ImageEditor( type="numpy", label="Image Brushing", sources=(), show_label=True, height=LENGTH, width=LENGTH, layers=False, transforms=("brush"), brush=gr.Brush( colors=["rgb(255, 255, 255)"], default_size=20 ), # 204, 50, 50 image_mode="RGBA", container=False, interactive=True, ) # fix_ex_brush = gr.Examples( # fix_example_brush, # inputs=[fix_ref], # examples_per_page=20, # ) # keypoint selection with gr.Column(): gr.Markdown( """3. Target hand pose
""" ) # gr.Markdown( # """Either get hand pose from Examples, or manually give hand pose (located at the bottom)
""" # ) fix_kp_all = gr.Image( type="numpy", label="Target Hand Pose", show_label=False, height=LENGTH, width=LENGTH, interactive=False, visible=True, sources=(), image_mode="RGBA" ) # with gr.Accordion(open=True): # fix_ex_kpts = gr.Examples( # fix_example_kpts, # inputs=[fix_kp_all, fix_cfg, fix_seed, fix_kpts_path], # examples_per_page=20, # postprocess=False, # elem_id="kpts_examples" # ) with gr.Accordion("[Your own image] Manually give hand pose", open=False, elem_id="accordion_bold"): gr.Markdown( """① Tell us if this is right, left, or both hands
""" ) fix_checkbox = gr.CheckboxGroup( ["Right hand", "Left hand"], show_label=False, interactive=False, ) fix_kp_r_info = gr.Markdown( """② Click 21 keypoints on the image to provide the target hand pose of right hand. See the \"OpenPose keypoints convention\" for guidance.
""", visible=False ) fix_kp_right = gr.Image( type="numpy", label="Keypoint Selection (right hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): fix_undo_right = gr.Button( value="Undo", interactive=False, visible=False ) fix_reset_right = gr.Button( value="Reset", interactive=False, visible=False ) fix_kp_l_info = gr.Markdown( """② Click 21 keypoints on the image to provide the target hand pose of left hand. See the \"OpenPose keypoints convention\" for guidance.
""", visible=False ) fix_kp_left = gr.Image( type="numpy", label="Keypoint Selection (left hand)", show_label=True, height=LENGTH, width=LENGTH, interactive=False, visible=False, sources=[], ) with gr.Row(): fix_undo_left = gr.Button( value="Undo", interactive=False, visible=False ) fix_reset_left = gr.Button( value="Reset", interactive=False, visible=False ) gr.Markdown( """OpenPose keypoints convention
""" ) fix_openpose = gr.Image( value="openpose.png", type="numpy", show_label=False, height=LENGTH // 2, width=LENGTH // 2, interactive=False, ) # result column with gr.Column(): gr.Markdown( """4. Press "Run" 🎯
""" ) fix_vis_mask32 = gr.Image( type="numpy", label=f"Visualized {opts.latent_size} Inpaint Mask", show_label=True, height=opts.latent_size, width=opts.latent_size, interactive=False, visible=False, ) fix_run = gr.Button(value="Run", interactive=False, elem_id="run_button") fix_vis_mask256 = gr.Image( type="numpy", show_label=False, height=opts.image_size, width=opts.image_size, interactive=False, visible=False, ) # gr.Markdown( # """⚠️ >3min per generation
""" # ) fix_result_original = gr.Gallery( type="numpy", label="Results", show_label=True, height=LENGTH, min_width=LENGTH, columns=FIX_MAX_N, interactive=False, preview=True, ) with gr.Accordion(label="Results of cropped area / Results with pose", open=False): fix_result = gr.Gallery( type="numpy", label="Results of cropped area", show_label=True, height=LENGTH, min_width=LENGTH, columns=FIX_MAX_N, interactive=False, preview=True, ) fix_result_pose = gr.Gallery( type="numpy", label="Results Pose", show_label=True, height=LENGTH, min_width=LENGTH, columns=FIX_MAX_N, interactive=False, preview=True, ) # gr.Markdown( # """✨ Hit "Clear" to restart from the beginning
""" # ) fix_clear = gr.ClearButton(elem_id="clear_button") with gr.Row(): gr.Examples( fix_example_all, inputs=[fix_crop, fix_ref, fix_ex_mask, fix_kp_all, fix_cfg, fix_seed, fix_kpts_path], examples_per_page=20, postprocess=True, elem_id="fix_examples_all", ) with gr.Row(): gr.Markdown( """⚠️ If brushed image doesn't load, please click the example again
""" ) # tutorial video with gr.Accordion("Tutorial Videos of Demo 2", elem_id="accordion_bold_large_center"): # gr.Markdown("""""") with gr.Row(variant="panel"): with gr.Column(): # gr.Video( # "how_to_videos/subtitled_fix_hands_custom.mp4", # label="Using your own image", # autoplay=True, # loop=True, # show_label=True, # ) gr.HTML(tut1_custom) with gr.Column(): # gr.Video( # "how_to_videos/subtitled_fix_hands_example.mp4", # label="Using our example image", # autoplay=True, # loop=True, # show_label=True, # ) gr.HTML(tut1_example) # listeners # fix_ex_mask.change(make_composite, [fix_crop, fix_ex_mask], fix_ref) fix_crop.input(lambda x: gr.update(None), fix_crop, fix_ex_mask) fix_crop.change(stash_original, fix_crop, fix_original) # fix_original: (real_H, real_W, 3) fix_crop.change(stay_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref]) fix_crop.select(process_crop, [fix_crop, fix_crop_coord], [fix_crop_coord, fix_ref, fix_crop]) fix_ref.change(visualize_ref, [fix_ref, fix_ex_mask], [fix_img, fix_inpaint_mask]) fix_img.change(lambda x: x, [fix_img], [fix_kp_right]) fix_img.change(lambda x: x, [fix_img], [fix_kp_left]) fix_ref.change( enable_component, [fix_ref, fix_ref], fix_checkbox ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left ) fix_inpaint_mask.change( enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run ) fix_checkbox.select( set_visible, [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left], [ fix_kpts, fix_kp_right, fix_kp_left, fix_kp_right, fix_undo_right, fix_reset_right, fix_kp_left, fix_undo_left, fix_reset_left, fix_kp_r_info, fix_kp_l_info, ], ) fix_kp_right.select( get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] # fix_img: (real_cropped_H, real_cropped_W, 3) ) fix_undo_right.click( undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] ) fix_reset_right.click( reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts] ) fix_kp_left.select( get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] ) fix_undo_left.click( undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] ) fix_reset_left.click( reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts] ) fix_kpts_path.change(read_kpts, fix_kpts_path, fix_kpts_np) fix_inpaint_mask.change(enable_component, [fix_inpaint_mask, fix_kpts_np], fix_run) fix_kpts_np.change(enable_component, [fix_inpaint_mask, fix_kpts_np], fix_run) fix_run.click( ready_sample, [fix_ref, fix_original, fix_ex_mask, fix_inpaint_mask, fix_kpts, fix_kpts_np], [ fix_ref_cond, fix_target_cond, fix_latent, fix_inpaint_latent, fix_kpts_np, fix_vis_mask32, fix_vis_mask256, ], ) fix_inpaint_latent.change( sample_inpaint, [ fix_ref_cond, fix_target_cond, fix_latent, fix_inpaint_latent, fix_kpts_np, fix_original, fix_crop_coord, fix_n_generation, fix_seed, fix_cfg, fix_quality, ], [fix_result, fix_result_pose, fix_result_original], ) fix_clear.click( fix_clear_all, [], [ fix_crop, fix_crop_coord, 
fix_ref, fix_checkbox, fix_kp_all, fix_kp_right, fix_kp_left, fix_result, fix_result_pose, fix_result_original, fix_inpaint_mask, fix_original, fix_img, fix_vis_mask32, fix_vis_mask256, fix_kpts, fix_kpts_np, fix_ref_cond, fix_target_cond, fix_latent, fix_inpaint_latent, fix_kpts_path, fix_n_generation, fix_seed, fix_cfg, fix_quality, fix_ex_mask, ], ) fix_clear.click( fix_set_unvisible, [], [ fix_kp_right, fix_kp_left, fix_kp_r_info, fix_kp_l_info, fix_undo_left, fix_undo_right, fix_reset_left, fix_reset_right ] ) # gr.Markdown("If this was useful, please cite us! ❤️
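# --- Editor's note (hedged sketch) ------------------------------------------
# `sample_inpaint` (wired above) returns results for the cropped area as well
# as results pasted back into the full image (fix_result_original), using the
# stored fix_crop_coord. Assuming crop_coord = [x1, y1, x2, y2] in original
# pixel coordinates, the paste-back step is conceptually:
import cv2
import numpy as np

def paste_back_sketch(original: np.ndarray, generated_crop: np.ndarray, crop_coord) -> np.ndarray:
    """Resize the generated crop to the crop box and paste it into a copy of the original."""
    x1, y1, x2, y2 = crop_coord
    out = original.copy()
    out[y1:y2, x1:x2] = cv2.resize(generated_crop, (x2 - x1, y2 - y1))
    return out
# ----------------------------------------------------------------------------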
""" ) gr.Markdown(_CITE_) with gr.Accordion("Trouble Shooting", open=False, elem_id="accordion_bold_large"): gr.Markdown("If error persists, please try the following steps: