#!/usr/bin/env python3
"""
Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive, GPU Monitoring, and API Fallback

This application provides a Gradio web interface for multimodal chat with a vLLM model
served on a remote machine. It forwards the remote vLLM API port and the remote
nvidia-smi monitoring endpoint to local ports over SSH tunnels, and falls back to the
Hyperbolic API when the vLLM server is unreachable or unhealthy.
"""

import base64
import html
import json
import logging
import mimetypes
import os
import threading
import time
from io import BytesIO  # noqa: F401  (currently unused; kept for compatibility)

import gradio as gr
import requests
from openai import OpenAI

from ssh_tunneler import SSHTunnel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('app')

# SSH / port configuration (from the environment)
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = int(os.environ.get('SSH_PORT', 22))
SSH_USERNAME = os.environ.get('SSH_USERNAME')
SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000))  # vLLM API port on remote machine
LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020))  # Local forwarded port
GPU_REMOTE_PORT = 5000  # GPU monitoring endpoint on remote machine
GPU_LOCAL_PORT = 5020  # Local forwarded port for GPU monitoring

VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct'  # Fallback model at Hyperbolic

# Maximum number of concurrent chat requests before Gradio starts queuing
MAX_CONCURRENT = int(os.environ.get('MAX_CONCURRENT', 3))  # Default to 3 concurrent calls

# API endpoints (tunnelled services are reached through their local forwarded ports)
VLLM_ENDPOINT = f"http://localhost:{LOCAL_PORT}/v1"
HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"
GPU_JSON_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/json"
GPU_TXT_ENDPOINT = f"http://localhost:{GPU_LOCAL_PORT}/gpu/txt"  # For backward compatibility

# Global state shared between UI handlers and the background threads
api_tunnel = None
gpu_tunnel = None
use_fallback = False  # Whether to use the fallback API instead of the local vLLM
# Set via the "Toggle API" button: keeps the fallback selected even while vLLM is
# healthy, so the monitor thread does not undo the user's explicit choice.
force_fallback = False
api_tunnel_status = {"is_running": False, "message": "Initializing API tunnel..."}
gpu_tunnel_status = {"is_running": False, "message": "Initializing GPU monitoring tunnel..."}
gpu_data = {"timestamp": "", "gpus": [], "processes": [], "success": False}
gpu_monitor_thread = None
gpu_monitor_running = False


def _make_tunnel(remote_port, local_port):
    """Create (but do not start) an SSHTunnel forwarding remote_port to local_port."""
    return SSHTunnel(
        ssh_host=SSH_HOST,
        ssh_port=SSH_PORT,
        username=SSH_USERNAME,
        password=SSH_PASSWORD,
        remote_port=remote_port,
        local_port=local_port,
        reconnect_interval=30,
        keep_alive_interval=15,
    )


def start_ssh_tunnels():
    """
    Start the API and GPU-monitoring SSH tunnels and record their status.

    Switches to the fallback API when credentials are missing, the API tunnel
    fails to start, or an exception occurs. Starts GPU monitoring once the
    GPU tunnel is up.
    """
    global api_tunnel, gpu_tunnel, use_fallback, api_tunnel_status, gpu_tunnel_status

    if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
        logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        gpu_tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        return

    try:
        # Start API tunnel
        logger.info("Starting API SSH tunnel...")
        api_tunnel = _make_tunnel(REMOTE_PORT, LOCAL_PORT)
        if api_tunnel.start():
            logger.info("API SSH tunnel started successfully")
            api_tunnel_status = {"is_running": True, "message": "Connected"}
        else:
            logger.warning("Failed to start API SSH tunnel. Falling back to Hyperbolic API.")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Connection failed"}

        # Start GPU monitoring tunnel
        logger.info("Starting GPU monitoring SSH tunnel...")
        gpu_tunnel = _make_tunnel(GPU_REMOTE_PORT, GPU_LOCAL_PORT)
        if gpu_tunnel.start():
            logger.info("GPU monitoring SSH tunnel started successfully")
            gpu_tunnel_status = {"is_running": True, "message": "Connected"}
            start_gpu_monitoring()
        else:
            logger.warning("Failed to start GPU monitoring SSH tunnel.")
            gpu_tunnel_status = {"is_running": False, "message": "Connection failed"}
    except Exception as e:
        logger.error(f"Error starting SSH tunnels: {str(e)}")
        use_fallback = True
        api_tunnel_status = {"is_running": False, "message": "Connection error"}
        gpu_tunnel_status = {"is_running": False, "message": "Connection error"}


