xujinheng666 committed on
Commit
9afd3be
·
verified ·
1 Parent(s): b20dd94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -70
app.py CHANGED
@@ -2,97 +2,66 @@ import streamlit as st
2
  import torch
3
  import torchaudio
4
  import os
5
- import re
6
- import jieba
7
- from difflib import SequenceMatcher
8
  from transformers import pipeline
9
 
10
  # Device setup
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load Whisper model for transcription
14
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
15
  language = "zh"
16
- pipe = pipeline(
17
  task="automatic-speech-recognition",
18
  model=MODEL_NAME,
19
  chunk_length_s=60,
20
- device=device,
21
- generate_kwargs={
22
- "no_repeat_ngram_size": 4,
23
- "repetition_penalty": 1.15,
24
- "temperature": 0.5,
25
- "top_p": 0.97,
26
- "top_k": 40,
27
- "max_new_tokens": 300,
28
- "do_sample": True
29
- }
30
  )
31
- pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
32
 
33
- def is_similar(a, b, threshold=0.8):
34
- return SequenceMatcher(None, a, b).ratio() > threshold
35
 
36
- def remove_repeated_phrases(text):
37
- sentences = re.split(r'(?<=[。！？])', text)
38
- cleaned_sentences = []
39
- for i, sentence in enumerate(sentences):
40
- if i == 0 or not is_similar(sentence.strip(), cleaned_sentences[-1].strip()):
41
- cleaned_sentences.append(sentence.strip())
42
- return " ".join(cleaned_sentences)
43
 
44
- def remove_punctuation(text):
45
- return re.sub(r'[^\w\s]', '', text)
46
 
47
  def transcribe_audio(audio_path):
48
- waveform, sample_rate = torchaudio.load(audio_path)
49
- duration = waveform.shape[1] / sample_rate
50
- if duration > 60:
51
- results = []
52
- for start in range(0, int(duration), 55):
53
- end = min(start + 60, int(duration))
54
- chunk = waveform[:, start * sample_rate:end * sample_rate]
55
- if chunk.shape[1] == 0:
56
- continue
57
- temp_filename = f"temp_chunk_{start}.wav"
58
- torchaudio.save(temp_filename, chunk, sample_rate)
59
- if os.path.exists(temp_filename):
60
- try:
61
- result = pipe(temp_filename)["text"]
62
- results.append(remove_punctuation(result))
63
- finally:
64
- os.remove(temp_filename)
65
- return remove_punctuation(remove_repeated_phrases(" ".join(results)))
66
- return remove_punctuation(remove_repeated_phrases(pipe(audio_path)["text"]))
67
 
68
- # Load quality rating model
69
- rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")
70
 
71
  def rate_quality(text):
72
- chunks = [text[i:i+512] for i in range(0, len(text), 512)]
73
- results = []
74
  label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
 
 
 
 
 
 
 
75
 
76
- for chunk in chunks:
77
- result = rating_pipe(chunk)[0]
78
- results.append(label_map.get(result["label"], "Unknown"))
 
79
 
80
- return max(set(results), key=results.count)
81
-
82
- # Streamlit UI
83
- st.title("Cantonese Audio Transcription and Quality Rating")
84
- st.write("Upload your Cantonese audio file to get the transcription and quality rating.")
85
-
86
- audio_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "m4a"])
87
-
88
- if audio_file is not None:
89
- audio_path = audio_file.name
90
- with open(audio_path, "wb") as f:
91
- f.write(audio_file.getbuffer())
92
 
93
- st.write("Processing audio...")
94
- transcript = transcribe_audio(audio_path)
95
- st.write("**Transcript:**", transcript)
 
96
 
97
- quality_rating = rate_quality(transcript)
98
- st.write("**Quality Rating:**", quality_rating)
 
2
  import torch
3
  import torchaudio
4
  import os
 
 
 
5
  from transformers import pipeline
6
 
7
  # Device setup
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ # Load Whisper model for Cantonese transcription
11
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
12
  language = "zh"
13
+ transcriber = pipeline(
14
  task="automatic-speech-recognition",
15
  model=MODEL_NAME,
16
  chunk_length_s=60,
17
+ device=device
 
 
 
 
 
 
 
 
 
18
  )
19
+ transcriber.model.config.forced_decoder_ids = transcriber.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
20
 
21
+ # Load Summarization model
22
+ summarizer = pipeline("summarization", model="Ayaka/bart-base-cantonese")
23
 
24
+ # Load quality rating model
25
+ rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")
26
+
27
+ # Streamlit UI setup
28
+ st.set_page_config(page_title="Cantonese Audio Analysis", layout="centered")
29
+ st.title("🌟 Cantonese Audio Analysis")
30
+ st.write("Upload a Cantonese audio file to transcribe, summarize, and evaluate its quality.")
31
 
32
+ # File uploader
33
+ audio_file = st.file_uploader("Upload your audio file (WAV format)", type=["wav"])
34
 
35
  def transcribe_audio(audio_path):
36
+ return transcriber(audio_path)["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ def summarize_text(text):
39
+ return summarizer(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
40
 
41
  def rate_quality(text):
42
+ result = rating_pipe(text[:512])[0]
 
43
  label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
44
+ return label_map.get(result["label"], "Unknown")
45
+
46
+ if audio_file:
47
+ st.audio(audio_file, format="audio/wav")
48
+ temp_audio_path = "temp_audio.wav"
49
+ with open(temp_audio_path, "wb") as f:
50
+ f.write(audio_file.read())
51
 
52
+ with st.spinner("Transcribing audio..."):
53
+ transcript = transcribe_audio(temp_audio_path)
54
+ st.subheader("📝 Transcript")
55
+ st.write(transcript)
56
 
57
+ with st.spinner("Summarizing transcript..."):
58
+ summary = summarize_text(transcript)
59
+ st.subheader("📖 Summary")
60
+ st.write(summary)
 
 
 
 
 
 
 
 
61
 
62
+ with st.spinner("Evaluating conversation quality..."):
63
+ quality_rating = rate_quality(summary)
64
+ st.subheader("🏆 Quality Rating")
65
+ st.write(f"**{quality_rating}**")
66
 
67
+ os.remove(temp_audio_path)