import os
import argparse

import einops
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from PIL.Image import Resampling

from depthfm import DepthFM


def get_dtype_from_str(dtype_str):
    return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str]


def resize_max_res(
    img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
) -> tuple:
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `tuple[Image.Image, tuple[int, int]]`: Resized image and the original
        (width, height), so the depth map can be scaled back afterwards.
    """
    original_width, original_height = img.size
    downscale_factor = min(
        max_edge_resolution / original_width,
        max_edge_resolution / original_height,
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    # Snap both edges to multiples of 64, as required by the model's downsampling
    new_width = round(new_width / 64) * 64
    new_height = round(new_height / 64) * 64

    print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}")

    resized_img = img.resize((new_width, new_height), resample=resample_method)
    return resized_img, (original_width, original_height)


def load_im(fp, processing_res=-1):
    assert os.path.exists(fp), f"File not found: {fp}"
    im = Image.open(fp).convert('RGB')
    if processing_res < 0:
        processing_res = max(im.size)
    im, orig_res = resize_max_res(im, processing_res)
    x = np.array(im)
    x = einops.rearrange(x, 'h w c -> c h w')
    x = x / 127.5 - 1  # normalize [0, 255] -> [-1, 1]
    x = torch.tensor(x, dtype=torch.float32)[None]
    return x, orig_res


def main(args):
    print(f"{'Input':<10}: {args.img}")
    print(f"{'Steps':<10}: {args.num_steps}")
    print(f"{'Ensemble':<10}: {args.ensemble_size}")

    # Load the model
    model = DepthFM(args.ckpt)
    model.cuda(args.device).eval()

    # Load an image
    im, orig_res = load_im(args.img, args.processing_res)
    im = im.cuda(args.device)

    # Generate depth
    dtype = get_dtype_from_str(args.dtype)
    model.model.dtype = dtype
    with torch.autocast(device_type="cuda", dtype=dtype):
        depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size)
    depth = depth.squeeze(0).squeeze(0).cpu().numpy()  # (h, w) in [0, 1]

    # Convert depth to [0, 255] range, either grayscale or colorized
    if args.no_color:
        depth = (depth * 255).astype(np.uint8)
    else:
        depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3]

    # Save the depth map, resized back to the original resolution
    depth_fp = args.img + '_depth.png'
    depth_img = Image.fromarray(depth)
    if depth_img.size != orig_res:
        depth_img = depth_img.resize(orig_res, Resampling.BILINEAR)
    depth_img.save(depth_fp)
    print(f"==> Saved depth map to {depth_fp}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("DepthFM Inference")
    parser.add_argument("--img", type=str, default="assets/dog.png",
                        help="Path to the input image")
    parser.add_argument("--ckpt", type=str, default="checkpoints/depthfm-v1.ckpt",
                        help="Path to the model checkpoint")
    parser.add_argument("--num_steps", type=int, default=2,
                        help="Number of steps for the ODE solver")
    parser.add_argument("--ensemble_size", type=int, default=4,
                        help="Number of ensemble members")
    parser.add_argument("--no_color", action="store_true",
                        help="If set, the depth map will be grayscale")
    parser.add_argument("--device", type=int, default=0,
                        help="GPU to use")
    parser.add_argument("--processing_res", type=int, default=-1,
                        help="Longer edge of the image will be resized to this resolution. "
                             "-1 to disable resizing.")
    parser.add_argument("--dtype", type=str, choices=["fp32", "bf16", "fp16"], default="fp16",
                        help="Run with a specific precision. Speeds up inference at a subtle loss in quality.")
    args = parser.parse_args()
    main(args)
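
# Example invocation (assuming this file is saved as `inference.py` and a
# DepthFM checkpoint exists at the default path; both are assumptions, so
# adjust the filename and --ckpt to your setup):
#
#   python inference.py --img assets/dog.png --num_steps 2 --ensemble_size 4 --dtype fp16
#
# With --no_color the output is a single-channel grayscale PNG; otherwise the
# depth map is colorized with matplotlib's magma colormap. Larger --num_steps
# and --ensemble_size values trade inference time for prediction quality.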