|
import io |
|
from flask import Flask, Response, send_from_directory, jsonify, request, abort |
|
import os |
|
from flask_cors import CORS |
|
from multiprocessing import Queue |
|
import base64 |
|
from typing import Any, Dict, Tuple |
|
from multiprocessing import Queue |
|
import logging |
|
import sys |
|
from threading import Lock |
|
from multiprocessing import Manager |
|
|
|
import torch |
|
|
|
from server.AudioTranscriber import AudioTranscriber |
|
from server.ActionProcessor import ActionProcessor |
|
from server.StandaloneApplication import StandaloneApplication |
|
from server.TextFilterer import TextFilterer |
|
|
|
|
|
# Root logging configuration: INFO and above, written to stdout with a
# timestamp, the logger name and the level on every line.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Module logger; its records propagate to the root handler configured above.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Directory the front-end is served from. With DEBUG=True (exact string) the
# ./html folder under the current working directory is used; otherwise the
# fixed deployment path /app/server/static.
STATIC_DIR = (
    os.path.join(os.getcwd(), "html")
    if os.getenv("DEBUG") == "True"
    else "/app/server/static"
)
|
|
|
|
|
# Inter-process pipeline queues. Every item is paired with the client's
# session token so results can be routed back to the caller that produced them.
audio_queue: "Queue[Tuple[io.BytesIO, str]]" = Queue()  # /api/process -> AudioTranscriber
text_queue: "Queue[Tuple[str, str]]" = Queue()  # AudioTranscriber, /api/order -> TextFilterer
filtered_text_queue: "Queue[Tuple[str, str]]" = Queue()  # TextFilterer -> ActionProcessor
action_queue: "Queue[Tuple[Dict[str, Any], str]]" = Queue()  # ActionProcessor -> ActionConsumer

# Pending actions per token. ActionConsumer (a thread in the launching process)
# appends to it, and get_actions (running in the web-server worker processes;
# __main__ requests 3 workers) reads and clears it, so it must live in a
# Manager server process rather than in ordinary process memory.
manager = Manager()
action_storage = manager.dict()

# The lock must also be a Manager lock. A threading.Lock only serializes
# threads inside one process, so it cannot make ActionConsumer's
# read-modify-write atomic with respect to other worker processes. A forked
# child can also inherit it in the acquired state, with no thread left to
# release it. The Manager lock lives in the manager server and is shared by
# every process that holds the proxy.
action_storage_lock = manager.Lock()
|
|
|
# Flask application serving both the JSON API and the static front-end.
app = Flask(__name__, static_folder=STATIC_DIR)

# flask-cors settings.
# NOTE(review): add_header also sets Access-Control-Allow-* on every response,
# with broader values (PUT/DELETE, any header); confirm which policy is intended.
_cors_settings = dict(
    origins=["*"],
    methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)
_ = CORS(app, **_cors_settings)
|
|
|
|
|
@app.after_request
def add_header(response: Response):
    """Stamp CORS and cross-origin headers onto every outgoing response.

    Allows any origin, method list and request header. The COOP ``same-origin``
    plus COEP ``require-corp`` pair is the combination browsers require before
    they treat a page as cross-origin isolated.

    Args:
        response: The response produced by the view function.

    Returns:
        The same response object, modified in place.
    """
    header_values = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "*"),
        ("Cross-Origin-Embedder-Policy", "require-corp"),
        ("Cross-Origin-Opener-Policy", "same-origin"),
        ("Cross-Origin-Resource-Policy", "cross-origin"),
    )
    for header_name, header_value in header_values:
        response.headers[header_name] = header_value
    return response
|
|
|
|
|
@app.route("/")
def serve_index():
    """Serve the front-end entry page ``index.html`` from the static folder.

    Returns:
        The file response, with COOP/COEP headers set.

    Raises:
        werkzeug NotFound (via ``abort(404)``): if the static folder is unset or
        does not contain ``index.html``. The description names the folder to
        ease deployment debugging.
    """
    static_folder = app.static_folder
    missing_message = (
        f"Static folder or index.html not found. Static folder: {static_folder}"
    )

    # send_from_directory reports a missing file by raising werkzeug's NotFound,
    # not FileNotFoundError, so check up front; otherwise the diagnostic
    # description above would never reach the client.
    if not static_folder or not os.path.isfile(
        os.path.join(static_folder, "index.html")
    ):
        abort(404, description=missing_message)

    try:
        response = send_from_directory(static_folder, "index.html")
    except FileNotFoundError:
        # Kept for the (unlikely) case where the file disappears after the check.
        abort(404, description=missing_message)

    response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
    response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
    return response
