ProximileAdmin commited on
Commit
3626d35
·
verified ·
1 Parent(s): 9aea266

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -201
app.py CHANGED
@@ -1,10 +1,10 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive and API Fallback
4
 
5
  This application provides a Gradio web interface for multimodal chat with a
6
- local vLLM model. It establishes an SSH tunnel to a local vLLM server and
7
- provides fallback to Hyperbolic API if that server is unavailable.
8
  """
9
 
10
  import os
@@ -13,6 +13,7 @@ import threading
13
  import logging
14
  import base64
15
  import json
 
16
  from io import BytesIO
17
  import gradio as gr
18
  from openai import OpenAI
@@ -31,7 +32,9 @@ SSH_PORT = int(os.environ.get('SSH_PORT', 22))
31
  SSH_USERNAME = os.environ.get('SSH_USERNAME')
32
  SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
33
  REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000)) # vLLM API port on remote machine
34
- LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020)) # Local forwarded port
 
 
35
  VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
36
  HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
37
  FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct' # Fallback model at Hyperbolic
@@ -42,27 +45,36 @@ MAX_CONCURRENT = int(os.environ.get('MAX_CONCURRENT', 3)) # Default to 3 concur
42
  # API endpoints
43
  VLLM_ENDPOINT = "http://localhost:" + str(LOCAL_PORT) + "/v1"
44
  HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
 
 
45
 
46
  # Global variables
47
- tunnel = None
 
48
  use_fallback = False # Whether to use fallback API instead of local vLLM
49
- tunnel_status = {"is_running": False, "message": "Initializing tunnel..."}
 
 
 
 
50
 
51
- def start_ssh_tunnel():
52
  """
53
- Start the SSH tunnel and monitor its status.
54
  """
55
- global tunnel, use_fallback, tunnel_status
56
 
57
  if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
58
  logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
59
  use_fallback = True
60
- tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
 
61
  return
62
 
63
  try:
64
- logger.info("Starting SSH tunnel...")
65
- tunnel = SSHTunnel(
 
66
  ssh_host=SSH_HOST,
67
  ssh_port=SSH_PORT,
68
  username=SSH_USERNAME,
@@ -73,19 +85,41 @@ def start_ssh_tunnel():
73
  keep_alive_interval=15
74
  )
75
 
76
- if tunnel.start():
77
- logger.info("SSH tunnel started successfully")
78
- use_fallback = False
79
- tunnel_status = {"is_running": True, "message": "Connected"}
80
  else:
81
- logger.warning("Failed to start SSH tunnel. Falling back to Hyperbolic API.")
82
  use_fallback = True
83
- tunnel_status = {"is_running": False, "message": "Connection failed"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  except Exception as e:
86
- logger.error(f"Error starting SSH tunnel: {str(e)}")
87
  use_fallback = True
88
- tunnel_status = {"is_running": False, "message": "Connection error"}
 
89
 
90
  def check_vllm_api_health():
91
  """
@@ -95,7 +129,6 @@ def check_vllm_api_health():
95
  tuple: (is_healthy, message)
96
  """
97
  try:
98
- import requests
99
  response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
100
  if response.status_code == 200:
101
  try:
@@ -112,119 +145,84 @@ def check_vllm_api_health():
112
  except Exception as e:
113
  return False, f"API request failed: {str(e)}"
114
 
115
- def monitor_tunnel():
116
  """
117
- Monitor the SSH tunnel status and update the global variables.
118
- """
119
- global tunnel, use_fallback, tunnel_status
120
-
121
- logger.info("Starting tunnel monitoring thread")
122
-
123
- while True:
124
- try:
125
- if tunnel is not None:
126
- ssh_status = tunnel.check_status()
127
-
128
- # Check if the tunnel is running
129
- if ssh_status["is_running"]:
130
- # Check if vLLM API is actually responding
131
- is_healthy, message = check_vllm_api_health()
132
-
133
- if is_healthy:
134
- use_fallback = False
135
- tunnel_status = {
136
- "is_running": True,
137
- "message": f"Connected and healthy. {message}"
138
- }
139
- else:
140
- use_fallback = True
141
- tunnel_status = {
142
- "is_running": False,
143
- "message": "Tunnel connected but vLLM API unhealthy"
144
- }
145
- else:
146
- # Log the actual error for troubleshooting but don't expose it in the UI
147
- logger.error(f"SSH tunnel disconnected: {ssh_status['error'] or 'Unknown error'}")
148
- use_fallback = True
149
- tunnel_status = {
150
- "is_running": False,
151
- "message": "Disconnected - Check server status"
152
- }
153
- else:
154
- use_fallback = True
155
- tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
156
-
157
- except Exception as e:
158
- logger.error(f"Error monitoring tunnel: {str(e)}")
159
- use_fallback = True
160
- tunnel_status = {"is_running": False, "message": "Monitoring error"}
161
-
162
- time.sleep(5) # Check every 5 seconds
163
-
164
- def get_openai_client(use_fallback_api=None):
165
- """
166
- Create and return an OpenAI client configured for the appropriate endpoint.
167
-
168
- Args:
169
- use_fallback_api (bool): If True, use Hyperbolic API. If False, use local vLLM.
170
- If None, use the global use_fallback setting.
171
 
