"""Interactive visualization of the GeoPrior position/depth prior.

Given an RGB image and a depth map, one query cell is selected on a 24x32
grid laid over a 640x480 canvas. The script then renders, as a JET heat-map
overlay on the RGB image, a blend of that cell's spatial (Manhattan) distance
row and depth-difference row produced by ``GeoPrior``.
"""
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcam.utils import overlay_mask  # noqa: F401  (kept; unused in this file)
from torchvision.transforms import Resize  # noqa: F401  (kept; unused in this file)
from torchvision.transforms.functional import to_pil_image  # noqa: F401  (kept; unused in this file)

# Canvas on which query coordinates are given and results are rendered.
_IMG_W, _IMG_H = 640, 480
# Side length, in canvas pixels, of one grid cell.
_CELL = 20
# Grid resolution fed to GeoPrior: 24 rows x 32 columns.
_GRID_H, _GRID_W = _IMG_H // _CELL, _IMG_W // _CELL


class GeoPrior(nn.Module):
    """Build sin/cos tables plus spatial-distance and depth-distance masks.

    Args:
        embed_dim: embedding width; each head gets ``embed_dim // num_heads // 2``
            frequencies, each duplicated for the sin/cos tables.
        num_heads: number of heads; one decay value per head.
        initial_value, heads_range: per-head decay is
            ``log(1 - 2 ** -(initial_value + heads_range * h / num_heads))``.
    """

    def __init__(self, embed_dim=128, num_heads=4, initial_value=2, heads_range=6):
        super().__init__()
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        # Duplicate every frequency (f0, f0, f1, f1, ...).
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        self.initial_value = initial_value
        self.heads_range = heads_range
        self.num_heads = num_heads
        decay = torch.log(
            1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads)
        )
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)

    def generate_pos_decay(self, H: int, W: int):
        """Return the (H*W, H*W) Manhattan distance between every pair of grid cells.

        Cells are flattened row-major. The per-head decay is NOT applied here.
        """
        index_h = torch.arange(H).to(self.decay)  # match decay's dtype/device
        index_w = torch.arange(W).to(self.decay)
        grid = torch.stack(torch.meshgrid(index_h, index_w, indexing='ij'), dim=-1).reshape(H * W, 2)
        diff = grid[:, None, :] - grid[None, :, :]  # (H*W, H*W, 2)
        return diff.abs().sum(dim=-1)

    def generate_2d_depth_decay(self, H: int, W: int, depth_grid):
        """Return the (B, 1, H*W, H*W) absolute depth difference between cell pairs.

        ``H`` and ``W`` are kept for interface compatibility; the actual sizes
        are read from ``depth_grid``, which is reshaped as (B, 1, H, W).
        The per-head decay is NOT applied here.
        """
        B, _, H, W = depth_grid.shape
        grid_d = depth_grid.reshape(B, H * W, 1)
        mask_d = (grid_d[:, :, None, :] - grid_d[:, None, :, :]).abs().sum(dim=-1)  # (B, H*W, H*W)
        return mask_d.unsqueeze(1)

    def forward(self, slen, depth_map, activate_recurrent=False, chunkwise_recurrent=False):
        """Compute the priors for an ``slen = (h, w)`` grid.

        Args:
            slen: target grid size ``(h, w)``; ``depth_map`` is resized to it.
            depth_map: depth tensor of shape (B, 1, H, W), any numeric dtype.
            activate_recurrent, chunkwise_recurrent: unused (recurrence is not
                implemented).

        Returns:
            ``((sin, cos), depth_mask, pos_mask, mask_sum)``, where
            - ``sin``/``cos`` are (h, w, d);
            - ``depth_mask`` is (B, 1, h*w, h*w);
            - ``pos_mask`` is (h*w, h*w);
            - ``mask_sum`` is the 0.85/0.15 blend times the per-head decay,
              shape (B, num_heads, h*w, h*w).
        """
        # Convert before interpolating: bilinear resize of integer tensors
        # either is unsupported or quantizes the result.
        depth_map = F.interpolate(depth_map.float(), size=slen, mode='bilinear', align_corners=False)

        index = torch.arange(slen[0] * slen[1]).to(self.decay)
        phase = index[:, None] * self.angle[None, :]  # (l, d)
        sin = torch.sin(phase).reshape(slen[0], slen[1], -1)
        cos = torch.cos(phase).reshape(slen[0], slen[1], -1)

        mask_1 = self.generate_pos_decay(slen[0], slen[1])
        mask = self.generate_2d_depth_decay(slen[0], slen[1], depth_map)
        # Put every operand on mask's device instead of forcing CPU, so GPU
        # inputs do not hit a device mismatch.
        decay = self.decay[:, None, None].to(mask)
        mask_sum = (0.85 * mask_1.to(mask) + 0.15 * mask) * decay
        return ((sin, cos), mask, mask_1, mask_sum)


