# ------------------------------------------------------------------------
# HOTR official code : hotr/models/post_process.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import time
import copy
import torch
import torch.nn.functional as F
from torch import nn
from hotr.util import box_ops

class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    def __init__(self, HOIDet):
        super().__init__()
        self.HOIDet = HOIDet  # if True, run the HOI pairing branches; otherwise plain detection

    @torch.no_grad()
    def forward(self, outputs, target_sizes, threshold=0, dataset='coco', args=None):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augmentation, but before padding
            threshold: minimum detection score for a box to enter the HOI pairing (V-COCO branch)
            dataset: one of 'coco', 'vcoco', or 'hico-det'; selects the post-processing branch
            args: must provide `augpath_name` (list) and `path_id` (int) to select the decoding path
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        num_path = 1 + len(args.augpath_name)  # base decoding path + augmented paths
        path_id = args.path_id                 # index of the path to evaluate
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

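        # per-query class probabilities; the last logit is the "no object"
        # class and is dropped before taking the per-query maximum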
        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

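        # convert normalized [cx, cy, w, h] boxes to absolute [x0, y0, x1, y1]
        # corners at each image's target size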
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        # Prediction branch for HOI detection
        if self.HOIDet:
            if dataset == 'vcoco':
                """ Compute HOI triplet prediction score for V-COCO.
                Our scoring function follows the implementation details of UnionDet.
                """
                
                out_time = outputs['hoi_recognition_time']
                bss, q, hd = outputs['pred_hidx'].shape  # (num_path * batch, #interaction queries, #instance queries)
                start_time = time.time()
                pair_actions = torch.sigmoid(outputs['pred_actions'][:, path_id, ...])
                # the H/O pointers are stacked over decoding paths along the batch
                # dimension; reshape and keep only the evaluated path
                h_prob = F.softmax(outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
                h_idx_score, h_indices = h_prob.max(-1)

                o_prob = F.softmax(outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id], -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time
                results = []
                # iterate over the batch
                for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
                    h_inds = (l == 1) & (s > threshold)  # humans: COCO category 1 (person)
                    o_inds = (s > threshold)             # objects: any confident detection

                    h_box, h_cat = b[h_inds], s[h_inds]
                    o_box, o_cat = b[o_inds], s[o_inds]

                    # V-COCO scenario 1 allows interactions without an object;
                    # append a dummy all-zero box as the "no object" slot
                    o_inds = torch.cat((o_inds, torch.ones(1, dtype=torch.bool, device=o_inds.device)))
                    o_box = torch.cat((o_box, torch.zeros(1, 4, device=o_box.device)))

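                    # raw detections for this image, returned alongside the pair scores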
                    result_dict = {
                        'h_box': h_box, 'h_cat': h_cat,
                        'o_box': o_box, 'o_cat': o_cat,
                        'scores': s, 'labels': l, 'boxes': b
                    }

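                    # indices of the kept human/object detections (kept for inspection; not used below)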
                    h_inds_lst = h_inds.nonzero(as_tuple=False).squeeze(-1)
                    o_inds_lst = o_inds.nonzero(as_tuple=False).squeeze(-1)

                    K = boxes.shape[1]  # number of instance queries
                    n_act = pair_actions[batch_idx][:, :-1].shape[-1]
                    # score[a, h, o]: accumulated score of action a for the pair of
                    # instance queries (h, o); the extra last column is the "no object" slot
                    score = torch.zeros((n_act, K, K + 1), device=pair_actions[batch_idx].device)
                    sorted_score = torch.zeros((n_act, K, K + 1), device=pair_actions[batch_idx].device)
                    id_score = torch.zeros((K, K + 1), device=pair_actions[batch_idx].device)

                    # scoring function: every interaction query votes for the
                    # (human, object) pair selected by its pointers
                    for hs, h_idx, os, o_idx, pair_action in zip(
                            h_idx_score[batch_idx], h_indices[batch_idx],
                            o_idx_score[batch_idx], o_indices[batch_idx],
                            pair_actions[batch_idx]):
                        matching_score = 1 - pair_action[-1]  # interactiveness: 1 - "no interaction" score
                        if h_idx == o_idx:
                            o_idx = -1  # self-pairs fall back to the "no object" slot (last column)
                        if matching_score > id_score[h_idx, o_idx]:
                            id_score[h_idx, o_idx] = matching_score
                            sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1]
                        score[:, h_idx, o_idx] += matching_score * pair_action[:-1]

                    # emphasize the most confident query per pair by adding its score
                    # once more, then keep only rows/columns of the kept humans/objects
                    score += sorted_score
                    score = score[:, h_inds, :]
                    score = score[:, :, o_inds]

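                    # attach the HOI pair scores and timing to this image's output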
                    result_dict.update({
                        'pair_score': score,
                        'hoi_recognition_time': hoi_recognition_time,
                    })

                    results.append(result_dict)

            elif dataset == 'hico-det':
                """ Compute HOI triplet prediction score for HICO-DET.
                For HICO-DET, we follow the same scoring function but do not accumulate the results.
                """
                
                bss, q, hd = outputs['pred_hidx'].shape
                out_time = outputs['hoi_recognition_time']
                _, nq, nc = outputs['pred_obj_logits'].shape
                start_time = time.time()
                # keep only the evaluated decoding path
                out_obj_logits = outputs['pred_obj_logits'].view(-1, num_path, nq, nc)[:, path_id, ...]
                out_verb_logits = outputs['pred_actions'][:, path_id, ...]

                # actions: weight each verb by the pair's interactiveness
                # (1 - "no interaction" score); the last logit is dropped
                matching_scores = 1 - out_verb_logits.sigmoid()[..., -1:]
                verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores

                # H/O pointer logits for the evaluated path
                outputs_hrepr = outputs['pred_hidx'].view(num_path, bss // num_path, q, hd)[path_id]
                outputs_orepr = outputs['pred_oidx'].view(num_path, bss // num_path, q, hd)[path_id]
                obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)

                # decode the human pointer: argmax over instance queries
                h_prob = F.softmax(outputs_hrepr, -1)
                h_idx_score, h_indices = h_prob.max(-1)

                # decode the object pointer
                o_prob = F.softmax(outputs_orepr, -1)
                o_idx_score, o_indices = o_prob.max(-1)
                hoi_recognition_time = (time.time() - start_time) + out_time

                # gather the instance boxes selected by each interaction query
                sub_boxes, obj_boxes = [], []
                for box, h_idx, o_idx in zip(boxes, h_indices, o_indices):
                    sub_boxes.append(box[h_idx, :])
                    obj_boxes.append(box[o_idx, :])
                sub_boxes = torch.stack(sub_boxes, dim=0)
                obj_boxes = torch.stack(obj_boxes, dim=0)

                # accumulate per-image results (iterate over the batch)
                results = []
                for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes):
                    sl = torch.zeros_like(ol)  # subject category id is 0 (person) in HICO-DET
                    l = torch.cat((sl, ol))
                    b = torch.cat((sb, ob))
                    results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')})
                    vs = vs * os.unsqueeze(1)  # weight verb scores by the object classification score
                    ids = torch.arange(b.shape[0])
                    res_dict = {
                        'verb_scores': vs.to('cpu'),
                        'sub_ids': ids[:ids.shape[0] // 2],  # first half of the boxes: subjects
                        'obj_ids': ids[ids.shape[0] // 2:],  # second half: objects
                        'hoi_recognition_time': hoi_recognition_time
                    }
                    results[-1].update(res_dict)
        else:
            # plain object detection (no HOI branch)
            results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results
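
# ------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the official repo). It builds
# random tensors with the shapes this module expects and exercises the
# plain detection branch (HOIDet=False). The Namespace fields
# `augpath_name` and `path_id` are assumptions inferred from the code
# above, not documented API.
# ------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    bs, num_queries, num_classes = 2, 100, 81
    outputs = {
        'pred_logits': torch.randn(bs, num_queries, num_classes + 1),  # +1 for "no object"
        'pred_boxes': torch.rand(bs, num_queries, 4),                  # normalized cxcywh
    }
    target_sizes = torch.tensor([[480, 640], [600, 800]])              # (h, w) per image

    post = PostProcess(HOIDet=False)
    results = post(outputs, target_sizes, dataset='coco', args=Namespace(augpath_name=[], path_id=0))
    print(results[0]['boxes'].shape)  # -> torch.Size([100, 4])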