172
  Returns:
173
- OpenAI: Configured OpenAI client
174
  """
175
- global use_fallback
176
-
177
- # Determine which API to use
178
- if use_fallback_api is None:
179
- use_fallback_api = use_fallback
180
 
181
- if use_fallback_api:
182
- logger.info("Using Hyperbolic API")
183
- return OpenAI(
184
- api_key=HYPERBOLIC_KEY,
185
- base_url=HYPERBOLIC_ENDPOINT
186
- )
187
- else:
188
- logger.info("Using local vLLM API")
189
- return OpenAI(
190
- api_key="EMPTY", # vLLM doesn't require an actual API key
191
- base_url=VLLM_ENDPOINT
192
- )
 
 
 
 
 
 
 
 
 
 
193
 
194
- def get_model_name(use_fallback_api=None):
195
  """
196
- Return the appropriate model name based on the API being used.
197
-
198
- Args:
199
- use_fallback_api (bool): If True, use fallback model. If None, use the global setting.
200
 
201
  Returns:
202
- str: Model name
203
  """
204
- global use_fallback
205
-
206
- if use_fallback_api is None:
207
- use_fallback_api = use_fallback
208
-
209
- return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL
 
 
210
 
211
- def convert_files_to_base64(files):
212
  """
213
- Convert uploaded files to base64 strings.
 
 
214
 
215
- Args:
216
- files (list): List of file paths
217
 
218
- Returns:
219
- list: List of base64-encoded strings
220
- """
221
- base64_images = []
222
- for file in files:
223
- with open(file, "rb") as image_file:
224
- # Read image data and encode to base64
225
- base64_data = base64.b64encode(image_file.read()).decode("utf-8")
226
- base64_images.append(base64_data)
227
- return base64_images
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  def process_chat(message_dict, history):
230
  """
@@ -242,39 +240,27 @@ def process_chat(message_dict, history):
242
  text = message_dict.get("text", "")
243
  files = message_dict.get("files", [])
244
 
245
- # Add user message to history first
246
  if not history:
247
  history = []
248
 
249
- # Add user message to chat history
250
  if files:
251
- # For each file, add a separate user message
252
  for file in files:
253
  history.append({"role": "user", "content": (file,)})
254
 
255
- # Add text message if not empty
256
  if text.strip():
257
  history.append({"role": "user", "content": text})
258
  else:
259
- # If no text but files exist, don't add an empty message
260
  if not files:
261
  history.append({"role": "user", "content": ""})
262
 
263
- # Convert all files to base64
264
  base64_images = convert_files_to_base64(files)
265
-
266
- # Prepare conversation history in OpenAI format
267
  openai_messages = []
268
 
269
- # Convert history to OpenAI format
270
  for h in history:
271
  if h["role"] == "user":
272
- # Handle user messages
273
  if isinstance(h["content"], tuple):
274
- # This is a file-only message, skip for now
275
  continue
276
  else:
277
- # Text message
278
  openai_messages.append({
279
  "role": "user",
280
  "content": h["content"]
@@ -285,21 +271,12 @@ def process_chat(message_dict, history):
285
  "content": h["content"]
286
  })
287
 
288
- # Handle images for the last user message if needed
289
  if base64_images:
290
- # Update the last user message to include image content
291
  if openai_messages and openai_messages[-1]["role"] == "user":
292
- # Get the last message
293
  last_msg = openai_messages[-1]
294
-
295
- # Format for OpenAI multimodal content structure
296
  content_list = []
297
-
298
- # Add text if there is any
299
  if last_msg["content"]:
300
  content_list.append({"type": "text", "text": last_msg["content"]})
301
-
302
- # Add images
303
  for img_b64 in base64_images:
304
  content_list.append({
305
  "type": "image_url",
@@ -307,37 +284,29 @@ def process_chat(message_dict, history):
307
  "url": f"data:image/jpeg;base64,{img_b64}"
308
  }
309
  })
310
-
311
- # Replace the content with the multimodal content list
312
  last_msg["content"] = content_list
313
 
314
- # Try primary API first, fall back if needed
315
  try:
316
- # First try with the currently selected API (vLLM or fallback)
317
  client = get_openai_client()
318
  model = get_model_name()
319
 
320
  response = client.chat.completions.create(
321
  model=model,
322
  messages=openai_messages,
323
- stream=True # Use streaming for better UX
324
  )
325
 
326
- # Stream the response
327
  assistant_message = ""
328
  for chunk in response:
329
  if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
330
  assistant_message += chunk.choices[0].delta.content
