import copy
from typing import Callable, Dict, List, Optional, Union

import numpy as np
from mmcv.transforms import BaseTransform, Compose
from mmcv.transforms.utils import cache_random_params, cache_randomness

from mmdet.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MultiBranch(BaseTransform):
r"""Multiple branch pipeline wrapper. |
|
|
|
Generate multiple data-augmented versions of the same image. |
|
`MultiBranch` needs to specify the branch names of all |
|
pipelines of the dataset, perform corresponding data augmentation |
|
for the current branch, and return None for other branches, |
|
which ensures the consistency of return format across |
|
different samples. |
|
|
|
Args: |
|
branch_field (list): List of branch names. |
|
branch_pipelines (dict): Dict of different pipeline configs |
|
to be composed. |
|
|
|
Examples: |
|
>>> branch_field = ['sup', 'unsup_teacher', 'unsup_student'] |
|
>>> sup_pipeline = [ |
|
>>> dict(type='LoadImageFromFile'), |
|
>>> dict(type='LoadAnnotations', with_bbox=True), |
|
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), |
|
>>> dict(type='RandomFlip', prob=0.5), |
|
>>> dict( |
|
>>> type='MultiBranch', |
|
>>> branch_field=branch_field, |
|
>>> sup=dict(type='PackDetInputs')) |
|
>>> ] |
|
>>> weak_pipeline = [ |
|
>>> dict(type='LoadImageFromFile'), |
|
>>> dict(type='LoadAnnotations', with_bbox=True), |
|
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), |
|
>>> dict(type='RandomFlip', prob=0.0), |
|
>>> dict( |
|
>>> type='MultiBranch', |
|
>>> branch_field=branch_field, |
|
>>> sup=dict(type='PackDetInputs')) |
|
>>> ] |
|
>>> strong_pipeline = [ |
|
>>> dict(type='LoadImageFromFile'), |
|
>>> dict(type='LoadAnnotations', with_bbox=True), |
|
>>> dict(type='Resize', scale=(1333, 800), keep_ratio=True), |
|
>>> dict(type='RandomFlip', prob=1.0), |
|
>>> dict( |
|
>>> type='MultiBranch', |
|
>>> branch_field=branch_field, |
|
>>> sup=dict(type='PackDetInputs')) |
|
>>> ] |
|
>>> unsup_pipeline = [ |
|
>>> dict(type='LoadImageFromFile'), |
|
>>> dict(type='LoadEmptyAnnotations'), |
|
>>> dict( |
|
>>> type='MultiBranch', |
|
>>> branch_field=branch_field, |
|
>>> unsup_teacher=weak_pipeline, |
|
>>> unsup_student=strong_pipeline) |
|
>>> ] |
|
>>> from mmcv.transforms import Compose |
|
>>> sup_branch = Compose(sup_pipeline) |
|
>>> unsup_branch = Compose(unsup_pipeline) |
|
>>> print(sup_branch) |
|
>>> Compose( |
|
>>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa |
|
>>> LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2') # noqa |
|
>>> Resize(scale=(1333, 800), scale_factor=None, keep_ratio=True, clip_object_border=True), backend=cv2), interpolation=bilinear) # noqa |
|
>>> RandomFlip(prob=0.5, direction=horizontal) |
|
>>> MultiBranch(branch_pipelines=['sup']) |
|
>>> ) |
|
>>> print(unsup_branch) |
|
>>> Compose( |
|
>>> LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2') # noqa |
|
>>> LoadEmptyAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, seg_ignore_label=255) # noqa |
|
>>> MultiBranch(branch_pipelines=['unsup_teacher', 'unsup_student']) |
|
>>> ) |
    """

    def __init__(self, branch_field: List[str],
                 **branch_pipelines: dict) -> None:
        self.branch_field = branch_field
        self.branch_pipelines = {
            branch: Compose(pipeline)
            for branch, pipeline in branch_pipelines.items()
        }

    def transform(self, results: dict) -> Optional[dict]:
        """Transform function to apply transforms sequentially.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict or None:

            - 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
              models from different branches.
            - 'data_samples' (Dict[str, obj:`DetDataSample`]): The annotation
              info of the sample from different branches.
        """
        multi_results = {}
        for branch in self.branch_field:
            multi_results[branch] = {'inputs': None, 'data_samples': None}
        for branch, pipeline in self.branch_pipelines.items():
            branch_results = pipeline(copy.deepcopy(results))
            # If a branch pipeline filters out this sample, drop the whole
            # sample so that all branches stay consistent.
            if branch_results is None:
                return None
            multi_results[branch] = branch_results

        format_results = {}
        for branch, results in multi_results.items():
            for key in results.keys():
                if format_results.get(key, None) is None:
                    format_results[key] = {branch: results[key]}
                else:
                    format_results[key][branch] = results[key]
        return format_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(branch_pipelines={list(self.branch_pipelines.keys())})'
        return repr_str