def check_vllm_api_health():
    """
    Check if the vLLM API is actually responding by querying the /v1/models endpoint.

    Returns:
        tuple: (is_healthy, message)
    """
    try:
        response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
        if response.status_code != 200:
            return False, f"API returned status code: {response.status_code}"
        try:
            data = response.json()
            if 'data' in data and len(data['data']) > 0:
                model_id = data['data'][0].get('id', 'Unknown model')
                return True, f"API is healthy. Available model: {model_id}"
            return True, "API is healthy but no models found"
        except Exception as e:
            return False, f"API returned 200 but invalid JSON: {str(e)}"
    except Exception as e:
        return False, f"API request failed: {str(e)}"


def _gpu_error_payload(error):
    """Return a GPU-data dict shaped like a successful fetch but marked as failed."""
    return {
        "success": False,
        "error": str(error),
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "gpus": [],
        "processes": [],
    }


def fetch_gpu_info():
    """
    Fetch GPU information from the remote server in JSON format.

    Returns:
        dict: Decoded JSON from the GPU endpoint, or an error payload with
        "success": False and an "error" message.
    """
    try:
        response = requests.get(GPU_JSON_ENDPOINT, timeout=5)
        if response.status_code == 200:
            return response.json()
        logger.warning(f"Error fetching GPU info: HTTP {response.status_code}")
        return _gpu_error_payload(f"HTTP Error: {response.status_code}")
    except Exception as e:
        logger.warning(f"Error fetching GPU info: {str(e)}")
        return _gpu_error_payload(e)


def fetch_gpu_text():
    """
    Fetch raw nvidia-smi output from the remote server for backward compatibility.

    Returns:
        str: nvidia-smi output or error message
    """
    try:
        response = requests.get(GPU_TXT_ENDPOINT, timeout=5)
        if response.status_code == 200:
            return response.text
        return f"Error fetching GPU info: HTTP {response.status_code}"
    except Exception as e:
        return f"Error fetching GPU info: {str(e)}"


def start_gpu_monitoring():
    """
    Start a daemon thread that refreshes the global gpu_data every 2 seconds.

    Calling this while monitoring is already running is a no-op.
    """
    global gpu_monitor_thread, gpu_monitor_running

    if gpu_monitor_running:
        return

    gpu_monitor_running = True

    def monitor_loop():
        global gpu_data
        while gpu_monitor_running:
            try:
                gpu_data = fetch_gpu_info()
            except Exception as e:
                logger.error(f"Error in GPU monitoring loop: {str(e)}")
                gpu_data = _gpu_error_payload(e)
            time.sleep(2)  # Update every 2 seconds

    gpu_monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
    gpu_monitor_thread.start()
    logger.info("GPU monitoring thread started")


def _guess_mime_type(path):
    """Return the MIME type implied by path's extension, defaulting to image/jpeg."""
    mime_type, _ = mimetypes.guess_type(str(path))
    return mime_type or "image/jpeg"


def _history_to_openai_messages(history):
    """
    Convert Gradio "messages"-format history into OpenAI chat messages.

    Only text turns are carried over. File entries (tuple contents) and any
    other non-string content are skipped, so images from earlier turns are
    not re-sent.
    """
    messages = []
    for entry in history:
        role = entry.get("role")
        content = entry.get("content")
        if role in ("user", "assistant") and isinstance(content, str):
            messages.append({"role": role, "content": content})
    return messages


def _build_user_content(text, files):
    """
    Build the OpenAI "content" for the current user turn.

    Returns plain text when there are no files, otherwise a list of an
    optional text part followed by one base64 data-URL image part per file.
    """
    if not files:
        return text
    parts = []
    if text.strip():
        parts.append({"type": "text", "text": text})
    for path, encoded in zip(files, convert_files_to_base64(files)):
        parts.append({
            "type": "image_url",
            "image_url": {"url": f"data:{_guess_mime_type(path)};base64,{encoded}"},
        })
    return parts


