root commited on
Commit
49bb3b0
·
1 Parent(s): 9586537

initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +0 -1
  2. hotr/engine/arg_parser.py +0 -2
  3. hotr/models/detr.py +16 -33
.gitignore CHANGED
@@ -131,7 +131,6 @@ wandb/
131
  checkpoints/
132
 
133
  # old version
134
- hotr/models/hotr_v1.py
135
  Makefile
136
 
137
  .DS_Store
 
131
  checkpoints/
132
 
133
  # old version
 
134
  Makefile
135
 
136
  .DS_Store
hotr/engine/arg_parser.py CHANGED
@@ -126,8 +126,6 @@ def get_args_parser():
126
  parser.add_argument('--stop_grad_stage',action='store_true',help='Do not back propogate loss to previous stage')
127
  parser.add_argument('--path_id', default=0, type=int)
128
 
129
- parser.add_argument('--sep_enc_forward',action='store_true')
130
-
131
  # * dataset parameters
132
  parser.add_argument('--dataset_file', help='[coco | vcoco]')
133
  parser.add_argument('--data_path', type=str)
 
126
  parser.add_argument('--stop_grad_stage',action='store_true',help='Do not back propogate loss to previous stage')
127
  parser.add_argument('--path_id', default=0, type=int)
128
 
 
 
129
  # * dataset parameters
130
  parser.add_argument('--dataset_file', help='[coco | vcoco]')
131
  parser.add_argument('--data_path', type=str)
hotr/models/detr.py CHANGED
@@ -23,7 +23,6 @@ from .post_process import PostProcess
23
  from .feed_forward import MLP
24
 
25
  from .hotr import HOTR
26
- from .hotr_v1 import HOTR_V1
27
 
28
  class DETR(nn.Module):
29
  """ This is the DETR module that performs object detection """
@@ -145,38 +144,22 @@ def build(args):
145
 
146
  kwargs = {}
147
  if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids
148
- if args.sep_enc_forward:
149
- model = HOTR_V1(
150
- detr=model,
151
- num_hoi_queries=args.num_hoi_queries,
152
- num_actions=args.num_actions,
153
- interaction_transformer=interaction_transformer,
154
- augpath_name = args.augpath_name,
155
- share_dec_param = args.share_dec_param,
156
- stop_grad_stage = args.stop_grad_stage,
157
- freeze_detr=(args.frozen_weights is not None),
158
- share_enc=args.share_enc,
159
- pretrained_dec=args.pretrained_dec,
160
- temperature=args.temperature,
161
- hoi_aux_loss=args.hoi_aux_loss,
162
- **kwargs # only return verb class for HICO-DET dataset
163
- )
164
- else:
165
- model = HOTR(
166
- detr=model,
167
- num_hoi_queries=args.num_hoi_queries,
168
- num_actions=args.num_actions,
169
- interaction_transformer=interaction_transformer,
170
- augpath_name = args.augpath_name,
171
- share_dec_param = args.share_dec_param,
172
- stop_grad_stage = args.stop_grad_stage,
173
- freeze_detr=(args.frozen_weights is not None),
174
- share_enc=args.share_enc,
175
- pretrained_dec=args.pretrained_dec,
176
- temperature=args.temperature,
177
- hoi_aux_loss=args.hoi_aux_loss,
178
- **kwargs # only return verb class for HICO-DET dataset
179
- )
180
  postprocessors = {'hoi': PostProcess(args.HOIDet)}
181
  else:
182
  criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
 
23
  from .feed_forward import MLP
24
 
25
  from .hotr import HOTR
 
26
 
27
  class DETR(nn.Module):
28
  """ This is the DETR module that performs object detection """
 
144
 
145
  kwargs = {}
146
  if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids
147
+
148
+ model = HOTR(
149
+ detr=model,
150
+ num_hoi_queries=args.num_hoi_queries,
151
+ num_actions=args.num_actions,
152
+ interaction_transformer=interaction_transformer,
153
+ augpath_name = args.augpath_name,
154
+ share_dec_param = args.share_dec_param,
155
+ stop_grad_stage = args.stop_grad_stage,
156
+ freeze_detr=(args.frozen_weights is not None),
157
+ share_enc=args.share_enc,
158
+ pretrained_dec=args.pretrained_dec,
159
+ temperature=args.temperature,
160
+ hoi_aux_loss=args.hoi_aux_loss,
161
+ **kwargs # only return verb class for HICO-DET dataset
162
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  postprocessors = {'hoi': PostProcess(args.HOIDet)}
164
  else:
165
  criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,