def _grid_index(x, y):
    """Map canvas pixel (x, y) to ``(col, row, flat_index)`` on the 24x32 grid.

    Out-of-range coordinates are clamped to the nearest edge cell. The flat
    index is row-major (``row * 32 + col``), matching GeoPrior's flattening.
    """
    col = min(max(int(float(x) // _CELL), 0), _GRID_W - 1)
    row = min(max(int(float(y) // _CELL), 0), _GRID_H - 1)
    return col, row, row * _GRID_W + col


def _clamp(value, upper):
    """Coerce a UI number (possibly None) to float and clamp it to [0, upper]."""
    value = 0.0 if value is None else float(value)
    return max(0.0, min(float(upper), value))


def _scale_to_255(t):
    """Min-max scale tensor ``t`` to [0, 255]; a constant tensor maps to zeros."""
    lo, hi = torch.min(t), torch.max(t)
    span = hi - lo
    if span <= 0:
        return torch.zeros_like(t)
    return 255 * (t - lo) / span


def _imread_checked(path, flags):
    """``cv2.imread`` that raises instead of silently returning None."""
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image


def _to_bgr(image):
    """Convert a Gradio RGB / RGBA / grayscale array to 3-channel BGR for cv2."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _to_rgb(image):
    """Return a new 3-channel RGB copy of a Gradio RGB / RGBA / grayscale array."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image.copy()


def fangda(mask, in_size=(_GRID_H, _GRID_W), out_size=(_IMG_H, _IMG_W)):
    """Upscale a grid mask by nearest-neighbour block replication.

    Each cell ``mask[i, j]`` fills an ``(out_h // in_h) x (out_w // in_w)``
    block. Pixels beyond the last whole block stay 0.

    Args:
        mask: 2-D tensor or array with at least ``in_size`` entries.
        in_size: ``(rows, cols)`` of the grid to read from ``mask``.
        out_size: ``(height, width)`` of the result.

    Returns:
        float32 tensor of shape ``out_size``.
    """
    in_h, in_w = in_size
    ratio_h, ratio_w = out_size[0] // in_h, out_size[1] // in_w
    new_mask = torch.zeros(out_size)
    block = torch.as_tensor(mask)[:in_h, :in_w]
    block = block.repeat_interleave(ratio_h, dim=0).repeat_interleave(ratio_w, dim=1)
    new_mask[:in_h * ratio_h, :in_w * ratio_w] = block
    return new_mask


def put_mask(image, mask, color_rgb=None, border_mask=False, color_temp='jet', num_c='', beta=2, fixed_num=None):
    """Blend a JET heat map of the *inverted* mask onto ``image`` (60/40).

    Both inputs are resized to 640x480. Larger mask values become "cooler"
    colours because the mask is inverted before colour mapping.

    Args:
        image: uint8 3-channel BGR image.
        mask: 2-D tensor or array of non-negative values.
        fixed_num: if not None, the mask is treated as already on a 0-255
            scale; otherwise it is normalized by its own maximum.
        color_rgb, border_mask, color_temp, num_c, beta: accepted for
            backward compatibility; currently ignored.

    Returns:
        uint8 BGR array of shape (480, 640, 3).
    """
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask, dtype=np.float32)
    image = cv2.resize(image, dsize=(_IMG_W, _IMG_H), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, dsize=(_IMG_W, _IMG_H), interpolation=cv2.INTER_LINEAR)

    if fixed_num is not None:
        mask = 1 - mask / 255
    else:
        peak = np.max(mask)
        # Guard an all-zero mask, which would otherwise become NaN.
        mask = 1 - mask / peak if peak > 0 else np.ones_like(mask)

    heatmap = cv2.applyColorMap((np.clip(mask, 0, 1) * 255).astype(np.uint8), cv2.COLORMAP_JET)
    result = cv2.addWeighted(image, 0.6, heatmap, 0.4, 0)
    return np.array(result)


def visualize_geometry_prior(RGB_path, Depth_path, index_list=None, cmap_list=None, x=0, y=0):
    """Render the fused position/depth prior of one query cell over the RGB image.

    Args:
        RGB_path: path to a colour image readable by cv2.
        Depth_path: path to a depth image; read as 8-bit grayscale.
        index_list: ``[[flat_index, ...]]``. Only the first index is rendered.
            When None (default), the index is derived from (x, y) with the same
            row-major mapping the UI displays.
        cmap_list: colormap names. Only the first is used and forwarded to
            put_mask (which currently always applies JET).
        x, y: query pixel on the 640x480 canvas.

    Returns:
        uint8 BGR image of shape (480, 640, 3).

    Raises:
        FileNotFoundError: if either image cannot be read.
        ValueError: if the query index lies outside the 24x32 grid.
    """
    if index_list is None:
        index_list = [[_grid_index(x, y)[2]]]
    if cmap_list is None:
        cmap_list = ['jet_r']
    # Only the first query index and the first colormap are rendered.
    query = int(index_list[0][0])
    color_temp = cmap_list[0]
    if not 0 <= query < _GRID_H * _GRID_W:
        raise ValueError(f"Grid index {query} is outside [0, {_GRID_H * _GRID_W}).")

    depth = _imread_checked(Depth_path, cv2.IMREAD_GRAYSCALE)
    depth = cv2.resize(depth, dsize=(_GRID_W, _GRID_H), interpolation=cv2.INTER_LINEAR)
    depth = torch.tensor(depth).reshape(1, 1, _GRID_H, _GRID_W)

    img = _imread_checked(RGB_path, cv2.IMREAD_COLOR)
    img = cv2.resize(img, dsize=(_IMG_W, _IMG_H), interpolation=cv2.INTER_LINEAR)

    (_, depth_mask, pos_mask, _) = GeoPrior()((_GRID_H, _GRID_W), depth)

    # Row of each pairwise matrix for the query cell, reshaped back to the grid.
    depth_row = depth_mask[0, 0, query, :].reshape(_GRID_H, _GRID_W).cpu()
    pos_row = pos_mask[query, :].reshape(_GRID_H, _GRID_W).cpu()

    depth_row = _scale_to_255(F.normalize(depth_row, p=2.0, dim=1, eps=1e-12))
    pos_row = _scale_to_255(pos_row)

    gamma = 0.55  # weight of the spatial term in the blend
    fused = gamma * pos_row + (1 - gamma) * depth_row
    return put_mask(img, fangda(fused), color_temp=color_temp).astype(np.uint8)


def _run_visualization(rgb_image, depth_image, x, y, index_list=None):
    """Write uploaded arrays to a private temp dir, render, and return an RGB result.

    Raises:
        ValueError: if either image is missing.
        OSError: if a temporary image cannot be written.
    """
    if rgb_image is None or depth_image is None:
        raise ValueError("Please upload both an RGB image and a depth image.")
    # A per-call directory avoids clashes between concurrent requests, and the
    # context manager deletes the files even when rendering fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        rgb_path = os.path.join(tmp_dir, "rgb.png")
        depth_path = os.path.join(tmp_dir, "depth.png")
        if not cv2.imwrite(rgb_path, _to_bgr(rgb_image)):
            raise OSError(f"Could not write {rgb_path}")
        if not cv2.imwrite(depth_path, depth_image):
            raise OSError(f"Could not write {depth_path}")
        result = visualize_geometry_prior(rgb_path, depth_path, index_list=index_list, x=x, y=y)
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)


def process_images(rgb_image, depth_image, x=0, y=0):
    """Process uploaded images and return the visualization result.

    Args:
        rgb_image: RGB image array uploaded through Gradio.
        depth_image: depth image array uploaded through Gradio.
        x, y: query pixel on the 640x480 canvas (defaults to the top-left cell).

    Returns:
        RGB visualization array, or None if processing failed (the error is printed).
    """
    try:
        return _run_visualization(rgb_image, depth_image, x, y)
    except Exception as e:  # UI boundary: report and return nothing
        print(f"Error during processing: {str(e)}")
        return None


def draw_star(image, x, y, size=20, color=(255, 0, 0), thickness=2):
    """Draw a five-pointed star outline centred at (x, y) onto ``image`` in place.

    Returns the same image object for convenience.
    """
    # Vertices in pentagram drawing order: top, bottom-right, upper-left,
    # upper-right, bottom-left.
    pts = np.array([
        [x, y - size],
        [x + size * 0.588, y + size * 0.809],
        [x - size * 0.951, y - size * 0.309],
        [x + size * 0.951, y - size * 0.309],
        [x - size * 0.588, y + size * 0.809],
    ], np.int32)
    cv2.polylines(image, [pts], True, color, thickness)
    return image


def _mark_point(image, x, y, color):
    """Return an RGB copy of ``image`` with a star at canvas point (x, y), or None."""
    if image is None:
        return None
    marked = _to_rgb(image)
    # (x, y) refer to the 640x480 canvas; map them onto the image's own size so
    # the star matches the cell actually analysed after resizing.
    h, w = marked.shape[:2]
    return draw_star(marked, int(x * w / _IMG_W), int(y * h / _IMG_H), size=20, color=color)


def create_demo():
    """Build the Gradio Blocks UI for the geometry-prior demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Geometry Prior Visualization Demo")
        gr.Markdown("\n".join([
            "### Instructions:",
            "1. Upload RGB and Depth images",
            "2. Enter X (0-640) and Y (0-480) coordinates",
            "3. A star marker will be shown on the images at the selected position",
            '4. Click "Generate Visualization" to create the visualization',
        ]))

        with gr.Row():
            with gr.Column():
                rgb_input = gr.Image(label="Upload RGB Image")
                depth_input = gr.Image(label="Upload Depth Image", image_mode="L")
                with gr.Row():
                    x_coord = gr.Number(label="X (0-640)", value=160, minimum=0, maximum=640)
                    y_coord = gr.Number(label="Y (0-480)", value=270, minimum=0, maximum=480)
                coordinates_text = gr.Textbox(label="Grid Position and Index", interactive=False)
            with gr.Column():
                marked_rgb = gr.Image(label="Marked RGB Image")
                marked_depth = gr.Image(label="Marked Depth Image")
                output_image = gr.Image(label="Visualization Result")
                status_text = gr.Textbox(label="Status", interactive=False)

        def update_coordinates_and_images(rgb_image, depth_image, x, y):
            """Show the grid cell/index for (x, y) and mark it on both uploaded images."""
            x = _clamp(x, _IMG_W)
            y = _clamp(y, _IMG_H)
            col, row, grid_index = _grid_index(x, y)
            info = f"Grid position: ({col}, {row}), Index: {grid_index}"
            rgb_marked = _mark_point(rgb_image, x, y, color=(255, 0, 0))
            depth_marked = _mark_point(depth_image, x, y, color=(0, 255, 0))
            return (info, rgb_marked, depth_marked)

        coord_update_btn = gr.Button("Update Coordinates")
        coord_update_btn.click(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth],
        )

        def process_with_status(rgb_image, depth_image, coords_text, x, y):
            """Render the visualization for (x, y) and report success or the error.

            ``coords_text`` is the read-only grid/index text. The index is
            recomputed from (x, y) with the same mapping, so it is not parsed.
            """
            try:
                result = _run_visualization(rgb_image, depth_image, _clamp(x, _IMG_W), _clamp(y, _IMG_H))
            except Exception as e:  # UI boundary: surface any failure in the status box
                return None, f"Error: {str(e)}"
            return result, "Processing completed successfully!"

        process_btn = gr.Button("Generate Visualization")
        process_btn.click(
            fn=process_with_status,
            inputs=[rgb_input, depth_input, coordinates_text, x_coord, y_coord],
            outputs=[output_image, status_text],
        )

        # Refresh the grid info and star markers whenever a coordinate changes.
        for coord in (x_coord, y_coord):
            coord.change(
                fn=update_coordinates_and_images,
                inputs=[rgb_input, depth_input, x_coord, y_coord],
                outputs=[coordinates_text, marked_rgb, marked_depth],
            )

        gr.Examples(
            examples=[["assets/example_rgb.jpg", "assets/example_depth.png"]],
            inputs=[rgb_input, depth_input],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch(server_name="0.0.0.0", share=True, debug=True)