# g13_DL_project / app.py
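"""Streamlit app: scrape Reddit for a search keyword, then summarize the posts,
run sentiment analysis, and extract and visualize keyphrases."""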
from collections import Counter
from concurrent.futures import ThreadPoolExecutor  # parallel processing
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
import numpy as np
from wordcloud import WordCloud
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
    AutoModelForSeq2SeqLM
)
from transformers.pipelines import AggregationStrategy
from functions import (
scrape_reddit_data,
safe_sentiment,
analyze_detail,
preprocess_text,
generate_variants,
contains_excluded_keywords,
extract_terms,
# remove_excluded_from_list,
process_extracted_result
)
# ---------- Cached function for loading the model pipelines ----------
@st.cache_resource(show_spinner=False)
def load_summarizer_model():
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    return tokenizer, model

def summarizer(text, prompt, max_length=600, min_length=10):
    # Reuse the cached model/tokenizer instead of reloading them on every call
    tokenizer, model = load_summarizer_model()
# Tokenize the prompt and article separately without adding special tokens
prompt_tokens = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
article_tokens = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]
# Concatenate prompt and article tokens
combined_input_ids = torch.cat([prompt_tokens, article_tokens], dim=-1)
    # Skip inputs that would exceed the model's 1024-token limit
    # (leaving room for the special tokens added below)
    if len(combined_input_ids[0]) > 1024 - tokenizer.num_special_tokens_to_add():
        return None
# Convert the tensor to a list and add special tokens as required by the model
input_ids_list = tokenizer.build_inputs_with_special_tokens(combined_input_ids[0].tolist())
input_ids = torch.tensor([input_ids_list])
# Generate the summary
summary_ids = model.generate(input_ids, max_length=max_length, min_length=min_length, do_sample=False)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
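
# Illustrative usage (hypothetical text; real calls happen in the Summarize step below):
#   summary = summarizer("Long Reddit thread text ...", "Summarize the following post:")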
@st.cache_resource(show_spinner=False)
def load_sentiment_pipeline():  # fine-tuned sentiment pipeline
tokenizer = AutoTokenizer.from_pretrained("kusa04/CustomModel_reddit")
model = AutoModelForSequenceClassification.from_pretrained(
"kusa04/CustomModel_reddit",
use_auth_token=st.secrets["hugging_face_with_my_fine_turning_model"],
# use_auth_token=st.secrets["hugging_face_token"]
)
    # Use GPU (device=0) if available, otherwise fall back to CPU (device=-1)
    device = 0 if torch.cuda.is_available() else -1
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=device)
    max_tokens = tokenizer.model_max_length
    # Some tokenizers report a huge sentinel value for model_max_length;
    # fall back to a conservative chunk size in that case
    if max_tokens > 10000:
        max_tokens = 200
return sentiment_pipeline, tokenizer, max_tokens
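
# Illustrative output of the sentiment pipeline (score value hypothetical):
#   pipe, tok, max_tokens = load_sentiment_pipeline()
#   pipe("I love this game")  # -> [{'label': 'positive', 'score': 0.97}]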
# Keyphrase-extraction pipeline; instances are cached by get_keyword_pipeline()
# below, so the class itself does not need @st.cache_resource.
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
def __init__(self, model, *args, **kwargs):
super().__init__(
model=AutoModelForTokenClassification.from_pretrained(model),
tokenizer=AutoTokenizer.from_pretrained(model),
*args,
**kwargs
)
def postprocess(self, all_outputs):
results = super().postprocess(
all_outputs=all_outputs,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
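        # Deduplicate the aggregated keyphrases, e.g. ['game balance', 'weapon design']
        # (example values hypothetical)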
return np.unique([result.get("word").strip() for result in results])
@st.cache_resource(show_spinner=False)
def get_keyword_pipeline():
model_name = "ml6team/keyphrase-extraction-kbir-inspec"
return KeyphraseExtractionPipeline(model=model_name)
keyword_pipeline = get_keyword_pipeline()
def keyword_extractor(text):
try:
return keyword_pipeline(text)
    except Exception:
return None
st.title("Scraping & Analysis of Reddit")
# --- Sidebar ---
with st.sidebar:
st.header("Controls")
user_query = st.text_input("Enter search keyword:", value="Monster Hunter Wilds")
scrape_btn = st.button("Scrape")
summarize_btn = st.button("Summarize")
sentiment_btn = st.button("Sentiment Analysis")
keyword_extraction_btn = st.button("Keyword Extraction")
# --- User Input ---
# user_query = st.text_input("Enter search keyword:", value="Monster Hunter Wilds")
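# Match both the spaced and concatenated spellings of the keyword, e.g. the
# default input becomes '"Monster Hunter Wilds" OR "MonsterHunterWilds"'.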
if user_query:
search_query = f'"{user_query}" OR "{user_query.replace(" ", "")}"'
st.session_state["user_query"] = user_query
else:
search_query = ""
st.write("Search Query:", search_query)
# Button to trigger scraping
if scrape_btn:
with st.spinner("Scraping..."):
# progress_bar = st.progress(0)
progress_text = st.empty()
total_limit = 5000 # Maximum number of submissions to check
df = scrape_reddit_data(search_query, total_limit)
length = len(df)
progress_text.text(f"Collected {length} valid posts.")
st.session_state["df"] = df
# Load the summarization pipeline
# with st.spinner("Loading Summarizing Pipeline"):
# summarizer = summarizer()
if summarize_btn:
df = st.session_state.get("df")
if df is None or df.empty:
st.write("Please run 'Scrape' with an accurate keyword first.")
st.stop()
# Split the "Detail" texts into a list
all_details = df["Detail"].tolist()
# Divide the list into chunks of 4 posts each
chunk_size = 4
chunks = [all_details[i:i + chunk_size] for i in range(0, len(all_details), chunk_size)]
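    # e.g. 10 posts -> three chunks of 4, 4, and 2 posts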
chunk_summaries = []
# Summarize each chunk individually
for idx, chunk in enumerate(chunks):
# Join the posts in the chunk with "\n"
chunk_text = " \n ".join(chunk)
with st.spinner(f"Summarizing chunk {idx + 1} of {len(chunks)}..."):
prompt = """
Please summarize the following Reddit post formally,
highlighting the author's key experiences and opinions,
especially about HOW HE/SHE THINK ABOUT IT.
NOTE THAT every time you see indent, that means they are different users, posted by different people.
So, TELL ME THE HOLE TENDENCY ACROSS MULTIPLE REDDIT USERS:
"""
summary_output = summarizer(chunk_text, prompt, max_length=50, min_length=2)
chunk_summaries.append(summary_output)
    # Combine the chunk summaries, separating them with blank lines
    combined_summary_text = " \n ".join(str(chunk) + "\n\n\n\n" for chunk in chunk_summaries if chunk)
# Generate an overall summary from the combined chunk summaries
# with st.spinner("Generating overall summary from chunk summaries..."):
# prompt = """
# Based on the above text, what kind of tendencies do you think can be perceived in the users?
# I believe that there is a space between each sentence, which indicates that each belongs to a different user.
# With that in mind, PLEASE EXPLAIN to me in an easy-to-understand manner what each user tends to be seeking.
# """
# overall_summary_output = summarizer(combined_summary_text, prompt, max_length=600)
# Display the overall summary
st.subheader("Overall Summary of All Posts")
st.write(combined_summary_text)
# Save the DataFrame and overall summary in session state for later use
st.session_state["df"] = df
st.session_state["overall_summary"] = combined_summary_text
# button to trigger sentiment analysis
if sentiment_btn:
df = st.session_state.get("df")
if df is None or df.empty:
st.write("Please run 'Scrape' with an accurate keyword first.")
st.stop()
length = len(df)
with st.spinner("Loading..."):
sentiment_pipeline, tokenizer, max_tokens = load_sentiment_pipeline()
st.write("Loaded...")
with st.spinner("Doing Sentiment Analysis..."):
progress_bar = st.progress(0)
        # Titles are short, so batch processing isn't needed here
        df['Title_Sentiment'] = df['Title'].apply(
            lambda x: safe_sentiment(sentiment_pipeline, preprocess_text(x), length, progress_bar) if x else None
        )
        # Parallel processing for each row of Detail
with ThreadPoolExecutor() as executor:
detail_sentiments = list(executor.map(
lambda x: analyze_detail(x, tokenizer, sentiment_pipeline, max_tokens) if x else None,
df['Detail']
))
df['Detail_Sentiment'] = detail_sentiments
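        # Flatten each sentiment dict, e.g. {'label': 'positive', 'score': 0.93}
        # (score illustrative), into separate label and score columns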
df["Title_Sentiment_Label"] = df["Title_Sentiment"].apply(lambda x: x["label"] if x else None)
df["Title_Sentiment_Score"] = df["Title_Sentiment"].apply(lambda x: x["score"] if x else None)
df["Detail_Sentiment_Label"] = df["Detail_Sentiment"].apply(lambda x: x["label"] if x else None)
df["Detail_Sentiment_Score"] = df["Detail_Sentiment"].apply(lambda x: x["score"] if x else None)
df = df.drop(["Title_Sentiment", "Detail_Sentiment"], axis=1)
cols = ["Title", "Title_Sentiment_Label", "Title_Sentiment_Score",
"Detail", "Detail_Sentiment_Label", "Detail_Sentiment_Score", "Date"]
df = df[cols]
st.session_state["df"] = df
with st.spinner("Drawing Sentiment Graphs..."):
        # ① Create a yyyy-mm column
df["YearMonth"] = pd.to_datetime(df["Date"]).dt.to_period("M").astype(str)
df["Title_Sentiment_Label"] = df["Title_Sentiment_Label"].str.lower()
df["Detail_Sentiment_Label"] = df["Detail_Sentiment_Label"].str.lower()
# ② groupby and pivot title & detail
title_counts = df.groupby(["YearMonth", "Title_Sentiment_Label"]).size().reset_index(name="count")
detail_counts = df.groupby(["YearMonth", "Detail_Sentiment_Label"]).size().reset_index(name="count")
# ③ pivot → index=YearMonth, columns=sentiment, values=count
title_pivot = title_counts.pivot(index="YearMonth", columns="Title_Sentiment_Label", values="count").fillna(0)
detail_pivot = detail_counts.pivot(index="YearMonth", columns="Detail_Sentiment_Label", values="count").fillna(0)
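        # Illustrative pivot shape:
        # YearMonth   negative  neutral  positive
        # 2025-01          3.0      5.0      12.0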
# Sort
title_pivot = title_pivot.sort_index()
detail_pivot = detail_pivot.sort_index()
# --- ④ Visualize title graph ---
fig1, ax1 = plt.subplots(figsize=(15, 6))
# stacked bar plot
title_pivot.plot(kind="bar", stacked=True, ax=ax1, color={
"positive": "orange",
"neutral": "yellowgreen",
"negative": "blue"
})
# line graph
# for sentiment, color in zip(["positive", "neutral", "negative"], ["darkorange", "green", "navy"]):
# if sentiment in title_pivot.columns:
# ax1.plot(title_pivot.index, title_pivot[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
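        # Cumulative sums across the sentiment columns, so each dashed line
        # traces the top edge of its sentiment's stacked bar segment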
title_cum = title_pivot.cumsum(axis=1)
for sentiment, color in zip(["positive", "neutral", "negative"], ["darkorange", "green", "navy"]):
if sentiment in title_cum.columns:
ax1.plot(title_cum.index, title_cum[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
ax1.set_title("Monthly Title Sentiment Counts")
ax1.set_xlabel("Time (YYYY-MM)")
ax1.set_ylabel("Count")
ax1.legend()
plt.xticks(rotation=45)
st.pyplot(fig1)
# --- ⑤ Visualize detail ---
fig2, ax2 = plt.subplots(figsize=(15, 6))
# stacked bar plot
detail_pivot.plot(kind="bar", stacked=True, ax=ax2, color={
"positive": "darkorange",
"neutral": "forestgreen",
"negative": "navy"
})
# line graph
# for sentiment, color in zip(["positive", "neutral", "negative"], ["orangered", "limegreen", "darkblue"]):
# if sentiment in detail_pivot.columns:
# ax2.plot(detail_pivot.index, detail_pivot[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
detail_cum = detail_pivot.cumsum(axis=1)
for sentiment, color in zip(["positive", "neutral", "negative"], ["orangered", "limegreen", "darkblue"]):
            if sentiment in detail_cum.columns:
ax2.plot(detail_cum.index, detail_cum[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
ax2.set_title("Monthly Detail Sentiment Counts")
ax2.set_xlabel("Time (YYYY-MM)")
ax2.set_ylabel("Count")
ax2.legend()
plt.xticks(rotation=45)
st.pyplot(fig2)
if keyword_extraction_btn:
df = st.session_state.get("df")
user_query = st.session_state.get("user_query")
if (df is None or df.empty) or (user_query is None):
st.write("Please run 'Scrape' with an accurate keyword first.")
st.stop()
else:
with st.spinner("Extracting Keyword..."):
target_col = "Detail_Keyword"
# 並列処理で各 'Detail' に対してキーワード抽出を実行
with ThreadPoolExecutor() as executor:
results = list(
executor.map(lambda x: keyword_extractor(preprocess_text(x)) if x else None, df['Detail'])
)
df[target_col] = results
# st.write("df: ", df[target_col].head().to_list())
            # Generate keyword variants to exclude (derived from the search query)
excluded_keywords = generate_variants(user_query)
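            # e.g. for "Monster Hunter Wilds", the variants presumably include
            # lowercased and concatenated forms like "monsterhunterwilds"
            # (generate_variants lives in functions.py; its exact output is assumed here)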
df_filtered = df[~df[target_col].apply(
lambda cell: contains_excluded_keywords(cell, excluded_keywords=excluded_keywords)
)].copy()
# st.write("filtered: ", df_filtered[target_col].head().to_list())
            # Convert the keyword cells to a flat list of terms
terms_list = df_filtered[target_col].dropna().apply(lambda x: extract_terms(x))
terms = [term for sublist in terms_list for term in sublist]
# st.write("term_list: ", terms_list)
# st.write("terms:", terms)
# count frequency
freq = Counter(terms)
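            # e.g. Counter(['dlc', 'dlc', 'weapons']) -> Counter({'dlc': 2, 'weapons': 1})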
# st.write("freq:", freq)
with st.spinner("Drawing Keywords Diagram..."):
if freq:
# st.write("freq preview:", list(freq.items())[:10]) # optional debug
wc = WordCloud(width=800, height=400, background_color="white")
wc.generate_from_frequencies(freq)
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(wc, interpolation="bilinear")
ax.axis("off")
st.pyplot(fig)
else:
st.warning("No keywords to display.")