""" Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) Copyright(c) 2023 lyuwenyu. All Rights Reserved. """ import importlib.metadata from torch import Tensor if "0.15.2" in importlib.metadata.version("torchvision"): import torchvision torchvision.disable_beta_transforms_warning() from torchvision.datapoints import BoundingBox as BoundingBoxes from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes _boxes_keys = ["format", "spatial_size"] elif "0.17" > importlib.metadata.version("torchvision") >= "0.16": import torchvision torchvision.disable_beta_transforms_warning() from torchvision.transforms.v2 import SanitizeBoundingBoxes from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video _boxes_keys = ["format", "canvas_size"] elif importlib.metadata.version("torchvision") >= "0.17": import torchvision from torchvision.transforms.v2 import SanitizeBoundingBoxes from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video _boxes_keys = ["format", "canvas_size"] else: raise RuntimeError("Please make sure torchvision version >= 0.15.2") def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor: """ Args: tensor (Tensor): input tensor key (str): transform to key Return: Dict[str, TV_Tensor] """ assert key in ( "boxes", "masks", ), "Only support 'boxes' and 'masks'" if key == "boxes": box_format = getattr(BoundingBoxFormat, box_format.upper()) _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size])) return BoundingBoxes(tensor, **_kwargs) if key == "masks": return Mask(tensor)