def _stream_reply(history, openai_messages, use_fallback_api, empty_message):
    """
    Stream a chat completion from the selected API, yielding history snapshots.

    While streaming, each yield is history plus the partial assistant reply.
    On success, the final assistant message is appended to history in place
    and the completed history is yielded. Exceptions propagate to the caller,
    and history is not modified in that case.
    """
    client = get_openai_client(use_fallback_api=use_fallback_api)
    model = get_model_name(use_fallback_api=use_fallback_api)
    response = client.chat.completions.create(
        model=model,
        messages=openai_messages,
        stream=True,
    )

    assistant_message = ""
    for chunk in response:
        # Some servers emit chunks with an empty choices list; skip them.
        if not chunk.choices:
            continue
        piece = getattr(chunk.choices[0].delta, "content", None)
        if piece is None:
            continue
        assistant_message += piece
        yield history + [{"role": "assistant", "content": assistant_message}]

    if not assistant_message:
        assistant_message = empty_message
    history.append({"role": "assistant", "content": assistant_message})
    yield history


def process_chat(message_dict, history):
    """
    Process a user message, send it to the appropriate API, and stream the reply.

    This is a generator: Gradio renders every yielded value, and a generator's
    return value is ignored, so every final state (including error messages)
    is yielded.

    Args:
        message_dict (dict): MultimodalTextbox value with "text" and "files" keys
        history (list): Chat history in Gradio "messages" format

    Yields:
        list: Updated chat history
    """
    global use_fallback

    message_dict = message_dict or {}
    text = message_dict.get("text", "") or ""
    files = message_dict.get("files", []) or []
    history = list(history) if history else []

    # Nothing to send: leave the conversation unchanged.
    if not text.strip() and not files:
        yield history
        return

    # Earlier turns become context; build them before adding the current turn.
    openai_messages = _history_to_openai_messages(history)

    for path in files:
        history.append({"role": "user", "content": (path,)})
    if text.strip():
        history.append({"role": "user", "content": text})

    # The current turn always carries its own text and images, even when no
    # text was typed (previously image-only uploads were dropped).
    openai_messages.append({"role": "user", "content": _build_user_content(text, files)})

    # Show the user's message immediately, before the model responds.
    yield history

    # Snapshot the routing decision once. The monitor thread may flip
    # use_fallback concurrently, and client and model must match.
    primary_is_fallback = use_fallback

    try:
        yield from _stream_reply(
            history, openai_messages, primary_is_fallback,
            "No response received from the model.",
        )
        return
    except Exception as primary_error:
        logger.error(f"Primary API error: {str(primary_error)}")

    if primary_is_fallback:
        history.append({"role": "assistant", "content": "An error occurred with the model service."})
        yield history
        return

    try:
        logger.info("Falling back to Hyperbolic API")
        yield from _stream_reply(
            history, openai_messages, True,
            "No response received from the fallback model.",
        )
        use_fallback = True
    except Exception as fallback_error:
        logger.error(f"Fallback API error: {str(fallback_error)}")
        history.append({
            "role": "assistant",
            "content": "Error connecting to both primary and fallback APIs.",
        })
        yield history


