import asyncio
import base64
import logging
import os
import tempfile
from typing import Optional

import faster_whisper
import requests
import torch
import websockets
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Prefer GPU when available; faster-whisper runs on CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info('Device selected: %s', device)

model_name = 'ivrit-ai/faster-whisper-v2-d4'
logging.info('Loading model: %s', model_name)
model = faster_whisper.WhisperModel(model_name, device=device)
logging.info('Model loaded successfully')

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
logging.info('Max payload size set to: %d bytes', MAX_PAYLOAD_SIZE)

app = FastAPI()


class InputData(BaseModel):
    """Request body for POST /transcribe."""
    type: str                   # 'blob' or 'url'
    data: Optional[str] = None  # Used for blob input (base64-encoded audio)
    url: Optional[str] = None   # Used for url input


def _segment_to_dict(s):
    """Serialize a faster-whisper segment (with word timestamps) to a plain dict."""
    words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability}
             for w in s.words]
    return {
        'id': s.id,
        'seek': s.seek,
        'start': s.start,
        'end': s.end,
        'text': s.text,
        'avg_logprob': s.avg_logprob,
        'compression_ratio': s.compression_ratio,
        'no_speech_prob': s.no_speech_prob,
        'words': words,
    }


def download_file(url, max_size_bytes, output_filename, api_key=None):
    """
    Download a file from a given URL with size limit and optional API key.

    :param url: Source URL.
    :param max_size_bytes: Reject payloads larger than this many bytes.
    :param output_filename: Destination path for the downloaded bytes.
    :param api_key: Optional bearer token sent in the Authorization header.
    :return: True on success, False on oversize payload or network error.
    """
    logging.debug('Starting file download from URL: %s', url)
    try:
        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'
            logging.debug('API key provided, added to headers')

        # Stream so the size limit can be enforced without buffering the
        # whole payload in memory; context manager releases the connection.
        with requests.get(url, stream=True, headers=headers) as response:
            response.raise_for_status()

            # Content-Length may be absent (treated as 0); the streamed check
            # below still enforces the limit in that case.
            file_size = int(response.headers.get('Content-Length', 0))
            logging.info('File size: %d bytes', file_size)
            if file_size > max_size_bytes:
                logging.error('File size exceeds limit: %d > %d', file_size, max_size_bytes)
                return False

            downloaded_size = 0
            with open(output_filename, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    downloaded_size += len(chunk)
                    logging.debug('Downloaded %d bytes', downloaded_size)
                    if downloaded_size > max_size_bytes:
                        logging.error('Downloaded size exceeds maximum allowed payload size')
                        return False
                    file.write(chunk)

        logging.info('File downloaded successfully: %s', output_filename)
        return True
    except requests.RequestException as e:
        logging.error(f"Error downloading file: {e}")
        return False


@app.get("/")
async def read_root():
    """Simple landing/health endpoint."""
    return {"message": "This is the Ivrit AI Streaming service."}


@app.post("/transcribe")
async def transcribe(input_data: InputData):
    """
    Transcribe a complete audio payload supplied either as a base64 blob
    (type='blob') or as a downloadable URL (type='url').

    :raises HTTPException: 400 on missing/invalid fields or download failure.
    """
    logging.info('Received transcription request with data: %s', input_data)

    datatype = input_data.type
    if not datatype:
        logging.error('datatype field not provided')
        raise HTTPException(status_code=400,
                            detail="datatype field not provided. Should be 'blob' or 'url'.")
    if datatype not in ('blob', 'url'):
        logging.error('Invalid datatype: %s', datatype)
        raise HTTPException(status_code=400,
                            detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")

    with tempfile.TemporaryDirectory() as d:
        audio_file = f'{d}/audio.mp3'
        logging.debug('Created temporary directory: %s', d)

        if datatype == 'blob':
            if not input_data.data:
                logging.error("Missing 'data' for 'blob' input")
                raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
            logging.info('Decoding base64 blob data')
            mp3_bytes = base64.b64decode(input_data.data)
            # Context manager guarantees the handle is closed (the original
            # leaked it via open(...).write(...)).
            with open(audio_file, 'wb') as f:
                f.write(mp3_bytes)
            logging.info('Audio file written: %s', audio_file)
        else:  # datatype == 'url'
            if not input_data.url:
                logging.error("Missing 'url' for 'url' input")
                raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
            logging.info('Downloading file from URL: %s', input_data.url)
            success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, None)
            if not success:
                logging.error(f"Error downloading data from {input_data.url}")
                raise HTTPException(status_code=400,
                                    detail=f"Error downloading data from {input_data.url}")

        result = transcribe_core(audio_file)
        return {"result": result}


def transcribe_core(audio_file):
    """
    Run the whisper model over the whole file and return all segments.

    :param audio_file: Path to the audio file to transcribe.
    :return: {'segments': [segment dicts with word-level timestamps]}
    """
    logging.info('Starting transcription...')
    ret = {'segments': []}

    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
    logging.info('Transcription completed')

    for s in segs:
        seg = _segment_to_dict(s)
        logging.info(f'Transcription segment: {seg}')
        ret['segments'].append(seg)

    return ret


def transcribe_core_ws(audio_file, last_transcribed_time):
    """
    Transcribe the audio file and return only the segments that have not
    been processed yet.

    :param audio_file: Path to the growing audio file.
    :param last_transcribed_time: The last time (in seconds) that was transcribed.
    :return: Newly transcribed segments and the updated last transcribed time.
    """
    logging.info('Starting transcription for file: %s from %s seconds.',
                 audio_file, last_transcribed_time)

    ret = {'new_segments': []}
    new_last_transcribed_time = last_transcribed_time

    try:
        # Transcribe the entire audio file; the time filter below discards
        # segments the client has already received.
        logging.debug('Initiating model transcription for file: %s', audio_file)
        segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
        logging.info('Transcription completed successfully.')
    except Exception as e:
        logging.error(f"Error during transcription: {e}")
        raise

    for s in segs:
        logging.info('Processing segment with start time: %s and end time: %s', s.start, s.end)
        # Only process segments that start after the last transcribed time.
        if s.start >= last_transcribed_time:
            seg = _segment_to_dict(s)
            logging.info(f'Adding new transcription segment: {seg}')
            ret['new_segments'].append(seg)
            # Advance the watermark to the end of the newest segment seen.
            new_last_transcribed_time = max(new_last_transcribed_time, s.end)

    logging.info('Returning %d new segments and updated last transcribed time.',
                 len(ret['new_segments']))
    return ret, new_last_transcribed_time


@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """
    Streaming transcription: append each received audio chunk to a temp file,
    re-transcribe it, and push only the newly seen segments to the client.
    """
    logging.info("New WebSocket connection request received.")
    await websocket.accept()
    logging.info("WebSocket connection established successfully.")

    temp_path = None
    try:
        # BUG FIX: the original initialized processed_segments = [] and passed
        # that list to transcribe_core_ws as last_transcribed_time, where the
        # comparison `s.start >= last_transcribed_time` raises TypeError
        # (float >= list) on the first segment. Track the float watermark.
        last_transcribed_time = 0.0

        # delete=False so the file survives flush/reads while it grows; we
        # remove it ourselves in the finally block (the original leaked it).
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio_file:
            temp_path = temp_audio_file.name
            logging.info('Temporary audio file created at %s', temp_path)

            # Continuously receive and process audio chunks.
            while True:
                try:
                    logging.info('Waiting to receive the next chunk of audio data from WebSocket.')
                    audio_chunk = await websocket.receive_bytes()
                    logging.info('Received an audio chunk of size %d bytes.', len(audio_chunk))

                    if not audio_chunk:
                        logging.warning("Received empty audio chunk, skipping processing.")
                        continue

                    temp_audio_file.write(audio_chunk)
                    temp_audio_file.flush()
                    logging.debug('Written audio chunk to temporary file: %s', temp_path)

                    partial_result, last_transcribed_time = transcribe_core_ws(
                        temp_path, last_transcribed_time)
                    logging.info('Transcription completed. Sending %d new segments to the client.',
                                 len(partial_result['new_segments']))

                    # Send the new transcription result back to the client.
                    await websocket.send_json(partial_result)
                except WebSocketDisconnect:
                    logging.info('WebSocket connection closed by the client. '
                                 'Ending transcription session.')
                    break
                except Exception as e:
                    logging.error(f"Error processing audio chunk: {e}")
                    await websocket.send_json({"error": str(e)})
                    break
    except Exception as e:
        # Do not attempt to send here: the socket may already be closed.
        logging.error(f"Unexpected error during WebSocket transcription: {e}")
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
        logging.info("Cleaning up and closing WebSocket connection.")