331
- # Update in real-time
332
  history_with_stream = history.copy()
333
  history_with_stream.append({"role": "assistant", "content": assistant_message})
334
  yield history_with_stream
335
 
336
- # Ensure we have the final message added
337
  if not assistant_message:
338
  assistant_message = "No response received from the model."
339
 
340
- # Add assistant response to history if not already added
341
  if not history or history[-1]["role"] != "assistant":
342
  history.append({"role": "assistant", "content": assistant_message})
343
 
@@ -345,8 +314,6 @@ def process_chat(message_dict, history):
345
 
346
  except Exception as primary_error:
347
  logger.error(f"Primary API error: {str(primary_error)}")
348
-
349
- # If we're not already using fallback, try that
350
  if not use_fallback:
351
  try:
352
  logger.info("Falling back to Hyperbolic API")
@@ -359,27 +326,21 @@ def process_chat(message_dict, history):
359
  stream=True
360
  )
361
 
362
- # Stream the response
363
  assistant_message = ""
364
  for chunk in response:
365
  if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
366
  assistant_message += chunk.choices[0].delta.content
367
- # Update in real-time
368
  history_with_stream = history.copy()
369
  history_with_stream.append({"role": "assistant", "content": assistant_message})
370
  yield history_with_stream
371
 
372
- # Ensure we have the final message added
373
  if not assistant_message:
374
  assistant_message = "No response received from the fallback model."
375
 
376
- # Add assistant response to history if not already added
377
  if not history or history[-1]["role"] != "assistant":
378
  history.append({"role": "assistant", "content": assistant_message})
379
 
380
- # Update fallback status (global already declared at function start)
381
  use_fallback = True
382
-
383
  return history
384
 
385
  except Exception as fallback_error:
@@ -388,24 +349,195 @@ def process_chat(message_dict, history):
388
  history.append({"role": "assistant", "content": error_msg})
389
  return history
390
  else:
391
- # Already using fallback, just report the error
392
  error_msg = "An error occurred with the model service."
393
  history.append({"role": "assistant", "content": error_msg})
394
  return history
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  def get_tunnel_status_message():
397
  """
398
  Return a formatted status message for display in the UI.
399
  """
400
- global tunnel_status, use_fallback, MAX_CONCURRENT
401
-
402
  api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
403
  model = get_model_name()
404
-
405
- status_color = "🟢" if (tunnel_status["is_running"] and not use_fallback) else "🔴"
406
- status_text = tunnel_status["message"]
407
-
408
- return f"{status_color} Tunnel Status: {status_text}\nCurrent API: {api_mode}\nCurrent Model: {model}\nConcurrent Requests: {MAX_CONCURRENT}"
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  def toggle_api():
411
  """
@@ -413,10 +545,8 @@ def toggle_api():
413
  """
414
  global use_fallback
415
  use_fallback = not use_fallback
416
-
417
  api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
418
  model = get_model_name()
419
-
420
  return f"Switched to {api_mode} using {model}"
421
 
422
  def update_concurrency(new_value):
@@ -434,29 +564,20 @@ def update_concurrency(new_value):
434
  value = int(new_value)
435
  if value < 1:
436
  return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"
437
-
438
  MAX_CONCURRENT = value
439
- # Note: This only updates the value for future event handlers
440
- # Existing event handlers keep their original concurrency_limit
441
- # A page refresh is needed for this to fully take effect
442
  return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."
443
  except ValueError:
444
  return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"
445
 
446
- # Start the SSH tunnel in a background thread
447
  if __name__ == "__main__":
448
- # Start the SSH tunnel
449
- start_ssh_tunnel()
450
-
451
- # Start the monitoring thread
452
- monitor_thread = threading.Thread(target=monitor_tunnel, daemon=True)
453
  monitor_thread.start()
454
 
455
- # Create Gradio application with Blocks for more control
456
  with gr.Blocks(theme="soft") as demo:
457
  gr.Markdown("# Multimodal Chat Interface")
458
 
459
- # Create chatbot component with message type
460
  chatbot = gr.Chatbot(
461
  label="Conversation",
462
  type="messages",
@@ -465,7 +586,6 @@ if __name__ == "__main__":
465
  height=400
466
  )
467
 
468
- # Create multimodal textbox for input
469
  with gr.Row():
470
  textbox = gr.MultimodalTextbox(
471
  file_types=["image", "video"],
@@ -477,84 +597,86 @@ if __name__ == "__main__":
477
  )
478
  submit_btn = gr.Button("Send", size="sm", scale=1)
479
 
480
- # Clear button
481
  clear_btn = gr.Button("Clear Chat")
482
 
483
- # Set up submit event chain with concurrency limit
484
  submit_event = textbox.submit(
485
  fn=process_chat,
486
  inputs=[textbox, chatbot],
487
  outputs=chatbot,
488
- concurrency_limit=MAX_CONCURRENT # Set concurrency limit for this event
489
  ).then(
490
  fn=lambda: {"text": "", "files": []},
491
  inputs=None,
492
  outputs=textbox
493
  )
