# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import copy
import numpy as np
import itertools

from torch import nn

from hotr.util import box_ops
from hotr.util.misc import accuracy, get_world_size, is_dist_avail_and_initialized

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute a Hungarian assignment between the ground-truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for the list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.eos_coef = eos_coef
        self.HOI_losses = HOI_losses
        self.HOI_matcher = HOI_matcher
        self.use_consis = args.use_consis and len(args.augpath_name) > 0
        self.num_path = 1 + len(args.augpath_name)

        if args:
            self.HOI_eos_coef = args.hoi_eos_coef
            if args.dataset_file == 'vcoco':
                self.invalid_ids = args.invalid_ids
                self.valid_ids = np.concatenate((args.valid_ids, [-1]), axis=0)  # no interaction
            elif args.dataset_file == 'hico-det':
                self.invalid_ids = []
                self.valid_ids = list(range(num_actions)) + [-1]
                # for targets
                self.num_tgt_classes = len(args.valid_obj_ids)
                tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
                tgt_empty_weight[-1] = self.HOI_eos_coef
                self.register_buffer('tgt_empty_weight', tgt_empty_weight)
            self.dataset_file = args.dataset_file

        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    #######################################################################################################################
    # * DETR Losses
    #######################################################################################################################
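    # The losses below are the standard DETR detection losses (instance classification, cardinality
    # logging, and box regression), kept essentially as in the original DETR implementation.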
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
        This is not really a loss; it is intended for logging purposes only and does not propagate gradients.
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    #######################################################################################################################
    # * HOTR Losses
    #######################################################################################################################
    # >>> HOI Losses 1 : HO Pointer
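    # pred_hidx / pred_oidx are pointer logits over the instance-decoder queries: for every interaction
    # query they score which instance query is the human and which is the object. Matched interaction
    # queries are supervised with the index of their GT-matched instance query; unmatched queries are
    # padded with -1 and ignored by the cross-entropy below.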
    def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes, use_consis, log=False):
        assert ('pred_hidx' in outputs and 'pred_oidx' in outputs)
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
        nu, q, hd = outputs['pred_hidx'].shape
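        # The batch dimension stacks the num_path decoding paths in path-major order ([num_path * B, Q, D]);
        # regroup it so that the paths of each image are contiguous, matching the flattened per-image
        # hoi_indices used below.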
        src_hidx = outputs['pred_hidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1).flatten(0, 1)
        src_oidx = outputs['pred_oidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1).flatten(0, 1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
        target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)

        # H Pointer loss
        target_classes_h = torch.cat([t["h_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_hidx_classes[idx] = target_classes_h

        # O Pointer loss
        target_classes_o = torch.cat([t["o_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_oidx_classes[idx] = target_classes_o

        loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
        loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)

        # Consistency loss
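        # For GT pairs that are matched in more than one decoding path, pull the pointer distributions
        # of those paths toward each other with a symmetric KL divergence (each direction detaches its
        # target so that both paths receive a gradient).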
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_hidx_inputs = [F.softmax(src_hidx.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            src_hidx_targets = [F.softmax(src_hidx.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            src_oidx_inputs = [F.softmax(src_oidx.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            src_oidx_targets = [F.softmax(src_oidx.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            loss_h_consistency = [0.5 * (F.kl_div(src_hidx_input.log(), src_hidx_target.clone().detach(), reduction='batchmean')
                                         + F.kl_div(src_hidx_target.log(), src_hidx_input.clone().detach(), reduction='batchmean'))
                                  for src_hidx_input, src_hidx_target in zip(src_hidx_inputs, src_hidx_targets)]
            loss_o_consistency = [0.5 * (F.kl_div(src_oidx_input.log(), src_oidx_target.clone().detach(), reduction='batchmean')
                                         + F.kl_div(src_oidx_target.log(), src_oidx_input.clone().detach(), reduction='batchmean'))
                                  for src_oidx_input, src_oidx_target in zip(src_oidx_inputs, src_oidx_targets)]
            loss_h_consistency = torch.mean(torch.stack(loss_h_consistency))
            loss_o_consistency = torch.mean(torch.stack(loss_o_consistency))
            losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o,
                      'loss_h_consistency': loss_h_consistency, 'loss_o_consistency': loss_o_consistency}
        else:
            losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}
        return losses

    # >>> HOI Losses 2 : pair actions
    def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes, use_consis):
        assert 'pred_actions' in outputs
        src_actions = outputs['pred_actions'].flatten(end_dim=1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        # Construct Target --------------------------------------------------------------------------------------------------------------
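        # Action supervision is multi-label: every interaction query gets a multi-hot target over the
        # action classes plus a trailing no-interaction slot. Queries without a matched GT pair keep
        # no-interaction = 1; matched queries receive their GT action vector with no-interaction = 0.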
        target_classes_o = torch.cat([t["pair_actions"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_classes = torch.full(src_actions.shape, 0, dtype=torch.float32, device=src_actions.device)
        target_classes[..., -1] = 1  # default: the last (no-interaction) index is 1
        pos_classes = torch.full(target_classes[idx].shape, 0, dtype=torch.float32, device=src_actions.device)  # for matched pairs, no-interaction stays 0
        pos_classes[:, :-1] = target_classes_o.float()
        target_classes[idx] = pos_classes
        # --------------------------------------------------------------------------------------------------------------------------------

        # BCE Loss -----------------------------------------------------------------------------------------------------------------------
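        # Focal-style modulation of the per-class BCE: (1 - p_t)^2 down-weights easy examples and
        # alpha = 0.25 re-balances positives vs. negatives; the sum is normalized by the number of
        # positive action labels (the no-interaction slot is excluded from the normalizer).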
        probs = src_actions.sigmoid()
        loss_bce = F.binary_cross_entropy(probs[..., self.valid_ids], target_classes[..., self.valid_ids], reduction='none')
        p_t = probs[..., self.valid_ids] * target_classes[..., self.valid_ids] + (1 - probs[..., self.valid_ids]) * (1 - target_classes[..., self.valid_ids])
        loss_bce = ((1 - p_t) ** 2 * loss_bce)
        alpha_t = 0.25 * target_classes[..., self.valid_ids] + (1 - 0.25) * (1 - target_classes[..., self.valid_ids])
        loss_focal = alpha_t * loss_bce
        loss_act = loss_focal.sum() / max(target_classes[..., self.valid_ids[:-1]].sum(), 1)
        # --------------------------------------------------------------------------------------------------------------------------------

        # Consistency loss
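        # Unlike the pointer losses, action consistency is enforced with an MSE between the
        # log-sigmoid action scores of the paths matched to the same GT pair.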
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_action_inputs = [F.logsigmoid(outputs['pred_actions'][i][consistency_idx[0]]) for i, consistency_idx in enumerate(consistency_idxs)]
            src_action_targets = [F.logsigmoid(outputs['pred_actions'][i][consistency_idx[1]]) for i, consistency_idx in enumerate(consistency_idxs)]
            loss_action_consistency = [F.mse_loss(src_action_input, src_action_target) for src_action_input, src_action_target in zip(src_action_inputs, src_action_targets)]
            loss_action_consistency = torch.mean(torch.stack(loss_action_consistency))
            losses = {'loss_act': loss_act, 'loss_act_consistency': loss_action_consistency}
        else:
            losses = {'loss_act': loss_act}
        return losses

    # >>> HOI Losses 3 : action targets
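    # Object-category classification of the interaction target; only registered for HICO-DET
    # (see get_HOI_loss). Unmatched pairs are padded with -1 and ignored by the cross-entropy below.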
    def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions, use_consis, log=True):
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']
        nu, q, hd = outputs['pred_obj_logits'].shape
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        target_classes_o = torch.cat([t['pair_targets'][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        pad_tgt = -1  # src_logits.shape[2]-1
        target_classes = torch.full(src_logits.shape[:2], pad_tgt, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.tgt_empty_weight, ignore_index=-1)

        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            src_logits_inputs = [F.softmax(src_logits.view(-1, self.num_path, q, hd)[i][consistency_idx[0]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            src_logits_targets = [F.softmax(src_logits.view(-1, self.num_path, q, hd)[i][consistency_idx[1]], -1) for i, consistency_idx in enumerate(consistency_idxs)]
            loss_tgt_consistency = [0.5 * (F.kl_div(src_logit_input.log(), src_logit_target.clone().detach(), reduction='batchmean')
                                           + F.kl_div(src_logit_target.log(), src_logit_input.clone().detach(), reduction='batchmean'))
                                    for src_logit_input, src_logit_target in zip(src_logits_inputs, src_logits_targets)]
            loss_tgt_consistency = torch.mean(torch.stack(loss_tgt_consistency))
            losses = {'loss_tgt': loss_obj_ce, 'loss_tgt_label_consistency': loss_tgt_consistency}
        else:
            losses = {'loss_tgt': loss_obj_ce}

        if log:
            ignore_idx = (target_classes_o != -1)
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx][ignore_idx, :-1], target_classes_o[ignore_idx])[0]
            # losses['obj_class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_consistency_src_permutation_idx(self, indices):
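        # For every GT pair matched in at least two decoding paths, build all pairwise combinations of
        # the (path index, query index) locations that point to it. Returns two parallel index tuples so
        # that predictions[first] and predictions[second] can be compared for consistency.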
        all_tgt = torch.cat([j for (_, j) in indices]).unique()
        path_idxs = [torch.cat([torch.tensor([i]) for i, (_, t) in enumerate(indices) if (t == tgt).any()]) for tgt in all_tgt]
        q_idxs = [torch.cat([s[t == tgt] for (s, t) in indices]) for tgt in all_tgt]
        path_idxs = torch.cat([torch.combinations(path_idx) for path_idx in path_idxs if len(path_idx) > 1])
        q_idxs = torch.cat([torch.combinations(q_idx) for q_idx in q_idxs if len(q_idx) > 1])
        return (path_idxs[:, 0], q_idxs[:, 0]), (path_idxs[:, 1], q_idxs[:, 1])

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    # *****************************************************************************
    # >>> DETR Losses
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    # >>> HOTR Losses
    def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes, use_consis, **kwargs):
        loss_map = {
            'pair_labels': self.loss_pair_labels,
            'pair_actions': self.loss_pair_actions
        }
        if self.dataset_file == 'hico-det': loss_map['pair_targets'] = self.loss_pair_targets
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, use_consis, **kwargs)

    # *****************************************************************************
    def forward(self, outputs, targets, log=False):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if (k != 'aux_outputs' and k != 'hoi_aux_outputs')}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        if self.HOI_losses is not None:
            input_targets = [copy.deepcopy(target) for target in targets]
            hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # HOI detection losses
        if self.HOI_losses is not None:
            for loss in self.HOI_losses:
                losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis))
            # if self.dataset_file == 'hico-det': losses['loss_oidx'] += losses['loss_tgt']

            if 'hoi_aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
                    input_targets = [copy.deepcopy(target) for target in targets]
                    hoi_indices, targets_for_aux = self.HOI_matcher(aux_outputs, input_targets, indices, log)
                    for loss in self.HOI_losses:
                        kwargs = {}
                        if loss == 'pair_targets': kwargs = {'log': False}  # Logging is enabled only for the last layer
                        l_dict = self.get_HOI_loss(loss, aux_outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis, **kwargs)
                        l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                        losses.update(l_dict)
                        # if self.dataset_file == 'hico-det': losses[f'loss_oidx_{i}'] += losses[f'loss_tgt_{i}']

        return losses
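
# -----------------------------------------------------------------------------
# Usage sketch (for reference only; the real wiring lives in the model/engine
# code of the repo, and the matcher / weight_dict construction shown here is an
# assumption, not the exact original call):
#
#   criterion = SetCriterion(num_classes, matcher, weight_dict, eos_coef,
#                            losses=['labels', 'boxes', 'cardinality'],
#                            num_actions=num_actions,
#                            HOI_losses=['pair_labels', 'pair_actions'],
#                            HOI_matcher=hoi_matcher, args=args)
#   loss_dict = criterion(outputs, targets)
#   total_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
# -----------------------------------------------------------------------------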