|
import torch |
|
import torch.nn as nn |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
from torchvision.transforms.functional import to_pil_image |
|
from torchvision.transforms import Resize |
|
import cv2 |
|
import numpy as np |
|
from torchcam.utils import overlay_mask |
|
import gradio as gr |
|
import os |
|
|
|
class GeoPrior(nn.Module):
    """Build a geometry prior from a spatial-distance mask and a depth-difference mask.

    For an ``h x w`` token grid, ``forward`` returns ``((sin, cos), mask_d, mask_1, mask_sum)``:

    * ``sin`` / ``cos``: ``(h, w, 2 * (embed_dim // num_heads // 2))`` tables of the
      flattened token index times the per-channel frequencies in ``angle``.
    * ``mask_1``: ``(h*w, h*w)`` Manhattan distance between every pair of grid cells.
    * ``mask_d``: ``(B, 1, h*w, h*w)`` absolute depth difference between every pair of cells.
    * ``mask_sum``: ``(0.85 * mask_1 + 0.15 * mask_d) * decay[head]``, broadcast to
      ``(B, num_heads, h*w, h*w)``.

    Args:
        embed_dim: total embedding width; each head's sin/cos tables have
            ``2 * (embed_dim // num_heads // 2)`` entries.
        num_heads: number of heads, i.e. number of distinct decay rates.
        initial_value, heads_range: head ``k`` uses the log-decay
            ``log(1 - 2 ** -(initial_value + heads_range * k / num_heads))``.
    """

    def __init__(self, embed_dim=128, num_heads=4, initial_value=2, heads_range=6):
        super().__init__()
        # One frequency per channel pair, each repeated twice (sin/cos interleave layout).
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        self.initial_value = initial_value
        self.heads_range = heads_range
        self.num_heads = num_heads
        # Per-head (negative) log decay; later heads decay more slowly.
        decay = torch.log(
            1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads)
        )
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)

    def generate_pos_decay(self, H: int, W: int):
        """Return the ``(H*W, H*W)`` matrix of Manhattan distances between grid cells.

        Cells are flattened row-major (index ``r * W + c``), on ``self.decay``'s device and dtype.
        """
        index_h = torch.arange(H).to(self.decay)
        index_w = torch.arange(W).to(self.decay)
        # 'ij' is the historical default; passing it explicitly silences the deprecation warning.
        grid = torch.meshgrid([index_h, index_w], indexing="ij")
        grid = torch.stack(grid, dim=-1).reshape(H * W, 2)
        mask = grid[:, None, :] - grid[None, :, :]
        return mask.abs().sum(dim=-1)

    def generate_2d_depth_decay(self, H: int, W: int, depth_grid):
        """Return ``(B, 1, h*w, h*w)`` absolute depth differences between every pair of cells.

        Args:
            H, W: unused; the grid size is taken from ``depth_grid`` itself.
            depth_grid: ``(B, 1, h, w)`` tensor. A single channel is required by the
                reshape to ``(B, h*w, 1)``.
        """
        B, _, h, w = depth_grid.shape
        grid_d = depth_grid.reshape(B, h * w, 1)
        mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]
        mask_d = mask_d.abs().sum(dim=-1)
        return mask_d.unsqueeze(1)

    def forward(self, slen, depth_map, activate_recurrent=False, chunkwise_recurrent=False):
        """Compute the geometry prior for an ``slen = (h, w)`` grid.

        Args:
            slen: ``(h, w)`` target grid size; ``h * w`` is the token count.
            depth_map: ``(B, 1, H', W')`` depth tensor of any numeric dtype; it is
                cast to float32 and bilinearly resized to ``slen``.
            activate_recurrent, chunkwise_recurrent: unused (recurrence is not implemented).

        Returns:
            ``((sin, cos), mask_d, mask_1, mask_sum)``, as described on the class.
            ``mask_sum`` lives on ``depth_map``'s device.
        """
        # Cast first: bilinear interpolation of integer (e.g. uint8) tensors is rejected
        # by some torch versions, and float input also avoids intermediate rounding.
        depth_map = F.interpolate(depth_map.float(), size=slen, mode='bilinear', align_corners=False)

        index = torch.arange(slen[0] * slen[1]).to(self.decay)
        phase = index[:, None] * self.angle[None, :]
        sin = torch.sin(phase).reshape(slen[0], slen[1], -1)
        cos = torch.cos(phase).reshape(slen[0], slen[1], -1)

        mask_1 = self.generate_pos_decay(slen[0], slen[1])
        mask_d = self.generate_2d_depth_decay(slen[0], slen[1], depth_map)

        # Combine on the depth mask's device instead of forcing CPU, so a CUDA depth map works.
        device = mask_d.device
        mask_sum = (0.85 * mask_1.to(device) + 0.15 * mask_d) * self.decay[:, None, None].to(device)

        return ((sin, cos), mask_d, mask_1, mask_sum)