494
 
495
- # Connect the submit button to the same functions with same concurrency limit
496
  submit_btn.click(
497
  fn=process_chat,
498
  inputs=[textbox, chatbot],
499
  outputs=chatbot,
500
- concurrency_limit=MAX_CONCURRENT # Set concurrency limit for this event
501
  ).then(
502
  fn=lambda: {"text": "", "files": []},
503
  inputs=None,
504
  outputs=textbox
505
  )
506
 
507
- # Set up clear button
508
  clear_btn.click(lambda: [], None, chatbot)
509
-
510
- # Load example images if they exist
511
- examples = []
512
 
513
- # Define example images with paths
514
  example_images = {
515
  "dog_pic.jpg": "What breed is this?",
516
  "ghostimg.png": "What's in this image?",
517
  "newspaper.png": "Provide a python list of dicts about everything on this page."
518
  }
519
-
520
- # Check each image and add to examples if it exists
521
  for img_name, prompt_text in example_images.items():
522
  img_path = os.path.join(os.path.dirname(__file__), img_name)
523
  if os.path.exists(img_path):
524
  examples.append([{"text": prompt_text, "files": [img_path]}])
525
-
526
- # Add examples if we have any
527
  if examples:
528
  gr.Examples(
529
  examples=examples,
530
  inputs=textbox
531
  )
532
 
533
- # Add status display
534
  status_text = gr.Textbox(
535
  label="Tunnel and API Status",
536
  value=get_tunnel_status_message(),
537
  interactive=False
538
  )
539
 
540
- # Refresh status button and toggle API button
 
 
 
 
 
 
541
  with gr.Row():
542
  refresh_btn = gr.Button("Refresh Status")
543
-
544
- # Set up refresh status button
545
  refresh_btn.click(
546
  fn=get_tunnel_status_message,
547
  inputs=None,
548
  outputs=status_text
549
  )
550
 
551
- # Just load the initial status without auto-refresh
 
 
 
 
 
 
 
 
 
552
  demo.load(
553
  fn=get_tunnel_status_message,
554
  inputs=None,
555
  outputs=status_text
556
  )
557
 
558
- # Launch the interface with the specified concurrency setting
559
  demo.queue(default_concurrency_limit=MAX_CONCURRENT)
560
  demo.launch()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive, GPU Monitoring, and API Fallback
4
 
5
  This application provides a Gradio web interface for multimodal chat with a
6
+ local vLLM model. It establishes SSH tunnels to a local vLLM server and
7
+ the nvidia-smi monitoring endpoint, with fallback to Hyperbolic API if needed.
8
  """
9
 
10
  import os
 
13
  import logging
14
  import base64
15
  import json
16
+ import requests
17
  from io import BytesIO
18
  import gradio as gr
19
  from openai import OpenAI
 
32
  SSH_USERNAME = os.environ.get('SSH_USERNAME')
33
  SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
34
  REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000)) # vLLM API port on remote machine
35
+ LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020)) # Local forwarded port
36
+ GPU_REMOTE_PORT = 5000 # GPU monitoring endpoint on remote machine
37
+ GPU_LOCAL_PORT = 5020 # Local forwarded port for GPU monitoring
38
  VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
39
  HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
40
  FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct' # Fallback model at Hyperbolic
 
45
  # API endpoints
46
  VLLM_ENDPOINT = "http://localhost:" + str(LOCAL_PORT) + "/v1"
47
  HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
48
+ GPU_JSON_ENDPOINT = "http://localhost:" + str(GPU_LOCAL_PORT) + "/gpu/json"
49
+ GPU_TXT_ENDPOINT = "http://localhost:" + str(GPU_LOCAL_PORT) + "/gpu/txt" # For backward compatibility
50
 
51
  # Global variables
52
+ api_tunnel = None
53
+ gpu_tunnel = None
54
  use_fallback = False # Whether to use fallback API instead of local vLLM
55
+ api_tunnel_status = {"is_running": False, "message": "Initializing API tunnel..."}
56
+ gpu_tunnel_status = {"is_running": False, "message": "Initializing GPU monitoring tunnel..."}
57
+ gpu_data = {"timestamp": "", "gpus": [], "processes": [], "success": False}
58
+ gpu_monitor_thread = None
59
+ gpu_monitor_running = False
60
 
61
+ def start_ssh_tunnels():
62
  """
