File size: 4,253 Bytes
3324de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import torch
import einops
import argparse
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from depthfm import DepthFM
import matplotlib.pyplot as plt

def get_dtype_from_str(dtype_str):
    """
    Translate a precision tag into the matching torch dtype.

    Args:
        dtype_str (`str`):
            One of "fp32", "fp16" or "bf16".

    Returns:
        `torch.dtype`: The dtype registered for the given tag.

    Raises:
        KeyError: If the tag is not one of the supported names.
    """
    precision_by_tag = dict(
        fp32=torch.float32,
        fp16=torch.float16,
        bf16=torch.bfloat16,
    )
    return precision_by_tag[dtype_str]

def resize_max_res(
    img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
) -> "tuple[Image.Image, tuple[int, int]]":
    """
    Resize an image so its longer edge is about `max_edge_resolution`, then snap
    both edges to a multiple of 64.

    The same scale factor is applied to both edges (so small images are scaled
    up, not only down). Each scaled edge is then rounded to the nearest multiple
    of 64, which means the result can slightly exceed `max_edge_resolution` and
    the aspect ratio is only approximately preserved. An edge never drops below
    64 pixels, so very thin or very small images still produce a valid size.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Target length of the longer edge (pixel). Must be positive.
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `tuple`: `(resized_img, (original_width, original_height))`, i.e. the
        resized image and the size of `img` before resizing.

    Raises:
        ValueError: If `max_edge_resolution` is not positive.
    """
    if max_edge_resolution <= 0:
        raise ValueError(
            f"max_edge_resolution must be positive, got {max_edge_resolution}"
        )

    # NOTE(review): 64 is presumably the size granularity the model needs -- confirm.
    multiple = 64

    original_width, original_height = img.size
    scale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * scale_factor)
    new_height = int(original_height * scale_factor)

    # Rounding to the nearest multiple gives 0 for edges shorter than half a
    # multiple; clamp so the resize below never receives a zero dimension.
    new_width = max(multiple, round(new_width / multiple) * multiple)
    new_height = max(multiple, round(new_height / multiple) * multiple)

    print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}")

    resized_img = img.resize((new_width, new_height), resample=resample_method)
    return resized_img, (original_width, original_height)

def load_im(fp, processing_res=-1):
    """
    Load an image file as a normalized, batched float tensor.

    Args:
        fp (`str`):
            Path to the image file.
        processing_res (`int`):
            Target length of the longer edge, forwarded to `resize_max_res`.
            A negative value means "use the image's own longer edge".

    Returns:
        `tuple`: `(x, orig_res)` where `x` is a float32 tensor of shape
        (1, 3, h, w) with values in [-1, 1], and `orig_res` is the
        (width, height) of the image before resizing.

    Raises:
        FileNotFoundError: If `fp` does not exist.
    """
    # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(fp):
        raise FileNotFoundError(f"File not found: {fp}")

    # convert() returns an independent, fully loaded image, so the underlying
    # file handle can be released as soon as the conversion is done.
    with Image.open(fp) as source:
        im = source.convert('RGB')

    if processing_res < 0:
        processing_res = max(im.size)
    im, orig_res = resize_max_res(im, processing_res)

    x = np.array(im)                            # (h, w, 3), uint8
    x = einops.rearrange(x, 'h w c -> c h w')   # channels first
    x = x / 127.5 - 1                           # [0, 255] -> [-1, 1]
    x = torch.tensor(x, dtype=torch.float32)[None]  # add batch dimension
    return x, orig_res


def main(args):
    """
    Run DepthFM on a single image and save the predicted depth map as a PNG.

    The result is written next to the input as `<args.img>_depth.png`, resized
    back to the input image's original resolution.

    Args:
        args (`argparse.Namespace`):
            Parsed command-line options. Reads `img`, `ckpt`, `num_steps`,
            `ensemble_size`, `no_color`, `device`, `processing_res` and `dtype`.
    """
    print(f"{'Input':<10}: {args.img}")
    print(f"{'Steps':<10}: {args.num_steps}")
    print(f"{'Ensemble':<10}: {args.ensemble_size}")

    # Load the model
    model = DepthFM(args.ckpt)
    model.cuda(args.device).eval()

    # Load an image
    im, orig_res = load_im(args.img, args.processing_res)
    im = im.cuda(args.device)

    # Generate depth
    dtype = get_dtype_from_str(args.dtype)
    model.model.dtype = dtype
    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype):
        depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size)
    # Upcast before leaving torch: NumPy has no bfloat16, so calling .numpy() on
    # a bf16 tensor raises TypeError.
    depth = depth.squeeze(0).squeeze(0).float().cpu().numpy()   # (h, w), expected in [0, 1]

    # Convert depth to [0, 255] range
    if args.no_color:
        # Clip first so a value marginally outside [0, 1] cannot wrap around in uint8.
        depth = (np.clip(depth, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        # The colormap returns RGBA bytes; keep only the RGB channels.
        depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3]

    # Save the depth map
    depth_fp = args.img + '_depth.png'
    depth_img = Image.fromarray(depth)
    if depth_img.size != orig_res:
        # Undo the processing-resolution resize so the output matches the input size.
        depth_img = depth_img.resize(orig_res, Resampling.BILINEAR)
    depth_img.save(depth_fp)
    print(f"==> Saved depth map to {depth_fp}")


if __name__ == "__main__":
    # Command-line entry point: declare the options as a table, register them
    # in order, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser("DepthFM Inference")

    option_table = (
        ("--img", dict(
            type=str, default="assets/dog.png",
            help="Path to the input image")),
        ("--ckpt", dict(
            type=str, default="checkpoints/depthfm-v1.ckpt",
            help="Path to the model checkpoint")),
        ("--num_steps", dict(
            type=int, default=2,
            help="Number of steps for ODE solver")),
        ("--ensemble_size", dict(
            type=int, default=4,
            help="Number of ensemble members")),
        ("--no_color", dict(
            action="store_true",
            help="If set, the depth map will be grayscale")),
        ("--device", dict(
            type=int, default=0,
            help="GPU to use")),
        ("--processing_res", dict(
            type=int, default=-1,
            help="Longer edge of the image will be resized to this resolution. -1 to disable resizing.")),
        ("--dtype", dict(
            type=str, choices=["fp32", "bf16", "fp16"], default="fp16",
            help="Run with specific precision. Speeds up inference with subtle loss")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())