# proxydet / app.py
# doublejtoh's picture
# fix: move gdrive to huggingface hub
# c6e542f
import torch
import cv2
import os
import wget
import gradio as gr
import numpy as np
import gdown
from huggingface_hub import hf_hub_download
from argparse import Namespace
# detectron2 cannot be listed in requirements.txt: its setup.py needs torch,
# which is not installed yet at that point. Install it on first start instead.
try:
    import detectron2
except ImportError:
    import importlib
    import subprocess
    import sys

    # Run pip through the current interpreter so the package lands in the
    # active environment, and raise if the installation fails instead of
    # continuing to a confusing second ImportError.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()
    import detectron2
from demo import setup_cfg
from proxydet.predictor import VisualizationDemo
# Run on CUDA when torch reports a usable GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
# wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
# wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)
# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
# wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)
# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
# gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)
# Download the model checkpoint and the LVIS metadata files from the
# "doublejtoh/proxydet_data" Hub repository, using "./" as the local directory.
for _hub_file in (
    "models/proxydet_swinb_w_inl.pth",
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    "datasets/metadata/lvis_v1_clip_a+cname.npy",
    "datasets/metadata/lvis_v1_train_cat_info.json",
):
    hf_hub_download(
        repo_id="doublejtoh/proxydet_data",
        filename=_hub_file,
        repo_type="model",
        local_dir="./",
    )
# Paths (relative to the working directory) of the checkpoint and of the
# zero-shot classifier weights handed to the demo configuration.
model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = "datasets/metadata/lvis_v1_clip_a+cname.npy"

# Stand-in for the command-line arguments that setup_cfg and
# VisualizationDemo receive.
args = Namespace(
    base_cat_threshold=0.9,
    confidence_threshold=0.0,
    config_file="configs/ProxyDet_SwinB_Lbase_INL.yaml",
    cpu=not torch.cuda.is_available(),
    # Previously misspelled "coffe"; now matches the desk example query.
    custom_vocabulary="headphone,webcam,paper,coffee",
    # Previously ".assets/desk.jpg"; the example images are referenced
    # under "assets/" elsewhere in this file.
    input=["assets/desk.jpg"],
    opts=["MODEL.WEIGHTS", model_path],
    output="out.jpg",
    pred_all_class=False,
    video_input=None,
    vocabulary="custom",
    webcam=None,
    zeroshot_weight_path=zs_weight_path,
)
cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)
def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    """Detect the queried object names in an image and draw the detections.

    Args:
        img: Input image as an array indexed ``[row, column, channel]``. The
            channel axis is reversed before inference (assumes RGB input and
            a BGR detector -- TODO confirm against the Gradio image input).
        text_queries: Comma-separated object names, e.g. ``"dog,cat"``.
        score_threshold: Detections scoring below this value are not drawn.
        base_alpha: Value assigned to ``roi_heads.cmm_base_alpha``.
        novel_beta: Value assigned to ``roi_heads.cmm_novel_beta``.

    Returns:
        A copy of ``img`` with a box and a text label drawn for every
        detection whose score reaches ``score_threshold``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Labels are looked up by predicted class index, so keep exactly one
    # entry per comma-separated field and only trim surrounding whitespace.
    class_names = [name.strip() for name in text_queries.split(",")]

    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    roi_heads = ovd_demo.predictor.model.roi_heads
    roi_heads.cmm_base_alpha = base_alpha
    roi_heads.cmm_novel_beta = novel_beta

    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        predictions, _ = ovd_demo.run_on_image(img_bgr)

    # Move the results to the CPU once and convert them to plain Python
    # values, so the drawing loop below works on ordinary numbers.
    instances = predictions["instances"].to("cpu")
    boxes = instances.pred_boxes.tensor.tolist()
    scores = instances.scores.tolist()
    labels = instances.pred_classes.tolist()

    # Draw on a copy so the caller's array is left untouched.
    annotated = img.copy()
    img_height = annotated.shape[0]
    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes, scores, labels):
        if not score >= score_threshold:
            continue
        x1, y1, x2, y2 = (int(coord) for coord in box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 5)
        # Put the label just below the box, unless that would fall past the
        # bottom edge of this image; then place it inside the box instead.
        text_y = y2 + 25
        if text_y > img_height:
            text_y = y2 - 10
        cv2.putText(
            annotated, class_names[label], (x1, text_y), font, 1, color, 2, cv2.LINE_AA
        )
    return annotated
if __name__ == "__main__":
    # Text passed to gr.Interface as the description of the demo page.
    description = """
Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
\n\nYou can use ProxyDet to query images with text descriptions of any object.
How to use?
- Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.\n
Parameters:
- You can also use the score threshold slider to set a threshold to filter out low probability predictions.
- adjust alpha and beta value for base and novel classes, respectively. These determine <b>how much importance will you assign to the scores sourced from our proposed detection head which is trained with our proxy-novel classes</b>.
"""

    # One input per query_image parameter, in order:
    # img, text_queries, score_threshold, base_alpha, novel_beta.
    input_components = [
        gr.Image(),
        "text",
        gr.Slider(0, 1, value=0.1),
        gr.Slider(0, 1, value=0.15),
        gr.Slider(0, 1, value=0.35),
    ]

    # Example rows follow the same column order as input_components.
    example_rows = [
        ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
        ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
        ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
    ]

    demo = gr.Interface(
        fn=query_image,
        inputs=input_components,
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=example_rows,
    )
    demo.launch()