|
|
|
import copy |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, |
|
reweight_loss_dict) |
|
from mmdet.registry import MODELS |
|
from mmdet.structures import SampleList |
|
from mmdet.structures.bbox import bbox_project |
|
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig |
|
from .base import BaseDetector |
|
|
|
|
|
@MODELS.register_module()
class SemiBaseDetector(BaseDetector):
    """Base class for semi-supervised detectors.

    Semi-supervised detectors typically consist of a teacher model
    updated by exponential moving average and a student model updated
    by gradient descent. Both sub-detectors are built from the same
    ``detector`` config.

    Args:
        detector (:obj:`ConfigDict` or dict): The detector config.
        semi_train_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised training config. ``None`` is treated as
            an empty config. Defaults to None.
        semi_test_cfg (:obj:`ConfigDict` or dict, optional):
            The semi-supervised testing config. ``None`` is treated as
            an empty config, in which case inference runs on the teacher.
            Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 detector: ConfigType,
                 semi_train_cfg: OptConfigType = None,
                 semi_test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.student = MODELS.build(detector)
        self.teacher = MODELS.build(detector)
        # The documented defaults are ``None``; normalise them to empty
        # dicts so every ``.get(key, default)`` lookup falls back to its
        # default instead of raising on ``None``.
        if semi_train_cfg is None:
            semi_train_cfg = dict()
        if semi_test_cfg is None:
            semi_test_cfg = dict()
        self.semi_train_cfg = semi_train_cfg
        self.semi_test_cfg = semi_test_cfg
        # Only an explicit ``True`` (also the default) freezes the teacher.
        if self.semi_train_cfg.get('freeze_teacher', True) is True:
            self.freeze(self.teacher)

    @staticmethod
    def freeze(model: nn.Module):
        """Put ``model`` in eval mode and disable gradients on all of its
        parameters."""
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

    def loss(self, multi_batch_inputs: Dict[str, Tensor],
             multi_batch_data_samples: Dict[str, SampleList]) -> dict:
        """Calculate losses from multi-branch inputs and data samples.

        Args:
            multi_batch_inputs (Dict[str, Tensor]): The dict of multi-branch
                input images, each value with shape (N, C, H, W).
                Each value should usually be mean centered and std scaled.
                Must contain the keys ``'sup'``, ``'unsup_teacher'`` and
                ``'unsup_student'``.
            multi_batch_data_samples (Dict[str, List[:obj:`DetDataSample`]]):
                The dict of multi-branch data samples, with the same keys.
                Its ``'unsup_student'`` entry is replaced in place by the
                projected pseudo-labelled samples.

        Returns:
            dict: A dictionary of loss components, with supervised losses
            prefixed ``sup_`` and unsupervised losses prefixed ``unsup_``.
        """
        losses = dict()
        losses.update(
            self.loss_by_gt_instances(multi_batch_inputs['sup'],
                                      multi_batch_data_samples['sup']))

        # Teacher predictions on its branch, mapped back to the original
        # image frame.
        origin_pseudo_data_samples, batch_info = self.get_pseudo_instances(
            multi_batch_inputs['unsup_teacher'],
            multi_batch_data_samples['unsup_teacher'])
        # Re-project those pseudo boxes into the student branch's frame.
        multi_batch_data_samples[
            'unsup_student'] = self.project_pseudo_instances(
                origin_pseudo_data_samples,
                multi_batch_data_samples['unsup_student'])
        losses.update(
            self.loss_by_pseudo_instances(
                multi_batch_inputs['unsup_student'],
                multi_batch_data_samples['unsup_student'], batch_info))
        return losses

    def loss_by_gt_instances(self, batch_inputs: Tensor,
                             batch_data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and ground-truth data
        samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: The student's losses scaled by ``semi_train_cfg.sup_weight``
            (default 1.0), with every key prefixed by ``sup_``.
        """
        losses = self.student.loss(batch_inputs, batch_data_samples)
        sup_weight = self.semi_train_cfg.get('sup_weight', 1.)
        return rename_loss_dict('sup_', reweight_loss_dict(losses, sup_weight))

    def loss_by_pseudo_instances(self,
                                 batch_inputs: Tensor,
                                 batch_data_samples: SampleList,
                                 batch_info: Optional[dict] = None) -> dict:
        """Calculate losses from a batch of inputs and pseudo data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`,
                which are `pseudo_instance` or `pseudo_panoptic_seg`
                or `pseudo_sem_seg` in fact.
            batch_info (dict): Batch information of teacher model
                forward propagation process. Unused by this base
                implementation. Defaults to None.

        Returns:
            dict: The student's losses with every key prefixed by
            ``unsup_``. They are scaled by ``semi_train_cfg.unsup_weight``
            (default 1.0), or by 0 when no pseudo instance survives the
            score filter.

        Raises:
            KeyError: If ``semi_train_cfg`` has no ``cls_pseudo_thr``.
        """
        # Subscript access works for both plain ``dict`` and ``ConfigDict``.
        batch_data_samples = filter_gt_instances(
            batch_data_samples,
            score_thr=self.semi_train_cfg['cls_pseudo_thr'])
        losses = self.student.loss(batch_inputs, batch_data_samples)
        pseudo_instances_num = sum(
            len(data_samples.gt_instances)
            for data_samples in batch_data_samples)
        # With no surviving pseudo boxes the losses are still computed but
        # zero-weighted (presumably so the loss dict keeps the same keys
        # every iteration).
        unsup_weight = self.semi_train_cfg.get(
            'unsup_weight', 1.) if pseudo_instances_num > 0 else 0.
        return rename_loss_dict('unsup_',
                                reweight_loss_dict(losses, unsup_weight))

    @torch.no_grad()
    def get_pseudo_instances(
            self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[SampleList, Optional[dict]]:
        """Get pseudo instances from teacher model.

        The teacher's predictions (made with ``rescale=False``) are stored
        in place as ``gt_instances`` of each data sample. Their boxes are
        then projected with the inverse of the sample's
        ``homography_matrix``, with ``ori_shape`` forwarded to
        ``bbox_project``.

        Returns:
            tuple: ``(batch_data_samples, batch_info)``. The first element is
            the same list that was passed in, now mutated. ``batch_info`` is
            always an empty dict in this base implementation.
        """
        # Re-assert eval mode: a frozen teacher may have been switched back
        # to train mode by a parent ``.train()`` call.
        self.teacher.eval()
        results_list = self.teacher.predict(
            batch_inputs, batch_data_samples, rescale=False)
        batch_info = {}
        device = self.data_preprocessor.device
        for data_samples, results in zip(batch_data_samples, results_list):
            data_samples.gt_instances = results.pred_instances
            # Invert on CPU, then move to the model's device.
            inv_homography = torch.from_numpy(
                data_samples.homography_matrix).inverse().to(device)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes, inv_homography,
                data_samples.ori_shape)
        return batch_data_samples, batch_info

    def project_pseudo_instances(self, batch_pseudo_instances: SampleList,
                                 batch_data_samples: SampleList) -> SampleList:
        """Project pseudo instances into the frame of ``batch_data_samples``.

        Each target sample receives a deep copy of the corresponding pseudo
        ``gt_instances``. The copied boxes are projected with the target's
        ``homography_matrix`` (with ``img_shape`` forwarded to
        ``bbox_project``). Boxes smaller than
        ``semi_train_cfg.min_pseudo_bbox_wh`` (default ``(1e-2, 1e-2)``)
        are then filtered out.

        Returns:
            list[:obj:`DetDataSample`]: The filtered data samples.
        """
        device = self.data_preprocessor.device
        for pseudo_instances, data_samples in zip(batch_pseudo_instances,
                                                  batch_data_samples):
            # Deep copy so the student's projection does not alter the
            # teacher-side instances.
            data_samples.gt_instances = copy.deepcopy(
                pseudo_instances.gt_instances)
            homography = torch.tensor(data_samples.homography_matrix).to(
                device)
            data_samples.gt_instances.bboxes = bbox_project(
                data_samples.gt_instances.bboxes, homography,
                data_samples.img_shape)
        wh_thr = self.semi_train_cfg.get('min_pseudo_bbox_wh', (1e-2, 1e-2))
        return filter_gt_instances(batch_data_samples, wh_thr=wh_thr)

    def _select_model(self, cfg_key: str) -> nn.Module:
        """Return the teacher unless ``semi_test_cfg[cfg_key]`` names a
        different model, in which case return the student."""
        if self.semi_test_cfg.get(cfg_key, 'teacher') == 'teacher':
            return self.teacher
        return self.student

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        The model used is chosen by ``semi_test_cfg.predict_on``. The value
        ``'teacher'`` (the default) selects the teacher; any other value
        selects the student.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        return self._select_model('predict_on')(
            batch_inputs, batch_data_samples, mode='predict')

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> SampleList:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        The model used is chosen by ``semi_test_cfg.forward_on``. The value
        ``'teacher'`` (the default) selects the teacher; any other value
        selects the student.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples forwarded to the selected sub-detector.

        Returns:
            Whatever the selected sub-detector returns in ``'tensor'`` mode,
            typically a tuple of raw head outputs.
        """
        return self._select_model('forward_on')(
            batch_inputs, batch_data_samples, mode='tensor')

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        The model used is chosen by ``semi_test_cfg.extract_feat_on``. The
        value ``'teacher'`` (the default) selects the teacher; any other
        value selects the student.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have
            different resolutions.
        """
        return self._select_model('extract_feat_on').extract_feat(
            batch_inputs)

    def _load_from_state_dict(self, state_dict: dict, prefix: str,
                              local_metadata: dict, strict: bool,
                              missing_keys: Union[List[str], str],
                              unexpected_keys: Union[List[str], str],
                              error_msgs: Union[List[str], str]) -> None:
        """Add teacher and student prefixes to model parameter names.

        Only entries that belong to this module (keys starting with
        ``prefix``) are considered. If none of them already lives under
        ``teacher.`` or ``student.``, the checkpoint is treated as a plain
        detector checkpoint, and every such entry is duplicated into both
        sub-detectors. ``state_dict`` is modified in place before
        delegating to the parent implementation.
        """
        own_keys = [key for key in state_dict if key.startswith(prefix)]
        already_split = any(
            key[len(prefix):].startswith(('teacher.', 'student.'))
            for key in own_keys)
        if not already_split:
            for key in own_keys:
                # Insert the sub-model name *after* this module's prefix so
                # the remapped keys still resolve when the detector is nested
                # inside another module; keys outside ``prefix`` belong to
                # other modules and are left untouched.
                suffix = key[len(prefix):]
                value = state_dict.pop(key)
                state_dict[f'{prefix}teacher.{suffix}'] = value
                state_dict[f'{prefix}student.{suffix}'] = value
        return super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|
|