|
|
|
|
|
@app.route("/api/data", methods=["GET"])
def get_data():
    """Return a constant success payload (presumably a liveness check)."""
    payload = {"status": "success"}
    return jsonify(payload)
|
|
|
|
|
@app.route("/api/order", methods=["POST"])
def post_order() -> Tuple[Response, int]:
    """Queue a typed text command for the session identified by ``?token=``.

    Expects a JSON object body with a string ``action`` field. The text is put
    on ``text_queue`` three times, tagged with the token: first half, full text,
    then second half.

    Returns:
        ``(json_response, status)``: 200 on success; 400 when the body is not a
        JSON object with a string ``action``, or the token is missing; 500 for
        unexpected failures.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so it is reported as a client error (400) below rather than
        # landing in the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "action" not in data:
            return (
                jsonify({"error": "Missing 'action' in request", "status": "error"}),
                400,
            )

        action_text = data["action"]
        if not isinstance(action_text, str):
            return (
                jsonify({"error": "'action' must be a string", "status": "error"}),
                400,
            )

        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        mid_split = len(action_text) // 2

        # NOTE(review): the halves plus the full text are all enqueued, presumably
        # to mimic incremental transcripts for TextFilterer; confirm intent.
        text_queue.put((action_text[:mid_split], token))
        text_queue.put((action_text, token))
        text_queue.put((action_text[mid_split:], token))

        return jsonify({"status": "success"}), 200

    except Exception as e:
        logger.exception("Failed to process /api/order request")
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )
|
|
|
|
|
@app.route("/api/process", methods=["POST"])
def process_data():
    """Accept a base64-encoded audio chunk and queue it for transcription.

    The token is taken from the ``?token=`` query parameter. The chunk is read from:
      * a JSON object body's ``audio_chunk`` field (``application/json``),
      * a form field named ``audio_chunk`` (``multipart/form-data``), or
      * the raw request body decoded as UTF-8 (any other Content-Type).

    The decoded bytes are wrapped in ``io.BytesIO`` and put on ``audio_queue``
    together with the token.

    Returns:
        A JSON response: 200 on success; 400 for a missing token, a missing,
        non-UTF-8 or undecodable chunk; 500 for unexpected failures.
    """

    def _bad_decode(err: Exception):
        # Shared 400 response for any chunk that cannot be turned into bytes.
        return (
            jsonify(
                {
                    "error": f"Failed to decode audio chunk: {str(err)}",
                    "status": "error",
                }
            ),
            400,
        )

    try:
        content_type = request.headers.get("Content-Type", "")
        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        if "application/json" in content_type:
            # silent=True: invalid JSON yields None instead of raising; a
            # non-object body (list, string, ...) has no audio_chunk either way.
            data = request.get_json(silent=True)
            audio_base64 = data.get("audio_chunk") if isinstance(data, dict) else None
        elif "multipart/form-data" in content_type:
            audio_base64 = request.form.get("audio_chunk")
        else:
            # Raw body: treat it as base64 text. Non-UTF-8 bytes cannot be valid
            # base64, so report them as a client error instead of a 500.
            try:
                audio_base64 = request.get_data().decode("utf-8")
            except UnicodeDecodeError as e:
                return _bad_decode(e)

        if not audio_base64:
            return (
                jsonify({"error": "Missing audio_chunk in request", "status": "error"}),
                400,
            )

        try:
            audio_chunk = base64.b64decode(audio_base64)
        except Exception as e:
            # binascii.Error for bad padding, ValueError for non-ASCII text,
            # TypeError for non-string JSON values.
            return _bad_decode(e)

        audio_queue.put((io.BytesIO(audio_chunk), token))

        return jsonify({"status": "success"})

    except Exception as e:
        logger.exception("Failed to process /api/process request")
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )
|
|
|
|
|
@app.route("/api/actions", methods=["GET"])
def get_actions() -> Tuple[Response, int]:
    """Retrieve and clear all pending actions for the current session.

    The session is identified by the ``?token=`` query parameter. Actions are
    those ActionConsumer has stored in ``action_storage`` under that token.

    Returns:
        ``(json_response, status)``: 200 with ``{"actions": [...]}`` (possibly
        empty), or 400 when the token is missing.
    """
    token = request.args.get("token")
    if not token:
        return (
            jsonify(
                {
                    "actions": [],
                    "error": "Missing token parameter",
                    "status": "error",
                }
            ),
            400,
        )

    with action_storage_lock:
        # pop() reads and removes the entry in a single shared-dict call. Unlike
        # writing back an empty list, it never creates entries for tokens that
        # have no actions, so polling with arbitrary tokens cannot grow the
        # shared dict without bound. ActionConsumer recreates the entry on demand.
        actions = action_storage.pop(token, [])

    return jsonify({"actions": actions, "status": "success"}), 200
|
|
|
|
|
@app.route("/<path:path>")
def serve_static(path: str):
    """Catch-all route: serve ``path`` relative to the app's static folder.

    send_from_directory refuses paths that would escape the folder.
    NOTE(review): send_from_directory signals a missing file with werkzeug's
    NotFound, so the FileNotFoundError branch is probably unreachable; the
    client still receives a 404 either way.
    """
    folder = app.static_folder
    try:
        asset_response = send_from_directory(folder, path)
    except FileNotFoundError:
        abort(404, description=f"File {path} not found in static folder")
    else:
        return asset_response
|
|
|
|
|
class ActionConsumer:
    """Background thread that moves actions from a queue into ``action_storage``.

    Each queue item is an ``(action, token)`` pair. The action is appended to
    the list stored under ``token`` in the module-level ``action_storage``,
    from which ``GET /api/actions`` later hands it to the client.
    """

    def __init__(self, action_queue: Queue, poll_interval: float = 0.5):
        """
        Args:
            action_queue: Queue of ``(action, token)`` tuples to drain.
            poll_interval: Seconds to wait for an item before re-checking
                ``running``. This bounds how long ``stop()`` takes to take effect.
        """
        self.action_queue = action_queue
        self.poll_interval = poll_interval
        self.running = True

    def start(self):
        """Start draining the queue on a daemon thread (stored as ``self.thread``)."""
        import threading

        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def stop(self, timeout: float = 5.0):
        """Ask the consumer loop to exit and wait up to ``timeout`` seconds for it."""
        self.running = False
        thread = getattr(self, "thread", None)
        if thread is not None:
            thread.join(timeout)

    def run(self):
        """Consume actions until ``running`` is cleared; errors are logged, not raised."""
        from queue import Empty  # multiprocessing.Queue.get raises queue.Empty on timeout

        while self.running:
            try:
                # Bounded wait: an unbounded get() would block forever on an idle
                # queue and never observe running=False.
                action, token = self.action_queue.get(timeout=self.poll_interval)
            except Empty:
                continue
            except Exception as e:
                logger.exception("Error in ActionConsumer: %s", e)
                continue

            try:
                with action_storage_lock:
                    # Values read from a Manager dict are copies, so mutate the
                    # local list and write it back explicitly.
                    current_actions = action_storage.get(token, [])
                    current_actions.append(action)
                    action_storage[token] = current_actions
            except Exception as e:
                logger.exception("Error in ActionConsumer: %s", e)
|
|
|
|
|
if __name__ == "__main__":
    # Validate configuration before starting any worker, so a missing key does
    # not leave dozens of transcribers running behind a crashed launcher.
    MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY is not set")

    debug = os.getenv("DEBUG") == "True"

    if os.path.exists(app.static_folder):
        logger.info("Static folder contents: %s", os.listdir(app.static_folder))

    os.makedirs(app.static_folder, exist_ok=True)

    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        # Without this guard, `i % num_devices` below fails with an opaque
        # ZeroDivisionError.
        raise RuntimeError(
            "torch.cuda.device_count() returned 0; "
            "cannot assign devices to AudioTranscriber workers"
        )

    # Transcribers are assigned round-robin over the visible CUDA devices.
    transcribers = [
        AudioTranscriber(audio_queue, text_queue, device_index=i % num_devices)
        for i in range(4 if debug else 40)
    ]
    for transcriber in transcribers:
        transcriber.start()

    # Moves actions from action_queue into action_storage for GET /api/actions.
    action_consumer = ActionConsumer(action_queue)
    action_consumer.start()

    filterer = TextFilterer(text_queue, filtered_text_queue)
    filterer.start()

    actions_processors = [
        ActionProcessor(filtered_text_queue, action_queue, MISTRAL_API_KEY)
        for _ in range(4 if debug else 16)
    ]
    for actions_processor in actions_processors:
        actions_processor.start()

    # Server options handed to StandaloneApplication (the keys match gunicorn
    # settings; presumably a gunicorn BaseApplication wrapper - confirm).
    options: Any = {
        "bind": "0.0.0.0:7860",
        "workers": 3,
        "worker_class": "sync",
        "timeout": 120,
        "forwarded_allow_ips": "*",
        "accesslog": None,
        "errorlog": "-",
        "capture_output": True,
        "enable_stdio_inheritance": True,
    }

    StandaloneApplication(app, options).run()
|
|