def monitor_tunnels():
    """
    Poll both SSH tunnels every 5 seconds and update the global status/routing.

    The local vLLM API is selected only while its tunnel is up, its health
    check passes, and the user has not explicitly chosen the fallback API.
    """
    global use_fallback, api_tunnel_status, gpu_tunnel_status

    logger.info("Starting tunnel monitoring thread")

    while True:
        try:
            if api_tunnel is not None:
                ssh_status = api_tunnel.check_status()
                if ssh_status["is_running"]:
                    is_healthy, message = check_vllm_api_health()
                    if is_healthy:
                        # Respect a manual switch to the fallback made via toggle_api().
                        use_fallback = force_fallback
                        api_tunnel_status = {
                            "is_running": True,
                            "message": f"Connected and healthy. {message}"
                        }
                    else:
                        use_fallback = True
                        api_tunnel_status = {
                            "is_running": False,
                            "message": "Tunnel connected but vLLM API unhealthy"
                        }
                else:
                    logger.error(f"API SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
                    use_fallback = True
                    api_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
            else:
                use_fallback = True
                api_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}

            if gpu_tunnel is not None:
                ssh_status = gpu_tunnel.check_status()
                if ssh_status["is_running"]:
                    gpu_tunnel_status = {"is_running": True, "message": "Connected"}
                    if not gpu_monitor_running:
                        start_gpu_monitoring()
                else:
                    logger.error(f"GPU SSH tunnel disconnected: {ssh_status.get('error', 'Unknown error')}")
                    gpu_tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
            else:
                gpu_tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}

        except Exception as e:
            logger.error(f"Error monitoring tunnels: {str(e)}")
            use_fallback = True
            api_tunnel_status = {"is_running": False, "message": "Monitoring error"}
            gpu_tunnel_status = {"is_running": False, "message": "Monitoring error"}

        time.sleep(5)  # Check every 5 seconds


def get_openai_client(use_fallback_api=None):
    """
    Create and return an OpenAI client configured for the appropriate endpoint.

    Args:
        use_fallback_api (bool): If True, use Hyperbolic API. If False, use local vLLM.
                                If None, use the global use_fallback setting.

    Returns:
        OpenAI: Configured OpenAI client
    """
    if use_fallback_api is None:
        use_fallback_api = use_fallback

    if use_fallback_api:
        logger.info("Using Hyperbolic API")
        return OpenAI(api_key=HYPERBOLIC_KEY, base_url=HYPERBOLIC_ENDPOINT)

    logger.info("Using local vLLM API")
    # vLLM doesn't require an actual API key
    return OpenAI(api_key="EMPTY", base_url=VLLM_ENDPOINT)


def get_model_name(use_fallback_api=None):
    """
    Return the appropriate model name based on the API being used.

    Args:
        use_fallback_api (bool): If True, use fallback model. If None, use the global setting.

    Returns:
        str: Model name
    """
    if use_fallback_api is None:
        use_fallback_api = use_fallback
    return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL


def convert_files_to_base64(files):
    """
    Convert uploaded files to base64 strings.

    Args:
        files (list): List of file paths

    Returns:
        list: List of base64-encoded strings, in the same order as files
    """
    base64_images = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_images.append(base64.b64encode(image_file.read()).decode("utf-8"))
    return base64_images


