xujinheng666 committed on
Commit
b20dd94
·
verified ·
1 Parent(s): d974db8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -29
app.py CHANGED
@@ -1,8 +1,9 @@
 
1
  import torch
2
  import torchaudio
3
  import os
4
  import re
5
- import streamlit as st
6
  from difflib import SequenceMatcher
7
  from transformers import pipeline
8
 
@@ -16,46 +17,82 @@ pipe = pipeline(
16
  task="automatic-speech-recognition",
17
  model=MODEL_NAME,
18
  chunk_length_s=60,
19
- device=device
 
 
 
 
 
 
 
 
 
20
  )
21
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
22
 
23
- # Load quality rating model
24
- rating_pipe = pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")
25
 
26
- # Sentiment label mapping
27
- label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
 
 
 
 
 
28
 
29
  def remove_punctuation(text):
30
  return re.sub(r'[^\w\s]', '', text)
31
 
32
  def transcribe_audio(audio_path):
33
- transcript = pipe(audio_path)["text"]
34
- return remove_punctuation(transcript)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def rate_quality(text):
37
- result = rating_pipe(text)[0]
38
- return label_map.get(result["label"], "Unknown")
 
 
 
 
 
 
 
39
 
40
  # Streamlit UI
41
- st.set_page_config(page_title="Cantonese Audio Transcription & Analysis", layout="centered")
42
- st.title("🗣️ Customer Service Conversation Quality Analyzer")
43
- st.markdown("Upload your Cantonese audio file, and we will transcribe and analyze its sentiment.")
44
-
45
- uploaded_file = st.file_uploader("Upload an audio file (WAV, MP3, etc.)", type=["wav", "mp3", "m4a"])
46
- if uploaded_file is not None:
47
- with st.spinner("Processing audio..."):
48
- temp_audio_path = "temp_audio.wav"
49
- with open(temp_audio_path, "wb") as f:
50
- f.write(uploaded_file.getbuffer())
51
- transcript = transcribe_audio(temp_audio_path)
52
- sentiment = rate_quality(transcript)
53
- os.remove(temp_audio_path)
54
-
55
- st.subheader("Transcription")
56
- st.text_area("", transcript, height=150)
57
 
58
- st.subheader("Sentiment Analysis")
59
- st.markdown(f"### 🎭 Sentiment: **{sentiment}**")
 
60
 
61
- st.success("Processing complete! 🎉")
 
 
1
+ import streamlit as st
2
  import torch
3
  import torchaudio
4
  import os
5
  import re
6
+ import jieba
7
  from difflib import SequenceMatcher
8
  from transformers import pipeline
9
 
 
17
  task="automatic-speech-recognition",
18
  model=MODEL_NAME,
19
  chunk_length_s=60,
20
+ device=device,
21
+ generate_kwargs={
22
+ "no_repeat_ngram_size": 4,
23
+ "repetition_penalty": 1.15,
24
+ "temperature": 0.5,
25
+ "top_p": 0.97,
26
+ "top_k": 40,
27
+ "max_new_tokens": 300,
28
+ "do_sample": True
29
+ }
30
  )
31
  pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
32
 
33
def is_similar(a, b, threshold=0.8):
    """Return True when strings *a* and *b* are near-duplicates.

    Similarity is ``SequenceMatcher.ratio()``; the pair counts as similar only
    when that ratio is strictly greater than *threshold*.

    Args:
        a: First string.
        b: Second string.
        threshold: Ratio (0.0-1.0) that must be exceeded.

    Returns:
        bool: ``True`` if the similarity ratio is above *threshold*.
    """
    # autojunk is switched off on purpose: once the second sequence reaches
    # 200 items the default heuristic ignores every "popular" element, which
    # for long transcripts throws away the most frequent characters and can
    # drive the ratio of two almost identical strings down to 0.
    return SequenceMatcher(None, a, b, autojunk=False).ratio() > threshold
35
 
36
def remove_repeated_phrases(text):
    """Drop sentences that nearly repeat the sentence kept just before them.

    *text* is split right after each sentence terminator. Every piece is
    stripped, and a piece is kept unless ``is_similar`` reports it as a
    near-duplicate of the most recently kept piece.

    Args:
        text: Transcript text, possibly containing repeated sentences.

    Returns:
        str: The kept sentences joined by single spaces ("" for empty input).
    """
    kept = []
    # Both ASCII and full-width "!"/"?" are accepted as terminators.
    # NOTE(review): the ASCII-only class looked like a normalised form of the
    # full-width marks used in Chinese text - confirm against the real file.
    for segment in re.split(r'(?<=[。!?！？])', text):
        sentence = segment.strip()
        # Splitting after a trailing terminator leaves an empty tail; keeping
        # it would add a dangling space to the joined result.
        if not sentence:
            continue
        if not kept or not is_similar(sentence, kept[-1]):
            kept.append(sentence)
    return " ".join(kept)
43
 
44
  def remove_punctuation(text):
45
  return re.sub(r'[^\w\s]', '', text)
46
 
47
  def transcribe_audio(audio_path):
48
+ waveform, sample_rate = torchaudio.load(audio_path)
49
+ duration = waveform.shape[1] / sample_rate
50
+ if duration > 60:
51
+ results = []
52
+ for start in range(0, int(duration), 55):
53
+ end = min(start + 60, int(duration))
54
+ chunk = waveform[:, start * sample_rate:end * sample_rate]
55
+ if chunk.shape[1] == 0:
56
+ continue
57
+ temp_filename = f"temp_chunk_{start}.wav"
58
+ torchaudio.save(temp_filename, chunk, sample_rate)
59
+ if os.path.exists(temp_filename):
60
+ try:
61
+ result = pipe(temp_filename)["text"]
62
+ results.append(remove_punctuation(result))
63
+ finally:
64
+ os.remove(temp_filename)
65
+ return remove_punctuation(remove_repeated_phrases(" ".join(results)))
66
+ return remove_punctuation(remove_repeated_phrases(pipe(audio_path)["text"]))
67
+
68
# Load quality rating model
@st.cache_resource
def _load_rating_pipe():
    """Build the text-classification pipeline bound to ``rating_pipe``.

    Streamlit re-executes the whole script on every user interaction; caching
    the pipeline as a resource keeps a single instance alive instead of
    rebuilding the model on each rerun.
    """
    return pipeline("text-classification", model="tabularisai/multilingual-sentiment-analysis")


rating_pipe = _load_rating_pipe()
70
 
71
  def rate_quality(text):
72
+ chunks = [text[i:i+512] for i in range(0, len(text), 512)]
73
+ results = []
74
+ label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
75
+
76
+ for chunk in chunks:
77
+ result = rating_pipe(chunk)[0]
78
+ results.append(label_map.get(result["label"], "Unknown"))
79
+
80
+ return max(set(results), key=results.count)
81
 
82
  # Streamlit UI
83
+ st.title("Cantonese Audio Transcription and Quality Rating")
84
+ st.write("Upload your Cantonese audio file to get the transcription and quality rating.")
85
+
86
+ audio_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "m4a"])
87
+
88
+ if audio_file is not None:
89
+ audio_path = audio_file.name
90
+ with open(audio_path, "wb") as f:
91
+ f.write(audio_file.getbuffer())
 
 
 
 
 
 
 
92
 
93
+ st.write("Processing audio...")
94
+ transcript = transcribe_audio(audio_path)
95
+ st.write("**Transcript:**", transcript)
96
 
97
+ quality_rating = rate_quality(transcript)
98
+ st.write("**Quality Rating:**", quality_rating)