#!/usr/bin/env python3
"""
Gradio Interface for Multimodal Chat with SSH Tunnel Keepalive and API Fallback

This application provides a Gradio web interface for multimodal chat with a
local vLLM model. It establishes an SSH tunnel to a local vLLM server and
falls back to the Hyperbolic API if that server is unavailable.
"""
import os
import time
import threading
import logging
import base64

import requests
import gradio as gr
from openai import OpenAI

from ssh_tunneler import SSHTunnel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('app')
# Get environment variables
SSH_HOST = os.environ.get('SSH_HOST')
SSH_PORT = int(os.environ.get('SSH_PORT', 22))
SSH_USERNAME = os.environ.get('SSH_USERNAME')
SSH_PASSWORD = os.environ.get('SSH_PASSWORD')
REMOTE_PORT = int(os.environ.get('REMOTE_PORT', 8000))  # vLLM API port on the remote machine
LOCAL_PORT = int(os.environ.get('LOCAL_PORT', 8020))    # Local forwarded port
VLLM_MODEL = os.environ.get('MODEL_NAME', 'google/gemma-3-27b-it')
HYPERBOLIC_KEY = os.environ.get('HYPERBOLIC_XYZ_KEY')
FALLBACK_MODEL = 'Qwen/Qwen2.5-VL-72B-Instruct'  # Fallback model served by Hyperbolic

# Maximum number of concurrent API calls before requests are queued
MAX_CONCURRENT = int(os.environ.get('MAX_CONCURRENT', 3))
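
# Illustrative configuration only: the values below are placeholders, not the real
# deployment's settings. The app reads everything from environment variables
# (e.g. Space secrets), roughly like this:
#
#   SSH_HOST=gpu-box.example.com      SSH_PORT=22
#   SSH_USERNAME=ubuntu               SSH_PASSWORD=<secret>
#   REMOTE_PORT=8000                  LOCAL_PORT=8020
#   MODEL_NAME=google/gemma-3-27b-it  HYPERBOLIC_XYZ_KEY=<secret>
#   MAX_CONCURRENT=3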
# API endpoints
VLLM_ENDPOINT = f"http://localhost:{LOCAL_PORT}/v1"
HYPERBOLIC_ENDPOINT = "https://api.hyperbolic.xyz/v1"

# Global variables
tunnel = None
use_fallback = False  # Whether to use the fallback API instead of local vLLM
tunnel_status = {"is_running": False, "message": "Initializing tunnel..."}

def start_ssh_tunnel():
    """
    Start the SSH tunnel and monitor its status.
    """
    global tunnel, use_fallback, tunnel_status

    if not all([SSH_HOST, SSH_USERNAME, SSH_PASSWORD]):
        logger.error("Missing SSH connection details. Falling back to Hyperbolic API.")
        use_fallback = True
        tunnel_status = {"is_running": False, "message": "Missing SSH credentials"}
        return

    try:
        logger.info("Starting SSH tunnel...")
        tunnel = SSHTunnel(
            ssh_host=SSH_HOST,
            ssh_port=SSH_PORT,
            username=SSH_USERNAME,
            password=SSH_PASSWORD,
            remote_port=REMOTE_PORT,
            local_port=LOCAL_PORT,
            reconnect_interval=30,
            keep_alive_interval=15
        )
        if tunnel.start():
            logger.info("SSH tunnel started successfully")
            use_fallback = False
            tunnel_status = {"is_running": True, "message": "Connected"}
        else:
            logger.warning("Failed to start SSH tunnel. Falling back to Hyperbolic API.")
            use_fallback = True
            tunnel_status = {"is_running": False, "message": "Connection failed"}
    except Exception as e:
        logger.error(f"Error starting SSH tunnel: {str(e)}")
        use_fallback = True
        tunnel_status = {"is_running": False, "message": "Connection error"}

def check_vllm_api_health():
    """
    Check whether the vLLM API is actually responding by querying the /v1/models endpoint.

    Returns:
        tuple: (is_healthy, message)
    """
    try:
        response = requests.get(f"{VLLM_ENDPOINT}/models", timeout=5)
        if response.status_code == 200:
            try:
                data = response.json()
                if 'data' in data and len(data['data']) > 0:
                    model_id = data['data'][0].get('id', 'Unknown model')
                    return True, f"API is healthy. Available model: {model_id}"
                else:
                    return True, "API is healthy but no models found"
            except Exception as e:
                return False, f"API returned 200 but invalid JSON: {str(e)}"
        else:
            return False, f"API returned status code: {response.status_code}"
    except Exception as e:
        return False, f"API request failed: {str(e)}"
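
# For reference, a healthy vLLM /v1/models reply follows the OpenAI model-list
# format, roughly (fields abbreviated, values illustrative):
#   {"object": "list", "data": [{"id": "google/gemma-3-27b-it", "object": "model", ...}]}
# check_vllm_api_health() only relies on the "data" list and each entry's "id".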

def monitor_tunnel():
    """
    Monitor the SSH tunnel status and update the global variables.
    """
    global tunnel, use_fallback, tunnel_status

    logger.info("Starting tunnel monitoring thread")
    while True:
        try:
            if tunnel is not None:
                ssh_status = tunnel.check_status()
                # Check if the tunnel is running
                if ssh_status["is_running"]:
                    # Check if the vLLM API is actually responding
                    is_healthy, message = check_vllm_api_health()
                    if is_healthy:
                        use_fallback = False
                        tunnel_status = {
                            "is_running": True,
                            "message": f"Connected and healthy. {message}"
                        }
                    else:
                        use_fallback = True
                        tunnel_status = {
                            "is_running": False,
                            "message": "Tunnel connected but vLLM API unhealthy"
                        }
                else:
                    # Log the actual error for troubleshooting, but don't expose it in the UI
                    logger.error(f"SSH tunnel disconnected: {ssh_status['error'] or 'Unknown error'}")
                    use_fallback = True
                    tunnel_status = {
                        "is_running": False,
                        "message": "Disconnected - Check server status"
                    }
            else:
                use_fallback = True
                tunnel_status = {"is_running": False, "message": "Tunnel not initialized"}
        except Exception as e:
            logger.error(f"Error monitoring tunnel: {str(e)}")
            use_fallback = True
            tunnel_status = {"is_running": False, "message": "Monitoring error"}
        time.sleep(5)  # Check every 5 seconds

def get_openai_client(use_fallback_api=None):
    """
    Create and return an OpenAI client configured for the appropriate endpoint.

    Args:
        use_fallback_api (bool): If True, use the Hyperbolic API. If False, use local vLLM.
                                 If None, use the global use_fallback setting.

    Returns:
        OpenAI: Configured OpenAI client
    """
    global use_fallback

    # Determine which API to use
    if use_fallback_api is None:
        use_fallback_api = use_fallback

    if use_fallback_api:
        logger.info("Using Hyperbolic API")
        return OpenAI(
            api_key=HYPERBOLIC_KEY,
            base_url=HYPERBOLIC_ENDPOINT
        )
    else:
        logger.info("Using local vLLM API")
        return OpenAI(
            api_key="EMPTY",  # vLLM doesn't require a real API key
            base_url=VLLM_ENDPOINT
        )

def get_model_name(use_fallback_api=None):
    """
    Return the appropriate model name based on the API being used.

    Args:
        use_fallback_api (bool): If True, use the fallback model. If None, use the global setting.

    Returns:
        str: Model name
    """
    global use_fallback

    if use_fallback_api is None:
        use_fallback_api = use_fallback
    return FALLBACK_MODEL if use_fallback_api else VLLM_MODEL
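
# Typical usage (illustrative): pick the client and model together so they stay in sync,
# as process_chat() does below:
#   client = get_openai_client()
#   completion = client.chat.completions.create(model=get_model_name(), messages=[...], stream=True)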

def convert_files_to_base64(files):
    """
    Convert uploaded files to base64 strings.

    Args:
        files (list): List of file paths

    Returns:
        list: List of base64-encoded strings
    """
    base64_images = []
    for file in files:
        with open(file, "rb") as image_file:
            # Read the image data and encode it to base64
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            base64_images.append(base64_data)
    return base64_images
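
# Note: the strings returned above are raw base64 payloads. process_chat() below
# wraps them as data URLs for the OpenAI-style image parts, e.g. (illustrative):
#   {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<encoded>"}}
# The "image/jpeg" MIME type is applied to every upload regardless of its actual format.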

def process_chat(message_dict, history):
    """
    Process the user message and send it to the appropriate API.

    Args:
        message_dict (dict): User message containing text and files
        history (list): Chat history

    Yields:
        list: Updated chat history
    """
    global use_fallback

    text = message_dict.get("text", "")
    files = message_dict.get("files", [])

    # Make sure we have a history list to append to
    if not history:
        history = []

    # Add the user message to the chat history
    if files:
        # For each file, add a separate user message
        for file in files:
            history.append({"role": "user", "content": (file,)})
    # Add the text message if it is not empty
    if text.strip():
        history.append({"role": "user", "content": text})
    else:
        # If there is no text but files exist, don't add an empty message
        if not files:
            history.append({"role": "user", "content": ""})

    # Convert all files to base64
    base64_images = convert_files_to_base64(files)

    # Convert the history to OpenAI format
    openai_messages = []
    for h in history:
        if h["role"] == "user":
            if isinstance(h["content"], tuple):
                # File-only message; images are attached to the text message below
                continue
            openai_messages.append({
                "role": "user",
                "content": h["content"]
            })
        elif h["role"] == "assistant":
            openai_messages.append({
                "role": "assistant",
                "content": h["content"]
            })

    # Attach images to the last user message if needed
    if base64_images:
        if openai_messages and openai_messages[-1]["role"] == "user":
            last_msg = openai_messages[-1]
            # Build the OpenAI multimodal content structure
            content_list = []
            # Add text if there is any
            if last_msg["content"]:
                content_list.append({"type": "text", "text": last_msg["content"]})
            # Add images as data URLs
            for img_b64 in base64_images:
                content_list.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_b64}"
                    }
                })
            # Replace the plain text content with the multimodal content list
            last_msg["content"] = content_list
    # Try the primary API first, fall back if needed
    try:
        # First try with the currently selected API (vLLM or fallback)
        client = get_openai_client()
        model = get_model_name()
        response = client.chat.completions.create(
            model=model,
            messages=openai_messages,
            stream=True  # Use streaming for better UX
        )

        # Stream the response
        assistant_message = ""
        for chunk in response:
            if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                assistant_message += chunk.choices[0].delta.content
                # Update the chat in real time
                history_with_stream = history.copy()
                history_with_stream.append({"role": "assistant", "content": assistant_message})
                yield history_with_stream

        # Ensure we have a final message
        if not assistant_message:
            assistant_message = "No response received from the model."

        # Add the assistant response to the history if not already added
        if not history or history[-1]["role"] != "assistant":
            history.append({"role": "assistant", "content": assistant_message})
        # This is a generator, so the final state must be yielded, not returned
        yield history
    except Exception as primary_error:
        logger.error(f"Primary API error: {str(primary_error)}")
        # If we're not already using the fallback, try it
        if not use_fallback:
            try:
                logger.info("Falling back to Hyperbolic API")
                client = get_openai_client(use_fallback_api=True)
                model = get_model_name(use_fallback_api=True)
                response = client.chat.completions.create(
                    model=model,
                    messages=openai_messages,
                    stream=True
                )

                # Stream the response
                assistant_message = ""
                for chunk in response:
                    if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content is not None:
                        assistant_message += chunk.choices[0].delta.content
                        # Update the chat in real time
                        history_with_stream = history.copy()
                        history_with_stream.append({"role": "assistant", "content": assistant_message})
                        yield history_with_stream

                # Ensure we have a final message
                if not assistant_message:
                    assistant_message = "No response received from the fallback model."

                # Add the assistant response to the history if not already added
                if not history or history[-1]["role"] != "assistant":
                    history.append({"role": "assistant", "content": assistant_message})
                # Remember that the fallback is now in use (global declared at function start)
                use_fallback = True
                yield history
            except Exception as fallback_error:
                logger.error(f"Fallback API error: {str(fallback_error)}")
                error_msg = "Error connecting to both primary and fallback APIs."
                history.append({"role": "assistant", "content": error_msg})
                yield history
        else:
            # Already using the fallback, just report the error
            error_msg = "An error occurred with the model service."
            history.append({"role": "assistant", "content": error_msg})
            yield history

def get_tunnel_status_message():
    """
    Return a formatted status message for display in the UI.
    """
    global tunnel_status, use_fallback, MAX_CONCURRENT

    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()
    status_color = "🟢" if (tunnel_status["is_running"] and not use_fallback) else "🔴"
    status_text = tunnel_status["message"]
    return (
        f"{status_color} Tunnel Status: {status_text}\n"
        f"Current API: {api_mode}\n"
        f"Current Model: {model}\n"
        f"Concurrent Requests: {MAX_CONCURRENT}"
    )
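
# Example of the rendered status text (values illustrative):
#   🟢 Tunnel Status: Connected and healthy. API is healthy. Available model: google/gemma-3-27b-it
#   Current API: Local vLLM API
#   Current Model: google/gemma-3-27b-it
#   Concurrent Requests: 3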

def toggle_api():
    """
    Toggle between the local vLLM and Hyperbolic APIs.
    """
    global use_fallback

    use_fallback = not use_fallback
    api_mode = "Hyperbolic API" if use_fallback else "Local vLLM API"
    model = get_model_name()
    return f"Switched to {api_mode} using {model}"

def update_concurrency(new_value):
    """
    Update the MAX_CONCURRENT value.

    Args:
        new_value (str): New concurrency value as a string

    Returns:
        str: Status message
    """
    global MAX_CONCURRENT

    try:
        value = int(new_value)
        if value < 1:
            return f"Error: Concurrency must be at least 1. Keeping current value: {MAX_CONCURRENT}"
        MAX_CONCURRENT = value
        # Note: this only affects event handlers created after the change.
        # Existing handlers keep their original concurrency_limit, so a page
        # refresh is needed for the new value to fully take effect.
        return f"Concurrency updated to {MAX_CONCURRENT}. You may need to refresh the page for all changes to take effect."
    except ValueError:
        return f"Error: Invalid number. Keeping current value: {MAX_CONCURRENT}"

if __name__ == "__main__":
    # Start the SSH tunnel, then monitor it from a background thread
    start_ssh_tunnel()

    # Start the monitoring thread
    monitor_thread = threading.Thread(target=monitor_tunnel, daemon=True)
    monitor_thread.start()

    # Create the Gradio application with Blocks for more control
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("# Multimodal Chat Interface")

        # Chatbot component using message-style history
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            show_copy_button=True,
            avatar_images=("👤", "🗣️"),
            height=400
        )

        # Multimodal textbox for input
        with gr.Row():
            textbox = gr.MultimodalTextbox(
                file_types=["image", "video"],
                file_count="multiple",
                placeholder="Type your message here and/or upload images...",
                label="Message",
                show_label=False,
                scale=9
            )
            submit_btn = gr.Button("Send", size="sm", scale=1)

        # Clear button
        clear_btn = gr.Button("Clear Chat")
        # Set up the submit event chain with a concurrency limit
        submit_event = textbox.submit(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT  # Concurrency limit for this event
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        # Connect the submit button to the same functions with the same concurrency limit
        submit_btn.click(
            fn=process_chat,
            inputs=[textbox, chatbot],
            outputs=chatbot,
            concurrency_limit=MAX_CONCURRENT  # Concurrency limit for this event
        ).then(
            fn=lambda: {"text": "", "files": []},
            inputs=None,
            outputs=textbox
        )

        # Set up the clear button
        clear_btn.click(lambda: [], None, chatbot)
        # Load example images if they exist
        examples = []

        # Example images with their prompts
        example_images = {
            "dog_pic.jpg": "What breed is this?",
            "ghostimg.png": "What's in this image?",
            "newspaper.png": "Provide a python list of dicts about everything on this page."
        }

        # Check each image and add it to the examples if it exists
        for img_name, prompt_text in example_images.items():
            img_path = os.path.join(os.path.dirname(__file__), img_name)
            if os.path.exists(img_path):
                examples.append([{"text": prompt_text, "files": [img_path]}])

        # Add examples if we have any
        if examples:
            gr.Examples(
                examples=examples,
                inputs=textbox
            )

        # Status display
        status_text = gr.Textbox(
            label="Tunnel and API Status",
            value=get_tunnel_status_message(),
            interactive=False
        )

        # Refresh status button
        with gr.Row():
            refresh_btn = gr.Button("Refresh Status")

        # Set up the refresh status button
        refresh_btn.click(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

        # Load the initial status once (no auto-refresh)
        demo.load(
            fn=get_tunnel_status_message,
            inputs=None,
            outputs=status_text
        )

    # Launch the interface with the configured concurrency setting
    demo.queue(default_concurrency_limit=MAX_CONCURRENT)
    demo.launch()