ProximileAdmin committed
Commit 4518a2e · verified · 1 Parent(s): 2c24615

Create app.py

Files changed (1):
  1. app.py +536 -0

app.py ADDED
@@ -0,0 +1,536 @@
#!/usr/bin/env python3
"""
Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive and API Fallback

This application provides a Gradio web interface for multimodal chat backed by a
self-hosted vLLM model. It establishes an SSH tunnel to the remote vLLM server and
falls back to the Hyperbolic API if that server is unavailable.
"""

import os
import time
import threading
import logging
import base64
import json
from io import BytesIO
import gradio as gr
from openai import OpenAI
from ssh_tunneler import SSHTunnel
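
# Note: ssh_tunneler is a local helper module that ships alongside this app and is not
# part of this commit. Based on how it is used below, the assumed interface is roughly:
#
#   tunnel = SSHTunnel(ssh_host=..., ssh_port=..., username=..., password=...,
#                      remote_port=..., local_port=...,
#                      reconnect_interval=30, keep_alive_interval=15)
#   tunnel.start()          # -> bool, True if the port forward came up
#   tunnel.check_status()   # -> {"is_running": bool, "error": str or None}
#
# This is an illustrative sketch inferred from the call sites, not the module's
# documented API.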

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('app')

# Get environment variables
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = int(os.environ.get('SSH_PORT', 22))
SSH_USERNAME = os.environ.get('SSH_USERNAME')
SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000))  # vLLM API port on remote machine
LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020))  # Local forwarded port
VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct'  # Fallback model at Hyperbolic

# API endpoints
VLLM_ENDPOINT = f"http://localhost:{LOCAL_PORT}/v1"
HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
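
# Illustrative environment configuration (placeholder values, not real credentials):
#
#   SSH_HOST=gpu-box.example.com
#   SSH_PORT=22
#   SSH_USERNAME=ubuntu
#   SSH_PASSWORD=********
#   REMOTE_PORT=8000                  # port vLLM serves on the remote machine
#   LOCAL_PORT=8020                   # local end of the forwarded tunnel
#   MODEL_NAME=google/gemma-3-27b-it
#   HYPERBOLIC_XYZ_KEY=<your-api-key> # only needed for the Hyperbolic fallback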

# Global variables
tunnel = None
use_fallback = False  # Whether to use fallback API instead of local vLLM
tunnel_status = {"is_running": False, "message": "Initializing tunnel..."}

def start_ssh_tunnel():
    """
    Start the SSH tunnel and record its initial status.
    """
    global tunnel, use_fallback, tunnel_status

    if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
        logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
        use_fallback = True
        tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        return

    try:
        logger.info("Starting SSH tunnel...")
        tunnel = SSHTunnel(
            ssh_host=SSH_HOST,
            ssh_port=SSH_PORT,
            username=SSH_USERNAME,
            password=SSH_PASSWORD,
            remote_port=REMOTE_PORT,
            local_port=LOCAL_PORT,
            reconnect_interval=30,
            keep_alive_interval=15
        )

        if tunnel.start():
            logger.info("SSH tunnel started successfully")
            use_fallback = False
            tunnel_status = {"is_running": True, "message": "Connected"}
        else:
            logger.warning("Failed to start SSH tunnel. Falling back to Hyperbolic API.")
            use_fallback = True
            tunnel_status = {"is_running": False, "message": "Failed to connect"}

    except Exception as e:
        logger.error(f"Error starting SSH tunnel: {str(e)}")
        use_fallback = True
        tunnel_status = {"is_running": False, "message": f"Error: {str(e)}"}

def check_vllm_api_health():
    """
    Check if the vLLM API is actually responding by querying the /v1/models endpoint.

    Returns:
        tuple: (is_healthy, message)
    """
    try:
        import requests
        response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
        if response.status_code == 200:
            try:
                data = response.json()
                if 'data' in data and len(data['data']) > 0:
                    model_id = data['data'][0].get('id', 'Unknown model')
                    return True, f"API is healthy. Available model: {model_id}"
                else:
                    return True, "API is healthy but no models found"
            except Exception as e:
                return False, f"API returned 200 but invalid JSON: {str(e)}"
        else:
            return False, f"API returned status code: {response.status_code}"
    except Exception as e:
        return False, f"API request failed: {str(e)}"
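
# For reference, the health check above parses an OpenAI-compatible /v1/models response.
# A typical (abridged, illustrative) payload looks like:
#
#   {"object": "list",
#    "data": [{"id": "google/gemma-3-27b-it", "object": "model", ...}]}
#
# Only the presence of a non-empty "data" list and the first entry's "id" are used here.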

def monitor_tunnel():
    """
    Monitor the SSH tunnel status and update the global variables.
    """
    global tunnel, use_fallback, tunnel_status

    logger.info("Starting tunnel monitoring thread")

    while True:
        try:
            if tunnel is not None:
                ssh_status = tunnel.check_status()

                # Check if the tunnel is running
                if ssh_status["is_running"]:
                    # Check if vLLM API is actually responding
                    is_healthy, message = check_vllm_api_health()

                    if is_healthy:
                        use_fallback = False
                        tunnel_status = {
                            "is_running": True,
                            "message": f"Connected and healthy. {message}"
                        }
                    else:
                        use_fallback = True
                        tunnel_status = {
                            "is_running": False,
                            "message": f"Tunnel connected but vLLM API unhealthy: {message}"
                        }
                else:
                    use_fallback = True
                    tunnel_status = {
                        "is_running": False,
                        "message": f"Disconnected: {ssh_status['error'] or 'Unknown error'}"
                    }
            else:
                use_fallback = True
                tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}

        except Exception as e:
            logger.error(f"Error monitoring tunnel: {str(e)}")
            use_fallback = True
            tunnel_status = {"is_running": False, "message": f"Monitoring error: {str(e)}"}

        time.sleep(5)  # Check every 5 seconds

def get_openai_client(use_fallback_api=None):
    """
    Create and return an OpenAI client configured for the appropriate endpoint.

    Args:
        use_fallback_api (bool): If True, use Hyperbolic API. If False, use local vLLM.
                                 If None, use the global use_fallback setting.

    Returns:
        OpenAI: Configured OpenAI client
    """
    global use_fallback

    # Determine which API to use
    if use_fallback_api is None:
        use_fallback_api = use_fallback

    if use_fallback_api:
        logger.info("Using Hyperbolic API")
        return OpenAI(
            api_key=HYPERBOLIC_KEY,
            base_url=HYPERBOLIC_ENDPOINT
        )
    else:
        logger.info("Using local vLLM API")
        return OpenAI(
            api_key="EMPTY",  # vLLM doesn't require an actual API key
            base_url=VLLM_ENDPOINT
        )

def get_model_name(use_fallback_api=None):
    """
    Return the appropriate model name based on the API being used.

    Args:
        use_fallback_api (bool): If True, use fallback model. If None, use the global setting.

    Returns:
        str: Model name
    """
    global use_fallback

    if use_fallback_api is None:
        use_fallback_api = use_fallback

    return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL
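
# Both endpoints speak the OpenAI chat-completions protocol, so the rest of the app can
# use a single code path regardless of which backend is active. Illustrative usage:
#
#   client = get_openai_client()
#   reply = client.chat.completions.create(
#       model=get_model_name(),
#       messages=[{"role": "user", "content": "Hello"}],
#   )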

def convert_files_to_base64(files):
    """
    Convert uploaded files to base64 strings.

    Args:
        files (list): List of file paths

    Returns:
        list: List of base64-encoded strings
    """
    base64_images = []
    for file in files:
        with open(file, "rb") as image_file:
            # Read image data and encode to base64
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            base64_images.append(base64_data)
    return base64_images

def process_chat(message_dict, history):
    """
    Process the user message and send it to the appropriate API.

    Args:
        message_dict (dict): User message containing text and files
        history (list): Chat history

    Yields:
        list: Updated chat history
    """
    global use_fallback

    text = message_dict.get("text", "")
    files = message_dict.get("files", [])

    # Add user message to history first
    if not history:
        history = []

    # Add user message to chat history
    if files:
        # For each file, add a separate user message
        for file in files:
            history.append({"role": "user", "content": (file,)})

    # Add text message if not empty
    if text.strip():
        history.append({"role": "user", "content": text})
    else:
        # If no text but files exist, don't add an empty message
        if not files:
            history.append({"role": "user", "content": ""})

    # Convert all files to base64
    base64_images = convert_files_to_base64(files)

    # Prepare conversation history in OpenAI format
    openai_messages = []

    # Convert history to OpenAI format
    for h in history:
        if h["role"] == "user":
            # Handle user messages
            if isinstance(h["content"], tuple):
                # This is a file-only message; the image data is attached below
                continue
            else:
                # Text message
                openai_messages.append({
                    "role": "user",
                    "content": h["content"]
                })
        elif h["role"] == "assistant":
            openai_messages.append({
                "role": "assistant",
                "content": h["content"]
            })

    # Handle images for the last user message if needed
    if base64_images:
        # Update the last user message to include image content
        if openai_messages and openai_messages[-1]["role"] == "user":
            # Get the last message
            last_msg = openai_messages[-1]

            # Format for OpenAI multimodal content structure
            content_list = []

            # Add text if there is any
            if last_msg["content"]:
                content_list.append({"type": "text", "text": last_msg["content"]})

            # Add images as data URLs
            for img_b64 in base64_images:
                content_list.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_b64}"
                    }
                })

            # Replace the content with the multimodal content list
            last_msg["content"] = content_list
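
    # For clarity, the resulting last user message now looks roughly like (illustrative):
    #   {"role": "user",
    #    "content": [{"type": "text", "text": "What's in this image?"},
    #                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]}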

    # Try the primary API first, fall back if needed
    try:
        # First try with the currently selected API (vLLM or fallback)
        client = get_openai_client()
        model = get_model_name()

        response = client.chat.completions.create(
            model=model,
            messages=openai_messages,
            stream=True  # Use streaming for better UX
        )

        # Stream the response
        assistant_message = ""
        for chunk in response:
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                assistant_message += chunk.choices[0].delta.content
                # Update in real-time
                history_with_stream = history.copy()
                history_with_stream.append({"role": "assistant", "content": assistant_message})
                yield history_with_stream

        # Ensure we have the final message added
        if not assistant_message:
            assistant_message = "No response received from the model."

        # Add assistant response to history if not already added
        if not history or history[-1]["role"] != "assistant":
            history.append({"role": "assistant", "content": assistant_message})

        # This function is a generator, so yield (rather than return) the final history
        yield history

    except Exception as primary_error:
        logger.error(f"Primary API error: {str(primary_error)}")

        # If we're not already using the fallback, try that
        if not use_fallback:
            try:
                logger.info("Falling back to Hyperbolic API")
                client = get_openai_client(use_fallback_api=True)
                model = get_model_name(use_fallback_api=True)

                response = client.chat.completions.create(
                    model=model,
                    messages=openai_messages,
                    stream=True
                )

                # Stream the response
                assistant_message = ""
                for chunk in response:
                    if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                        assistant_message += chunk.choices[0].delta.content
                        # Update in real-time
                        history_with_stream = history.copy()
                        history_with_stream.append({"role": "assistant", "content": assistant_message})
                        yield history_with_stream

                # Ensure we have the final message added
                if not assistant_message:
                    assistant_message = "No response received from the fallback model."

                # Add assistant response to history if not already added
                if not history or history[-1]["role"] != "assistant":
                    history.append({"role": "assistant", "content": assistant_message})

                # Update fallback status (global already declared at function start)
                use_fallback = True

                yield history

            except Exception as fallback_error:
                logger.error(f"Fallback API error: {str(fallback_error)}")
                error_msg = f"Error with both primary and fallback APIs. Primary: {str(primary_error)}. Fallback: {str(fallback_error)}"
                history.append({"role": "assistant", "content": error_msg})
                yield history
        else:
            # Already using fallback, just report the error
            error_msg = f"An error occurred with the model: {str(primary_error)}"
            history.append({"role": "assistant", "content": error_msg})
            yield history

def get_tunnel_status_message():
    """
    Return a formatted status message for display in the UI.
    """
    global tunnel_status, use_fallback

    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()

    status_color = "🟢" if (tunnel_status["is_running"] and not use_fallback) else "🔴"
    status_text = tunnel_status["message"]

    return f"{status_color} Tunnel Status: {status_text}\nCurrent API: {api_mode}\nCurrent Model: {model}"

def toggle_api():
    """
    Toggle between local vLLM and Hyperbolic API.
    """
    global use_fallback
    use_fallback = not use_fallback

    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()

    return f"Switched to {api_mode} using {model}"

# Application entry point
if __name__ == "__main__":
    # Start the SSH tunnel (synchronous; ongoing monitoring happens in the thread below)
    start_ssh_tunnel()

    # Start the monitoring thread
    monitor_thread = threading.Thread(target=monitor_tunnel, daemon=True)
    monitor_thread.start()

    # Create the Gradio application with Blocks for more control
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# Multimodal Chat Interface")

        # Create chatbot component with message-style history
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            show_copy_button=True,
            avatar_images=("👤", "🗣️"),
            height=400
        )

        # Create multimodal textbox for input
        with gr.Row():
            textbox = gr.MultimodalTextbox(
                file_types=["image", "video"],
                file_count="multiple",
                placeholder="Type your message here and/or upload images...",
                label="Message",
                show_label=False,
                scale=9
            )
            submit_btn = gr.Button("Send", size="sm", scale=1)

        # Clear button
        clear_btn = gr.Button("Clear Chat")

        # Set up the submit event chain: run the chat, then clear the textbox
        submit_event = textbox.submit(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        # Connect the submit button to the same functions
        submit_btn.click(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        # Set up the clear button
        clear_btn.click(lambda: [], None, chatbot)

        # Load example images if they exist
        examples = []

        # Define example images with their prompts
        example_images = {
            "dog_pic.jpg": "What breed is this?",
            "ghostimg.png": "What's in this image?",
            "newspaper.png": "Provide a python list of dicts about everything on this page."
        }

        # Check each image and add it to the examples if the file exists
        for img_name, prompt_text in example_images.items():
            img_path = os.path.join(os.path.dirname(__file__), img_name)
            if os.path.exists(img_path):
                examples.append([{"text": prompt_text, "files": [img_path]}])

        # Add examples if we have any
        if examples:
            gr.Examples(
                examples=examples,
                inputs=textbox
            )

        # Add status display
        status_text = gr.Textbox(
            label="Tunnel and API Status",
            value=get_tunnel_status_message(),
            interactive=False
        )

        # Refresh-status and toggle-API buttons
        with gr.Row():
            refresh_btn = gr.Button("Refresh Status")
            toggle_api_btn = gr.Button("Toggle API (Local/Hyperbolic)")

        # Set up the refresh status button
        refresh_btn.click(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        # Set up the toggle API button
        toggle_api_btn.click(
            fn=toggle_api,
            inputs=None,
            outputs=status_text
        )

        # Just load the initial status without auto-refresh
        demo.load(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

    # Launch the interface on a different port than the SSH tunnel
    demo.launch()
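
    # Note (illustrative): demo.launch() uses Gradio's default server settings (port 7860
    # unless GRADIO_SERVER_PORT or a server_port argument says otherwise), so it should not
    # collide with the forwarded tunnel port (LOCAL_PORT, default 8020). Run locally with:
    #   python app.py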