# DFormerv2 / app.py
# Hugging Face Space file uploaded by bbynku ("Upload 5 files", commit 4dee8df, verified).
import torch
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms import Resize
import cv2
import numpy as np
from torchcam.utils import overlay_mask
import gradio as gr
import os
class GeoPrior(nn.Module):
    """Build positional and depth-difference masks for an (h, w) token grid.

    Two buffers are registered at construction time:

    * ``angle``: ``embed_dim // num_heads`` frequencies (each value of a
      ``10000 ** -t`` schedule repeated twice) used for the sin/cos tables.
    * ``decay``: one value per head, ``log(1 - 2 ** (-initial_value -
      heads_range * k / num_heads))`` for head ``k``.
    """

    def __init__(self, embed_dim=128, num_heads=4, initial_value=2, heads_range=6):
        super().__init__()
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        # Duplicate every frequency so consecutive pairs share the same angle.
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        self.initial_value = initial_value
        self.heads_range = heads_range
        self.num_heads = num_heads
        decay = torch.log(
            1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads)
        )
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)

    def generate_pos_decay(self, H: int, W: int):
        """Return the (H*W, H*W) Manhattan distance between every pair of grid cells.

        Cells are enumerated row-major, so cell ``(r, c)`` has index ``r * W + c``.
        The per-head ``decay`` is NOT applied here.
        """
        index_h = torch.arange(H).to(self.decay)  # match the buffer's dtype/device
        index_w = torch.arange(W).to(self.decay)
        # 'ij' is the historical default; passing it explicitly avoids the
        # deprecation warning emitted when `indexing` is omitted.
        grid = torch.meshgrid(index_h, index_w, indexing='ij')
        grid = torch.stack(grid, dim=-1).reshape(H * W, 2)  # (H*W, 2)
        offsets = grid[:, None, :] - grid[None, :, :]  # (H*W, H*W, 2)
        return offsets.abs().sum(dim=-1)

    def generate_2d_depth_decay(self, H: int, W: int, depth_grid):
        """Return the (B, 1, h*w, h*w) absolute depth difference between every pair of cells.

        ``depth_grid`` must be a single-channel (B, 1, h, w) tensor. The grid
        size is taken from ``depth_grid`` itself; the ``H`` and ``W`` arguments
        are kept for interface compatibility and are not used. The per-head
        ``decay`` is NOT applied here.
        """
        batch, _, rows, cols = depth_grid.shape
        grid_d = depth_grid.reshape(batch, rows * cols, 1)
        mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]  # (B, h*w, h*w, 1)
        mask_d = mask_d.abs().sum(dim=-1)
        return mask_d.unsqueeze(1)

    def forward(self, slen, depth_map, activate_recurrent=False, chunkwise_recurrent=False):
        """Compute the sin/cos tables and the position / depth / fused masks.

        Args:
            slen: ``(h, w)`` size of the token grid; ``l = h * w``.
            depth_map: single-channel (B, 1, H', W') depth tensor of any
                numeric dtype; it is resized bilinearly to ``slen``.
            activate_recurrent, chunkwise_recurrent: accepted but unused
                (the recurrent forms are not implemented).

        Returns:
            ``((sin, cos), mask_d, mask_1, mask_sum)`` where ``sin``/``cos``
            are (h, w, d), ``mask_d`` is the (B, 1, l, l) depth-difference
            mask, ``mask_1`` is the (l, l) Manhattan-distance mask and
            ``mask_sum`` is ``(0.85 * mask_1 + 0.15 * mask_d) * decay`` with
            shape (B, num_heads, l, l).
        """
        # Convert before resizing: bilinear interpolation is a floating-point
        # operation and integer inputs (e.g. uint8 depth images) are not
        # reliably supported by it.
        depth_map = F.interpolate(depth_map.float(), size=slen, mode='bilinear', align_corners=False)

        index = torch.arange(slen[0] * slen[1]).to(self.decay)
        phase = index[:, None] * self.angle[None, :]  # (l, d)
        sin = torch.sin(phase).reshape(slen[0], slen[1], -1)  # (h, w, d)
        cos = torch.cos(phase).reshape(slen[0], slen[1], -1)  # (h, w, d)

        mask_1 = self.generate_pos_decay(slen[0], slen[1])  # (l, l)
        mask_d = self.generate_2d_depth_decay(slen[0], slen[1], depth_map)  # (B, 1, l, l)

        # Combine on the depth mask's device instead of forcing the CPU, so a
        # GPU depth map works too; for CPU inputs the result is unchanged.
        device = mask_d.device
        mask_sum = (0.85 * mask_1.to(device) + 0.15 * mask_d) * self.decay[:, None, None].to(device)
        return ((sin, cos), mask_d, mask_1, mask_sum)
