# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr_matcher.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from hotr.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
import hotr.util.misc as utils
import wandb


class HungarianPairMatcher(nn.Module):
    """Hungarian matcher between predicted HOI triplets and ground-truth pairs.

    For every batch element a cost matrix is built from the human-index,
    object-index and action predictions (plus the object-class prediction for
    HICO-DET) and solved with ``scipy.optimize.linear_sum_assignment``, once
    per prediction path.
    """

    def __init__(self, args):
        """Creates the matcher

        Params (all read from ``args``):
            set_cost_act: relative weight of the multi-label action
                classification error in the matching cost
            set_cost_idx: relative weight of the classification error for the
                human idx and the object idx in the matching cost (shared)
            set_cost_tgt: relative weight of the object-class error in the
                matching cost (only applied when the dataset is 'hico-det')
            wandb: truthy to enable cost logging through wandb
            dataset_file: 'vcoco' or 'hico-det'
            valid_ids / invalid_ids: action ids, only read for 'vcoco'
        """
        super().__init__()
        self.cost_action = args.set_cost_act
        self.cost_hbox = self.cost_obox = args.set_cost_idx
        self.cost_target = args.set_cost_tgt

        self.log_printer = args.wandb
        self.is_vcoco = (args.dataset_file == 'vcoco')
        self.is_hico = (args.dataset_file == 'hico-det')
        if self.is_vcoco:
            self.valid_ids = args.valid_ids
            self.invalid_ids = args.invalid_ids

        assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs cant be 0"

    def reduce_redundant_gt_box(self, tgt_bbox, indices):
        """Filters redundant Ground-Truth Bounding Boxes

        Due to random crop augmentation, there exists cases where there exists
        multiple redundant labels for the exact same bounding box and object
        class. This function deals with the redundant labels for smoother
        HOTR training.

        Args:
            tgt_bbox: 2-D tensor with one row per ground-truth box
                (``forward`` passes box coordinates concatenated with class).
            indices: pair ``(k_idx, bbox_idx)`` of index tensors, where
                ``k_idx[i]`` is the query matched to ground-truth row
                ``bbox_idx[i]``.

        Returns:
            ``(boxes, k_idx, bbox_idx)``. When duplicate rows exist, ``boxes``
            holds the unique rows, ``bbox_idx`` indexes into them and only the
            first query matched to each unique row is kept. Otherwise the
            inputs are returned unchanged. ``bbox_idx`` is always placed on
            ``tgt_bbox.device``.
        """
        tgt_bbox_unique, map_idx = torch.unique(tgt_bbox, dim=0, return_inverse=True)

        k_idx, bbox_idx = indices
        if len(tgt_bbox) != len(tgt_bbox_unique):
            # original row id -> row id inside the de-duplicated tensor
            unique_of = map_idx.tolist()
            map_bbox2kidx = {int(bbox_id): int(k_id) for bbox_id, k_id in zip(bbox_idx, k_idx)}

            seen = set()
            bbox_lst, k_lst = [], []
            for bbox_id in bbox_idx.tolist():
                unique_id = unique_of[bbox_id]
                if unique_id in seen:
                    continue  # keep only the first query matched to a duplicated box
                seen.add(unique_id)
                bbox_lst.append(unique_id)
                k_lst.append(map_bbox2kidx[bbox_id])

            # Explicit dtype keeps these usable as indices even when empty.
            bbox_idx = torch.tensor(bbox_lst, dtype=torch.int64)
            k_idx = torch.tensor(k_lst, dtype=torch.int64)
            tgt_bbox_res = tgt_bbox_unique
        else:
            tgt_bbox_res = tgt_bbox
        bbox_idx = bbox_idx.to(tgt_bbox.device)

        return tgt_bbox_res, k_idx, bbox_idx

    @torch.no_grad()
    def forward(self, outputs, targets, indices, log=False):
        """Matches HOI predictions to ground-truth pairs for every batch element.

        Args:
            outputs: dict holding "pred_actions", "pred_hidx", "pred_oidx" and,
                for HICO-DET, "pred_obj_logits". The first three dimensions of
                "pred_actions" are read as (batch, num_path, num_queries).
            targets: list of per-image dicts. Reads "boxes", "labels",
                "pair_actions", "pair_targets" and either "pair_boxes"
                (V-COCO) or "sub_boxes"/"obj_boxes" (HICO-DET).
            indices: per-image ``(k_idx, bbox_idx)`` pairs mapping queries to
                ground-truth boxes (consumed by ``reduce_redundant_gt_box``).
            log: when True and wandb logging is enabled, mean minimum costs
                are logged from rank 0.

        Returns:
            ``(return_list, targets)``. ``return_list[b]`` is a list with one
            ``(prediction_idx, target_idx)`` pair of int64 tensors per path.

        Side effects on ``targets``:
            * "h_labels" / "o_labels" are written for every image.
            * V-COCO: invalid action columns are zeroed in place and pairs
              left without any action are dropped from "pair_boxes",
              "pair_actions" and "pair_targets".
            * Images without pairs get a single all-zero "pair_actions" row
              (and "pair_targets" = [-1] for HICO-DET).

        Raises:
            RuntimeError: if the number of human boxes and object boxes that
                could be located among the ground-truth boxes differ.
        """
        assert "pred_actions" in outputs, "There is no action output for pair matching"
        bs, num_path, num_queries = outputs["pred_actions"].shape[:3]

        return_list = []
        if self.log_printer and log:
            log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
            if self.is_hico:
                log_dict['tgt_cost'] = []

        # Batch-wide reshapes do not depend on the batch element: do them once.
        outputs_hidx = outputs["pred_hidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_oidx = outputs["pred_oidx"].view(num_path, bs, num_queries, -1).transpose(0, 1).flatten(1, 2)
        outputs_action = outputs["pred_actions"].view(bs, num_path * num_queries, -1)
        if self.is_hico:
            outputs_obj_logits = outputs["pred_obj_logits"].view(bs, num_path * num_queries, -1)

        for batch_idx in range(bs):
            tgt_bbox = targets[batch_idx]["boxes"]  # (num_boxes, 4)
            tgt_cls = targets[batch_idx]["labels"]  # (num_boxes)

            if self.is_vcoco:
                targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0
                keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0)
                targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx]
                targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx]
                targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx]

                tgt_pbox = targets[batch_idx]["pair_boxes"]  # (num_pair_boxes, 8)
                tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 29)
                tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)

                tgt_hbox = tgt_pbox[:, :4]  # (num_pair_boxes, 4)
                tgt_obox = tgt_pbox[:, 4:]  # (num_pair_boxes, 4)
            elif self.is_hico:
                tgt_act = targets[batch_idx]["pair_actions"]  # (num_pair_boxes, 117)
                tgt_tgt = targets[batch_idx]["pair_targets"]  # (num_pair_boxes)

                tgt_hbox = targets[batch_idx]["sub_boxes"]  # (num_pair_boxes, 4)
                tgt_obox = targets[batch_idx]["obj_boxes"]  # (num_pair_boxes, 4)

            # find which gt boxes match the h, o boxes in the pair
            # (the human class column is 1 for V-COCO and 0 for HICO-DET)
            if self.is_vcoco:
                hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            elif self.is_hico:
                hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
            obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
            obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1  # turn the class of occluded objects to -1

            bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
            bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
            # Extra all -1 row: the row occluded objects are matched against.
            bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.] * 5).unsqueeze(0).to(tgt_cls.device)), dim=0)

            cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1)
            cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1)

            # find which gt boxes matches which prediction in K
            h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False)  # (num_hbox, num_boxes)
            o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False)  # (num_obox, num_boxes)

            # The loop below pairs the i-th human match with the i-th object
            # match; unequal counts would silently mis-pair them, so fail loudly
            # instead of dropping into an interactive debugger.
            if len(h_match_indices) != len(o_match_indices):
                raise RuntimeError(
                    "Pair boxes could not be aligned with the ground-truth boxes "
                    f"for batch element {batch_idx}: {len(h_match_indices)} human "
                    f"matches vs {len(o_match_indices)} object matches."
                )

            # obtain ground truth indices for h and o
            tgt_hids, tgt_oids = [], []
            for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices):
                _, H_bbox_idx = h_match_idx
                _, O_bbox_idx = o_match_idx
                if O_bbox_idx == (len(bbox_with_cls) - 1):  # if the object class is -1
                    O_bbox_idx = H_bbox_idx  # happens in V-COCO, the target object may not appear

                GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                tgt_hids.append(k_idx[GT_idx_for_H])

                GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
                tgt_oids.append(k_idx[GT_idx_for_O])

            # check if empty
            if len(tgt_hids) == 0:
                tgt_hids.append(torch.as_tensor([-1]))  # we later ignore the label -1
            if len(tgt_oids) == 0:
                tgt_oids.append(torch.as_tensor([-1]))  # we later ignore the label -1

            tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0)
            if tgt_act.shape[0] == 0:
                # No pair in this image: pad with one all-zero action row.
                tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device)
                targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device)
                if self.is_hico:
                    pad_tgt = -1  # outputs["pred_obj_logits"].shape[-1]-1
                    tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"])
                    targets[batch_idx]["pair_targets"] = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device)
                # +1 avoids dividing by zero for the padded all-zero row.
                tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)

            # Concat target label (tiled once per path)
            tgt_hids = torch.cat(tgt_hids).repeat(num_path)
            tgt_oids = torch.cat(tgt_oids).repeat(num_path)

            out_hprob = outputs_hidx[batch_idx].softmax(-1)
            out_oprob = outputs_oidx[batch_idx].softmax(-1)
            out_act = outputs_action[batch_idx].clone()
            if self.is_vcoco:
                out_act[..., self.invalid_ids] = 0
            if self.is_hico:
                out_tgt = outputs_obj_logits[batch_idx].softmax(-1)
                out_tgt[..., -1] = 0  # don't get cost for no-object
            # Append one zero column to the action targets, then tile per path.
            tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1).repeat(num_path, 1)

            cost_hclass = -out_hprob[:, tgt_hids]
            cost_oclass = -out_oprob[:, tgt_oids]

            # Positive part rewards activations on labelled actions, negative
            # part penalises activations on unlabelled ones; each is normalised
            # by the corresponding number of actions.
            cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum.repeat(1, num_path)
            neg_act = (~tgt_act.bool()).type(torch.int64)
            cost_neg_act = torch.matmul(out_act, neg_act.t().float()) / neg_act.sum(dim=-1).unsqueeze(0)
            cost_action = cost_pos_act + cost_neg_act

            h_cost = self.cost_hbox * cost_hclass
            o_cost = self.cost_obox * cost_oclass
            act_cost = self.cost_action * cost_action

            C = h_cost + o_cost + act_cost

            if self.is_hico:
                cost_target = -out_tgt[:, tgt_tgt.repeat(num_path)]
                tgt_cost = self.cost_target * cost_target
                C += tgt_cost
            C = C.view(num_path, num_queries, -1).cpu()

            # Path i is matched only against the i-th copy of the targets.
            sizes = [len(tgt_hids) // num_path] * num_path
            hoi_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return_list.append([
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in hoi_indices
            ])

            targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device)
            targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device)

            if self.log_printer and log:
                # Only the first num_queries rows are used for the logged statistics.
                log_dict['h_cost'].append(h_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['o_cost'].append(o_cost[:num_queries].min(dim=0)[0].mean())
                log_dict['act_cost'].append(act_cost[:num_queries].min(dim=0)[0].mean())
                if self.is_hico:
                    log_dict['tgt_cost'].append(tgt_cost[:num_queries].min(dim=0)[0].mean())

        if self.log_printer and log:
            log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean()
            log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean()
            log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean()
            if self.is_hico:
                log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean()
            if utils.get_rank() == 0:
                wandb.log(log_dict)

        return return_list, targets


def build_hoi_matcher(args):
    """Builds the HOI pair matcher from the parsed arguments."""
    return HungarianPairMatcher(args)