File size: 5,506 Bytes
69e8976
 
668bc57
fec1a45
69e8976
 
43afdb7
d9cdf23
8075387
 
 
 
fec1a45
668bc57
8075387
 
 
69e8976
 
 
 
 
 
 
d9cdf23
 
 
 
69e8976
d9cdf23
 
 
7f3a3f1
d9cdf23
 
 
7f3a3f1
d9cdf23
 
 
 
3ad7c50
d9cdf23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ad7c50
c6e542f
43afdb7
8075387
 
 
 
 
 
 
43afdb7
8075387
 
 
 
 
43afdb7
8075387
 
 
69e8976
8075387
 
 
 
 
 
 
69e8976
8075387
 
 
 
 
69e8976
 
 
 
 
 
 
 
 
 
 
 
 
8075387
69e8976
 
 
668bc57
 
8075387
 
3354665
 
d6f8777
3354665
 
43afdb7
668bc57
 
 
8075387
668bc57
8075387
668bc57
 
8075387
 
 
668bc57
 
d6f8777
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import cv2
import os
import wget
import gradio as gr
import numpy as np
import gdown
from huggingface_hub import hf_hub_download
from argparse import Namespace
try:
    import detectron2
except ImportError:
    # requirements.txt gives error since detectron2 > setup.py requires torch to be installed, which is not installed before this.
    # Catch only ImportError: a bare `except:` would also swallow
    # SystemExit/KeyboardInterrupt raised during import.
    os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
    import detectron2
from demo import setup_cfg 
from proxydet.predictor import VisualizationDemo

# Select the compute device: prefer CUDA when a GPU is present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
#     wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)

# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
#     wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)

# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
#     wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)

# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
#     gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)

# Fetch the model weights and LVIS metadata from the Hugging Face Hub.
# local_dir="./" preserves the repo's directory layout, so the files land
# under ./models and ./datasets exactly where the ProxyDet config expects them.
_HF_REPO_ID = "doublejtoh/proxydet_data"
_HF_FILES = [
    "models/proxydet_swinb_w_inl.pth",
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "datasets/metadata/lvis_v1_train_cat_info.json",
]
for _hf_file in _HF_FILES:
    hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename=_hf_file,
        repo_type="model",
        local_dir="./",
    )

model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'

# Build the argument namespace normally produced by demo.py's CLI parser.
# `input`/`custom_vocabulary` are only initial defaults: query_image()
# resets the classifier per request. Fixed two placeholder typos so the
# defaults agree with the example rows below: 'coffe' -> 'coffee' and
# '.assets/desk.jpg' -> 'assets/desk.jpg'.
args = Namespace(
    base_cat_threshold=0.9,          # score above which a box is treated as a base-category hit
    confidence_threshold=0.0,        # keep all detections; filtering happens in query_image
    config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
    cpu=not torch.cuda.is_available(),
    custom_vocabulary='headphone,webcam,paper,coffee',
    input=['assets/desk.jpg'],
    opts=['MODEL.WEIGHTS', model_path],
    output='out.jpg',
    pred_all_class=False,
    video_input=None,
    vocabulary='custom',
    webcam=None,
    zeroshot_weight_path=zs_weight_path
)
cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)

def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    """Run open-vocabulary detection on an RGB image and draw the results.

    Args:
        img: RGB image as a numpy array (H, W, 3), as supplied by gr.Image().
        text_queries: comma-separated class names, e.g. "dog,cat,headphone".
        score_threshold: minimum detection score for a box to be drawn.
        base_alpha: mixing weight applied to base-class scores.
        novel_beta: mixing weight applied to novel-class scores.

    Returns:
        The input image with boxes and class labels drawn on it.
    """
    text_queries_split = text_queries.split(",")
    # Rebuild the zero-shot classifier for the user-supplied vocabulary.
    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
    ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta
    # Gradio supplies RGB; the detector follows the OpenCV BGR convention.
    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        predictions, visualized_output = ovd_demo.run_on_image(img_bgr)
    output_instances = predictions["instances"].to(device)
    boxes = output_instances.pred_boxes.tensor
    scores = output_instances.scores
    labels = output_instances.pred_classes.tolist()
    font = cv2.FONT_HERSHEY_SIMPLEX
    img_height = img.shape[0]

    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
        # Put the label below the box, or above it when the box is too close
        # to the bottom edge. BUGFIX: the original hard-coded a 768px image
        # height; use the actual image height so labels stay visible for
        # inputs of any size.
        if box[3] + 25 > img_height:
            y = box[3] - 10
        else:
            y = box[3] + 25

        img = cv2.putText(
            img, text_queries_split[label], (box[0], y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img

if __name__ == "__main__":
    # Markdown/HTML text rendered above the Gradio interface.
    description = """
Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>. 
\n\nYou can use ProxyDet to query images with text descriptions of any object. 

How to use?
- Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.\n
Parameters:
- You can also use the score threshold slider to set a threshold to filter out low probability predictions.
- adjust alpha and beta value for base and novel classes, respectively. These determine <b>how much importance will you assign to the scores sourced from our proposed detection head which is trained with our proxy-novel classes</b>.
    """

    # Sliders: detection score threshold, base-class alpha, novel-class beta.
    score_slider = gr.Slider(0, 1, value=0.1)
    alpha_slider = gr.Slider(0, 1, value=0.15)
    beta_slider = gr.Slider(0, 1, value=0.35)

    # Each example row: image path, vocabulary, threshold, alpha, beta.
    example_rows = [
        ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
        ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
        ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
    ]

    demo = gr.Interface(
        query_image,
        inputs=[gr.Image(), "text", score_slider, alpha_slider, beta_slider],
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=example_rows,
    )
    demo.launch()