Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,253 Bytes
3324de2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
import torch
import einops
import argparse
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from depthfm import DepthFM
import matplotlib.pyplot as plt
def get_dtype_from_str(dtype_str):
    """
    Translate a short precision tag into the matching torch dtype.
    Args:
        dtype_str (`str`):
            One of "fp32", "fp16" or "bf16".
    Returns:
        `torch.dtype`: The floating-point dtype registered for the tag.
    Raises:
        KeyError: If `dtype_str` is not one of the known tags.
    """
    known_dtypes = dict(
        fp32=torch.float32,
        fp16=torch.float16,
        bf16=torch.bfloat16,
    )
    return known_dtypes[dtype_str]
def resize_max_res(
    img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
) -> "tuple[Image.Image, tuple[int, int]]":
    """
    Rescale an image so its longer edge matches `max_edge_resolution`.

    The scale factor is `max_edge_resolution / max(width, height)`, so the
    image is enlarged when the requested resolution exceeds its longer edge.
    Both target edges are then snapped to the nearest multiple of 64 (and
    never below 64), which means the aspect ratio is only approximately kept.
    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Target length of the longer edge (pixel). Must be positive.
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.
    Returns:
        `tuple`: `(resized_img, (original_width, original_height))`, i.e. the
        resized image together with the size of the input image.
    Raises:
        ValueError: If `max_edge_resolution` is not positive.
    """
    if max_edge_resolution <= 0:
        raise ValueError(
            f"max_edge_resolution must be positive, got {max_edge_resolution}"
        )

    # Edges are aligned to this step; presumably a requirement of the
    # downstream model -- TODO confirm.
    multiple = 64

    original_width, original_height = img.size
    scale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    def snap(edge):
        # Truncate the scaled edge, round to the nearest multiple, and clamp so
        # that small edges (< multiple / 2) do not collapse to a 0-pixel target.
        return max(multiple, round(int(edge * scale_factor) / multiple) * multiple)

    new_width = snap(original_width)
    new_height = snap(original_height)
    print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}")
    resized_img = img.resize((new_width, new_height), resample=resample_method)
    return resized_img, (original_width, original_height)
def load_im(fp, processing_res=-1):
    """
    Load an image file and convert it into a normalized RGB tensor.

    The image is always passed through `resize_max_res`, so its edges are
    aligned to multiples of 64 even when `processing_res` is negative.
    Args:
        fp (`str`):
            Path to the image file.
        processing_res (`int`):
            Target length of the longer edge. A negative value keeps the
            longer edge of the loaded image as the target.
    Returns:
        `tuple`: `(x, orig_res)` where `x` is a float32 tensor laid out as
        (1, c, h, w) with values scaled from [0, 255] to [-1, 1], and
        `orig_res` is the `(width, height)` of the image before resizing.
    Raises:
        FileNotFoundError: If `fp` does not exist.
    """
    # Raise explicitly rather than `assert`: assertions are removed when
    # Python runs with -O, which would silently skip this check.
    if not os.path.exists(fp):
        raise FileNotFoundError(f"File not found: {fp}")

    # The context manager releases the file handle; convert() returns a
    # separate in-memory image that stays valid after the file is closed.
    with Image.open(fp) as raw:
        im = raw.convert('RGB')

    if processing_res < 0:
        processing_res = max(im.size)
    im, orig_res = resize_max_res(im, processing_res)

    x = np.array(im)
    x = einops.rearrange(x, 'h w c -> c h w')
    x = x / 127.5 - 1  # map [0, 255] to [-1, 1]
    x = torch.tensor(x, dtype=torch.float32)[None]  # prepend batch axis
    return x, orig_res
def main(args):
    """
    Run DepthFM on a single image and save the predicted depth map as a PNG.

    The result is written next to the input as `<args.img>_depth.png`, resized
    back to the input image's original resolution if needed.
    Args:
        args (`argparse.Namespace`):
            Parsed options. Reads `img`, `ckpt`, `num_steps`, `ensemble_size`,
            `no_color`, `device`, `processing_res` and `dtype`.
    """
    print(f"{'Input':<10}: {args.img}")
    print(f"{'Steps':<10}: {args.num_steps}")
    print(f"{'Ensemble':<10}: {args.ensemble_size}")

    # Load the model
    model = DepthFM(args.ckpt)
    model.cuda(args.device).eval()

    # Load an image
    im, orig_res = load_im(args.img, args.processing_res)
    im = im.cuda(args.device)

    # Generate depth
    dtype = get_dtype_from_str(args.dtype)
    model.model.dtype = dtype
    with torch.autocast(device_type="cuda", dtype=dtype):
        depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size)

    # Cast to float32 before handing over to NumPy: NumPy has no bfloat16, so
    # Tensor.numpy() fails on bf16 tensors. Whether predict_depth returns
    # reduced precision under autocast is not visible here -- the cast is a
    # no-op if it already returns float32.
    depth = depth.squeeze(0).squeeze(0).cpu().float().numpy()  # (h, w) in [0, 1] per the original note

    # Convert depth to [0, 255] range
    if args.no_color:
        depth = (depth * 255).astype(np.uint8)
    else:
        # bytes=True yields uint8 RGBA; drop the alpha channel.
        depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3]

    # Save the depth map
    depth_fp = args.img + '_depth.png'
    depth_img = Image.fromarray(depth)
    if depth_img.size != orig_res:
        # Undo the processing-resolution resize so the output matches the input size.
        depth_img = depth_img.resize(orig_res, Resampling.BILINEAR)
    depth_img.save(depth_fp)
    print(f"==> Saved depth map to {depth_fp}")
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run inference.
    parser = argparse.ArgumentParser("DepthFM Inference")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--img", dict(
            type=str, default="assets/dog.png",
            help="Path to the input image")),
        ("--ckpt", dict(
            type=str, default="checkpoints/depthfm-v1.ckpt",
            help="Path to the model checkpoint")),
        ("--num_steps", dict(
            type=int, default=2,
            help="Number of steps for ODE solver")),
        ("--ensemble_size", dict(
            type=int, default=4,
            help="Number of ensemble members")),
        ("--no_color", dict(
            action="store_true",
            help="If set, the depth map will be grayscale")),
        ("--device", dict(
            type=int, default=0,
            help="GPU to use")),
        ("--processing_res", dict(
            type=int, default=-1,
            help="Longer edge of the image will be resized to this resolution. -1 to disable resizing.")),
        ("--dtype", dict(
            type=str, choices=["fp32", "bf16", "fp16"], default="fp16",
            help="Run with specific precision. Speeds up inference with subtle loss")),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)
|