Spaces:
Runtime error
Runtime error
File size: 8,207 Bytes
5e0b9df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/post_process.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import time
import copy
import torch
import torch.nn.functional as F
from torch import nn
from hotr.util import box_ops
class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    # Dataset names that have an HOI scoring branch in forward().
    _HOI_DATASETS = ('vcoco', 'hico-det')

    def __init__(self, HOIDet):
        """Store the flag that switches between detection-only and HOI post-processing."""
        super().__init__()
        self.HOIDet = HOIDet

    @torch.no_grad()
    def forward(self, outputs, target_sizes, threshold=0, dataset='coco', args=None):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
            threshold: detection score a query must exceed to be kept as a human / object
                       candidate (only used by the 'vcoco' branch)
            dataset: 'vcoco' or 'hico-det' when HOI detection is enabled; ignored otherwise
            args: object providing `augpath_name` (a sized collection) and `path_id`.
                  Only read when HOI detection is enabled. When None, a single
                  decoding path is assumed (num_path=1, path_id=0).
        Returns:
            A list with one dict per image. Detection-only mode yields
            'scores' / 'labels' / 'boxes'; the HOI branches add the keys built in
            `_postprocess_vcoco` / `_postprocess_hico`.
        Raises:
            ValueError: HOI detection is enabled and `dataset` has no scoring branch.
        """
        # Fail early and explicitly instead of reaching the end of the method
        # with no result bound.
        if self.HOIDet and dataset not in self._HOI_DATASETS:
            raise ValueError(
                "Unsupported dataset for HOI post-processing: %r (expected one of %r)"
                % (dataset, self._HOI_DATASETS)
            )

        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        # Best class per query; the last logit is dropped before taking the max.
        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # Convert (cx, cy, w, h) boxes to (x0, y0, x1, y1) and scale by the image size.
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        if not self.HOIDet:
            return [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        # Prediction Branch for HOI detection
        # `args` is only needed here, so detection-only callers may omit it.
        if args is None:
            num_path, path_id = 1, 0
        else:
            num_path = 1 + len(args.augpath_name)
            path_id = args.path_id

        if dataset == 'vcoco':
            return self._postprocess_vcoco(outputs, scores, labels, boxes, threshold, num_path, path_id)
        return self._postprocess_hico(outputs, boxes, num_path, path_id)

    @staticmethod
    def _select_path(stacked, num_path, path_id):
        """Reshape a 3-D tensor to (num_path, first_dim // num_path, ...) and return slice `path_id`."""
        total, n_query, n_dim = stacked.shape
        return stacked.view(num_path, total // num_path, n_query, n_dim)[path_id]

    def _postprocess_vcoco(self, outputs, scores, labels, boxes, threshold, num_path, path_id):
        """ Compute HOI triplet prediction score for V-COCO.
        Our scoring function follows the implementation details of UnionDet.
        """
        out_time = outputs['hoi_recognition_time']

        start_time = time.time()
        pair_actions = torch.sigmoid(outputs['pred_actions'][:, path_id, ...])
        h_prob = F.softmax(self._select_path(outputs['pred_hidx'], num_path, path_id), -1)
        h_indices = h_prob.max(-1)[1]
        o_prob = F.softmax(self._select_path(outputs['pred_oidx'], num_path, path_id), -1)
        o_indices = o_prob.max(-1)[1]
        hoi_recognition_time = (time.time() - start_time) + out_time

        # Loop-invariant sizes: number of detection queries and number of
        # action classes (the last action channel is the "no interaction" score).
        K = boxes.shape[1]
        n_act = pair_actions.shape[-1] - 1
        device = pair_actions.device

        results = []
        # iterate for batch size
        for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
            # Humans are queries labelled 1 above the threshold; any query above
            # the threshold is an object candidate.
            h_inds = (l == 1) & (s > threshold)
            o_inds = (s > threshold)

            h_box, h_cat = b[h_inds], s[h_inds]
            o_box, o_cat = b[o_inds], s[o_inds]

            # for scenario 1 in v-coco dataset
            # An extra always-selected object slot with an all-zero box is appended;
            # the scoring loop uses it (index -1) when human and object index coincide.
            o_inds = torch.cat((o_inds, torch.ones(1, dtype=torch.bool, device=o_inds.device)))
            o_box = torch.cat((o_box, torch.zeros((1, 4), device=o_box.device)))

            result_dict = {
                'h_box': h_box, 'h_cat': h_cat,
                'o_box': o_box, 'o_cat': o_cat,
                'scores': s, 'labels': l, 'boxes': b
            }

            score = torch.zeros((n_act, K, K + 1), device=device)         # sum over all queries
            sorted_score = torch.zeros((n_act, K, K + 1), device=device)  # best query per pair
            id_score = torch.zeros((K, K + 1), device=device)             # best matching score per pair

            # Score function
            for h_idx, o_idx, pair_action in zip(h_indices[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]):
                matching_score = (1 - pair_action[-1])  # no interaction score
                weighted = matching_score * pair_action[:-1]
                if h_idx == o_idx:
                    o_idx = -1  # redirect to the extra object slot
                if matching_score > id_score[h_idx, o_idx]:
                    id_score[h_idx, o_idx] = matching_score
                    sorted_score[:, h_idx, o_idx] = weighted
                # Accumulated for every query, not only the best one.
                score[:, h_idx, o_idx] += weighted

            score += sorted_score
            score = score[:, h_inds, :]
            score = score[:, :, o_inds]

            result_dict.update({
                'pair_score': score,
                'hoi_recognition_time': hoi_recognition_time,
            })
            results.append(result_dict)

        return results

    def _postprocess_hico(self, outputs, boxes, num_path, path_id):
        """ Compute HOI triplet prediction score for HICO-DET.
        For HICO-DET, we follow the same scoring function but do not accumulate the results.
        """
        out_time = outputs['hoi_recognition_time']
        _, n_query, n_cls = outputs['pred_obj_logits'].shape

        start_time = time.time()
        # NOTE(review): object logits are viewed as (-1, num_path, ...) whereas
        # pred_hidx / pred_oidx are viewed as (num_path, -1, ...) - kept exactly
        # as in the original; confirm the layouts against the model.
        out_obj_logits = outputs['pred_obj_logits'].view(-1, num_path, n_query, n_cls)[:, path_id, ...]
        out_verb_logits = outputs['pred_actions'][:, path_id, ...]

        # actions
        verb_prob = out_verb_logits.sigmoid()
        matching_scores = (1 - verb_prob[..., -1:])
        verb_scores = verb_prob[..., :-1] * matching_scores

        # hbox, obox
        outputs_hrepr = self._select_path(outputs['pred_hidx'], num_path, path_id)
        outputs_orepr = self._select_path(outputs['pred_oidx'], num_path, path_id)
        obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)

        h_indices = F.softmax(outputs_hrepr, -1).max(-1)[1]

        # targets
        o_indices = F.softmax(outputs_orepr, -1).max(-1)[1]
        hoi_recognition_time = (time.time() - start_time) + out_time

        # hidx, oidx
        # Gather, per image, the detection box each interaction query points at.
        sub_boxes = torch.stack([box[h_idx, :] for box, h_idx in zip(boxes, h_indices)], dim=0)
        obj_boxes = torch.stack([box[o_idx, :] for box, o_idx in zip(boxes, o_indices)], dim=0)

        # accumulate results (iterate through interaction queries)
        results = []
        for obj_score, obj_label, verb_score, sub_box, obj_box in zip(
                obj_scores, obj_labels, verb_scores, sub_boxes, obj_boxes):
            sub_label = torch.full_like(obj_label, 0)  # self.subject_category_id = 0 in HICO-DET
            pair_labels = torch.cat((sub_label, obj_label))
            pair_boxes = torch.cat((sub_box, obj_box))

            # Subjects occupy the first half of the concatenated boxes, objects the second.
            ids = torch.arange(pair_boxes.shape[0])
            half = ids.shape[0] // 2
            results.append({
                'labels': pair_labels.to('cpu'),
                'boxes': pair_boxes.to('cpu'),
                'verb_scores': (verb_score * obj_score.unsqueeze(1)).to('cpu'),
                'sub_ids': ids[:half],
                'obj_ids': ids[half:],
                'hoi_recognition_time': hoi_recognition_time,
            })

        return results
|