Spaces:
Runtime error
Runtime error
File size: 5,506 Bytes
69e8976 668bc57 fec1a45 69e8976 43afdb7 d9cdf23 8075387 fec1a45 668bc57 8075387 69e8976 d9cdf23 69e8976 d9cdf23 7f3a3f1 d9cdf23 7f3a3f1 d9cdf23 3ad7c50 d9cdf23 3ad7c50 c6e542f 43afdb7 8075387 43afdb7 8075387 43afdb7 8075387 69e8976 8075387 69e8976 8075387 69e8976 8075387 69e8976 668bc57 8075387 3354665 d6f8777 3354665 43afdb7 668bc57 8075387 668bc57 8075387 668bc57 8075387 668bc57 d6f8777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import torch
import cv2
import os
import wget
import gradio as gr
import numpy as np
import gdown
from huggingface_hub import hf_hub_download
from argparse import Namespace
try:
    import detectron2
except ImportError:
    # detectron2 cannot go in requirements.txt: its setup.py imports torch at
    # build time, and torch is not installed yet at that point. Install it here,
    # after torch is already importable, then retry the import.
    # NOTE(review): os.system with a fixed, trusted URL — no untrusted input.
    os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
    import detectron2
from demo import setup_cfg
from proxydet.predictor import VisualizationDemo
# Pick the compute device once at import time; prefer CUDA when a GPU exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
# wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
# wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)
# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
# wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)
# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
# gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)
# Model checkpoint and LVIS metadata live in one Hub repo; local_dir="./"
# preserves each file's repo-relative path under the working directory.
model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'

_required_files = (
    model_path,
    "datasets/metadata/lvis_v1_base_cat_mask.npy",
    zs_weight_path,
    "datasets/metadata/lvis_v1_train_cat_info.json",
)
for _required_file in _required_files:
    hf_hub_download(
        repo_id="doublejtoh/proxydet_data",
        filename=_required_file,
        repo_type="model",
        local_dir="./",
    )
# Inference configuration mirroring the demo CLI's argparse namespace.
# The vocabulary/input defaults are placeholders: query_image() resets the
# classifier per request, so only config_file/opts/zeroshot_weight_path matter.
args = Namespace(
    base_cat_threshold=0.9,
    confidence_threshold=0.0,  # keep everything; query_image filters by slider value
    config_file='configs/ProxyDet_SwinB_Lbase_INL.yaml',
    cpu=not torch.cuda.is_available(),
    custom_vocabulary='headphone,webcam,paper,coffee',  # fix: was 'coffe' (see gradio examples)
    input=['assets/desk.jpg'],  # fix: was '.assets/desk.jpg'; examples use 'assets/'
    opts=['MODEL.WEIGHTS', model_path],
    output='out.jpg',
    pred_all_class=False,
    video_input=None,
    vocabulary='custom',
    webcam=None,
    zeroshot_weight_path=zs_weight_path,
)
cfg = setup_cfg(args)
# Shared predictor/visualizer; re-configured per request in query_image().
ovd_demo = VisualizationDemo(cfg, args)
def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    """Detect the queried objects in ``img`` and draw boxes + labels on it.

    Args:
        img: RGB image as a numpy array of shape (H, W, 3), as supplied by gradio.
        text_queries: comma-separated class names, e.g. ``"dog,cat,headphone"``.
        score_threshold: minimum detection score for a box to be drawn.
        base_alpha: weight for the proxy-novel head on base classes
            (written to ``roi_heads.cmm_base_alpha``).
        novel_beta: weight for the proxy-novel head on novel classes
            (written to ``roi_heads.cmm_novel_beta``).

    Returns:
        The input image (RGB numpy array) with detections drawn in-place.
    """
    # Display names: strip whitespace so "dog, cat" still labels boxes correctly.
    class_names = [name.strip() for name in text_queries.split(",")]

    # Rebuild the zero-shot classifier for this request's vocabulary.
    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
    ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta

    # detectron2 expects BGR; gradio delivers RGB.
    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        predictions, visualized_output = ovd_demo.run_on_image(img_bgr)
    output_instances = predictions["instances"].to(device)
    boxes = output_instances.pred_boxes.tensor
    scores = output_instances.scores
    labels = output_instances.pred_classes.tolist()

    font = cv2.FONT_HERSHEY_SIMPLEX
    img_height = img.shape[0]  # fix: was hard-coded 768, wrong for other image sizes
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(i) for i in box.tolist()]
        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
        # Put the label just below the box unless that would fall off the image.
        if box[3] + 25 > img_height:
            y = box[3] - 10
        else:
            y = box[3] + 25
        img = cv2.putText(
            img, class_names[label], (box[0], y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img
if __name__ == "__main__":
    # Markdown/HTML blurb rendered above the demo widgets.
    description = """
Gradio demo for ProxyDet, introduced in <a href="https://arxiv.org/abs/2312.07266">ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection</a>.
\n\nYou can use ProxyDet to query images with text descriptions of any object.
How to use?
- Simply upload an image and enter comma separated objects (e.g., "dog,cat,headphone") which you want to detect within the image.\n
Parameters:
- You can also use the score threshold slider to set a threshold to filter out low probability predictions.
- adjust alpha and beta value for base and novel classes, respectively. These determine <b>how much importance will you assign to the scores sourced from our proposed detection head which is trained with our proxy-novel classes</b>.
"""
    # Sliders: score filter, base-class alpha, novel-class beta.
    score_slider = gr.Slider(0, 1, value=0.1)
    alpha_slider = gr.Slider(0, 1, value=0.15)
    beta_slider = gr.Slider(0, 1, value=0.35)
    example_rows = [
        ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
        ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
        ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
    ]
    demo = gr.Interface(
        query_image,
        inputs=[gr.Image(), "text", score_slider, alpha_slider, beta_slider],
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=example_rows,
    )
    demo.launch()
|