root
initial commit
5e0b9df
# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import copy
import numpy as np
import itertools
from torch import nn
from hotr.util import box_ops
from hotr.util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized)
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR and for the HOTR interaction heads.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    When HOI losses are enabled, an HOI matcher additionally assigns interaction queries to
    ground-truth interactions (one assignment per decoding path), and optional consistency
    losses encourage different decoding paths to agree on the same ground-truth interaction.
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            num_actions: number of action classes; used to build valid_ids for 'hico-det'.
            HOI_losses: list of HOI losses to apply (see get_HOI_loss), or None to disable them.
            HOI_matcher: module matching interaction queries to ground-truth interactions.
            args: option namespace (may be None). Reads use_consis, augpath_name, hoi_eos_coef,
                  dataset_file, and per dataset invalid_ids / valid_ids ('vcoco') or
                  valid_obj_ids ('hico-det').
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.eos_coef = eos_coef
        self.HOI_losses = HOI_losses
        self.HOI_matcher = HOI_matcher

        # Defaults used when no args are supplied: a single decoding path, no consistency losses.
        self.use_consis = False
        self.num_path = 1
        self.dataset_file = None
        if args:
            # Written with `and`: the former `a & n > 0` parsed as `(a & n) > 0`, which is False
            # for use_consis=True combined with an even number of augmented paths.
            self.use_consis = bool(args.use_consis) and len(args.augpath_name) > 0
            # The original decoding path plus one per augmented path.
            self.num_path = 1 + len(args.augpath_name)

            self.HOI_eos_coef = args.hoi_eos_coef
            if args.dataset_file == 'vcoco':
                self.invalid_ids = args.invalid_ids
                self.valid_ids = np.concatenate((args.valid_ids, [-1]), axis=0)  # no interaction
            elif args.dataset_file == 'hico-det':
                self.invalid_ids = []
                self.valid_ids = list(range(num_actions)) + [-1]

                # for targets: class weights of the interaction-target classifier,
                # the last (empty) class is down-weighted by hoi_eos_coef.
                self.num_tgt_classes = len(args.valid_obj_ids)
                tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
                tgt_empty_weight[-1] = self.HOI_eos_coef
                self.register_buffer('tgt_empty_weight', tgt_empty_weight)
            self.dataset_file = args.dataset_file

        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    #######################################################################################################################
    # * DETR Losses
    #######################################################################################################################
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Unmatched queries are supervised as the no-object class (index num_classes).
        target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    #######################################################################################################################
    # * HOTR Losses
    #######################################################################################################################
    # >>> HOI Losses 1 : HO Pointer
    def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes, use_consis, log=False):
        """Human / object pointer losses, plus optional pointer consistency losses.

        Parameters:
            outputs: must contain 'pred_hidx' and 'pred_oidx' of shape (num_path * bs, num_queries, K).
                Dim 0 is read as (num_path, bs) and reordered here to (bs, num_path) so that it
                lines up with the flattened hoi_indices.
            targets: per-image dicts with 'h_labels' and 'o_labels' (pointer target indices).
            hoi_indices: per image, a list with one (query_idx, target_idx) pair per decoding path.
            num_boxes: unused; kept for a uniform loss signature.
            use_consis: if True, also compute symmetric-KL consistency between paths.
            log: unused.
        Returns:
            dict with 'loss_hidx', 'loss_oidx' and, if use_consis, 'loss_h_consistency' and
            'loss_o_consistency'.
        """
        assert ('pred_hidx' in outputs and 'pred_oidx' in outputs)
        nu, q, hd = outputs['pred_hidx'].shape
        bs = nu // self.num_path
        # (path, image) row order -> (image, path) row order.
        src_hidx = outputs['pred_hidx'].view(self.num_path, bs, q, -1).transpose(0, 1).flatten(0, 1)
        src_oidx = outputs['pred_oidx'].view(self.num_path, bs, q, -1).transpose(0, 1).flatten(0, 1)

        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        # -1 marks queries without a matched interaction; cross_entropy ignores them.
        target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
        target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)

        # H Pointer loss
        target_classes_h = torch.cat([t["h_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_hidx_classes[idx] = target_classes_h

        # O Pointer loss
        target_classes_o = torch.cat([t["o_labels"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        target_oidx_classes[idx] = target_classes_o

        loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
        loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)

        losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}
        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            losses['loss_h_consistency'] = self._symmetric_kl_consistency(src_hidx, consistency_idxs)
            losses['loss_o_consistency'] = self._symmetric_kl_consistency(src_oidx, consistency_idxs)
        return losses

    # >>> HOI Losses 2 : pair actions
    def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes, use_consis):
        """Multi-label action loss (sigmoid focal loss, gamma=2, alpha=0.25), plus optional consistency.

        Parameters:
            outputs: must contain 'pred_actions'. It is indexed as pred_actions[image][(path, query)],
                i.e. laid out (bs, num_path, num_queries, num_action_columns), where the last column
                is "no interaction".
            targets: per-image dicts with 'pair_actions' (multi-hot action labels without the
                no-interaction column).
            hoi_indices: per image, a list with one (query_idx, target_idx) pair per decoding path.
            num_boxes: unused; kept for a uniform loss signature.
            use_consis: if True, also compute an MSE consistency loss on log-sigmoid action scores.
        Returns:
            dict with 'loss_act' and, if use_consis, 'loss_act_consistency'.
        """
        assert 'pred_actions' in outputs
        src_actions = outputs['pred_actions'].flatten(end_dim=1)
        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)

        # Construct Target --------------------------------------------------------------------------------------------------------------
        target_classes_o = torch.cat([t["pair_actions"][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])
        # Every query defaults to "no interaction" (last column = 1) ...
        target_classes = torch.full(src_actions.shape, 0, dtype=torch.float32, device=src_actions.device)
        target_classes[..., -1] = 1
        # ... while matched queries get their action labels and a 0 in the no-interaction column.
        pos_classes = torch.full(target_classes[idx].shape, 0, dtype=torch.float32, device=src_actions.device)
        pos_classes[:, :-1] = target_classes_o.float()
        target_classes[idx] = pos_classes
        # --------------------------------------------------------------------------------------------------------------------------------

        # Focal loss on the valid action columns (valid_ids ends with -1, the no-interaction column) ----------------------------------
        valid_ids = self.valid_ids
        probs = src_actions.sigmoid()[..., valid_ids]
        tgt = target_classes[..., valid_ids]
        loss_bce = F.binary_cross_entropy(probs, tgt, reduction='none')
        p_t = probs * tgt + (1 - probs) * (1 - tgt)
        loss_bce = ((1 - p_t) ** 2 * loss_bce)
        alpha_t = 0.25 * tgt + (1 - 0.25) * (1 - tgt)
        loss_focal = alpha_t * loss_bce
        # Normalise by the number of positive action labels (no-interaction column excluded), at least 1.
        loss_act = loss_focal.sum() / max(target_classes[..., valid_ids[:-1]].sum(), 1)
        # --------------------------------------------------------------------------------------------------------------------------------

        losses = {'loss_act': loss_act}
        # Consistency loss
        if use_consis:
            pred_actions = outputs['pred_actions']
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            per_image = [
                F.mse_loss(F.logsigmoid(pred_actions[i][idx_a]), F.logsigmoid(pred_actions[i][idx_b]))
                for i, (idx_a, idx_b) in enumerate(consistency_idxs)
                if idx_a[0].numel() > 0  # images without paired queries would give a NaN mean
            ]
            losses['loss_act_consistency'] = torch.mean(torch.stack(per_image)) if per_image else src_actions.new_zeros(())
        return losses

    # HOI Losses 3 : action targets
    def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions, use_consis, log=True):
        """Interaction-target (object class) cross-entropy, plus optional consistency.

        Parameters:
            outputs: must contain 'pred_obj_logits' of shape (bs * num_path, num_queries, C);
                rows are indexed in (image, path) order, matching the flattened hoi_indices.
            targets: per-image dicts with 'pair_targets' (class index, -1 = ignore).
            hoi_indices: per image, a list with one (query_idx, target_idx) pair per decoding path.
            num_interactions: unused; kept for a uniform loss signature.
            use_consis: if True, also compute symmetric-KL consistency between paths.
            log: if True, also report 'obj_class_error' (accuracy excluding the last logit column).
        Returns:
            dict with 'loss_tgt', optionally 'loss_tgt_label_consistency' and 'obj_class_error'.
        """
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']

        hoi_ind = list(itertools.chain.from_iterable(hoi_indices))
        idx = self._get_src_permutation_idx(hoi_ind)
        target_classes_o = torch.cat([t['pair_targets'][J] for t, hoi_indice in zip(targets, hoi_indices) for (_, J) in hoi_indice])

        pad_tgt = -1  # unmatched queries are ignored by the loss
        target_classes = torch.full(src_logits.shape[:2], pad_tgt, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.tgt_empty_weight, ignore_index=-1)
        losses = {'loss_tgt': loss_obj_ce}

        # Consistency loss
        if use_consis:
            consistency_idxs = [self._get_consistency_src_permutation_idx(hoi_indice) for hoi_indice in hoi_indices]
            losses["loss_tgt_label_consistency"] = self._symmetric_kl_consistency(src_logits, consistency_idxs)

        if log:
            valid = (target_classes_o != -1)
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx][valid, :-1], target_classes_o[valid])[0]
        return losses

    def _symmetric_kl_consistency(self, logits, consistency_idxs):
        """Mean over images of the symmetric KL divergence between paired path predictions.

        `logits` (3-D, rows in (image, path) order) is viewed as (bs, num_path, num_queries, C).
        For image i, consistency_idxs[i] = ((paths_a, queries_a), (paths_b, queries_b)) selects
        the paired rows. Each KL direction treats the other side as a detached target. Images
        without any pair are skipped; if no image has a pair, a zero loss is returned.
        """
        grouped = logits.view(-1, self.num_path, *logits.shape[-2:])
        per_image = []
        for i, (idx_a, idx_b) in enumerate(consistency_idxs):
            if idx_a[0].numel() == 0:
                continue  # batchmean over zero rows would be NaN
            prob_a = F.softmax(grouped[i][idx_a], -1)
            prob_b = F.softmax(grouped[i][idx_b], -1)
            kl_ab = F.kl_div(prob_a.log(), prob_b.detach(), reduction='batchmean')
            kl_ba = F.kl_div(prob_b.log(), prob_a.detach(), reduction='batchmean')
            per_image.append(0.5 * (kl_ab + kl_ba))
        if not per_image:
            return logits.new_zeros(())
        return torch.mean(torch.stack(per_image))

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_consistency_src_permutation_idx(self, indices):
        """Pair up the queries that different decoding paths matched to the same target.

        Parameters:
            indices: for one image, a list with one (src, tgt) index pair per decoding path.
        Returns:
            ((paths_a, queries_a), (paths_b, queries_b)): for every target matched in at least
            two paths, one entry per pair of those paths (torch.combinations order) giving the
            (path, query) of the first and second member. All four are empty int64 tensors when
            no target is matched in more than one path (e.g. an image without interactions).
        """
        all_tgt = torch.cat([tgt for (_, tgt) in indices]).unique()
        path_pairs, query_pairs = [], []
        for tgt_id in all_tgt:
            paths = torch.tensor([p for p, (_, tgt) in enumerate(indices) if (tgt == tgt_id).any()], dtype=torch.int64)
            queries = torch.cat([src[tgt == tgt_id] for (src, tgt) in indices])
            if len(paths) > 1:
                path_pairs.append(torch.combinations(paths))
            if len(queries) > 1:
                query_pairs.append(torch.combinations(queries))
        if not path_pairs or not query_pairs:
            # torch.cat([]) would raise; report "no pairs" instead.
            empty = torch.zeros(0, dtype=torch.int64)
            return (empty, empty), (empty, empty)
        path_pairs = torch.cat(path_pairs)
        query_pairs = torch.cat(query_pairs)
        return (path_pairs[:, 0], query_pairs[:, 0]), (path_pairs[:, 1], query_pairs[:, 1])

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    # *****************************************************************************
    # >>> DETR Losses
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    # >>> HOTR Losses
    def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes, use_consis, **kwargs):
        loss_map = {
            'pair_labels': self.loss_pair_labels,
            'pair_actions': self.loss_pair_actions
        }
        # The interaction-target classifier loss only exists for HICO-DET.
        if self.dataset_file == 'hico-det': loss_map['pair_targets'] = self.loss_pair_targets
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, use_consis, **kwargs)
    # *****************************************************************************

    def forward(self, outputs, targets, log=False):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
            log: forwarded to the HOI matcher.
        Returns:
            dict of losses; auxiliary-layer losses carry an `_{i}` suffix.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if (k != 'aux_outputs' and k != 'hoi_aux_outputs')}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        if self.HOI_losses is not None:
            input_targets = [copy.deepcopy(target) for target in targets]
            hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                # Separate name: `indices` must keep the last-layer matching, because the HOI
                # auxiliary matching below is scored against hoi_targets, built from it.
                aux_indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, aux_indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # HOI detection losses
        if self.HOI_losses is not None:
            for loss in self.HOI_losses:
                losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices, num_boxes, self.use_consis))

            if 'hoi_aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
                    input_targets = [copy.deepcopy(target) for target in targets]
                    aux_hoi_indices, _ = self.HOI_matcher(aux_outputs, input_targets, indices, log)
                    for loss in self.HOI_losses:
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False} if loss == 'pair_targets' else {}
                        l_dict = self.get_HOI_loss(loss, aux_outputs, hoi_targets, aux_hoi_indices, num_boxes, self.use_consis, **kwargs)
                        l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                        losses.update(l_dict)

        return losses