|
|
|
def fangda(mask, in_size=(480//20, 640//20), out_size=(480, 640)):
    """Upscale a coarse 2-D mask to ``out_size`` by nearest-neighbour block replication.

    ("fangda" is pinyin for "enlarge".) Cell ``(i, j)`` of the top-left ``in_size``
    window of ``mask`` fills the ``ratio_h x ratio_w`` block starting at
    ``(i * ratio_h, j * ratio_w)``. The ratios are the floor quotients
    ``out_size // in_size``. When ``out_size`` is not an exact multiple of ``in_size``,
    the leftover bottom rows and right columns stay zero.

    Args:
        mask: 2-D torch tensor or array-like with at least ``in_size`` rows and columns.
        in_size: ``(rows, cols)`` of the coarse grid read from ``mask``.
        out_size: ``(height, width)`` of the returned mask.

    Returns:
        ``torch.float32`` tensor of shape ``out_size``.

    Raises:
        IndexError: if ``mask`` is smaller than ``in_size``.
    """
    in_h, in_w = in_size
    ratio_h, ratio_w = out_size[0] // in_h, out_size[1] // in_w
    coarse = torch.as_tensor(mask)
    if coarse.shape[0] < in_h or coarse.shape[1] < in_w:
        raise IndexError(
            f"mask of shape {tuple(coarse.shape)} is smaller than in_size {tuple(in_size)}"
        )
    new_mask = torch.zeros(out_size)
    # Vectorised replacement for a per-cell Python loop: repeat every row ratio_h times
    # and every column ratio_w times, then write the result into the top-left corner.
    upscaled = coarse[:in_h, :in_w].repeat_interleave(ratio_h, dim=0).repeat_interleave(ratio_w, dim=1)
    new_mask[:in_h * ratio_h, :in_w * ratio_w] = upscaled
    return new_mask
|
|
|
def put_mask(image, mask, color_rgb=None, border_mask=False, color_temp='jet', num_c='', beta=2, fixed_num=None):
    """Overlay a colour-mapped, inverted mask on ``image`` at a fixed 640x480 resolution.

    Both inputs are resized to 640x480 (width x height). The mask is then inverted to
    ``1 - mask / scale``, so small mask values become values near 1. ``scale`` is 255
    when ``fixed_num`` is given, otherwise the mask's own maximum. The blend is produced
    by torchcam's ``overlay_mask`` with ``alpha=0.4``.

    Args:
        image: HxWx3 image array, converted to uint8 before overlaying.
            NOTE(review): callers in this file pass ``cv2.imread`` (BGR) images,
            and no channel conversion is done here. Confirm the intended colour order.
        mask: 2-D torch tensor or array-like of mask values.
        color_rgb, border_mask, num_c, beta: unused; accepted for backward compatibility.
        color_temp: matplotlib colormap name forwarded to ``overlay_mask``.
        fixed_num: when not None, normalise by 255 instead of the mask maximum.

    Returns:
        ``np.ndarray`` of the blended image (presumably 480x640x3 uint8; confirm
        against torchcam's ``overlay_mask``).
    """
    # Accept tensors on any device, including ones that require grad, as well as plain arrays.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)

    image = cv2.resize(image, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)

    if fixed_num is not None:
        mask = 1 - mask / 255
    else:
        peak = np.max(mask)
        if peak != 0:
            mask = 1 - mask / peak
        else:
            # An all-zero mask would give 0/0 = NaN everywhere. 1 - 0/scale is 1 for any
            # positive scale, so use that value instead.
            mask = np.ones_like(mask, dtype=np.float32)

    result = overlay_mask(
        to_pil_image(image.astype(np.uint8)),
        to_pil_image(mask),
        colormap=color_temp,
        alpha=0.4,
    )
    return np.array(result)
|
|
|
|
|
def _scale_to_255(values):
    """Min-max rescale a tensor to [0, 255]; a constant tensor maps to all zeros (not NaN)."""
    low = torch.min(values)
    span = torch.max(values) - low
    if span == 0:
        return torch.zeros_like(values)
    return 255 * (values - low) / span


def visualize_geometry_prior(RGB_path, Depth_path, index_list=None, cmap_list=None, x=0, y=0):
    """Render the geometry prior of one query cell as a heat-map over the RGB image.

    The depth image is resized to a 24x32 grid (20-px cells of a 640x480 canvas). The
    query cell containing ``(x, y)`` is selected. Its spatial-distance row and its
    depth-difference row (L2-normalised per grid row) are min-max scaled to [0, 255],
    blended 0.55 / 0.45, upscaled to 640x480 and overlaid on the RGB image.

    Args:
        RGB_path: path of the colour image (read with ``cv2.imread``).
        Depth_path: path of the depth image (read as single-channel grayscale).
        index_list: ignored. The query cell is always derived from ``x`` / ``y``; the
            parameter is kept for backward compatibility with existing callers.
        cmap_list: colormap names; only the last one is rendered (default ``['jet_r']``).
        x, y: query position in 640x480 pixel coordinates, clamped to the grid.

    Returns:
        uint8 image array produced by ``put_mask``.

    Raises:
        FileNotFoundError: if either image cannot be read.
        ValueError: if ``cmap_list`` is empty.
    """
    cell = 20
    H, W = 480 // cell, 640 // cell
    if cmap_list is None:
        cmap_list = ['jet_r']
    if not cmap_list:
        raise ValueError("cmap_list must contain at least one colormap name")
    color_temp = cmap_list[-1]

    # Row-major flattened index of the cell containing (x, y), matching the token order
    # of GeoPrior's masks. Clamping keeps x=640 / y=480 inside the grid.
    row = min(max(int(y // cell), 0), H - 1)
    col = min(max(int(x // cell), 0), W - 1)
    index_num = row * W + col

    depth = cv2.imread(Depth_path, 0)
    if depth is None:
        raise FileNotFoundError(f"Cannot read depth image: {Depth_path}")
    img = cv2.imread(RGB_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read RGB image: {RGB_path}")

    depth = cv2.resize(depth, dsize=(W, H), interpolation=cv2.INTER_LINEAR)
    # Float input avoids integer bilinear interpolation inside GeoPrior (the resize is to the same size).
    grid_d = torch.tensor(depth, dtype=torch.float32).reshape(1, 1, H, W)
    img = cv2.resize(img, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)

    _, depth_decay, pos_decay, _ = GeoPrior()((H, W), grid_d)

    mask_pos = pos_decay[index_num, :].reshape(H, W).cpu()
    mask_depth = depth_decay[0, 0, index_num, :].reshape(H, W).cpu()
    mask_depth = F.normalize(mask_depth, p=2.0, dim=1, eps=1e-12)

    mask_pos = _scale_to_255(mask_pos)
    mask_depth = _scale_to_255(mask_depth)

    blend_weight = 0.55
    blended = blend_weight * mask_pos + (1 - blend_weight) * mask_depth
    return put_mask(img, fangda(blended), color_temp=color_temp).astype(np.uint8)
|
|
|
|
|
def process_images(rgb_image, depth_image, x=0, y=0):
    """Process uploaded images and return the geometry-prior visualisation.

    The arrays are written to a private temporary directory, so concurrent calls never
    share files. They are rendered by ``visualize_geometry_prior``, and the result is
    converted from BGR to RGB.

    Args:
        rgb_image: RGB image array as uploaded through Gradio.
        depth_image: depth image array as uploaded through Gradio (written unchanged).
        x, y: query position in 640x480 coordinates.

    Returns:
        The visualisation image (RGB), or ``None`` if processing failed. The error is printed.
    """
    import tempfile

    try:
        # The directory and its files are removed when the block exits, even on error.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
            temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")
            cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
            cv2.imwrite(temp_depth_path, depth_image)
            result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, x=x, y=y)
        return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    except Exception as e:
        # UI boundary: report and return None rather than crash the app.
        print(f"Error during processing: {str(e)}")
        return None
|
|
|
def draw_star(image, x, y, size=20, color=(255, 0, 0), thickness=2):
    """Outline a five-pointed star centred at (x, y) on ``image``.

    The five tips are connected in pentagram order as one closed polyline. ``cv2.polylines``
    draws on ``image`` itself, and the same object is returned for convenience.
    """
    # Tip offsets in units of ``size``, listed in drawing order (image y axis points down).
    tip_offsets = (
        (0, -1),
        (0.588, 0.809),
        (-0.951, -0.309),
        (0.951, -0.309),
        (-0.588, 0.809),
    )
    vertices = np.array(
        [[x + dx * size, y + dy * size] for dx, dy in tip_offsets],
        np.int32,
    )
    cv2.polylines(image, [vertices], True, color, thickness)
    return image
|
|
|
|
|
def create_demo():
    """Build the Gradio Blocks UI for the geometry-prior visualisation demo.

    The user uploads an RGB/depth pair and picks a query position in 640x480
    coordinates. The position is previewed as a star on both inputs, and the
    visualisation is rendered with ``visualize_geometry_prior``.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    import tempfile

    # Coordinate canvas used by the visualisation and its grid of 20-px cells.
    canvas_w, canvas_h = 640, 480
    grid_h, grid_w = canvas_h // 20, canvas_w // 20

    with gr.Blocks() as demo:
        gr.Markdown("# Geometry Prior Visualization Demo")
        gr.Markdown("""
        ### Instructions:
        1. Upload RGB and Depth images
        2. Enter X (0-640) and Y (0-480) coordinates
        3. A star marker will be shown on the images at the selected position
        4. Click "Generate Visualization" to create the visualization
        """)

        with gr.Row():
            with gr.Column():
                rgb_input = gr.Image(label="Upload RGB Image")
                depth_input = gr.Image(label="Upload Depth Image", image_mode="L")
                with gr.Row():
                    x_coord = gr.Number(label="X (0-640)", value=160, minimum=0, maximum=640)
                    y_coord = gr.Number(label="Y (0-480)", value=270, minimum=0, maximum=480)
                coordinates_text = gr.Textbox(label="Grid Position and Index", interactive=False)

            with gr.Column():
                marked_rgb = gr.Image(label="Marked RGB Image")
                marked_depth = gr.Image(label="Marked Depth Image")
                output_image = gr.Image(label="Visualization Result")
                status_text = gr.Textbox(label="Status", interactive=False)

        def _mark(image, x, y, color):
            """Return an RGB copy of ``image`` with a star at canvas position (x, y)."""
            marked = image.copy()
            if marked.ndim == 2:
                marked = cv2.cvtColor(marked, cv2.COLOR_GRAY2RGB)
            elif marked.shape[2] == 4:
                # Gradio images are RGB(A); stay in RGB so the preview colours are not swapped.
                marked = cv2.cvtColor(marked, cv2.COLOR_RGBA2RGB)
            # The visualisation resizes inputs to 640x480, so map the canvas point onto this image's size.
            img_h, img_w = marked.shape[:2]
            px = int(x * img_w / canvas_w)
            py = int(y * img_h / canvas_h)
            return draw_star(marked, px, py, size=20, color=color)

        def update_coordinates_and_images(rgb_image, depth_image, x, y):
            x = max(0, min(canvas_w, float(0 if x is None else x)))
            y = max(0, min(canvas_h, float(0 if y is None else y)))

            # One grid cell per 20 px. Clamp so x=640 / y=480 select the last column / row
            # instead of wrapping into the next row or leaving the grid.
            scaled_x = min(int(x * grid_w / canvas_w), grid_w - 1)
            scaled_y = min(int(y * grid_h / canvas_h), grid_h - 1)
            grid_index = scaled_y * grid_w + scaled_x

            # Images may not be uploaded yet when the coordinates change.
            rgb_marked = None if rgb_image is None else _mark(rgb_image, x, y, (255, 0, 0))
            depth_marked = None if depth_image is None else _mark(depth_image, x, y, (0, 255, 0))

            return (f"Grid position: ({scaled_x}, {scaled_y}), Index: {grid_index}",
                    rgb_marked,
                    depth_marked)

        coord_update_btn = gr.Button("Update Coordinates")
        coord_update_btn.click(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        def process_with_status(rgb_image, depth_image, coords_text, x, y):
            if rgb_image is None or depth_image is None:
                return None, "Error: please upload both an RGB and a depth image"
            try:
                if coords_text:
                    index_list = [[int(coords_text.split("Index: ")[-1])]]
                else:
                    index_list = [[584]]

                # A private temp directory per request: no clashes between concurrent
                # users, and cleanup is guaranteed on both success and failure.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
                    temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")
                    cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(temp_depth_path, depth_image)
                    result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, index_list=index_list, x=x, y=y)

                result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
                return result, "Processing completed successfully!"
            except Exception as e:
                # UI boundary: surface the error in the status box instead of crashing.
                return None, f"Error: {str(e)}"

        process_btn = gr.Button("Generate Visualization")
        process_btn.click(
            fn=process_with_status,
            inputs=[rgb_input, depth_input, coordinates_text, x_coord, y_coord],
            outputs=[output_image, status_text]
        )

        # Keep the preview and grid index in sync as the coordinates are edited.
        x_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )
        y_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        gr.Examples(
            examples=[
                ["assets/example_rgb.jpg", "assets/example_depth.png"]
            ],
            inputs=[rgb_input, depth_input]
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, and serve it on all
    # interfaces with a public share link and debug output.
    demo = create_demo()
    demo.queue()
    launch_options = {"server_name": "0.0.0.0", "share": True, "debug": True}
    demo.launch(**launch_options)
|
|