63
+ Start the SSH tunnels and monitor their status.
64
  """
65
+ global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status
66
 
67
  if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
68
  logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
69
  use_fallback = True
70
+ api_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
71
+ gpu_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
72
  return
73
 
74
  try:
75
+ # Start API tunnel
76
+ logger.info("Starting API SSH tunnel...")
77
+ api_tunnel = SSHTunnel(
78
  ssh_host=SSH_HOST,
79
  ssh_port=SSH_PORT,
80
  username=SSH_USERNAME,
 
85
  keep_alive_interval=15
86
  )
87
 
88
+ if api_tunnel.start():
89
+ logger.info("API SSH tunnel started successfully")
90
+ api_tunnel_status = {"is_running": True, "message": "Connected"}
 
91
  else:
92
+ logger.warning("Failed to start API SSH tunnel. Falling back to Hyperbolic API.")
93
  use_fallback = True
94
+ api_tunnel_status = {"is_running": False, "message": "Connection failed"}
95
+
96
+ # Start GPU monitoring tunnel
97
+ logger.info("Starting GPU monitoring SSH tunnel...")
98
+ gpu_tunnel = SSHTunnel(
99
+ ssh_host=SSH_HOST,
100
+ ssh_port=SSH_PORT,
101
+ username=SSH_USERNAME,
102
+ password=SSH_PASSWORD,
103
+ remote_port=GPU_REMOTE_PORT,
104
+ local_port=GPU_LOCAL_PORT,
105
+ reconnect_interval=30,
106
+ keep_alive_interval=15
107
+ )
108
+
109
+ if gpu_tunnel.start():
110
+ logger.info("GPU monitoring SSH tunnel started successfully")
111
+ gpu_tunnel_status = {"is_running": True, "message": "Connected"}
112
+ # Start GPU monitoring
113
+ start_gpu_monitoring()
114
+ else:
115
+ logger.warning("Failed to start GPU monitoring SSH tunnel.")
116
+ gpu_tunnel_status = {"is_running": False, "message": "Connection failed"}
117
 
118
  except Exception as e:
119
+ logger.error(f"Error starting SSH tunnels: {str(e)}")
120
  use_fallback = True
121
+ api_tunnel_status = {"is_running": False, "message": "Connection error"}
122
+ gpu_tunnel_status = {"is_running": False, "message": "Connection error"}
123
 
124
  def check_vllm_api_health():
125
  """
 
129
  tuple: (is_healthy, message)
130
  """
131
  try:
 
132
  response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
133
  if response.status_code == 200:
134
  try:
 
145
  except Exception as e:
146
  return False, f"API request failed: {str(e)}"
147
 
148
+ def fetch_gpu_info():
149
  """
150
+ Fetch GPU information from the remote server in JSON format.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  Returns:
153
+ dict: GPU information or error message
154
  """
155
+ global gpu_tunnel_status
 
 
 
 
156
 
157
+ try:
158
+ response = requests.get(GPU_JSON_ENDPOINT, timeout=5)
159
+ if response.status_code == 200:
160
+ return response.json()
161
+ else:
162
+ logger.warning(f"Error fetching GPU info: HTTP {response.status_code}")
163
+ return {
164
+ "success": False,
165
+ "error": f"HTTP Error: {response.status_code}",
166
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
167
+ "gpus": [],
168
+ "processes": []
169
+ }
170
+ except Exception as e:
171
+ logger.warning(f"Error fetching GPU info: {str(e)}")
172
+ return {
173
+ "success": False,
174
+ "error": str(e),
175
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
176
+ "gpus": [],
177
+ "processes": []
178
+ }
179
 
180
+ def fetch_gpu_text():
181
  """
182
+ Fetch raw nvidia-smi output from the remote server for backward compatibility.
 
 
 
183
 
184
  Returns:
185
+ str: nvidia-smi output or error message
186
  """
187
+ try:
188
+ response = requests.get(GPU_TXT_ENDPOINT, timeout=5)
189
+ if response.status_code == 200:
190
+ return response.text
191
+ else:
192
+ return f"Error fetching GPU info: HTTP {response.status_code}"
193
+ except Exception as e:
194
+ return f"Error fetching GPU info: {str(e)}"
195
 
196
+ def start_gpu_monitoring():
197
  """
198
+ Start the GPU monitoring thread.
199
+ """
200
+ global gpu_monitor_thread, gpu_monitor_running, gpu_data
201
 
202
+ if gpu_monitor_running:
203
+ return
204
 
205
+ gpu_monitor_running = True
206
+
207
+ def monitor_loop():
208
+ global gpu_data
209
+ while gpu_monitor_running:
210
+ try:
211
+ gpu_data = fetch_gpu_info()
212
+ except Exception as e:
213
+ logger.error(f"Error in GPU monitoring loop: {str(e)}")
214
+ gpu_data = {
215
+ "success": False,
216
+ "error": str(e),
217
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
218
+ "gpus": [],
219
+ "processes": []
220
+ }
221
+ time.sleep(2) # Update every 2 seconds
222
+
223
+ gpu_monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
224
+ gpu_monitor_thread.start()
225
+ logger.info("GPU monitoring thread started")
226
 
227
  def process_chat(message_dict, history):
228
  """
 
240
  text = message_dict.get("text", "")