def fangda(mask, in_size=(480//20, 640//20), out_size=(480, 640)):
    """Enlarge a coarse 2-D mask by replicating each cell into a rectangular block.

    Args:
        mask: 2-D tensor (or array) with at least ``in_size`` rows/columns;
            only the top-left ``in_size`` region is used.
        in_size: ``(rows, cols)`` of the coarse grid.
        out_size: ``(height, width)`` of the returned tensor.

    Returns:
        A float tensor of shape ``out_size``. Each cell covers
        ``out_size[0] // in_size[0]`` by ``out_size[1] // in_size[1]`` pixels;
        any remainder on the bottom/right edge stays zero.

    Raises:
        IndexError: if ``mask`` is smaller than ``in_size``.
    """
    rows, cols = in_size
    ratio_h, ratio_w = out_size[0] // rows, out_size[1] // cols
    new_mask = torch.zeros(out_size)

    cells = torch.as_tensor(mask)[:rows, :cols]
    if tuple(cells.shape) != (rows, cols):
        raise IndexError(f"mask of shape {tuple(cells.shape)} is smaller than in_size {tuple(in_size)}")

    # Vectorised block replication instead of a Python loop over every cell.
    tiled = cells.repeat_interleave(ratio_h, dim=0).repeat_interleave(ratio_w, dim=1)
    new_mask[:rows * ratio_h, :cols * ratio_w] = tiled
    return new_mask
def put_mask(image, mask, color_rgb=None, border_mask=False, color_temp='jet', num_c='', beta=2, fixed_num=None):
    """Overlay an inverted, normalised heat-map of ``mask`` on ``image``.

    Args:
        image: image array; resized to 640x480 and cast to uint8.
        mask: 2-D tensor supporting ``.numpy()``; resized to 640x480.
        color_temp: colormap name forwarded to ``overlay_mask``.
        fixed_num: when not ``None`` the mask is scaled by 255, otherwise by
            its own maximum.
        color_rgb, border_mask, num_c, beta: accepted for backward
            compatibility; they do not affect the result.

    Returns:
        The blended image as a NumPy array.
    """
    mask = mask.numpy()
    # With an explicit dsize the fx/fy scale factors are ignored, so they are omitted.
    image = cv2.resize(image, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)

    if fixed_num is not None:
        mask = 1 - mask / 255
    else:
        peak = np.max(mask)
        if peak == 0:
            # An all-zero mask would give 0/0 = NaN; treat it as "no response".
            mask = np.ones_like(mask)
        else:
            mask = 1 - mask / peak

    result = overlay_mask(
        to_pil_image(image.astype(np.uint8)),
        to_pil_image(mask),
        colormap=color_temp,
        alpha=0.4,
    )
    return np.array(result)
def _rescale_to_255(values):
    """Min-max rescale a tensor to [0, 255]; a constant tensor maps to zeros."""
    low = torch.min(values)
    span = torch.max(values) - low
    if span == 0:
        # Avoid 0/0 = NaN for flat inputs (e.g. a uniform depth image).
        return torch.zeros_like(values)
    return 255 * ((values - low) / span)


def visualize_geometry_prior(RGB_path, Depth_path, index_list=[[584]], cmap_list=['jet_r'], x=0, y=0):
    """Render the fused position/depth prior of one grid cell over the RGB image.

    The images are processed on a 24x32 grid (640x480 divided into 20-pixel
    cells). The anchor cell is the one containing pixel ``(x, y)``.

    Args:
        RGB_path: path of the colour image (read with ``cv2.imread``).
        Depth_path: path of the depth image (read as grayscale).
        index_list: kept for backward compatibility; the anchor cell is always
            derived from ``x`` and ``y``, so this value is not used.
        cmap_list: colormap names; only the first entry is rendered.
        x, y: anchor position in 640x480 pixel coordinates. Values outside the
            image are clamped to the nearest grid cell.

    Returns:
        A uint8 array as produced by ``put_mask`` for the fused mask
        ``0.55 * position + 0.45 * depth``.

    Raises:
        ValueError: if ``cmap_list`` is empty.
        FileNotFoundError: if either image cannot be read.
    """
    H = 480 // 20
    W = 640 // 20
    if not cmap_list:
        raise ValueError("cmap_list must contain at least one colormap name")
    color_temp = cmap_list[0]

    # Row-major cell index of the cell containing (x, y). Clamping keeps
    # x == 640 / y == 480 inside the grid instead of indexing past its end.
    col = min(max(int(x // 20), 0), W - 1)
    row = min(max(int(y // 20), 0), H - 1)
    index_num = row * W + col

    img = cv2.imread(RGB_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read RGB image: {RGB_path}")
    depth = cv2.imread(Depth_path, 0)
    if depth is None:
        raise FileNotFoundError(f"Cannot read depth image: {Depth_path}")

    img = cv2.resize(img, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
    grid_d = cv2.resize(depth, dsize=(W, H), interpolation=cv2.INTER_LINEAR)
    grid_d = torch.tensor(grid_d).reshape(1, 1, H, W)

    respos = GeoPrior()
    (_, depth_mask, pos_mask, _) = respos((H, W), grid_d)

    # Row `index_num` of each mask holds the anchor cell's distance to every cell.
    temp_mask_d = depth_mask[0, 0, index_num, :].reshape(H, W).cpu()
    temp_mask = pos_mask[index_num, :].reshape(H, W).cpu()

    # The depth mask is L2-normalised per grid row before min-max scaling.
    temp_mask_d = torch.nn.functional.normalize(temp_mask_d, p=2.0, dim=1, eps=1e-12)
    temp_mask_d = _rescale_to_255(temp_mask_d)
    temp_mask = _rescale_to_255(temp_mask)

    gama = 0.55  # weight of the positional mask in the fused map
    fused = gama * temp_mask + (1 - gama) * temp_mask_d
    overlay = put_mask(img, fangda(fused), color_temp=color_temp)
    return overlay.astype(np.uint8)
# Gradio-facing wrapper around the visualization function
def process_images(rgb_image, depth_image, x=0, y=0):
    """Run the geometry-prior visualization on two uploaded images.

    Args:
        rgb_image: RGB image array as delivered by Gradio.
        depth_image: depth image array as delivered by Gradio.
        x, y: anchor position forwarded to ``visualize_geometry_prior``.
            They default to 0 so existing two-argument calls keep working.

    Returns:
        The visualization as an RGB array, or ``None`` when the
        visualization step fails (the error is printed).
    """
    import tempfile

    # A private temporary directory avoids clashes between concurrent
    # requests and is removed automatically, even on errors.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
        temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")

        # visualize_geometry_prior works on file paths, so write the uploads out.
        cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
        cv2.imwrite(temp_depth_path, depth_image)

        try:
            result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, x=x, y=y)
            # Convert colour order for display.
            return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error during processing: {str(e)}")
            return None
def draw_star(image, x, y, size=20, color=(255, 0, 0), thickness=2):
    """Draw a five-pointed star outline centred near (x, y) onto ``image``.

    The image is modified in place and also returned. Vertex coordinates are
    truncated to 32-bit integers before drawing.
    """
    # Horizontal / vertical offsets of the two vertex pairs, scaled by `size`.
    upper_dx, upper_dy = size * 0.951, size * 0.309
    lower_dx, lower_dy = size * 0.588, size * 0.809

    # Visiting the vertices in this order makes the closed polyline cross
    # itself, producing the star shape rather than a pentagon.
    vertices = [
        [x, y - size],                    # top
        [x + lower_dx, y + lower_dy],     # lower right
        [x - upper_dx, y - upper_dy],     # upper left
        [x + upper_dx, y - upper_dy],     # upper right
        [x - lower_dx, y + lower_dy],     # lower left
    ]
    outline = np.array(vertices, np.int32)

    cv2.polylines(image, [outline], True, color, thickness)
    return image
# Build the Gradio interface
def create_demo():
    """Build and return the Gradio Blocks UI for the geometry-prior demo.

    The UI takes an RGB image, a depth image and an (x, y) position given in
    640x480 coordinates. It previews the position with a star marker on both
    images and, on request, renders the visualization for that position.
    """
    import tempfile

    grid_h, grid_w = 480 // 20, 640 // 20  # 24 x 32 grid

    def _clamp(value, upper):
        # A cleared gr.Number field arrives as None; treat it as 0.
        return max(0.0, min(float(upper), float(value or 0)))

    def _mark(image, x, y, color):
        # Return an RGB copy of `image` with a star at (x, y), or None if no image.
        if image is None:
            return None
        if image.ndim == 2:  # grayscale
            canvas = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:  # RGBA: drop alpha but keep RGB channel order
            canvas = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        else:
            canvas = image.copy()
        # (x, y) is expressed in 640x480 space; map it onto the image's own size.
        height, width = canvas.shape[:2]
        px = int(x * width / 640)
        py = int(y * height / 480)
        return draw_star(canvas, px, py, size=20, color=color)

    def update_coordinates_and_images(rgb_image, depth_image, x, y):
        # Keep the coordinates inside the valid range.
        x = _clamp(x, 640)
        y = _clamp(y, 480)
        # Position on the 24x32 grid; the upper bounds fall into the last cell.
        scaled_x = min(int(x * grid_w / 640), grid_w - 1)
        scaled_y = min(int(y * grid_h / 480), grid_h - 1)
        # Row-major flat index.
        grid_index = scaled_y * grid_w + scaled_x
        return (f"Grid position: ({scaled_x}, {scaled_y}), Index: {grid_index}",
                _mark(rgb_image, x, y, (255, 0, 0)),
                _mark(depth_image, x, y, (0, 255, 0)))

    def process_with_status(rgb_image, depth_image, coords_text, x, y):
        if rgb_image is None or depth_image is None:
            return None, "Error: please upload both an RGB image and a depth image."
        try:
            index_list = [[584]]  # default
            if coords_text:
                try:
                    index_list = [[int(coords_text.split("Index: ")[-1])]]
                except ValueError:
                    pass  # keep the default when the text has no parsable index
            # A private temporary directory avoids clashes between concurrent
            # requests and is cleaned up automatically.
            with tempfile.TemporaryDirectory() as tmp_dir:
                temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
                temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")
                cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
                cv2.imwrite(temp_depth_path, depth_image)
                result = visualize_geometry_prior(
                    temp_rgb_path,
                    temp_depth_path,
                    index_list=index_list,
                    x=_clamp(x, 640),
                    y=_clamp(y, 480),
                )
            # Convert colour order for display.
            result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
            return result, "Processing completed successfully!"
        except Exception as e:
            return None, f"Error: {str(e)}"

    with gr.Blocks() as demo:
        gr.Markdown("# Geometry Prior Visualization Demo")
        gr.Markdown("""
### Instructions:
1. Upload RGB and Depth images
2. Enter X (0-640) and Y (0-480) coordinates
3. A star marker will be shown on the images at the selected position
4. Click "Generate Visualization" to create the visualization
""")
        with gr.Row():
            with gr.Column():
                rgb_input = gr.Image(label="Upload RGB Image")
                depth_input = gr.Image(label="Upload Depth Image", image_mode="L")
                with gr.Row():
                    x_coord = gr.Number(label="X (0-640)", value=160, minimum=0, maximum=640)
                    y_coord = gr.Number(label="Y (0-480)", value=270, minimum=0, maximum=480)
                coordinates_text = gr.Textbox(label="Grid Position and Index", interactive=False)
            with gr.Column():
                marked_rgb = gr.Image(label="Marked RGB Image")
                marked_depth = gr.Image(label="Marked Depth Image")
                output_image = gr.Image(label="Visualization Result")
                status_text = gr.Textbox(label="Status", interactive=False)

        # Button that refreshes the coordinate read-out and the marked previews.
        coord_update_btn = gr.Button("Update Coordinates")
        coord_update_btn.click(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        process_btn = gr.Button("Generate Visualization")
        process_btn.click(
            fn=process_with_status,
            inputs=[rgb_input, depth_input, coordinates_text, x_coord, y_coord],
            outputs=[output_image, status_text]
        )

        # Also refresh automatically whenever either coordinate changes.
        x_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )
        y_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        gr.Examples(
            examples=[
                ["assets/example_rgb.jpg", "assets/example_depth.png"]
            ],
            inputs=[rgb_input, depth_input]
        )
    return demo
# Script entry point
if __name__ == "__main__":
    # Build the UI and enable request queuing before serving.
    demo = create_demo()
    demo.queue()

    # Launch settings, collected in one place and passed through unchanged.
    launch_options = {
        "server_name": "0.0.0.0",
        "share": True,
        "debug": True,
    }
    demo.launch(**launch_options)