bbynku committed
Commit 4dee8df · verified · 1 Parent(s): 4e78087

Upload 5 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/example_depth.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_rgb.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,9 @@
+ # Geometry Prior Visualization Demo
+
+ This demo visualizes the geometry prior computed from a paired RGB image and depth image.
+
+ ## Usage
+ Upload an RGB image and a depth image, pick an (X, Y) pixel coordinate, and click "Generate Visualization" to see the resulting geometry-prior heatmap overlaid on the RGB image.
+
+ ## Examples
+ The demo includes example images from the NYUDepthv2 dataset.
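For orientation before the `app.py` diff below, here is a minimal sketch of the core computation: a spatial Manhattan-distance decay mixed with a depth-difference decay and scaled by per-head decay rates, mirroring `GeoPrior.generate_pos_decay`, `generate_2d_depth_decay`, and `forward`. The function name `geometry_prior_sketch` and its defaults are illustrative, not part of the uploaded files.

```python
# Minimal sketch of the geometry-prior computation (not the uploaded module);
# mirrors GeoPrior.generate_pos_decay / generate_2d_depth_decay / forward in app.py.
import torch
import torch.nn.functional as F

def geometry_prior_sketch(depth_map: torch.Tensor, h: int = 24, w: int = 32,
                          num_heads: int = 4, spatial_weight: float = 0.85):
    """depth_map: (B, 1, H, W) tensor; returns a (B, num_heads, h*w, h*w) prior."""
    # Spatial term: Manhattan distance between every pair of grid cells.
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    grid = torch.stack([ys, xs], dim=-1).reshape(h * w, 2).float()
    pos_decay = (grid[:, None, :] - grid[None, :, :]).abs().sum(-1)      # (h*w, h*w)

    # Depth term: absolute depth difference between every pair of cells.
    d = F.interpolate(depth_map.float(), size=(h, w), mode="bilinear",
                      align_corners=False).reshape(depth_map.shape[0], h * w, 1)
    depth_decay = (d[:, :, None, :] - d[:, None, :, :]).abs().sum(-1)    # (B, h*w, h*w)

    # Per-head decay rates, as in the retention-style formulation used by GeoPrior.
    decay = torch.log(1 - 2 ** (-2 - 6 * torch.arange(num_heads).float() / num_heads))

    mixed = spatial_weight * pos_decay + (1 - spatial_weight) * depth_decay.unsqueeze(1)
    return mixed * decay[None, :, None, None]
```

The uploaded `app.py` then slices one row of this matrix (the row of the selected grid cell), upsamples it to 480x640, and overlays it on the RGB image.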
app.py ADDED
@@ -0,0 +1,383 @@
+ import os
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.transforms.functional import to_pil_image
+ from torchcam.utils import overlay_mask
+ import gradio as gr
+
+
+ class GeoPrior(nn.Module):
+     """Geometry prior: combines a spatial decay mask with a depth-difference decay mask."""
+
+     def __init__(self, embed_dim=128, num_heads=4, initial_value=2, heads_range=6):
+         super().__init__()
+         angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
+         angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
+         self.initial_value = initial_value
+         self.heads_range = heads_range
+         self.num_heads = num_heads
+         decay = torch.log(1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads))
+         self.register_buffer('angle', angle)
+         self.register_buffer('decay', decay)
+
+     def generate_pos_decay(self, H: int, W: int):
+         '''Generate the 2D positional decay mask; the result is (HW) x (HW).'''
+         index_h = torch.arange(H).to(self.decay)  # keep the same dtype/device as the decay buffer
+         index_w = torch.arange(W).to(self.decay)
+         grid = torch.meshgrid([index_h, index_w], indexing='ij')
+         grid = torch.stack(grid, dim=-1).reshape(H * W, 2)   # (H*W, 2)
+         mask = grid[:, None, :] - grid[None, :, :]           # (H*W, H*W, 2)
+         mask = mask.abs().sum(dim=-1)                        # Manhattan distance, (H*W, H*W)
+         return mask
+
+     def generate_2d_depth_decay(self, H: int, W: int, depth_grid):
+         '''Generate the 2D depth-difference decay mask; the result is (HW) x (HW) per image.'''
+         B, _, H, W = depth_grid.shape
+         grid_d = depth_grid.reshape(B, H * W, 1)
+         mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]   # (B, H*W, H*W, 1)
+         mask_d = mask_d.abs().sum(dim=-1)                        # (B, H*W, H*W)
+         mask_d = mask_d.unsqueeze(1)                             # (B, 1, H*W, H*W)
+         return mask_d
+
+     def forward(self, slen, depth_map, activate_recurrent=False, chunkwise_recurrent=False):
+         '''
+         slen: (h, w) with h * w == l
+         recurrent mode is not implemented
+         '''
+         depth_map = F.interpolate(depth_map, size=slen, mode='bilinear', align_corners=False)
+         depth_map = depth_map.float()
+
+         index = torch.arange(slen[0] * slen[1]).to(self.decay)
+         sin = torch.sin(index[:, None] * self.angle[None, :])   # (l, d1)
+         sin = sin.reshape(slen[0], slen[1], -1)                 # (h, w, d1)
+         cos = torch.cos(index[:, None] * self.angle[None, :])   # (l, d1)
+         cos = cos.reshape(slen[0], slen[1], -1)                 # (h, w, d1)
+         mask_1 = self.generate_pos_decay(slen[0], slen[1])                  # spatial decay, (l, l)
+         mask_d = self.generate_2d_depth_decay(slen[0], slen[1], depth_map)  # depth decay, (B, 1, l, l)
+         mask = mask_d
+         # Weighted mix of spatial and depth decay, scaled by the per-head decay rates.
+         mask_sum = (0.85 * mask_1.cpu() + 0.15 * mask) * self.decay[:, None, None].cpu()
+         retention_rel_pos = ((sin, cos), mask, mask_1, mask_sum)
+         return retention_rel_pos
+
+
+ def fangda(mask, in_size=(480 // 20, 640 // 20), out_size=(480, 640)):
+     """Upscale a coarse (24x32) mask to full resolution by block replication."""
+     new_mask = torch.zeros(out_size)
+     ratio_h, ratio_w = out_size[0] // in_size[0], out_size[1] // in_size[1]
+     for i in range(in_size[0]):
+         for j in range(in_size[1]):
+             new_mask[i * ratio_h:(i + 1) * ratio_h, j * ratio_w:(j + 1) * ratio_w] = mask[i, j]
+     return new_mask
+
+
+ def put_mask(image, mask, color_rgb=None, border_mask=False, color_temp='jet', num_c='', beta=2, fixed_num=None):
+     """Overlay a heatmap mask on an image with torchcam's overlay_mask."""
+     mask = mask.numpy()
+     image = cv2.resize(image, dsize=(640, 480), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+     mask = cv2.resize(mask, dsize=(640, 480), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+     color = np.zeros((1, 1, 3), dtype=np.uint8)
+     if color_rgb is not None:
+         color[0, 0, 2], color[0, 0, 1], color[0, 0, 0] = color_rgb
+     else:
+         color[0, 0, 2], color[0, 0, 1], color[0, 0, 0] = 120, 86, 87
+     # Invert the normalized mask before applying the colormap.
+     if fixed_num is not None:
+         mask = 1 - mask / 255
+     else:
+         mask = 1 - mask / np.max(mask)
+
+     result = overlay_mask(to_pil_image(image.astype(np.uint8)), to_pil_image(mask), colormap=color_temp, alpha=0.4)
+     return np.array(result)
+
+
+ def visualize_geometry_prior(RGB_path, Depth_path, index_list=[[584]], cmap_list=['jet_r'], x=0, y=0):
+     H = 480 // 20
+     W = 640 // 20
+     # Map the pixel coordinate (x, y) to a flattened index on the 24x32 grid.
+     index_num = int(x // 20) + int((y // 20 + 1) * 32)
+     index_list = [[index_num]]
+
+     # Depth map at grid resolution, shaped (1, 1, H, W) for GeoPrior.
+     grid_d = cv2.imread(Depth_path, 0)
+     grid_d = cv2.resize(grid_d, dsize=(W, H), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+     grid_d = torch.tensor(grid_d).reshape(1, 1, H, W).cpu()
+     # Full-resolution copies of the depth map for display.
+     grid_d_copy = cv2.imread(Depth_path)
+     grid_d_copy = cv2.resize(grid_d_copy, dsize=(640, 480), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+     grid_d_copy_gray = cv2.imread(Depth_path, 0)
+     grid_d_copy_gray = cv2.resize(grid_d_copy_gray, dsize=(640, 480), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+
+     respos = GeoPrior()
+     ((sin, cos), depth_map, mask_1, mask_sum) = respos((H, W), grid_d)
+
+     img = cv2.imread(RGB_path)
+     img = cv2.resize(img, dsize=(640, 480), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+
+     # Raw depth-difference mask computed directly from the depth grid (kept for comparison).
+     grid_d_old = cv2.imread(Depth_path, 0)
+     grid_d_old = cv2.resize(grid_d_old, dsize=(W, H), fx=1, fy=1, interpolation=cv2.INTER_LINEAR)
+     grid_d_old = torch.tensor(grid_d_old).reshape(H * W, 1)
+     mask_d_old = grid_d_old[:, None, :] - grid_d_old[None, :, :]   # (H*W, H*W, 1)
+     mask_d_old = mask_d_old.abs().sum(dim=-1)
+
+     for i in index_list[0]:
+         for color_temp in cmap_list:
+             # Slice the row of the selected grid cell from each mask and reshape it to (H, W).
+             temp_mask_d = depth_map[0, 0, i, :].reshape(H, W).cpu()
+             temp_mask = mask_1[i, :].reshape(H, W).cpu()
+             temp_mask_d_old = mask_d_old[i, :].reshape(H, W).cpu()
+             temp_mask_sum = mask_sum[0, 0, i, :].reshape(H, W).cpu()
+             temp_mask_d = torch.nn.functional.normalize(temp_mask_d, p=2.0, dim=1, eps=1e-12, out=None)
+
+             # Rescale each mask to [0, 255] for visualization.
+             temp_mask_d = 255 * (temp_mask_d - torch.min(temp_mask_d)) / (torch.max(temp_mask_d) - torch.min(temp_mask_d))
+             temp_mask = 255 * (temp_mask - torch.min(temp_mask)) / (torch.max(temp_mask) - torch.min(temp_mask))
+             temp_mask_sum = 255 * (temp_mask_sum - torch.min(temp_mask_sum)) / (torch.max(temp_mask_sum) - torch.min(temp_mask_sum))
+             temp_mask_d_old = 255 * (temp_mask_d_old - torch.min(temp_mask_d_old)) / (torch.max(temp_mask_d_old) - torch.min(temp_mask_d_old))
+
+             gama = 0.55
+             # Intermediate overlays (spatial-only, depth-only, and a side-by-side strip); computed but not returned.
+             a0 = put_mask(img, fangda(temp_mask), color_temp=color_temp)
+             jiange = 255 * torch.ones(img.shape[0], 20)   # white spacer column between panels
+             temp_mask_fuse = torch.cat([fangda(temp_mask), jiange, fangda(temp_mask_d), jiange,
+                                         fangda(gama * temp_mask + (1 - gama) * temp_mask_d), jiange,
+                                         torch.tensor(grid_d_copy_gray)], dim=1)
+             a2 = put_mask(img, fangda(temp_mask_d), color_temp=color_temp)
+
+             # Returned result: the fused (spatial + depth) prior overlaid on the RGB image.
+             a3 = put_mask(img, fangda(gama * temp_mask + (1 - gama) * temp_mask_d), color_temp=color_temp)
+             return a3.astype(np.uint8)
+
+
+ # Gradio interface function (kept for completeness; the Blocks UI below uses process_with_status).
+ def process_images(rgb_image, depth_image, x=0, y=0):
+     """
+     Process the uploaded images and return the visualization result.
+
+     Args:
+         rgb_image: RGB image uploaded through Gradio
+         depth_image: depth image uploaded through Gradio
+     Returns:
+         the visualization result image, or None on failure
+     """
+     # Write the uploads to temporary files so the OpenCV-based pipeline can read them.
+     temp_rgb_path = "temp_rgb.jpg"
+     temp_depth_path = "temp_depth.png"
+     cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
+     cv2.imwrite(temp_depth_path, depth_image)
+
+     try:
+         result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, x=x, y=y)
+         # Convert back to RGB for display.
+         result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+         return result
+     except Exception as e:
+         print(f"Error during processing: {str(e)}")
+         return None
+     finally:
+         # Make sure the temporary files are removed.
+         if os.path.exists(temp_rgb_path):
+             os.remove(temp_rgb_path)
+         if os.path.exists(temp_depth_path):
+             os.remove(temp_depth_path)
+
+
+ def draw_star(image, x, y, size=20, color=(255, 0, 0), thickness=2):
+     """Draw a five-pointed star marker on the image at (x, y)."""
+     # Vertices of the star.
+     pts = np.array([[x, y - size],                          # top
+                     [x + size * 0.588, y + size * 0.809],   # bottom right
+                     [x - size * 0.951, y - size * 0.309],   # upper left
+                     [x + size * 0.951, y - size * 0.309],   # upper right
+                     [x - size * 0.588, y + size * 0.809]], np.int32)  # bottom left
+
+     cv2.polylines(image, [pts], True, color, thickness)
+     return image
+
+
+ # Build the Gradio interface.
+ def create_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Geometry Prior Visualization Demo")
+         gr.Markdown("""
+         ### Instructions:
+         1. Upload RGB and Depth images
+         2. Enter X (0-640) and Y (0-480) coordinates
+         3. A star marker will be shown on the images at the selected position
+         4. Click "Generate Visualization" to create the visualization
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 rgb_input = gr.Image(label="Upload RGB Image")
+                 depth_input = gr.Image(label="Upload Depth Image", image_mode="L")
+                 with gr.Row():
+                     x_coord = gr.Number(label="X (0-640)", value=160, minimum=0, maximum=640)
+                     y_coord = gr.Number(label="Y (0-480)", value=270, minimum=0, maximum=480)
+                 coordinates_text = gr.Textbox(label="Grid Position and Index", interactive=False)
+
+             with gr.Column():
+                 marked_rgb = gr.Image(label="Marked RGB Image")
+                 marked_depth = gr.Image(label="Marked Depth Image")
+                 output_image = gr.Image(label="Visualization Result")
+                 status_text = gr.Textbox(label="Status", interactive=False)
+
+         def update_coordinates_and_images(rgb_image, depth_image, x, y):
+             # Clamp the coordinates to the valid range.
+             x = max(0, min(640, float(x)))
+             y = max(0, min(480, float(y)))
+
+             # Position on the 24x32 grid (images are resized to 480x640; each cell is 20x20 pixels).
+             H, W = 480 // 20, 640 // 20
+             scaled_x = int(x * W / 640)
+             scaled_y = int(y * H / 480)
+
+             # Flattened 1D index of the selected grid cell.
+             grid_index = scaled_y * W + scaled_x
+
+             # Draw a star marker on the RGB image.
+             rgb_marked = rgb_image.copy()
+             if len(rgb_marked.shape) == 2:       # grayscale input
+                 rgb_marked = cv2.cvtColor(rgb_marked, cv2.COLOR_GRAY2BGR)
+             elif rgb_marked.shape[2] == 4:       # RGBA input
+                 rgb_marked = cv2.cvtColor(rgb_marked, cv2.COLOR_RGBA2BGR)
+             rgb_marked = draw_star(rgb_marked, int(x), int(y), size=20, color=(255, 0, 0))
+
+             # Draw a star marker on the depth image.
+             depth_marked = depth_image.copy()
+             if len(depth_marked.shape) == 2:     # single-channel input
+                 depth_marked = cv2.cvtColor(depth_marked, cv2.COLOR_GRAY2BGR)
+             depth_marked = draw_star(depth_marked, int(x), int(y), size=20, color=(0, 255, 0))
+
+             return (f"Grid position: ({scaled_x}, {scaled_y}), Index: {grid_index}",
+                     rgb_marked,
+                     depth_marked)
+
+         # Button that refreshes the coordinate readout and the marked images.
+         coord_update_btn = gr.Button("Update Coordinates")
+         coord_update_btn.click(
+             fn=update_coordinates_and_images,
+             inputs=[rgb_input, depth_input, x_coord, y_coord],
+             outputs=[coordinates_text, marked_rgb, marked_depth]
+         )
+
+         def process_with_status(rgb_image, depth_image, coords_text, x, y):
+             # Write the uploads to temporary files for the OpenCV-based pipeline.
+             temp_rgb_path = "temp_rgb.jpg"
+             temp_depth_path = "temp_depth.png"
+             try:
+                 cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
+                 cv2.imwrite(temp_depth_path, depth_image)
+
+                 if coords_text:
+                     index = int(coords_text.split("Index: ")[-1])
+                     index_list = [[index]]
+                 else:
+                     index_list = [[584]]  # default index
+
+                 result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, index_list=index_list, x=x, y=y)
+
+                 # Convert to RGB so the colors display correctly.
+                 result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+                 return result, "Processing completed successfully!"
+             except Exception as e:
+                 return None, f"Error: {str(e)}"
+             finally:
+                 # Make sure the temporary files are removed.
+                 if os.path.exists(temp_rgb_path):
+                     os.remove(temp_rgb_path)
+                 if os.path.exists(temp_depth_path):
+                     os.remove(temp_depth_path)
+
+         process_btn = gr.Button("Generate Visualization")
+         process_btn.click(
+             fn=process_with_status,
+             inputs=[rgb_input, depth_input, coordinates_text, x_coord, y_coord],
+             outputs=[output_image, status_text]
+         )
+
+         # Refresh the markers automatically whenever either coordinate changes.
+         x_coord.change(
+             fn=update_coordinates_and_images,
+             inputs=[rgb_input, depth_input, x_coord, y_coord],
+             outputs=[coordinates_text, marked_rgb, marked_depth]
+         )
+         y_coord.change(
+             fn=update_coordinates_and_images,
+             inputs=[rgb_input, depth_input, x_coord, y_coord],
+             outputs=[coordinates_text, marked_rgb, marked_depth]
+         )
+
+         gr.Examples(
+             examples=[
+                 ["assets/example_rgb.jpg", "assets/example_depth.png"]
+             ],
+             inputs=[rgb_input, depth_input]
+         )
+
+     return demo
+
+
+ # Launch the app.
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.queue()
+     demo.launch(
+         server_name="0.0.0.0",
+         share=True,
+         debug=True
+     )
assets/example_depth.png ADDED

Git LFS Details

  • SHA256: 702a9dacb2e590795c2faf1ca9fd7477ecfe43ea4ebaaf7784a089cb577d30df
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
assets/example_rgb.jpg ADDED

Git LFS Details

  • SHA256: e08ce451dc2aa69601c9c73f85bfa39a59f35f6c98ea43c629b10fb7b376609e
  • Pointer size: 131 Bytes
  • Size of remote file: 111 kB
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ torch
+ torchvision
+ opencv-python-headless
+ numpy
+ torchcam
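A note on the `torchcam` entry: it is pulled in only for `overlay_mask`, the heatmap-blending helper used by `put_mask` in `app.py`. A small illustrative sketch (the random heatmap and output file name are placeholders, and Pillow arrives transitively via the other dependencies):

```python
# Illustrative only: what the torchcam dependency is used for in this Space.
import numpy as np
from PIL import Image
from torchcam.utils import overlay_mask

rgb = Image.open("assets/example_rgb.jpg").convert("RGB")
heat = (255 * np.random.rand(24, 32)).astype(np.uint8)   # placeholder for a geometry-prior slice
blended = overlay_mask(rgb, Image.fromarray(heat), colormap="jet", alpha=0.4)
blended.save("overlay_preview.png")
```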