241
  files = message_dict.get("files", [])
242
 
 
243
  if not history:
244
  history = []
245
 
 
246
  if files:
 
247
  for file in files:
248
  history.append({"role": "user", "content": (file,)})
249
 
 
250
  if text.strip():
251
  history.append({"role": "user", "content": text})
252
  else:
 
253
  if not files:
254
  history.append({"role": "user", "content": ""})
255
 
 
256
  base64_images = convert_files_to_base64(files)
 
 
257
  openai_messages = []
258
 
 
259
  for h in history:
260
  if h["role"] == "user":
 
261
  if isinstance(h["content"], tuple):
 
262
  continue
263
  else:
 
264
  openai_messages.append({
265
  "role": "user",
266
  "content": h["content"]
 
271
  "content": h["content"]
272
  })
273
 
 
274
  if base64_images:
 
275
  if openai_messages and openai_messages[-1]["role"] == "user":
 
276
  last_msg = openai_messages[-1]
 
 
277
  content_list = []
 
 
278
  if last_msg["content"]:
279
  content_list.append({"type": "text", "text": last_msg["content"]})
 
 
280
  for img_b64 in base64_images:
281
  content_list.append({
282
  "type": "image_url",
 
284
  "url": f"data:image/jpeg;base64,{img_b64}"
285
  }
286
  })
 
 
287
  last_msg["content"] = content_list
288
 
 
289
  try:
 
290
  client = get_openai_client()
291
  model = get_model_name()
292
 
293
  response = client.chat.completions.create(
294
  model=model,
295
  messages=openai_messages,
296
+ stream=True
297
  )
298
 
 
299
  assistant_message = ""
300
  for chunk in response:
301
  if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
302
  assistant_message += chunk.choices[0].delta.content
 
303
  history_with_stream = history.copy()
304
  history_with_stream.append({"role": "assistant", "content": assistant_message})
305
  yield history_with_stream
306
 
 
307
  if not assistant_message:
308
  assistant_message = "No response received from the model."
309
 
 
310
  if not history or history[-1]["role"] != "assistant":
311
  history.append({"role": "assistant", "content": assistant_message})
312
 
 
314
 
315
  except Exception as primary_error:
316
  logger.error(f"Primary API error: {str(primary_error)}")
 
 
317
  if not use_fallback:
318
  try:
319
  logger.info("Falling back to Hyperbolic API")
 
326
  stream=True
327
  )
328
 
 
329
  assistant_message = ""
330
  for chunk in response:
331
  if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
332
  assistant_message += chunk.choices[0].delta.content
 
333
  history_with_stream = history.copy()
334
  history_with_stream.append({"role": "assistant", "content": assistant_message})
335
  yield history_with_stream
336
 
 
337
  if not assistant_message:
338
  assistant_message = "No response received from the fallback model."
339
 
 
340
  if not history or history[-1]["role"] != "assistant":
341
  history.append({"role": "assistant", "content": assistant_message})
342
 
 
343
  use_fallback = True
 
344
  return history
345
 
346
  except Exception as fallback_error:
 
349
  history.append({"role": "assistant", "content": error_msg})
350
  return history
351
  else:
 
352
  error_msg = "An error occurred with the model service."
353
  history.append({"role": "assistant", "content": error_msg})
354
  return history
355
 
