xujinheng666 commited on
Commit
6e645b6
·
verified ·
1 Parent(s): a8439f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import torchaudio
5
+ import os
6
+
7
+ def load_models():
8
+ st.session_state.transcription_pipe = pipeline(
9
+ task="automatic-speech-recognition",
10
+ model="alvanlii/whisper-small-cantonese",
11
+ chunk_length_s=60,
12
+ device="cuda" if torch.cuda.is_available() else "cpu"
13
+ )
14
+ st.session_state.transcription_pipe.model.config.forced_decoder_ids = st.session_state.transcription_pipe.tokenizer.get_decoder_prompt_ids(language="zh", task="transcribe")
15
+
16
+ st.session_state.translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
17
+ st.session_state.translation_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
18
+
19
+ st.session_state.summary_pipe = pipeline("text-summarization", model="facebook/bart-large-cnn")
20
+
21
+ st.session_state.rating_pipe = pipeline("text-classification", model="cardiffnlp/twitter-roberta-base-sentiment-latest")
22
+
23
+ def transcribe_audio(audio_path):
24
+ pipe = st.session_state.transcription_pipe
25
+ return pipe(audio_path)["text"]
26
+
27
+ def translate_text(text):
28
+ tokenizer = st.session_state.translation_tokenizer
29
+ model = st.session_state.translation_model
30
+ inputs = tokenizer(text, return_tensors="pt")
31
+ outputs = model.generate(inputs["input_ids"], max_length=1000, num_beams=5)
32
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
33
+
34
+ def summarize_text(text):
35
+ return st.session_state.summary_pipe(text)[0]['summary_text']
36
+
37
+ def rate_quality(text):
38
+ result = st.session_state.rating_pipe(text)[0]
39
+ label_map = {"LABEL_0": "Poor", "LABEL_1": "Average", "LABEL_2": "Good"}
40
+ return label_map.get(result["label"], "Unknown")
41
+
42
+ def main():
43
+ st.title("Audio Processing & Conversation Quality Rating")
44
+
45
+ if "transcription_pipe" not in st.session_state:
46
+ with st.spinner("Loading models..."):
47
+ load_models()
48
+
49
+ uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
50
+
51
+ if uploaded_file is not None:
52
+ with st.spinner("Processing audio..."):
53
+ file_path = "temp_audio.wav"
54
+ with open(file_path, "wb") as f:
55
+ f.write(uploaded_file.read())
56
+
57
+ transcript = transcribe_audio(file_path)
58
+ translation = translate_text(transcript)
59
+ summary = summarize_text(translation)
60
+ rating = rate_quality(translation)
61
+
62
+ os.remove(file_path)
63
+
64
+ st.subheader("Transcription")
65
+ st.write(transcript)
66
+
67
+ st.subheader("Translation (English)")
68
+ st.write(translation)
69
+
70
+ st.subheader("Summary")
71
+ st.write(summary)
72
+
73
+ st.subheader("Conversation Quality Rating")
74
+ st.write(rating)
75
+
76
+ if __name__ == "__main__":
77
+ main()