import torch
import gradio as gr
import plotly.graph_objects as go
import trimesh
from pathlib import Path

# All inference runs on the CPU.
device = torch.device("cpu")
# map_location lets a model that was scripted/saved on a GPU load on a CPU-only host.
# NOTE(review): the path is resolved relative to the current working directory.
model = torch.jit.load('model_scripted.pt', map_location=device).to(device)
# The model is only used for inference; switch to eval mode once at load time.
model.eval()


def normalize_vertices(verts):
    """Center vertices on their centroid and scale each axis independently into [-1, 1].

    Args:
        verts: Tensor of vertex positions; assumed shape (N, 3) with N >= 1
            (TODO confirm for non-xyz inputs).

    Returns:
        A new tensor of the same shape. Per axis, the coordinate furthest from the
        centroid maps to -1 or 1. Axes are scaled independently, so the mesh's aspect
        ratio is not preserved. An axis with zero extent is divided by 1 (left at 0)
        instead of producing NaN.
    """
    # Center the vertices on their centroid.
    verts = verts - verts.mean(dim=0)
    # Largest absolute coordinate along each axis, computed per dimension.
    scale = verts.abs().max(dim=0).values
    # A zero extent (flat axis) is replaced by 1 to avoid division by zero.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    return verts / scale


def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Build a Plotly figure showing the mesh with predicted UV-seam edges overlaid.

    Args:
        verts: Vertex tensor; columns 0, 1, 2 are plotted as x, y, z.
        faces: Triangle tensor; columns 0, 1, 2 are vertex indices into ``verts``.
        uv_seam_edge_indices: Iterable of ``(a, b)`` vertex-index pairs to draw as
            red line segments.

    Returns:
        A ``plotly.graph_objects.Figure`` with a translucent Mesh3d trace and a
        Scatter3d line trace for the predicted edges.
    """
    verts_np = verts.cpu().numpy()
    faces_np = faces.cpu().numpy()

    mesh = go.Mesh3d(
        x=verts_np[:, 0], y=verts_np[:, 1], z=verts_np[:, 2],
        i=faces_np[:, 0], j=faces_np[:, 1], k=faces_np[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # A single Scatter3d trace draws disconnected segments when consecutive
    # segments are separated by None entries.
    edge_x, edge_y, edge_z = [], [], []
    for a, b in uv_seam_edge_indices:
        (x0, y0, z0), (x1, y1, z1) = verts_np[a], verts_np[b]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edges_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z, mode='lines',
        line=dict(color='red', width=2), name='Predicted Edges',
    )

    def _axis_style(background_color):
        # Shared styling for the three scene axes; only the background differs.
        return dict(nticks=4, backgroundcolor=background_color, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    fig = go.Figure(data=[mesh, edges_trace])
    fig.update_layout(
        scene=dict(
            xaxis=_axis_style("rgb(200, 200, 230)"),
            yaxis=_axis_style("rgb(230, 200,230)"),
            zaxis=_axis_style("rgb(230, 230,200)"),
            # y-up camera looking at the origin from the (+x, +y, +z) diagonal.
            camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
        title_text='Predicted Edges',
    )
    return fig


def generate_prediction(file_input, treshold_value=0.5):
    """Predict UV-seam edges for a mesh file and return a Plotly figure of the result.

    Args:
        file_input: Path to a mesh file readable by ``trimesh.load_mesh``. A falsy
            value (nothing selected) returns ``None``.
        treshold_value: Edges whose sigmoid probability is strictly greater than this
            value are reported as seams.

    Returns:
        The figure produced by ``plot_3d_results``, or ``None`` if no file was given.

    Raises:
        gr.Error: If the loaded file does not expose any triangle faces.
    """
    if not file_input:
        return None

    # trimesh triangulates polygonal faces on load.
    mesh = trimesh.load_mesh(file_input)
    mesh_faces = getattr(mesh, "faces", None)
    if mesh_faces is None or len(mesh_faces) == 0:
        raise gr.Error("The selected file does not contain a triangle mesh with faces.")

    vertices = normalize_vertices(torch.tensor(mesh.vertices, dtype=torch.float32))

    # Merge vertices that have identical normalized coordinates and drop vertices no
    # face references. New indices follow the order of first appearance in the faces.
    # Convert to Python floats once rather than once per face corner.
    vertex_rows = vertices.tolist()
    vertex_mapping = {}
    kept_indices = []
    new_faces = []
    for face in mesh_faces:
        new_face = []
        for orig_index in face:
            key = tuple(vertex_rows[orig_index])
            new_index = vertex_mapping.get(key)
            if new_index is None:
                new_index = len(kept_indices)
                vertex_mapping[key] = new_index
                kept_indices.append(int(orig_index))
            new_face.append(new_index)
        new_faces.append(new_face)

    # Undirected edges, each stored once as a (smaller, larger) index pair.
    edge_set = set()
    for v1, v2, v3 in new_faces:
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))

    edges = torch.tensor(list(edge_set), dtype=torch.long)
    verts = vertices[torch.tensor(kept_indices, dtype=torch.long)]
    faces = torch.tensor(new_faces, dtype=torch.long)

    with torch.no_grad():
        logits = model(verts, edges)
        probabilities = torch.sigmoid(logits)

    # Flatten to 1-D so the mask indexes edges correctly even when there is a single
    # edge. The old squeeze() would produce a 0-d tensor in that case.
    # NOTE(review): assumes the model emits one logit per edge, in the order of `edges`.
    seam_mask = (probabilities > treshold_value).reshape(-1).cpu()
    uv_seam_edges = edges[seam_mask].tolist()

    return plot_3d_results(verts, faces, uv_seam_edges)


def run_gradio():
    """Build the Gradio demo UI and launch it."""
    with gr.Blocks() as demo:
        gr.Label("Proof of concept demo. Predict UV seams on 3D sphere meshes.")
        with gr.Row():
            model3d_input = gr.FileExplorer(label="Sphere Prototype Model",
                                            file_count='single',
                                            value='randomSphere_180.obj',
                                            glob='**/*.obj')
            with gr.Column():
                model3d_output = gr.Plot()
                # NOTE: the UI default (0.6) differs from generate_prediction's default (0.5).
                threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.6, label="Threshold")
                button = gr.Button("Predict")

        button.click(generate_prediction,
                     inputs=[model3d_input, threshold_slider],
                     outputs=model3d_output)

    demo.launch()


run_gradio()