def _format_number(value, spec):
    """Format value with spec, falling back to str(value) for non-numeric values (e.g. None)."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def format_simplified_gpu_data(gpu_data):
    """
    Format GPU data into a simplified, focused display.

    Args:
        gpu_data (dict): GPU data in JSON format

    Returns:
        str: Formatted GPU data
    """
    if not gpu_data.get("success", False):
        return f"Error fetching GPU data: {gpu_data.get('error', 'Unknown error')}"

    output = [f"Last updated: {gpu_data.get('timestamp', 'Unknown')}"]

    for i, gpu in enumerate(gpu_data.get("gpus", [])):
        mem_used = _format_number(gpu.get('memory_used', 0), '6.0f')
        mem_total = _format_number(gpu.get('memory_total', 0), '6.0f')
        mem_util = _format_number(gpu.get('memory_utilization', 0), '5.1f')
        power_draw = _format_number(gpu.get('power_draw', 0), '5.1f')
        power_limit = _format_number(gpu.get('power_limit', 0), '5.1f')

        output.append(f"GPU {gpu.get('index', i)}: {gpu.get('name', 'Unknown')}")
        output.append(f"  Memory: {mem_used} MB / {mem_total} MB ({mem_util}%)")
        output.append(f"  Power: {power_draw}W / {power_limit}W")
        if 'fan_speed' in gpu:
            output.append(f"  Fan: {_format_number(gpu.get('fan_speed', 0), '5.1f')}%")
        output.append(f"  Temp: {_format_number(gpu.get('temperature', 0), '5.1f')}°C")
        output.append("")

    return "\n".join(output)


def update_gpu_status():
    """
    Return the latest cached GPU status as formatted text.

    Returns:
        str: Formatted GPU status, or a notice when the GPU tunnel is down
    """
    if not gpu_tunnel_status["is_running"]:
        return "GPU monitoring tunnel is not connected."
    return format_simplified_gpu_data(gpu_data)


def get_tunnel_status_message():
    """
    Return a formatted status message for display in the UI.
    """
    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()

    api_status_color = "🟢" if (api_tunnel_status["is_running"] and not use_fallback) else "🔴"
    api_status_text = api_tunnel_status["message"]

    gpu_status_color = "🟢" if gpu_tunnel_status["is_running"] else "🔴"
    gpu_status_text = gpu_tunnel_status["message"]

    return (f"{api_status_color} API Tunnel: {api_status_text}\n"
            f"{gpu_status_color} GPU Tunnel: {gpu_status_text}\n"
            f"Current API: {api_mode}\n"
            f"Current Model: {model}\n"
            f"Concurrent Requests: {MAX_CONCURRENT}")


def get_gpu_json():
    """
    Return the raw GPU JSON data for debugging.
    """
    return json.dumps(gpu_data, indent=2)


def toggle_api():
    """
    Toggle between local vLLM and Hyperbolic API.

    Choosing the fallback is remembered in force_fallback so the monitor thread
    keeps it. Switching back to local vLLM only sticks while vLLM is healthy.
    """
    global use_fallback, force_fallback
    use_fallback = not use_fallback
    force_fallback = use_fallback
    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()
    return f"Switched to {api_mode} using {model}"


def update_concurrency(new_value):
    """
    Update the MAX_CONCURRENT value.

    Note: event handlers already registered with concurrency_limit=MAX_CONCURRENT
    keep the value they were created with; only the status display changes.

    Args:
        new_value (str): New concurrency value as string

    Returns:
        str: Status message
    """
    global MAX_CONCURRENT
    try:
        value = int(new_value)
    except ValueError:
        return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"
    if value < 1:
        return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"
    MAX_CONCURRENT = value
    return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."


# Start SSH tunnels and monitoring threads
if __name__ == "__main__":
    start_ssh_tunnels()
    monitor_thread = threading.Thread(target=monitor_tunnels, daemon=True)
    monitor_thread.start()

    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# Multimodal Chat Interface")

        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            show_copy_button=True,
            avatar_images=("👤", "🗣️"),
            height=400
        )

        with gr.Row():
            textbox = gr.MultimodalTextbox(
                file_types=["image", "video"],
                file_count="multiple",
                placeholder="Type your message here and/or upload images...",
                label="Message",
                show_label=False,
                scale=9
            )
            submit_btn = gr.Button("Send", size="sm", scale=1)

        clear_btn = gr.Button("Clear Chat")

        textbox.submit(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        submit_btn.click(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        clear_btn.click(lambda: [], None, chatbot)

        examples = []
        example_images = {
            "dog_pic.jpg": "What breed is this?",
            "ghostimg.png": "What's in this image?",
            "newspaper.png": "Provide a python list of dicts about everything on this page."
        }
        for img_name, prompt_text in example_images.items():
            img_path = os.path.join(os.path.dirname(__file__), img_name)
            if os.path.exists(img_path):
                examples.append([{"text": prompt_text, "files": [img_path]}])

        if examples:
            gr.Examples(examples=examples, inputs=textbox)

        status_text = gr.Textbox(
            label="Tunnel and API Status",
            value=get_tunnel_status_message(),
            interactive=False
        )

        with gr.Accordion("GPU Status", open=False):
            # <pre> preserves the aligned columns and line breaks of the text
            # report. The text is escaped because it embeds data and error
            # messages coming from the remote endpoint.
            gpu_status = gr.HTML(
                value=lambda: f"<pre>{html.escape(update_gpu_status())}</pre>",
                every=2
            )

        with gr.Row():
            refresh_btn = gr.Button("Refresh Status")
            toggle_api_btn = gr.Button("Toggle API")

        refresh_btn.click(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        toggle_api_btn.click(
            fn=toggle_api,
            inputs=None,
            outputs=status_text
        ).then(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        demo.load(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

    demo.queue(default_concurrency_limit=MAX_CONCURRENT)
    demo.launch()