Spaces:
Running
Running
brentyi
commited on
Commit
·
cdf42c7
1
Parent(s):
a04bda0
Initial commit
Browse files- app.py +64 -0
- requirements.txt +1 -0
- viser_proxy_manager.py +206 -0
app.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import fastapi
|
4 |
+
import gradio as gr
|
5 |
+
import uvicorn
|
6 |
+
|
7 |
+
from viser_proxy_manager import ViserProxyManager
|
8 |
+
|
9 |
+
|
10 |
+
def main() -> None:
|
11 |
+
app = fastapi.FastAPI()
|
12 |
+
viser_manager = ViserProxyManager(app)
|
13 |
+
|
14 |
+
# Create a Gradio interface with title, iframe, and buttons
|
15 |
+
with gr.Blocks(title="Viser Viewer") as demo:
|
16 |
+
# Add a title and description
|
17 |
+
gr.Markdown("# 🌐 Viser Interactive Viewer")
|
18 |
+
|
19 |
+
# Add the iframe with a border
|
20 |
+
add_sphere_btn = gr.Button("Add Random Sphere")
|
21 |
+
iframe_html = gr.HTML("")
|
22 |
+
|
23 |
+
@demo.load(outputs=[iframe_html])
|
24 |
+
def start_server(request: gr.Request):
|
25 |
+
assert request.session_hash is not None
|
26 |
+
viser_manager.start_server(request.session_hash)
|
27 |
+
|
28 |
+
# Use the request's base URL if available
|
29 |
+
host = request.headers["host"]
|
30 |
+
|
31 |
+
return f"""
|
32 |
+
<div style="border: 2px solid #ccc; padding: 10px;">
|
33 |
+
<iframe src="http://{host}/viser/{request.session_hash}/" width="100%" height="500px" frameborder="0"></iframe>
|
34 |
+
</div>
|
35 |
+
"""
|
36 |
+
|
37 |
+
@add_sphere_btn.click
|
38 |
+
def add_random_sphere(request: gr.Request):
|
39 |
+
assert request.session_hash is not None
|
40 |
+
server = viser_manager.get_server(request.session_hash)
|
41 |
+
|
42 |
+
# Add icosphere with random properties
|
43 |
+
server.scene.add_icosphere(
|
44 |
+
name=f"sphere_{random.randint(1, 10000)}",
|
45 |
+
position=(
|
46 |
+
random.uniform(-1, 1),
|
47 |
+
random.uniform(-1, 1),
|
48 |
+
random.uniform(-1, 1),
|
49 |
+
),
|
50 |
+
radius=random.uniform(0.05, 0.2),
|
51 |
+
color=(random.random(), random.random(), random.random()),
|
52 |
+
)
|
53 |
+
|
54 |
+
@demo.unload
|
55 |
+
def stop(request: gr.Request):
|
56 |
+
assert request.session_hash is not None
|
57 |
+
viser_manager.stop_server(request.session_hash)
|
58 |
+
|
59 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
60 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
viser
|
viser_proxy_manager.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
import httpx
|
4 |
+
import viser
|
5 |
+
import websockets
|
6 |
+
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
7 |
+
from fastapi.responses import Response
|
8 |
+
|
9 |
+
|
10 |
+
class ViserProxyManager:
|
11 |
+
"""Manages Viser server instances for Gradio applications.
|
12 |
+
|
13 |
+
This class handles the creation, retrieval, and cleanup of Viser server instances,
|
14 |
+
as well as proxying HTTP and WebSocket requests to the appropriate Viser server.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
app: The FastAPI application to which the proxy routes will be added.
|
18 |
+
min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
|
19 |
+
These ports are used only for internal communication and don't need to be publicly exposed.
|
20 |
+
max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
|
21 |
+
These ports are used only for internal communication and don't need to be publicly exposed.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
app: FastAPI,
|
27 |
+
min_local_port: int = 8000,
|
28 |
+
max_local_port: int = 9000,
|
29 |
+
) -> None:
|
30 |
+
self._min_port = min_local_port
|
31 |
+
self._max_port = max_local_port
|
32 |
+
self._server_from_session_hash: dict[str, viser.ViserServer] = {}
|
33 |
+
self._last_port = min_local_port - 1 # Track last port tried
|
34 |
+
|
35 |
+
@app.get("/viser/{server_id}/{proxy_path:path}")
|
36 |
+
async def proxy(request: Request, server_id: str, proxy_path: str):
|
37 |
+
"""Proxy HTTP requests to the appropriate Viser server."""
|
38 |
+
# Get the local port for this server ID
|
39 |
+
server = self._server_from_session_hash.get(server_id)
|
40 |
+
if server is None:
|
41 |
+
return Response(content="Server not found", status_code=404)
|
42 |
+
|
43 |
+
# Build target URL
|
44 |
+
if proxy_path:
|
45 |
+
path_suffix = f"/{proxy_path}"
|
46 |
+
else:
|
47 |
+
path_suffix = "/"
|
48 |
+
|
49 |
+
target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
|
50 |
+
if request.url.query:
|
51 |
+
target_url += f"?{request.url.query}"
|
52 |
+
|
53 |
+
# Forward request
|
54 |
+
async with httpx.AsyncClient() as client:
|
55 |
+
# Forward the original headers, but remove any problematic ones
|
56 |
+
headers = dict(request.headers)
|
57 |
+
headers.pop("host", None) # Remove host header to avoid conflicts
|
58 |
+
headers["accept-encoding"] = "identity" # Disable compression
|
59 |
+
|
60 |
+
proxied_req = client.build_request(
|
61 |
+
method=request.method,
|
62 |
+
url=target_url,
|
63 |
+
headers=headers,
|
64 |
+
content=await request.body(),
|
65 |
+
)
|
66 |
+
proxied_resp = await client.send(proxied_req, stream=True)
|
67 |
+
|
68 |
+
# Get response headers
|
69 |
+
response_headers = dict(proxied_resp.headers)
|
70 |
+
|
71 |
+
# Check if this is an HTML response
|
72 |
+
content = await proxied_resp.aread()
|
73 |
+
return Response(
|
74 |
+
content=content,
|
75 |
+
status_code=proxied_resp.status_code,
|
76 |
+
headers=response_headers,
|
77 |
+
)
|
78 |
+
|
79 |
+
# WebSocket Proxy
|
80 |
+
@app.websocket("/viser/{server_id}")
|
81 |
+
async def websocket_proxy(websocket: WebSocket, server_id: str):
|
82 |
+
"""Proxy WebSocket connections to the appropriate Viser server."""
|
83 |
+
await websocket.accept()
|
84 |
+
|
85 |
+
server = self._server_from_session_hash.get(server_id)
|
86 |
+
if server is None:
|
87 |
+
await websocket.close(code=1008, reason="Not Found")
|
88 |
+
return
|
89 |
+
|
90 |
+
# Determine target WebSocket URL
|
91 |
+
target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
|
92 |
+
|
93 |
+
if not target_ws_url:
|
94 |
+
await websocket.close(code=1008, reason="Not Found")
|
95 |
+
return
|
96 |
+
|
97 |
+
try:
|
98 |
+
# Connect to the target WebSocket
|
99 |
+
async with websockets.connect(target_ws_url) as ws_target:
|
100 |
+
# Create tasks for bidirectional communication
|
101 |
+
async def forward_to_target():
|
102 |
+
"""Forward messages from the client to the target WebSocket."""
|
103 |
+
try:
|
104 |
+
while True:
|
105 |
+
data = await websocket.receive_bytes()
|
106 |
+
await ws_target.send(data, text=False)
|
107 |
+
except WebSocketDisconnect:
|
108 |
+
try:
|
109 |
+
await ws_target.close()
|
110 |
+
except RuntimeError:
|
111 |
+
pass
|
112 |
+
|
113 |
+
async def forward_from_target():
|
114 |
+
"""Forward messages from the target WebSocket to the client."""
|
115 |
+
try:
|
116 |
+
while True:
|
117 |
+
data = await ws_target.recv(decode=False)
|
118 |
+
await websocket.send_bytes(data)
|
119 |
+
except websockets.exceptions.ConnectionClosed:
|
120 |
+
try:
|
121 |
+
await websocket.close()
|
122 |
+
except RuntimeError:
|
123 |
+
pass
|
124 |
+
|
125 |
+
# Run both forwarding tasks concurrently
|
126 |
+
forward_task = asyncio.create_task(forward_to_target())
|
127 |
+
backward_task = asyncio.create_task(forward_from_target())
|
128 |
+
|
129 |
+
# Wait for either task to complete (which means a connection was closed)
|
130 |
+
done, pending = await asyncio.wait(
|
131 |
+
[forward_task, backward_task],
|
132 |
+
return_when=asyncio.FIRST_COMPLETED,
|
133 |
+
)
|
134 |
+
|
135 |
+
# Cancel the remaining task
|
136 |
+
for task in pending:
|
137 |
+
task.cancel()
|
138 |
+
|
139 |
+
except Exception as e:
|
140 |
+
print(f"WebSocket proxy error: {e}")
|
141 |
+
await websocket.close(code=1011, reason=str(e))
|
142 |
+
|
143 |
+
def start_server(self, server_id: str) -> viser.ViserServer:
|
144 |
+
"""Start a new Viser server and associate it with the given server ID.
|
145 |
+
|
146 |
+
Finds an available port within the configured min_local_port and max_local_port range.
|
147 |
+
These ports are used only for internal communication and don't need to be publicly exposed.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
server_id: The unique identifier to associate with the new server.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
The newly created Viser server instance.
|
154 |
+
|
155 |
+
Raises:
|
156 |
+
RuntimeError: If no free ports are available in the configured range.
|
157 |
+
"""
|
158 |
+
import socket
|
159 |
+
|
160 |
+
# Start searching from the last port + 1 (with wraparound)
|
161 |
+
port_range_size = self._max_port - self._min_port + 1
|
162 |
+
start_port = (
|
163 |
+
(self._last_port + 1 - self._min_port) % port_range_size
|
164 |
+
) + self._min_port
|
165 |
+
|
166 |
+
# Try each port once
|
167 |
+
for offset in range(port_range_size):
|
168 |
+
port = (
|
169 |
+
(start_port - self._min_port + offset) % port_range_size
|
170 |
+
) + self._min_port
|
171 |
+
try:
|
172 |
+
# Check if port is available by attempting to bind to it
|
173 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
174 |
+
s.bind(("127.0.0.1", port))
|
175 |
+
# Port is available, create server with this port
|
176 |
+
server = viser.ViserServer(port=port)
|
177 |
+
self._server_from_session_hash[server_id] = server
|
178 |
+
self._last_port = port
|
179 |
+
return server
|
180 |
+
except OSError:
|
181 |
+
# Port is in use, try the next one
|
182 |
+
continue
|
183 |
+
|
184 |
+
# If we get here, no ports were available
|
185 |
+
raise RuntimeError(
|
186 |
+
f"No available local ports in range {self._min_port}-{self._max_port}"
|
187 |
+
)
|
188 |
+
|
189 |
+
def get_server(self, server_id: str) -> viser.ViserServer:
|
190 |
+
"""Retrieve a Viser server instance by its ID.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
server_id: The unique identifier of the server to retrieve.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
The Viser server instance associated with the given ID.
|
197 |
+
"""
|
198 |
+
return self._server_from_session_hash[server_id]
|
199 |
+
|
200 |
+
def stop_server(self, server_id: str) -> None:
|
201 |
+
"""Stop a Viser server and remove it from the manager.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
server_id: The unique identifier of the server to stop.
|
205 |
+
"""
|
206 |
+
self._server_from_session_hash.pop(server_id).stop()
|