import torch
import cv2
import os
import wget
import gradio as gr
import numpy as np
import gdown
from huggingface_hub import hf_hub_download
from argparse import Namespace
try:
    import detectron2
except ImportError:
    # detectron2 cannot go in requirements.txt: its setup.py imports torch, which is not
    # yet installed when requirements are resolved, so install it at runtime instead.
    os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
    import detectron2
from demo import setup_cfg
from proxydet.predictor import VisualizationDemo
# Use GPU if available
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
# wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
# wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)
# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
# wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)
# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
# gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)
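# Fetch the pretrained ProxyDet (Swin-B) checkpoint and the LVIS metadata from the Hugging Face Hub.
# local_dir="./" keeps the repo-relative paths, recreating the models/ and datasets/metadata/ layout
# that the paths below and the config expect.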
for filename in [
    "models/proxydet_swinb_w_inl.pth",
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "datasets/metadata/lvis_v1_train_cat_info.json",
]:
    hf_hub_download(
        repo_id="doublejtoh/proxydet_data",
        filename=filename,
        repo_type="model",
        local_dir="./",
    )
model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
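# Mirror the argument Namespace that demo.py's parser would produce so setup_cfg() can be reused here.
# custom_vocabulary and input are only defaults; query_image() below sets the vocabulary per request.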
args = Namespace(
base_cat_threshold=0.9,
confidence_threshold=0.0,
config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
cpu=not torch.cuda.is_available(),
    custom_vocabulary='headphone,webcam,paper,coffee',
    input=['assets/desk.jpg'],
opts=['MODEL.WEIGHTS', model_path],
output='out.jpg',
pred_all_class=False,
video_input=None,
vocabulary='custom',
webcam=None,
zeroshot_weight_path=zs_weight_path
)
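# Build the detectron2 config from the YAML file plus the MODEL.WEIGHTS override,
# then wrap the model in ProxyDet's VisualizationDemo for per-image inference.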
cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)
def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
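    """Run ProxyDet on an RGB image with a custom, comma-separated vocabulary.

    score_threshold filters out low-confidence boxes; base_alpha and novel_beta weight the
    scores from the classwise-mixup (proxy-novel) head for base and novel classes.
    Returns the input image with boxes and class labels drawn on it.
    """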
text_queries_split = text_queries.split(",")
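    # Point the detector at the user-supplied vocabulary and reset the base-category mask to match it.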
ovd_demo.reset_classifier(text_queries)
ovd_demo.reset_base_cat_mask()
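    # Blending weights for the classwise-mixup (CMM) head: base_alpha weights its scores
    # for base classes, novel_beta for novel classes.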
ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta
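    # Gradio supplies RGB images; the detectron2-based predictor expects BGR.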
img_bgr = img[:, :, ::-1]
with torch.no_grad():
predictions, visualized_output = ovd_demo.run_on_image(img_bgr)
output_instances = predictions["instances"].to(device)
boxes = output_instances.pred_boxes.tensor
scores = output_instances.scores
labels = output_instances.pred_classes.tolist()
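    # Draw a box and class label for every prediction above the score threshold.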
font = cv2.FONT_HERSHEY_SIMPLEX
for box, score, label in zip(boxes, scores, labels):
box = [int(i) for i in box.tolist()]
if score >= score_threshold:
img = cv2.rectangle(img, box[:2], box[2:], (255,0,0), 5)
            # Place the label below the box unless it would run off the bottom of the image.
            if box[3] + 25 > img.shape[0]:
y = box[3] - 10
else:
y = box[3] + 25
img = cv2.putText(
img, text_queries_split[label], (box[0], y), font, 1, (255,0,0), 2, cv2.LINE_AA
)
return img
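# A minimal sketch of calling query_image() directly, outside the Gradio UI
# (assumes assets/desk.jpg exists; cv2.imread returns BGR, so convert to RGB first):
#
#     rgb = cv2.cvtColor(cv2.imread("assets/desk.jpg"), cv2.COLOR_BGR2RGB)
#     out = query_image(rgb, "headphone,webcam,paper,coffee", 0.1, 0.15, 0.35)
#     cv2.imwrite("out.jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))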
if __name__ == "__main__":
    description = """
Gradio demo for ProxyDet, introduced in "ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection".
\n\nYou can use ProxyDet to query images with text descriptions of arbitrary objects.
How to use?
- Simply upload an image and enter the comma-separated objects you want to detect in it (e.g., "dog,cat,headphone").\n
Parameters:
- Use the score threshold slider to filter out low-probability predictions.
- Adjust the alpha and beta values for base and novel classes, respectively. They control how much weight is given to the scores from our proposed detection head, which is trained with proxy-novel classes.
    """
demo = gr.Interface(
query_image,
        inputs=[
            gr.Image(),
            "text",
            gr.Slider(0, 1, value=0.1, label="Score threshold"),
            gr.Slider(0, 1, value=0.15, label="Base alpha"),
            gr.Slider(0, 1, value=0.35, label="Novel beta"),
        ],
outputs="image",
title="Open-Vocabulary Object Detection with ProxyDet",
description=description,
examples=[
["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
],
)
demo.launch()