brentyi commited on
Commit
cdf42c7
·
1 Parent(s): a04bda0

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +64 -0
  2. requirements.txt +1 -0
  3. viser_proxy_manager.py +206 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import fastapi
4
+ import gradio as gr
5
+ import uvicorn
6
+
7
+ from viser_proxy_manager import ViserProxyManager
8
+
9
+
10
+ def main() -> None:
11
+ app = fastapi.FastAPI()
12
+ viser_manager = ViserProxyManager(app)
13
+
14
+ # Create a Gradio interface with title, iframe, and buttons
15
+ with gr.Blocks(title="Viser Viewer") as demo:
16
+ # Add a title and description
17
+ gr.Markdown("# 🌐 Viser Interactive Viewer")
18
+
19
+ # Add the iframe with a border
20
+ add_sphere_btn = gr.Button("Add Random Sphere")
21
+ iframe_html = gr.HTML("")
22
+
23
+ @demo.load(outputs=[iframe_html])
24
+ def start_server(request: gr.Request):
25
+ assert request.session_hash is not None
26
+ viser_manager.start_server(request.session_hash)
27
+
28
+ # Use the request's base URL if available
29
+ host = request.headers["host"]
30
+
31
+ return f"""
32
+ <div style="border: 2px solid #ccc; padding: 10px;">
33
+ <iframe src="http://{host}/viser/{request.session_hash}/" width="100%" height="500px" frameborder="0"></iframe>
34
+ </div>
35
+ """
36
+
37
+ @add_sphere_btn.click
38
+ def add_random_sphere(request: gr.Request):
39
+ assert request.session_hash is not None
40
+ server = viser_manager.get_server(request.session_hash)
41
+
42
+ # Add icosphere with random properties
43
+ server.scene.add_icosphere(
44
+ name=f"sphere_{random.randint(1, 10000)}",
45
+ position=(
46
+ random.uniform(-1, 1),
47
+ random.uniform(-1, 1),
48
+ random.uniform(-1, 1),
49
+ ),
50
+ radius=random.uniform(0.05, 0.2),
51
+ color=(random.random(), random.random(), random.random()),
52
+ )
53
+
54
+ @demo.unload
55
+ def stop(request: gr.Request):
56
+ assert request.session_hash is not None
57
+ viser_manager.stop_server(request.session_hash)
58
+
59
+ app = gr.mount_gradio_app(app, demo, path="/")
60
+ uvicorn.run(app, host="0.0.0.0", port=7860)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ main()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ viser
viser_proxy_manager.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+ import httpx
4
+ import viser
5
+ import websockets
6
+ from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
7
+ from fastapi.responses import Response
8
+
9
+
10
+ class ViserProxyManager:
11
+ """Manages Viser server instances for Gradio applications.
12
+
13
+ This class handles the creation, retrieval, and cleanup of Viser server instances,
14
+ as well as proxying HTTP and WebSocket requests to the appropriate Viser server.
15
+
16
+ Args:
17
+ app: The FastAPI application to which the proxy routes will be added.
18
+ min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
19
+ These ports are used only for internal communication and don't need to be publicly exposed.
20
+ max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
21
+ These ports are used only for internal communication and don't need to be publicly exposed.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ app: FastAPI,
27
+ min_local_port: int = 8000,
28
+ max_local_port: int = 9000,
29
+ ) -> None:
30
+ self._min_port = min_local_port
31
+ self._max_port = max_local_port
32
+ self._server_from_session_hash: dict[str, viser.ViserServer] = {}
33
+ self._last_port = min_local_port - 1 # Track last port tried
34
+
35
+ @app.get("/viser/{server_id}/{proxy_path:path}")
36
+ async def proxy(request: Request, server_id: str, proxy_path: str):
37
+ """Proxy HTTP requests to the appropriate Viser server."""
38
+ # Get the local port for this server ID
39
+ server = self._server_from_session_hash.get(server_id)
40
+ if server is None:
41
+ return Response(content="Server not found", status_code=404)
42
+
43
+ # Build target URL
44
+ if proxy_path:
45
+ path_suffix = f"/{proxy_path}"
46
+ else:
47
+ path_suffix = "/"
48
+
49
+ target_url = f"http://127.0.0.1:{server.get_port()}{path_suffix}"
50
+ if request.url.query:
51
+ target_url += f"?{request.url.query}"
52
+
53
+ # Forward request
54
+ async with httpx.AsyncClient() as client:
55
+ # Forward the original headers, but remove any problematic ones
56
+ headers = dict(request.headers)
57
+ headers.pop("host", None) # Remove host header to avoid conflicts
58
+ headers["accept-encoding"] = "identity" # Disable compression
59
+
60
+ proxied_req = client.build_request(
61
+ method=request.method,
62
+ url=target_url,
63
+ headers=headers,
64
+ content=await request.body(),
65
+ )
66
+ proxied_resp = await client.send(proxied_req, stream=True)
67
+
68
+ # Get response headers
69
+ response_headers = dict(proxied_resp.headers)
70
+
71
+ # Check if this is an HTML response
72
+ content = await proxied_resp.aread()
73
+ return Response(
74
+ content=content,
75
+ status_code=proxied_resp.status_code,
76
+ headers=response_headers,
77
+ )
78
+
79
+ # WebSocket Proxy
80
+ @app.websocket("/viser/{server_id}")
81
+ async def websocket_proxy(websocket: WebSocket, server_id: str):
82
+ """Proxy WebSocket connections to the appropriate Viser server."""
83
+ await websocket.accept()
84
+
85
+ server = self._server_from_session_hash.get(server_id)
86
+ if server is None:
87
+ await websocket.close(code=1008, reason="Not Found")
88
+ return
89
+
90
+ # Determine target WebSocket URL
91
+ target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
92
+
93
+ if not target_ws_url:
94
+ await websocket.close(code=1008, reason="Not Found")
95
+ return
96
+
97
+ try:
98
+ # Connect to the target WebSocket
99
+ async with websockets.connect(target_ws_url) as ws_target:
100
+ # Create tasks for bidirectional communication
101
+ async def forward_to_target():
102
+ """Forward messages from the client to the target WebSocket."""
103
+ try:
104
+ while True:
105
+ data = await websocket.receive_bytes()
106
+ await ws_target.send(data, text=False)
107
+ except WebSocketDisconnect:
108
+ try:
109
+ await ws_target.close()
110
+ except RuntimeError:
111
+ pass
112
+
113
+ async def forward_from_target():
114
+ """Forward messages from the target WebSocket to the client."""
115
+ try:
116
+ while True:
117
+ data = await ws_target.recv(decode=False)
118
+ await websocket.send_bytes(data)
119
+ except websockets.exceptions.ConnectionClosed:
120
+ try:
121
+ await websocket.close()
122
+ except RuntimeError:
123
+ pass
124
+
125
+ # Run both forwarding tasks concurrently
126
+ forward_task = asyncio.create_task(forward_to_target())
127
+ backward_task = asyncio.create_task(forward_from_target())
128
+
129
+ # Wait for either task to complete (which means a connection was closed)
130
+ done, pending = await asyncio.wait(
131
+ [forward_task, backward_task],
132
+ return_when=asyncio.FIRST_COMPLETED,
133
+ )
134
+
135
+ # Cancel the remaining task
136
+ for task in pending:
137
+ task.cancel()
138
+
139
+ except Exception as e:
140
+ print(f"WebSocket proxy error: {e}")
141
+ await websocket.close(code=1011, reason=str(e))
142
+
143
+ def start_server(self, server_id: str) -> viser.ViserServer:
144
+ """Start a new Viser server and associate it with the given server ID.
145
+
146
+ Finds an available port within the configured min_local_port and max_local_port range.
147
+ These ports are used only for internal communication and don't need to be publicly exposed.
148
+
149
+ Args:
150
+ server_id: The unique identifier to associate with the new server.
151
+
152
+ Returns:
153
+ The newly created Viser server instance.
154
+
155
+ Raises:
156
+ RuntimeError: If no free ports are available in the configured range.
157
+ """
158
+ import socket
159
+
160
+ # Start searching from the last port + 1 (with wraparound)
161
+ port_range_size = self._max_port - self._min_port + 1
162
+ start_port = (
163
+ (self._last_port + 1 - self._min_port) % port_range_size
164
+ ) + self._min_port
165
+
166
+ # Try each port once
167
+ for offset in range(port_range_size):
168
+ port = (
169
+ (start_port - self._min_port + offset) % port_range_size
170
+ ) + self._min_port
171
+ try:
172
+ # Check if port is available by attempting to bind to it
173
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
174
+ s.bind(("127.0.0.1", port))
175
+ # Port is available, create server with this port
176
+ server = viser.ViserServer(port=port)
177
+ self._server_from_session_hash[server_id] = server
178
+ self._last_port = port
179
+ return server
180
+ except OSError:
181
+ # Port is in use, try the next one
182
+ continue
183
+
184
+ # If we get here, no ports were available
185
+ raise RuntimeError(
186
+ f"No available local ports in range {self._min_port}-{self._max_port}"
187
+ )
188
+
189
+ def get_server(self, server_id: str) -> viser.ViserServer:
190
+ """Retrieve a Viser server instance by its ID.
191
+
192
+ Args:
193
+ server_id: The unique identifier of the server to retrieve.
194
+
195
+ Returns:
196
+ The Viser server instance associated with the given ID.
197
+ """
198
+ return self._server_from_session_hash[server_id]
199
+
200
+ def stop_server(self, server_id: str) -> None:
201
+ """Stop a Viser server and remove it from the manager.
202
+
203
+ Args:
204
+ server_id: The unique identifier of the server to stop.
205
+ """
206
+ self._server_from_session_hash.pop(server_id).stop()