356
+ def monitor_tunnels():
357
+ """
358
+ Monitor the SSH tunnels status and update the global variables.
359
+ """
360
+ global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status
361
+
362
+ logger.info("Starting tunnel monitoring thread")
363
+
364
+ while True:
365
+ try:
366
+ if api_tunnel is not None:
367
+ ssh_status = api_tunnel.check_status()
368
+ if ssh_status["is_running"]:
369
+ is_healthy, message = check_vllm_api_health()
370
+ if is_healthy:
371
+ use_fallback = False
372
+ api_tunnel_status = {
373
+ "is_running": True,
374
+ "message": f"Connected and healthy. {message}"
375
+ }
376
+ else:
377
+ use_fallback = True
378
+ api_tunnel_status = {
379
+ "is_running": False,
380
+ "message": "Tunnel connected but vLLM API unhealthy"
381
+ }
382
+ else:
383
+ logger.error(f"API SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
384
+ use_fallback = True
385
+ api_tunnel_status = {
386
+ "is_running": False,
387
+ "message": "Disconnected - Check server status"
388
+ }
389
+ else:
390
+ use_fallback = True
391
+ api_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
392
+
393
+ if gpu_tunnel is not None:
394
+ ssh_status = gpu_tunnel.check_status()
395
+ if ssh_status["is_running"]:
396
+ gpu_tunnel_status = {
397
+ "is_running": True,
398
+ "message": "Connected"
399
+ }
400
+ if not gpu_monitor_running:
401
+ start_gpu_monitoring()
402
+ else:
403
+ logger.error(f"GPU SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
404
+ gpu_tunnel_status = {
405
+ "is_running": False,
406
+ "message": "Disconnected - Check server status"
407
+ }
408
+ else:
409
+ gpu_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
410
+
411
+ except Exception as e:
412
+ logger.error(f"Error monitoring tunnels: {str(e)}")
413
+ use_fallback = True
414
+ api_tunnel_status = {"is_running": False, "message": "Monitoring error"}
415
+ gpu_tunnel_status = {"is_running": False, "message": "Monitoring error"}
416
+
417
+ time.sleep(5) # Check every 5 seconds
418
+
419
+ def get_openai_client(use_fallback_api=None):
420
+ """
421
+ Create and return an OpenAI client configured for the appropriate endpoint.
422
+
423
+ Args:
424
+ use_fallback_api (bool): If True, use Hyperbolic API. If False, use local vLLM.
425
+ If None, use the global use_fallback setting.
426
+
427
+ Returns:
428
+ OpenAI: Configured OpenAI client
429
+ """
430
+ global use_fallback
431
+ if use_fallback_api is None:
432
+ use_fallback_api = use_fallback
433
+
434
+ if use_fallback_api:
435
+ logger.info("Using Hyperbolic API")
436
+ return OpenAI(
437
+ api_key=HYPERBOLIC_KEY,
438
+ base_url=HYPERBOLIC_ENDPOINT
439
+ )
440
+ else:
441
+ logger.info("Using local vLLM API")
442
+ return OpenAI(
443
+ api_key="EMPTY", # vLLM doesn't require an actual API key
444
+ base_url=VLLM_ENDPOINT
445
+ )
446
+
447
+ def get_model_name(use_fallback_api=None):
448
+ """
449
+ Return the appropriate model name based on the API being used.
450
+
451
+ Args:
452
+ use_fallback_api (bool): If True, use fallback model. If None, use the global setting.
453
+
454
+ Returns:
455
+ str: Model name
456
+ """
457
+ global use_fallback
458
+ if use_fallback_api is None:
459
+ use_fallback_api = use_fallback
460
+ return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL
461
+
462
+ def convert_files_to_base64(files):
463
+ """
464
+ Convert uploaded files to base64 strings.
465
+
466
+ Args:
467
+ files (list): List of file paths
468
+
469
+ Returns:
470
+ list: List of base64-encoded strings
471
+ """
472
+ base64_images = []
473
+ for file in files:
474
+ with open(file, "rb") as image_file:
475
+ base64_data = base64.b64encode(image_file.read()).decode("utf-8")
476
+ base64_images.append(base64_data)
477
+ return base64_images
478
+
479
+ def format_simplified_gpu_data(gpu_data):
480
+ """
481
+ Format GPU data into a simplified, focused display.
482
+
483
+ Args:
484
+ gpu_data (dict): GPU data in JSON format
485
+
486
+ Returns:
487
+ str: Formatted GPU data
488
+ """
489
+ if not gpu_data.get("success", False):
490
+ return f"Error fetching GPU data: {gpu_data.get('error', 'Unknown error')}"
491
+
492
+ output = []
493
+ output.append(f"Last updated: {gpu_data.get('timestamp', 'Unknown')}")
494
+
495
+ for i, gpu in enumerate(gpu_data.get("gpus", [])):
496
+ output.append(f"GPU {gpu.get('index', i)}: {gpu.get('name', 'Unknown')}")
497
+ output.append(f" Memory: {gpu.get('memory_used', 0):6.0f} MB / {gpu.get('memory_total', 0):6.0f} MB ({gpu.get('memory_utilization', 0):5.1f}%)")
498
+ output.append(f" Power: {gpu.get('power_draw', 0):5.1f}W / {gpu.get('power_limit', 0):5.1f}W")
499
+ if 'fan_speed' in gpu:
500
+ output.append(f" Fan: {gpu.get('fan_speed', 0):5.1f}%")
501
+ output.append(f" Temp: {gpu.get('temperature', 0):5.1f}°C")
502
+ output.append("")
503
+
504
+ return "\n".join(output)
505
+
506
+ def update_gpu_status():
507
+ """
508
+ Fetch and format the current GPU status.
509
+
510
+ Returns:
511
+ str: Formatted GPU status
512
+ """
513
+ global gpu_data, gpu_tunnel_status
514
+ if not gpu_tunnel_status["is_running"]:
515
+ return "GPU monitoring tunnel is not connected."
516
+ return format_simplified_gpu_data(gpu_data)
517
+
518
  def get_tunnel_status_message():
519
  """
520
  Return a formatted status message for display in the UI.
521
  """
522
+ global api_tunnel_status, gpu_tunnel_status, use_fallback, MAX_CONCURRENT
 
523
  api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
524
  model = get_model_name()
525
+ api_status_color = "🟢" if (api_tunnel_status["is_running"] and not use_fallback) else "🔴"
526
+ api_status_text = api_tunnel_status["message"]
527
+ gpu_status_color = "🟢" if gpu_tunnel_status["is_running"] else "🔴"
528
+ gpu_status_text = gpu_tunnel_status["message"]
529
+ return (f"{api_status_color} API Tunnel: {api_status_text}\n"
530
+ f"{gpu_status_color} GPU Tunnel: {gpu_status_text}\n"
531
+ f"Current API: {api_mode}\n"
532
+ f"Current Model: {model}\n"
533
+ f"Concurrent Requests: {MAX_CONCURRENT}")
534
+
535
+ def get_gpu_json():
536
+ """
537
+ Return the raw GPU JSON data for debugging.
538
+ """
539
+ global gpu_data
540
+ return json.dumps(gpu_data, indent=2)
541
 
542
  def toggle_api():
543
  """
 
545
  """
546
  global use_fallback
547
  use_fallback = not use_fallback
 
548
  api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
549
  model = get_model_name()
 
550
  return f"Switched to {api_mode} using {model}"
551
 
552
  def update_concurrency(new_value):
 
564
  value = int(new_value)
565
  if value < 1:
566
  return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"
 
567
  MAX_CONCURRENT = value
 
 
 
568
  return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."
569
  except ValueError:
570
  return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"
571
 
572
+ # Start SSH tunnels and monitoring threads
573
  if __name__ == "__main__":
574
+ start_ssh_tunnels()
575
+ monitor_thread = threading.Thread(target=monitor_tunnels, daemon=True)
 
 
 
576
  monitor_thread.start()
577
 
 
578
  with gr.Blocks(theme="soft") as demo:
579
  gr.Markdown("# Multimodal Chat Interface")
580
 
 
581
  chatbot = gr.Chatbot(
582
  label="Conversation",
583
  type="messages",
 
586
  height=400
587
  )
588
 
 
589
  with gr.Row():
590
  textbox = gr.MultimodalTextbox(
591
  file_types=["image", "video"],
 
597
  )
598
  submit_btn = gr.Button("Send", size="sm", scale=1)
599
 
 
600
  clear_btn = gr.Button("Clear Chat")
601
 
 
602
  submit_event = textbox.submit(
603
  fn=process_chat,
604
  inputs=[textbox, chatbot],
605
  outputs=chatbot,
606
+ concurrency_limit=MAX_CONCURRENT
607
  ).then(
608
  fn=lambda: {"text": "", "files": []},
609
  inputs=None,
610
  outputs=textbox
611
  )
612
 
 
613
  submit_btn.click(
614
  fn=process_chat,
615
  inputs=[textbox, chatbot],
616
  outputs=chatbot,
617
+ concurrency_limit=MAX_CONCURRENT
618
  ).then(
619
  fn=lambda: {"text": "", "files": []},
620
  inputs=None,
621
  outputs=textbox
622
  )
623
 
 
624
  clear_btn.click(lambda: [], None, chatbot)
 
 
 
625
 
626
+ examples = []
627
  example_images = {
628
  "dog_pic.jpg": "What breed is this?",
629
  "ghostimg.png": "What's in this image?",
630
  "newspaper.png": "Provide a python list of dicts about everything on this page."
631
  }
 
 
632
  for img_name, prompt_text in example_images.items():
633
  img_path = os.path.join(os.path.dirname(__file__), img_name)
634
  if os.path.exists(img_path):
635
  examples.append([{"text": prompt_text, "files": [img_path]}])
 
 
636
  if examples:
637
  gr.Examples(
638
  examples=examples,
639
  inputs=textbox
640
  )
641
 
 
642
  status_text = gr.Textbox(
643
  label="Tunnel and API Status",
644
  value=get_tunnel_status_message(),
645
  interactive=False
646
  )
647
 
648
+ with gr.Accordion("GPU Status", open=False):
649
+ # Changed from Textbox to HTML component
650
+ gpu_status = gr.HTML(
651
+ value=lambda: f"<pre style='font-family: monospace; white-space: pre; overflow: auto;'>{update_gpu_status()}</pre>",
652
+ every=2
653
+ )
654
+
655
  with gr.Row():
656
  refresh_btn = gr.Button("Refresh Status")
657
+ toggle_api_btn = gr.Button("Toggle API")
658
+
659
  refresh_btn.click(
660
  fn=get_tunnel_status_message,
661
  inputs=None,
662
  outputs=status_text
663
  )
664
 
665
+ toggle_api_btn.click(
666
+ fn=toggle_api,
667
+ inputs=None,
668
+ outputs=status_text
669
+ ).then(
670
+ fn=get_tunnel_status_message,
671
+ inputs=None,
672
+ outputs=status_text
673
+ )
674
+
675
  demo.load(
676
  fn=get_tunnel_status_message,
677
  inputs=None,
678
  outputs=status_text
679
  )
680
 
 
681
  demo.queue(default_concurrency_limit=MAX_CONCURRENT)
682
  demo.launch()