import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import time
import numpy as np
# --- Configuration ---
# Device selection (GPU if available, else CPU)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Using device: {device}")
# STT Model (Use smaller model for lower latency)
stt_model_id = "openai/whisper-tiny" # Or "openai/whisper-base". Avoid larger models for streaming.
# Summarization Model
summarizer_model_id = "sshleifer/distilbart-cnn-6-6" # Use a distilled/smaller model for speed
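# Alternative (an assumption, not from the original script): swap in
# "facebook/bart-large-cnn" for higher-quality summaries at the cost of speed.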
# Summarization Interval (seconds) - How often to regenerate the summary
SUMMARY_INTERVAL = 30.0 # Summarize every 30 seconds
# --- Load Models ---
print("Loading STT model...")
stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    stt_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
stt_model.to(device)
processor = AutoProcessor.from_pretrained(stt_model_id)
stt_pipeline = pipeline(
    "automatic-speech-recognition",
    model=stt_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)
print("STT model loaded.")
print("Loading Summarization pipeline...")
summarizer = pipeline(
    "summarization",
    model=summarizer_model_id,
    device=device,
)
print("Summarization pipeline loaded.")
# --- Helper Functions ---
def format_summary_as_bullets(summary_text):
    """Attempts to format a summary text block into bullet points."""
    if not summary_text:
        return ""
    # Simple approach: break on sentence boundaries and prefix each with a bullet.
    # More sophisticated sentence segmentation could be used here.
    bullet_summary = "- " + summary_text.replace(". ", ".\n- ").strip()
    # Drop any empty bullets.
    bullet_summary = "\n".join(line for line in bullet_summary.split("\n") if line.strip() not in ("-", ""))
    return bullet_summary
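# Illustration of the transformation this helper performs:
#   format_summary_as_bullets("The team met. Budgets were approved.")
#   -> "- The team met.\n- Budgets were approved."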
# --- Processing Function for Streaming ---
# This function only processes audio; it does not interact with the webcam video.
def process_audio_stream(
    new_chunk_tuple,               # Gradio streaming yields (sample_rate, numpy_data)
    accumulated_transcript_state,  # gr.State holding the full transcript text
    last_summary_time_state,       # gr.State holding the timestamp of the last summary
    current_summary_state,         # gr.State holding the last generated summary
):
    if new_chunk_tuple is None:
        # Initial call or stream ended: return the current state unchanged.
        return accumulated_transcript_state, current_summary_state, accumulated_transcript_state, last_summary_time_state, current_summary_state
    sample_rate, audio_chunk = new_chunk_tuple
    if audio_chunk is None or sample_rate is None or audio_chunk.size == 0:
        # Skip empty chunks gracefully.
        return accumulated_transcript_state, current_summary_state, accumulated_transcript_state, last_summary_time_state, current_summary_state
print(f"Received chunk: {audio_chunk.shape}, Sample Rate: {sample_rate}, Duration: {len(audio_chunk)/sample_rate:.2f}s")
# Ensure audio is float32 and mono, as Whisper expects
if audio_chunk.dtype != np.float32:
# Normalize assuming input is int16
# Adjust if your microphone provides different integer types
audio_chunk = audio_chunk.astype(np.float32) / 32768.0 # Max value for int16 is 32767
    # --- 1. Transcribe the new chunk ---
    new_text = ""
    try:
        result = stt_pipeline({"sampling_rate": sample_rate, "raw": audio_chunk.copy()})
        new_text = result["text"].strip() if result["text"] else ""
        print(f"Transcription chunk: '{new_text}'")
    except Exception as e:
        print(f"Error during transcription chunk: {e}")
        new_text = f"[Transcription Error: {e}]"
    # --- 2. Update Accumulated Transcript ---
    # Insert a space between chunks unless the transcript already ends with whitespace.
    if accumulated_transcript_state and not accumulated_transcript_state.endswith((" ", "\n")) and new_text:
        updated_transcript = accumulated_transcript_state + " " + new_text
    else:
        updated_transcript = accumulated_transcript_state + new_text
    # --- 3. Periodic Summarization ---
    current_time = time.time()
    new_summary = current_summary_state  # Keep the old summary by default
    updated_last_summary_time = last_summary_time_state
    # Require a minimum transcript length so we don't summarize tiny fragments.
    if updated_transcript and len(updated_transcript) > 50 and (current_time - last_summary_time_state > SUMMARY_INTERVAL):
        print(f"Summarizing transcript (length: {len(updated_transcript)})...")
        try:
            # Summarize the *entire* transcript up to this point.
            summary_result = summarizer(updated_transcript, max_length=150, min_length=30, do_sample=False)
            if summary_result and isinstance(summary_result, list):
                raw_summary = summary_result[0]["summary_text"]
                new_summary = format_summary_as_bullets(raw_summary)
                updated_last_summary_time = current_time  # Update time only on a successful summary
                print("Summary updated.")
            else:
                print("Summarization did not produce expected output.")
        except Exception as e:
            print(f"Error during summarization: {e}")
            # Show the error in the summary box, but keep the last known good
            # summary in state so a transient failure doesn't wipe it out.
            error_display_summary = f"[Summarization Error]\n\nLast good summary:\n{current_summary_state}"
            return updated_transcript, error_display_summary, updated_transcript, last_summary_time_state, current_summary_state
    # --- 4. Return Updated State and Outputs ---
    return updated_transcript, new_summary, updated_transcript, updated_last_summary_time, new_summary
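# Caveat (an assumption beyond the original script): BART-family summarizers such
# as distilbart-cnn-6-6 accept at most 1024 input tokens, so a long meeting will
# eventually overflow the model's context and may be truncated or fail. A rough
# mitigation is to summarize only a recent window of the transcript, e.g.:
#   summarizer(updated_transcript[-4000:], max_length=150, min_length=30, do_sample=False)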
# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks() as demo:
gr.Markdown("# Real-Time Meeting Notes with Webcam View")
gr.Markdown("Speak into your microphone. Transcription appears below. Summary updates periodically.")
# State variables to store data between stream calls
transcript_state = gr.State("") # Holds the full transcript
last_summary_time = gr.State(0.0) # Holds the time the summary was last generated
summary_state = gr.State("") # Holds the current bullet point summary
with gr.Row():
with gr.Column(scale=1):
# Input: Microphone stream
audio_stream = gr.Audio(sources=["microphone"], streaming=True, label="Live Microphone Input", type="numpy")
# NEW: Webcam Display
# Use gr.Image which is simpler for just displaying webcam feed
# live=True makes it update continuously
webcam_view = gr.Image(sources=["webcam"], label="Your Webcam", streaming=True) # Use streaming=True for live view
with gr.Column(scale=2):
transcription_output = gr.Textbox(label="Full Transcription", lines=15, interactive=False) # Display only
summary_output = gr.Textbox(label=f"Bullet Point Summary (Updates ~every {SUMMARY_INTERVAL}s)", lines=10, interactive=False) # Display only
    # Connect the streaming audio input to the processing function.
    # Note: the webcam component runs independently in the browser; it doesn't feed data here.
    audio_stream.stream(
        fn=process_audio_stream,
        inputs=[audio_stream, transcript_state, last_summary_time, summary_state],
        outputs=[transcription_output, summary_output, transcript_state, last_summary_time, summary_state],
    )
    # Button to clear the display textboxes and reset the internal state.
    def clear_all():
        print("Clearing state.")
        # Empty values for: transcription display, summary display,
        # transcript state, last-summary timestamp, summary state.
        return "", "", "", 0.0, ""
    clear_button = gr.Button("Clear Transcript & Summary")
    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[transcription_output, summary_output, transcript_state, last_summary_time, summary_state],
    )
print("Launching Gradio interface...")
demo.queue() # Enable queue for handling multiple requests/stream chunks
demo.launch(debug=True, share=True) # share=True for Colab public link