@TRANSFORMS.register_module()
class RandomOrder(Compose):
    """Shuffle the transform Sequence.

    @cache_randomness
    def _random_permutation(self):
        return np.random.permutation(len(self.transforms))

    def transform(self, results: Dict) -> Optional[Dict]:
        """Transform function to apply transforms in random order.

        Args:
            results (dict): A result dict containing the data to transform.

        Returns:
            dict or None: Transformed results.
        """
        inds = self._random_permutation()
        for idx in inds:
            t = self.transforms[idx]
            results = t(results)
            if results is None:
                return None
        return results

    def __repr__(self):
        """Compute the string representation."""
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += f'{t.__class__.__name__}, '
        format_string += ')'
        return format_string


@TRANSFORMS.register_module()
class ProposalBroadcaster(BaseTransform):
    """A transform wrapper that applies the wrapped transforms to both
    `gt_bboxes` and `proposals` without any extra code. It performs the
    following steps:

    1. Scatter the broadcasting targets to a list of inputs for the wrapped
       transforms. The type of the list should be list[dict, dict], where
       the first item is the original input and the second is a copy whose
       `gt_bboxes` are overwritten by the `proposals`.
    2. Apply ``self.transforms`` with the same random parameters, which are
       shared via a context manager. The type of the output is a
       list[dict, dict].
    3. Gather the outputs and update the `proposals` in the first item of
       the outputs with the `gt_bboxes` in the second item.

    Args:
        transforms (list, optional): Sequence of transform
            objects or config dicts to be wrapped. Defaults to [].

    Note: The `TransformBroadcaster` in MMCV can achieve the same operation
    as `ProposalBroadcaster`, but it requires more complex parameters.

    Examples:
        >>> pipeline = [
        >>>     dict(type='LoadImageFromFile'),
        >>>     dict(type='LoadProposals', num_max_proposals=2000),
        >>>     dict(type='LoadAnnotations', with_bbox=True),
        >>>     dict(
        >>>         type='ProposalBroadcaster',
        >>>         transforms=[
        >>>             dict(type='Resize', scale=(1333, 800),
        >>>                  keep_ratio=True),
        >>>             dict(type='RandomFlip', prob=0.5),
        >>>         ]),
        >>>     dict(type='PackDetInputs')]
    """

    def __init__(self, transforms: List[Union[dict, Callable]] = []) -> None:
        self.transforms = Compose(transforms)

    def transform(self, results: dict) -> dict:
        """Apply the wrapped transform functions to both `gt_bboxes` and
        `proposals`.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        assert results.get('proposals', None) is not None, \
            '`proposals` should be in the results; please delete ' \
            '`ProposalBroadcaster` from your configs, or check whether ' \
            'the proposals have been loaded successfully.'

        inputs = self._process_input(results)
        outputs = self._apply_transforms(inputs)
        outputs = self._process_output(outputs)
        return outputs

    def _process_input(self, data: dict) -> list:
        """Scatter the broadcasting targets to a list of inputs for the
        wrapped transforms.

        Args:
            data (dict): The original input data.

        Returns:
            list[dict]: A list of input data.
        """
        cp_data = copy.deepcopy(data)
        # The copy carries the proposals in place of `gt_bboxes`, so the
        # wrapped transforms treat them exactly like ground-truth boxes.
        cp_data['gt_bboxes'] = cp_data['proposals']
        scatters = [data, cp_data]
        return scatters

    def _apply_transforms(self, inputs: list) -> list:
        """Apply ``self.transforms``.

        Args:
            inputs (list[dict, dict]): List of input data.

        Returns:
            list[dict]: The output of the wrapped pipeline.
        """
        assert len(inputs) == 2
        ctx = cache_random_params
        # Cache the random parameters so that both scattered inputs pass
        # through the wrapped transforms with identical randomness.
        with ctx(self.transforms):
            output_scatters = [self.transforms(_input) for _input in inputs]
        return output_scatters

    def _process_output(self, output_scatters: list) -> dict:
        """Gather and rename data items.

        Args:
            output_scatters (list[dict, dict]): The output of the wrapped
                pipeline.

        Returns:
            dict: Updated result dict.
        """
        assert isinstance(output_scatters, list) and \
               isinstance(output_scatters[0], dict) and \
               len(output_scatters) == 2
        outputs = output_scatters[0]
        # The first scatter keeps the original annotations; its `proposals`
        # are replaced with the transformed boxes from the second scatter.
        outputs['proposals'] = output_scatters[1]['gt_bboxes']
        return outputs