g13_DL_project / functions.py
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd
import praw # Reddit's API
import re # Regular expression module
import streamlit as st
import time
import numpy as np
from wordcloud import WordCloud
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TokenClassificationPipeline
)
from transformers.pipelines import AggregationStrategy
# Function to normalize text by replacing multiple spaces/newlines with a single space
def normalize_text(text):
    if not isinstance(text, str):
        return ""
    return re.sub(r'\s+', ' ', text).strip()

# ---------- Function for scraping Reddit data (st.cache_data currently disabled) ----------
# @st.cache_data(show_spinner=False)
def scrape_reddit_data(search_query, total_limit):
    # Retrieve API credentials from st.secrets
    reddit = praw.Reddit(
        client_id=st.secrets["reddit_client_id"],
        client_secret=st.secrets["reddit_client_secret"],
        user_agent=st.secrets["reddit_user_agent"]
    )
    subreddit = reddit.subreddit("all")
    posts_data = []
    # Iterate over submissions based on the search query and limit
    for i, submission in enumerate(subreddit.search(search_query, sort="relevance", limit=total_limit)):
        # No UI updates here as caching does not allow live progress updates
        if submission.title and submission.selftext:
            posts_data.append([
                submission.title,
                submission.url,
                submission.created_utc,
                submission.selftext,
            ])
        time.sleep(0.25)
    df = pd.DataFrame(posts_data, columns=["Title", "URL", "Date", "Detail"])
    for col in ["Title", "Detail"]:
        df[col] = df[col].apply(normalize_text)
    # Filter out rows with empty Title or Detail
    df = df[(df["Title"] != "") & (df["Detail"] != "")]
    df['Date'] = pd.to_datetime(df['Date'], unit='s')
    df = df.sort_values(by="Date", ascending=True).reset_index(drop=True)
    return df

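# Illustrative usage sketch (not called by the app): how scrape_reddit_data is
# expected to be invoked. The query string and limit below are made-up values,
# and the call assumes reddit_* credentials are present in st.secrets.
def _example_scrape_usage():
    df = scrape_reddit_data("monster hunter", total_limit=50)
    # Returns a DataFrame with columns Title, URL, Date (datetime64), Detail,
    # sorted by Date in ascending order.
    return df
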
# ------------------ Sentiment Analysis Functions ------------------------#
def split_text_by_token_limit(text, tokenizer, max_tokens):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
        chunks.append(chunk_text)
    return chunks

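# Illustrative sketch of chunking a long post with a Hugging Face tokenizer before
# sentiment analysis. The checkpoint name and the 510-token budget (leaving room
# for special tokens under a 512-token model limit) are assumptions, not values
# taken from the original app.
def _example_chunking_usage(long_text):
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
    return split_text_by_token_limit(long_text, tokenizer, max_tokens=510)
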
def safe_sentiment(sentiment_pipeline, text, length, progress_bar):
    try:
        result = sentiment_pipeline(text)[0]
    except Exception:
        result = None
    if "count" not in st.session_state:
        st.session_state.count = 0
    st.session_state.count += 1
    progress = st.session_state.count / length
    # Clamp the progress value between 0.0 and 1.0
    progress = min(max(progress, 0.0), 1.0)
    progress_bar.progress(progress)
    return result

def safe_sentiment_batch(sentiment_pipeline, texts):
    try:
        results = sentiment_pipeline(texts)
    except Exception:
        results = [None] * len(texts)
    return results

def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
    text = preprocess_text(text)
    chunks = split_text_by_token_limit(text, tokenizer, max_tokens)
    if not chunks:
        return None
    # Run the sentiment pipeline on all chunks in one batched call
    results = safe_sentiment_batch(sentiment_pipeline, chunks)
    # Aggregate per-chunk scores into a single label
    scores = {"POSITIVE": 0, "NEGATIVE": 0, "NEUTRAL": 0}
    for result in results:
        if result is not None:
            label = result['label'].upper()
            if label in scores:
                scores[label] += result['score']
    final_label = max(scores, key=lambda k: scores[k])
    final_score = scores[final_label]
    return {"label": final_label, "score": final_score}

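# Illustrative sketch of wiring a sentiment pipeline into analyze_detail over the
# scraped DataFrame. The checkpoint name and max_tokens value are assumptions; any
# sequence-classification sentiment model would fit this pattern.
def _example_detail_sentiment(df):
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
    # analyze_detail returns {"label": ..., "score": ...} or None for each post body.
    df["Detail_Sentiment"] = df["Detail"].apply(
        lambda text: analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens=510)
    )
    return df
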
def preprocess_text(text):
    # Replace URLs and user mentions with placeholder tokens
    text = re.sub(r'http\S+', 'http', text)
    text = re.sub(r'@\w+', '@user', text)
    return text

def generate_variants(keyword):
    # Split the keyword into individual words
    words = keyword.split()
    # Original keyword
    original = keyword
    # Convert the keyword to all uppercase letters
    all_upper = keyword.upper()
    # Convert the keyword to all lowercase letters
    all_lower = keyword.lower()
    # Concatenate words with each word capitalized (no spaces)
    no_space_title = ''.join(word.capitalize() for word in words)
    # Concatenate words in all uppercase (no spaces)
    no_space_upper = ''.join(word.upper() for word in words)
    # Concatenate words in all lowercase (no spaces)
    no_space_lower = ''.join(word.lower() for word in words)
    # Create a string with only the first letter of each word (e.g., MHW)
    initials = ''.join(word[0].upper() for word in words)
    # Return all variants as a list
    return [original, all_upper, all_lower, no_space_title, no_space_upper, no_space_lower, initials]

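# Expected behavior of generate_variants, shown on a made-up keyword (illustrative only):
def _example_generate_variants():
    return generate_variants("Monster Hunter Wilds")
    # -> ['Monster Hunter Wilds', 'MONSTER HUNTER WILDS', 'monster hunter wilds',
    #     'MonsterHunterWilds', 'MONSTERHUNTERWILDS', 'monsterhunterwilds', 'MHW']
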
# Function to check if a cell contains any excluded keywords
def contains_excluded_keywords(cell, excluded_keywords):
    if isinstance(cell, np.ndarray):
        cell_str = ' '.join(map(str, cell))
        return any(keyword in cell_str for keyword in excluded_keywords)
    elif isinstance(cell, str):
        return any(keyword in cell for keyword in excluded_keywords)
    return False

# Function to extract terms from a cell
def extract_terms(cell):
    if isinstance(cell, np.ndarray):
        # Convert each element to a string and strip whitespace
        return [str(item).strip() for item in cell if str(item).strip()]
    elif isinstance(cell, str):
        # Split the string by commas and strip whitespace from each term
        return [term.strip() for term in cell.split(',') if term.strip()]
    else:
        return []

# def remove_excluded_from_list(keywords_list, excluded_keywords):
#     """
#     Remove items from the keywords_list if they contain any of the excluded keywords.
#     This function checks for partial matches in a case-insensitive manner.
#     """
#     if not isinstance(keywords_list, list):
#         return keywords_list  # If it's not a list, return as is
#     filtered_list = []
#     for item in keywords_list:
#         # Check if item contains any excluded keyword (case-insensitive)
#         if any(kw.lower() in item.lower() for kw in excluded_keywords):
#             # Skip this item if it matches an excluded keyword
#             continue
#         else:
#             filtered_list.append(item)
#     return filtered_list

def remove_excluded_from_text(text, excluded_keywords):
    """
    Remove occurrences of any excluded keyword from the text.
    Matching is case-insensitive. Extra whitespace is cleaned.
    """
    if not isinstance(text, str):
        return text
    filtered_text = text
    for kw in excluded_keywords:
        # Create a regex pattern for the keyword (case-insensitive)
        pattern = re.compile(re.escape(kw), re.IGNORECASE)
        # Replace any occurrence of the keyword with an empty string
        filtered_text = pattern.sub("", filtered_text)
    # Remove extra spaces and strip the result
    filtered_text = re.sub(r'\s+', ' ', filtered_text).strip()
    return filtered_text

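# Illustrative example of remove_excluded_from_text on made-up inputs (assumed values):
def _example_remove_excluded():
    # "Monster Hunter" is stripped case-insensitively and leftover whitespace is collapsed.
    return remove_excluded_from_text("Monster Hunter Wilds release date", ["monster hunter"])
    # -> "Wilds release date"
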
def process_extracted_result(result, excluded_keywords):
    """
    Process an extracted result by removing excluded keywords from each string.
    If result is a list, process each element; if it's a string, process it directly.
    Return a list of non-empty cleaned strings.
    """
    cleaned_items = []
    if isinstance(result, list):
        for item in result:
            cleaned_item = remove_excluded_from_text(item, excluded_keywords)
            if cleaned_item:  # Only add non-empty strings
                cleaned_items.append(cleaned_item)
    elif isinstance(result, str):
        cleaned_item = remove_excluded_from_text(result, excluded_keywords)
        if cleaned_item:
            cleaned_items.append(cleaned_item)
    return cleaned_items

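# Illustrative example of process_extracted_result on made-up extraction output (assumed values):
def _example_process_extracted():
    excluded = ["mhw"]
    # Excluded keywords are removed from each item and empty strings are dropped.
    return process_extracted_result(["MHW weapons", "best armor", ""], excluded)
    # -> ["weapons", "best armor"]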