import os
from argparse import Namespace

import cv2
import gdown  # only needed if the commented-out download fallback below is re-enabled
import gradio as gr
import numpy as np
import torch
import wget  # only needed if the commented-out download fallback below is re-enabled
from huggingface_hub import hf_hub_download

try:
    import detectron2
except ImportError:
    # Installing detectron2 via requirements.txt fails because its setup.py
    # requires torch to already be installed, so install it at runtime instead.
    os.system("python3 -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
    import detectron2

from demo import setup_cfg
from proxydet.predictor import VisualizationDemo

# Use GPU if available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# # download metadata
# zs_weight_path = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
# if not os.path.exists(zs_weight_path):
#     wget.download("https://github.com/facebookresearch/Detic/raw/main/datasets/metadata/lvis_v1_clip_a+cname.npy", out=zs_weight_path)
# base_cat_mask_path = "datasets/metadata/lvis_v1_base_cat_mask.npy"
# if not os.path.exists(base_cat_mask_path):
#     wget.download("https://docs.google.com/uc?export=download&id=1CbSs5yeqMsWDkRSsIlB-ln_bXDv686rH", out=base_cat_mask_path)
# lvis_train_cat_info_path = "datasets/metadata/lvis_v1_train_cat_info.json"
# if not os.path.exists(lvis_train_cat_info_path):
#     wget.download("https://docs.google.com/uc?export=download&id=17WmkAJYBK4xT-YkiXLcwIWmtfulSUtmO", out=lvis_train_cat_info_path)
# # download model
# model_path = "models/proxydet_swinb_w_inl.pth"
# if not os.path.exists(model_path):
#     gdown.download("https://docs.google.com/uc?export=download&id=17kUPoi-pEK7BlTBheGzWxe_DXJlg28qF", model_path)

# Download the model weights and LVIS metadata from the Hugging Face Hub.
hf_hub_download(
    repo_id="doublejtoh/proxydet_data",
    filename="models/proxydet_swinb_w_inl.pth",
    repo_type="model",
    local_dir="./",
)
hf_hub_download(
    repo_id="doublejtoh/proxydet_data",
    filename="datasets/metadata/lvis_v1_base_cat_mask.npy",
    repo_type="model",
    local_dir="./",
)
hf_hub_download(
    repo_id="doublejtoh/proxydet_data",
    filename="datasets/metadata/lvis_v1_clip_a+cname.npy",
    repo_type="model",
    local_dir="./",
)
hf_hub_download(
    repo_id="doublejtoh/proxydet_data",
    filename="datasets/metadata/lvis_v1_train_cat_info.json",
    repo_type="model",
    local_dir="./",
)

model_path = "models/proxydet_swinb_w_inl.pth"
zs_weight_path = "datasets/metadata/lvis_v1_clip_a+cname.npy"

# Mirror the CLI arguments expected by demo.py / setup_cfg.
args = Namespace(
    base_cat_threshold=0.9,
    confidence_threshold=0.0,
    config_file="configs/ProxyDet_SwinB_Lbase_INL.yaml",
    cpu=not torch.cuda.is_available(),
    custom_vocabulary="headphone,webcam,paper,coffee",
    input=["assets/desk.jpg"],
    opts=["MODEL.WEIGHTS", model_path],
    output="out.jpg",
    pred_all_class=False,
    video_input=None,
    vocabulary="custom",
    webcam=None,
    zeroshot_weight_path=zs_weight_path,
)

cfg = setup_cfg(args)
ovd_demo = VisualizationDemo(cfg, args)


def query_image(img, text_queries, score_threshold, base_alpha, novel_beta):
    # Rebuild the open-vocabulary classifier from the comma-separated queries.
    text_queries_split = text_queries.split(",")
    ovd_demo.reset_classifier(text_queries)
    ovd_demo.reset_base_cat_mask()
    # Mixing weights for the classwise-mixup head: alpha for base classes,
    # beta for novel classes.
    ovd_demo.predictor.model.roi_heads.cmm_base_alpha = base_alpha
    ovd_demo.predictor.model.roi_heads.cmm_novel_beta = novel_beta

    # Gradio supplies RGB; the detectron2-based predictor expects BGR.
    img_bgr = img[:, :, ::-1]
    with torch.no_grad():
        # visualized_output is unused; boxes and labels are drawn below.
        predictions, visualized_output = ovd_demo.run_on_image(img_bgr)

    output_instances = predictions["instances"].to(device)
    boxes = output_instances.pred_boxes.tensor
    scores = output_instances.scores
    labels = output_instances.pred_classes.tolist()

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        box = [int(i) for i in box.tolist()]
        if score >= score_threshold:
            img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 5)
            # Keep the label inside the image: draw it above the box bottom
            # when it would otherwise spill past the bottom edge.
            if box[3] + 25 > img.shape[0]:
                y = box[3] - 10
            else:
                y = box[3] + 25
            img = cv2.putText(
                img, text_queries_split[label], (box[0], y),
                font, 1, (255, 0, 0), 2, cv2.LINE_AA,
            )
    return img


if __name__ == "__main__":
    description = """
    Gradio demo for ProxyDet, introduced in "ProxyDet: Synthesizing Proxy Novel Classes via Classwise Mixup for Open-Vocabulary Object Detection".
    \n\nYou can use ProxyDet to query images with text descriptions of any object.
    How to use?
    - Simply upload an image and enter the comma-separated objects (e.g., "dog,cat,headphone") that you want to detect in the image.\n
    Parameters:
    - Use the score threshold slider to filter out low-probability predictions.
    - Adjust the alpha and beta values for base and novel classes, respectively. They determine how much weight is given to the scores from our proposed detection head, which is trained with our proxy-novel classes.
    """
    demo = gr.Interface(
        query_image,
        inputs=[
            gr.Image(),
            "text",
            gr.Slider(0, 1, value=0.1, label="Score threshold"),
            gr.Slider(0, 1, value=0.15, label="Base alpha"),
            gr.Slider(0, 1, value=0.35, label="Novel beta"),
        ],
        outputs="image",
        title="Open-Vocabulary Object Detection with ProxyDet",
        description=description,
        examples=[
            ["assets/desk.jpg", "headphone,webcam,paper,coffee", 0.11, 0.15, 0.35],
            ["assets/beach.jpg", "person,kite", 0.1, 0.15, 0.35],
            ["assets/pikachu.jpg", "pikachu,person", 0.15, 0.15, 0.35],
        ],
    )
    demo.launch()
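
# ----------------------------------------------------------------------------
# Minimal sketch of calling the detector without the Gradio UI (not part of the
# original demo). It assumes an image exists at "assets/desk.jpg"; query_image
# expects an RGB numpy array, the same format gr.Image() passes in:
#
#     rgb = cv2.imread("assets/desk.jpg")[:, :, ::-1].copy()  # BGR -> RGB
#     out = query_image(rgb, "headphone,webcam,paper,coffee",
#                       score_threshold=0.11, base_alpha=0.15, novel_beta=0.35)
#     cv2.imwrite("out.jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
# ----------------------------------------------------------------------------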