hujiecpp committed
Commit a908963 · Parent: 04a035a

init project

modules/dust3r/cloud_opt/base_opt.py CHANGED
@@ -121,12 +121,11 @@ class BasePCOptimizer (nn.Module):
         self.fix_imgs = rgb(ori_imgs)
         self.smoothed_imgs = rgb(smoothed_imgs)
 
-        self.cogs = [torch.zeros((h, w, 1024), device="cuda") for h, w in self.imshapes]
-        semantic_feats = semantic_feats.to("cuda")
-        self.segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-        self.rev_segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-        # self.conf_1 = [torch.zeros((h, w), device="cuda") for h, w in self.imshapes]
-        # self.conf_2 = [torch.zeros((h, w), device="cuda") for h, w in self.imshapes]
+        self.cogs = [torch.zeros((h, w, 1024)) for h, w in self.imshapes]
+        # semantic_feats = semantic_feats.to("cuda")
+        self.segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
+        self.rev_segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
+
         for v in range(len(self.edges)):
             idx = view1['idx'][v]
 
@@ -142,8 +141,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map.cuda()
-            self.rev_segmaps[idx] = rev_seg_map.cuda()
+            self.segmaps[idx] = cog_seg_map  # .cuda()
+            self.rev_segmaps[idx] = rev_seg_map  # .cuda()
 
             idx = view2['idx'][v]
             h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
@@ -158,8 +157,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map.cuda()
-            self.rev_segmaps[idx] = rev_seg_map.cuda()
+            self.segmaps[idx] = cog_seg_map  # .cuda()
+            self.rev_segmaps[idx] = rev_seg_map  # .cuda()
 
         self.rendered_imgs = []
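Net effect of the base_opt.py hunks above: the per-image semantic buffers (cogs, segmaps, rev_segmaps) are now allocated on the CPU at construction time and the explicit .cuda() / .to("cuda") transfers are commented out, deferring the GPU transfer to the point of use (the optimizer's forward still calls torch.stack(self.segmaps).cuda()). A minimal sketch of that allocate-on-CPU, move-on-use pattern; the class and method names here are hypothetical illustrations, not code from this repo:

import torch

class LazyDeviceBuffers:
    """Illustrative only: allocate per-image buffers on the CPU at init time,
    and move them to the compute device only when a loss actually needs them."""

    def __init__(self, imshapes):
        # CPU allocation: construction no longer requires a CUDA device
        self.segmaps = [-torch.ones((h, w)) for h, w in imshapes]

    def stacked_segmaps(self, device):
        # a single transfer at the point of use,
        # cf. forward(): torch.stack(self.segmaps).cuda()
        return torch.stack(self.segmaps).to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
buffers = LazyDeviceBuffers([(480, 640)] * 3)
print(buffers.stacked_segmaps(device).shape)  # torch.Size([3, 480, 640])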
modules/dust3r/cloud_opt/optimizer.py.bak.1216 DELETED
@@ -1,533 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# Main class for the implementation of the global alignment
-# --------------------------------------------------------
-import numpy as np
-import torch
-import torch.nn as nn
-
-from dust3r.cloud_opt.base_opt import BasePCOptimizer
-from dust3r.utils.geometry import xy_grid, geotrf
-from dust3r.utils.device import to_cpu, to_numpy
-import torch.nn.functional as F
-
-class PointCloudOptimizer(BasePCOptimizer):
-    """ Optimize a global scene, given a list of pairwise observations.
-    Graph node: images
-    Graph edges: observations = (pred1, pred2)
-    """
-
-    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.has_im_poses = True  # by definition of this class
-        self.focal_break = focal_break
-
-        # adding things to optimize
-        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
-        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
-        self.im_focals = nn.ParameterList(torch.FloatTensor(
-            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
-        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
-        self.im_pp.requires_grad_(optimize_pp)
-
-        self.imshape = self.imshapes[0]
-        im_areas = [h*w for h, w in self.imshapes]
-        self.max_area = max(im_areas)
-
-        # stack the parameters for batched optimization
-        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
-        self.im_poses = ParameterStack(self.im_poses, is_param=True)
-        self.im_focals = ParameterStack(self.im_focals, is_param=True)
-        self.im_pp = ParameterStack(self.im_pp, is_param=True)
-        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
-        self.register_buffer('_grid', ParameterStack(
-            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
-
-        # pre-compute pixel weights
-        self.register_buffer('_weight_i', ParameterStack(
-            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
-        self.register_buffer('_weight_j', ParameterStack(
-            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
-
-        # precompute stacked pairwise predictions
-        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
-        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
-        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
-        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
-        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
-        self.total_area_j = sum([im_areas[j] for i, j in self.edges])
-
-    def _check_all_imgs_are_selected(self, msk):
-        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
-
-    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
-        self._check_all_imgs_are_selected(pose_msk)
-
-        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
-            known_poses = [known_poses]
-        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
-            if self.verbose:
-                print(f' (setting pose #{idx} = {pose[:3,3]})')
-            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
-
-        # normalize scale if at most one pose is known
-        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
-        self.norm_pw_scale = (n_known_poses <= 1)
-
-        self.im_poses.requires_grad_(False)
-        self.norm_pw_scale = False
-
-    def preset_focal(self, known_focals, msk=None):
-        self._check_all_imgs_are_selected(msk)
-
-        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
-            if self.verbose:
-                print(f' (setting focal #{idx} = {focal})')
-            self._no_grad(self._set_focal(idx, focal))
-
-        self.im_focals.requires_grad_(False)
-
-    def preset_principal_point(self, known_pp, msk=None):
-        self._check_all_imgs_are_selected(msk)
-
-        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
-            if self.verbose:
-                print(f' (setting principal point #{idx} = {pp})')
-            self._no_grad(self._set_principal_point(idx, pp))
-
-        self.im_pp.requires_grad_(False)
-
-    def _get_msk_indices(self, msk):
-        if msk is None:
-            return range(self.n_imgs)
-        elif isinstance(msk, int):
-            return [msk]
-        elif isinstance(msk, (tuple, list)):
-            return self._get_msk_indices(np.array(msk))
-        elif msk.dtype in (bool, torch.bool, np.bool_):
-            assert len(msk) == self.n_imgs
-            return np.where(msk)[0]
-        elif np.issubdtype(msk.dtype, np.integer):
-            return msk
-        else:
-            raise ValueError(f'bad {msk=}')
-
-    def _no_grad(self, tensor):
-        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
-
-    def _set_focal(self, idx, focal, force=False):
-        param = self.im_focals[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = self.focal_break * np.log(focal)
-        return param
-
-    def get_focals(self):
-        log_focals = torch.stack(list(self.im_focals), dim=0)
-        return (log_focals / self.focal_break).exp()
-
-    def get_known_focal_mask(self):
-        return torch.tensor([not (p.requires_grad) for p in self.im_focals])
-
-    def _set_principal_point(self, idx, pp, force=False):
-        param = self.im_pp[idx]
-        H, W = self.imshapes[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
-        return param
-
-    def get_principal_points(self):
-        return self._pp + 10 * self.im_pp
-
-    def get_intrinsics(self):
-        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
-        focals = self.get_focals().flatten()
-        K[:, 0, 0] = K[:, 1, 1] = focals
-        K[:, :2, 2] = self.get_principal_points()
-        K[:, 2, 2] = 1
-        return K
-
-    def get_im_poses(self):  # cam to world
-        cam2world = self._get_poses(self.im_poses)
-        return cam2world
-
-    def _set_depthmap(self, idx, depth, force=False):
-        depth = _ravel_hw(depth, self.max_area)
-
-        param = self.im_depthmaps[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = depth.log().nan_to_num(neginf=0)
-        return param
-
-    def get_depthmaps(self, raw=False):
-        res = self.im_depthmaps.exp()
-        if not raw:
-            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
-        return res
-
-    def depth_to_pts3d(self):
-        # get depths and projection params
-        focals = self.get_focals()
-        pp = self.get_principal_points()
-        im_poses = self.get_im_poses()
-        depth = self.get_depthmaps(raw=True)
-
-        # get pointmaps in camera frame
-        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
-        # project to world frame
-        return geotrf(im_poses, rel_ptmaps)
-
-    def get_pts3d(self, raw=False):
-        res = self.depth_to_pts3d()
-        if not raw:
-            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
-        return res
-
-    # def cosine_similarity_batch(self, semantic_features, query_pixels):
-    #     # expand dims to compute cosine similarity by broadcasting
-    #     query_pixels = query_pixels.unsqueeze(1)  # [B, 1, C]
-    #     semantic_features = semantic_features.unsqueeze(0)  # [1, H, W, C]
-    #     cos_sim = F.cosine_similarity(query_pixels, semantic_features, dim=-1)  # [B, H, W]
-    #     return cos_sim
-
-    # def semantic_loss(self, semantic_features, predicted_depth, window_size=32, stride=16, lambda_semantic=0.1):
-    #     # get the image dimensions
-    #     height, width, channels = semantic_features.shape
-    #     # process window by window
-    #     ret_loss = 0.0
-    #     cnt = 0
-    #     for i in range(0, height, stride):
-    #         for j in range(0, width, stride):
-    #             window_semantic = semantic_features[i:min(i+window_size,height), j:min(j+window_size,width), :]
-    #             window_depth = predicted_depth[i:min(i+window_size,height), j:min(j+window_size,width)]
-    #             # print(window_semantic.shape, window_depth.shape)
-
-    #             window_semantic = window_semantic.reshape(-1, channels)
-    #             window_depth = window_depth.reshape(-1, 1)
-
-    #             cos_sim = torch.matmul(window_semantic, window_semantic.t())
-    #             dep_dif = torch.abs(window_depth - window_depth.reshape(1, -1))
-
-    #             # print(torch.sum(cos_sim * dep_dif))
-    #             ret_loss += torch.mean(cos_sim * dep_dif)
-    #             cnt += 1
-
-    #     return ret_loss / cnt
-
-    # def segmap_loss(self, predicted_depth, seg_map):
-    #     ret_loss = 0.0
-    #     cnt = 0
-    #     seg_map = seg_map.view(-1)
-    #     predicted_depth = predicted_depth.view(-1, 1)
-    #     unique_groups = torch.unique(seg_map)
-    #     for group in unique_groups:
-    #         # print(group)
-    #         if group == -1:
-    #             continue
-    #         group_indices = (seg_map == group).nonzero(as_tuple=True)[0]
-    #         if len(group_indices) > 0:
-    #             now_feat = predicted_depth[group_indices]
-
-    #             dep_dif = torch.abs(now_feat - now_feat.reshape(1, -1))
-
-    #             ret_loss += torch.mean(dep_dif)
-    #             cnt += 1
-
-    #     return ret_loss / cnt if cnt > 0 else ret_loss
-
-    # def spatial_smoothness_loss(self, point_map, semantic_map):
-    #     """
-    #     Spatial smoothness loss: neighboring pixels of the same semantic class
-    #     should not vary sharply in 3D position. Uses the 8-neighborhood.
-
-    #     Args:
-    #     - point_map: (H, W, 3), the 3D coordinates (x, y, z) of each pixel
-    #     - semantic_map: (H, W), the semantic label of each pixel
-
-    #     Returns:
-    #     - total loss value
-    #     """
-
-    #     # image height and width
-    #     H, W = semantic_map.shape
-
-    #     # flatten the point map and the semantic map
-    #     point_map = point_map.view(-1, 3)  # (H * W, 3)
-    #     semantic_map = semantic_map.view(-1)  # (H * W,)
-
-    #     # build pixel indices
-    #     row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
-    #     row_idx = row_idx.flatten()
-    #     col_idx = col_idx.flatten()
-
-    #     # 8-neighborhood offsets
-    #     neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
-    #                                      [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)
-
-    #     # accumulate the loss
-    #     total_loss = 0.0
-
-    #     # compute for each pixel
-    #     for offset in neighbor_offsets:
-    #         # neighbor positions
-    #         neighbor_row = row_idx + offset[0]
-    #         neighbor_col = col_idx + offset[1]
-
-    #         # keep only neighbors that fall inside the image
-    #         valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
-    #         valid_row = neighbor_row[valid_mask]
-    #         valid_col = neighbor_col[valid_mask]
-
-    #         # indices of the valid pixels
-    #         idx = valid_mask.nonzero(as_tuple=True)[0]
-    #         neighbor_idx = valid_row * W + valid_col
-
-    #         # semantic labels and 3D coordinates of the neighboring pixels
-    #         sem_i = semantic_map[idx]
-    #         sem_j = semantic_map[neighbor_idx]
-    #         p_i = point_map[idx]
-    #         p_j = point_map[neighbor_idx]
-
-    #         # squared difference of the 3D coordinates
-    #         distance = torch.sum((p_i - p_j) ** 2, dim=1)
-
-    #         # penalize only neighbors that share the same semantic class
-    #         loss_mask = (sem_i == sem_j)
-    #         total_loss += torch.sum(loss_mask * distance)
-
-    #     # average loss
-    #     return total_loss / point_map.size(0)
-
-
-    def spatial_smoothness_loss_multi_image(self, point_maps, semantic_maps, confidence_maps):
-        """
-        Spatial smoothness loss over multiple images: pixels that belong to the
-        same object should vary smoothly in 3D.
-
-        Args:
-        - point_maps: (B, H, W, 3), per-image 3D coordinates (x, y, z); B is the batch size
-        - semantic_maps: (B, H, W), per-image semantic labels
-        - confidence_maps: (B, H, W), per-pixel confidence weights
-
-        Returns:
-        - total loss value
-        """
-
-        B, H, W = semantic_maps.shape
-
-        # flatten the point maps and the semantic maps
-        point_maps = point_maps.view(B, -1, 3)  # (B, H*W, 3)
-        semantic_maps = semantic_maps.view(B, -1)  # (B, H*W)
-        confidence_maps = confidence_maps.view(B, -1)  # (B, H*W)
-
-        # accumulate the loss
-        total_loss = 0.0
-
-        # compute per pixel, image by image
-        for b in range(B):
-            # point map and semantic map of the current image
-            point_map = point_maps[b]
-            semantic_map = semantic_maps[b]
-            confidence_map = confidence_maps[b]
-
-            # build pixel indices
-            row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
-            row_idx = row_idx.flatten()
-            col_idx = col_idx.flatten()
-
-            # 8-neighborhood offsets
-            neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
-                                             [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)
-
-            # compute for each pixel (neighborhoods are only taken within the current image)
-            for offset in neighbor_offsets:
-                # neighbor positions
-                neighbor_row = row_idx + offset[0]
-                neighbor_col = col_idx + offset[1]
-
-                # keep only neighbors that fall inside the image
-                valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
-                valid_row = neighbor_row[valid_mask]
-                valid_col = neighbor_col[valid_mask]
-
-                # indices of the valid pixels
-                idx = valid_mask.nonzero(as_tuple=True)[0]
-                neighbor_idx = valid_row * W + valid_col
-
-                # semantic labels and 3D coordinates of the neighboring pixels
-                sem_i = semantic_map[idx]
-                sem_j = semantic_map[neighbor_idx]
-                p_i = point_map[idx]
-                p_j = point_map[neighbor_idx]
-                conf_i = confidence_map[idx]
-                conf_j = confidence_map[neighbor_idx]
-
-                # squared difference of the 3D coordinates
-                distance = torch.sum((p_i - p_j)**2, dim=1)
-
-                # weighted loss, only for neighbors that share the same semantic class
-                loss_mask = (sem_i == sem_j)
-
-                # inverse weighting would give low-confidence points a higher weight
-                # inverse_weight_i = 1.0 / conf_i  # guard against division by zero
-                # inverse_weight_j = 1.0 / conf_j
-                weighted_distance = loss_mask * distance  # weighted loss (* inverse_weight_i * inverse_weight_j)
-                total_loss += torch.sum(weighted_distance)
-
-            # cross-image term: for pixels of the same semantic class, compare only
-            # the class means instead of all pairs
-            # for b2 in range(B):
-            #     if b == b2:
-            #         continue  # skip the comparison of an image with itself
-            #     point_map_b2 = point_maps[b2]
-            #     semantic_map_b2 = semantic_maps[b2]
-            #     confidence_map_b2 = confidence_maps[b2]
-
-            #     for sem_id in torch.unique(semantic_map):
-            #         sem_mask_a = (semantic_map == sem_id)
-            #         sem_mask_b2 = (semantic_map_b2 == sem_id)
-
-            #         # pixels of the same semantic class
-            #         shared_points_a = point_map[sem_mask_a]
-            #         shared_points_b2 = point_map_b2[sem_mask_b2]
-            #         shared_conf_a = confidence_map[sem_mask_a]
-            #         shared_conf_b2 = confidence_map_b2[sem_mask_b2]
-
-            #         if shared_points_a.shape[0] > 0 and shared_points_b2.shape[0] > 0:
-            #             # means of these pixels
-            #             mean_a = shared_points_a.mean(dim=0)  # class mean in the current image
-            #             mean_b2 = shared_points_b2.mean(dim=0)  # class mean in image b2
-            #             mean_conf_a = shared_conf_a.mean()  # mean class confidence in the current image
-            #             mean_conf_b2 = shared_conf_b2.mean()  # mean class confidence in image b2
-
-            #             # spatial difference between the means, weighted by confidence
-            #             distance_cross = torch.sum((mean_a - mean_b2) ** 2)
-            #             weighted_distance_cross = distance_cross * mean_conf_a * mean_conf_b2
-            #             total_loss += weighted_distance_cross
-
-        # average loss
-        return total_loss / (B * H * W)
-
-
-    def forward(self, cur_iter=0):
-        pw_poses = self.get_pw_poses()  # cam-to-world
-        pw_adapt = self.get_adaptors().unsqueeze(1)
-        proj_pts3d = self.get_pts3d(raw=True)
-
-        loss = 0.0
-
-        # depth = self.get_depthmaps(raw=True)
-        # print(depth.shape)
-        # if cur_iter < 100:
-        # # for i, pointmap in enumerate(proj_pts3d):
-        # # loss += self.spatial_smoothness_loss(pointmap, seg_maps[i].cuda())
-
-        # # depths = self.get_depthmaps()
-        # # # cogs = self.cogs
-        # # seg_maps = self.segmaps
-        # # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-
-        # # for i, depth in enumerate(depths):
-        # # # print(seg_maps[i].shape)
-        # # # H, W = depth.shape
-        # # # tmp = cogs[i].reshape(-1, 1024)
-        # # # tmp = torch.matmul(tmp, self.cog_matrix.detach().t())
-        # # # tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
-        # # # tmp = tmp.reshape(H, W, 3)
-        # # loss += self.segmap_loss(depth, seg_maps[i], im_conf[i])
-        # # loss += self.semantic_loss(cogs[i], depth)
-
-        # # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-
-        # # cogs = self.cogs.permute(0, 3, 1, 2)
-        # # cogs = F.interpolate(cogs, scale_factor=2, mode='nearest')
-        # # cogs = cogs.permute(0, 2, 3, 1)
-        # # cogs = torch.stack(self.cogs).view(-1, 1024)
-        # # proj = proj_pts3d.view(-1, 3)
-        # # proj = proj / proj.norm(dim=-1, keepdim=True)
-        # # img_conf = im_conf.view(-1,1)
-
-        # # selected_indices = torch.where(img_conf > 2.0)[0]
-        # # img_conf = img_conf[selected_indices]
-        # # cogs = cogs[selected_indices]
-        # # proj = proj[selected_indices]
-        # # print(img_conf.shape, cogs.shape, proj.shape)
-        # # proj_dis = torch.matmul(proj, proj.t())
-        # # cogs_dis = torch.matmul(cogs, cogs.t())
-        # # loss += (im_conf * F.mse_loss(proj_dis, cogs_dis, reduction='none')).mean()
-
-        # # if cur_iter % 2 == 0:
-        # # tmp = torch.matmul(cogs.detach(), self.cog_matrix.detach().t())
-        # # tmp = tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
-        # # loss += 0/1*(img_conf * F.mse_loss(proj, tmp, reduction='none')).mean()
-        # # if cur_iter % 2 == 1:
-        # # tmp = torch.matmul(cogs.view(-1, 1024), self.cog_matrix.detach().t())
-        # # tmp = tmp / tmp.norm(dim=-1, keepdim=True)
-        # # loss += (im_conf.view(-1,1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()
-        # # if cur_iter % 3 == 2:
-        # # tmp = torch.matmul(cogs.view(-1, 1024).detach(), self.cog_matrix.t())
-        # # tmp = tmp / tmp.norm(dim=-1, keepdim=True)
-        # # loss += (im_conf.view(-1,1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()
-        seg_maps = torch.stack(self.segmaps).cuda()
-        im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-        loss += self.spatial_smoothness_loss_multi_image(proj_pts3d, seg_maps, im_conf)
-        # # if cur_iter > 100:
-        # # rotate pairwise prediction according to pw_poses
-        # aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
-        # aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
-
-        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_i, seg_maps[self._ei], im_conf[self._ei])
-        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_j, seg_maps[self._ej], im_conf[self._ej])
-
-        # # compute the loss
-        # loss += self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
-        # loss += self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
-
-        return loss
-
-
-def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
-    pp = pp.unsqueeze(1)
-    focal = focal.unsqueeze(1)
-    assert focal.shape == (len(depth), 1, 1)
-    assert pp.shape == (len(depth), 1, 2)
-    assert pixel_grid.shape == depth.shape + (2,)
-    depth = depth.unsqueeze(-1)
-    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
-
-
-def ParameterStack(params, keys=None, is_param=None, fill=0):
-    if keys is not None:
-        params = [params[k] for k in keys]
-
-    if fill > 0:
-        params = [_ravel_hw(p, fill) for p in params]
-
-    requires_grad = params[0].requires_grad
-    assert all(p.requires_grad == requires_grad for p in params)
-
-    params = torch.stack(list(params)).float().detach()
-    if is_param or requires_grad:
-        params = nn.Parameter(params)
-        params.requires_grad_(requires_grad)
-    return params
-
-
-def _ravel_hw(tensor, fill=0):
-    # ravel H,W
-    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
-
-    if len(tensor) < fill:
-        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
-    return tensor
-
-
-def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
-    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
-    return minf*focal_base, maxf*focal_base
-
-
-def apply_mask(img, msk):
-    img = img.copy()
-    img[msk] = 0
-    return img
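The per-pixel Python loops are the main cost of spatial_smoothness_loss_multi_image above: for every offset it rebuilds index grids and gathers with integer indexing. The same 8-neighborhood penalty can be expressed with overlapping tensor slices. A self-contained sketch under that assumption; the helper name and test shapes are hypothetical, not code from this repo:

import torch

def smoothness_loss_vectorized(point_maps, semantic_maps):
    """Hypothetical vectorized variant of the 8-neighborhood smoothness loss:
    penalize squared 3D distance between neighboring pixels that share a
    semantic label. point_maps: (B, H, W, 3); semantic_maps: (B, H, W)."""
    B, H, W, _ = point_maps.shape
    total = point_maps.new_zeros(())
    # the 8 neighbor offsets (dy, dx), as in the original loop
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1),
               (-1, -1), (-1, 1), (1, -1), (1, 1)]
    for dy, dx in offsets:
        # overlapping crops implement the shift without index bookkeeping:
        # pixel (y, x) in the first crop pairs with (y+dy, x+dx) in the second
        ys, ye = max(-dy, 0), H + min(-dy, 0)
        xs, xe = max(-dx, 0), W + min(-dx, 0)
        p_i = point_maps[:, ys:ye, xs:xe]
        p_j = point_maps[:, ys + dy:ye + dy, xs + dx:xe + dx]
        sem_i = semantic_maps[:, ys:ye, xs:xe]
        sem_j = semantic_maps[:, ys + dy:ye + dy, xs + dx:xe + dx]
        same = (sem_i == sem_j).float()  # mask: neighbors with the same label
        total = total + (same * ((p_i - p_j) ** 2).sum(-1)).sum()
    # same normalization as the deleted implementation
    return total / (B * H * W)

pts = torch.randn(2, 8, 8, 3)
segs = torch.randint(0, 3, (2, 8, 8))
print(smoothness_loss_vectorized(pts, segs))

As in the original, each neighbor pair is counted once per direction (so twice overall), and the confidence weighting that the deleted code left commented out is omitted here.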