from collections import Counter
from concurrent.futures import ThreadPoolExecutor  # parallel processing
import matplotlib.pyplot as plt
import pandas as pd
import praw  # Reddit API wrapper
import re  # Regular expression module
import streamlit as st
import time
import torch
import numpy as np
from wordcloud import WordCloud
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
    T5Tokenizer,
    T5ForConditionalGeneration,
    AutoModelForSeq2SeqLM
)
from transformers.pipelines import AggregationStrategy
from functions import (
    scrape_reddit_data,
    safe_sentiment,
    analyze_detail,
    preprocess_text,
    generate_variants,
    contains_excluded_keywords,
    extract_terms,
    # remove_excluded_from_list,
    process_extracted_result
)

# ---------- Functions for loading the model pipelines ----------
def summarizer(text, prompt, max_length=600, min_length=10):
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

    # Tokenize the prompt and article separately without adding special tokens
    prompt_tokens = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
    article_tokens = tokenizer(text, add_special_tokens=False, return_tensors="pt")["input_ids"]

    # Concatenate prompt and article tokens
    combined_input_ids = torch.cat([prompt_tokens, article_tokens], dim=-1)

    # Skip inputs that exceed BART's 1024-token limit
    if len(combined_input_ids[0]) > 1024:
        return None
    # st.write(len(combined_input_ids[0]))

    # Convert the tensor to a list and add special tokens as required by the model
    input_ids_list = tokenizer.build_inputs_with_special_tokens(combined_input_ids[0].tolist())
    input_ids = torch.tensor([input_ids_list])

    # Generate the summary
    summary_ids = model.generate(input_ids, max_length=max_length, min_length=min_length, do_sample=False)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
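
# Note: summarizer() returns None when the prompt plus the article exceeds BART's 1024-token input
# limit, so callers should be prepared to handle a None result. Illustrative call (`post_text` is a
# placeholder used only in this comment, not defined in the app):
#   summary = summarizer(post_text, "Summarize the following Reddit post:", max_length=120, min_length=20)
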
def load_sentiment_pipeline():  # fine-tuned sentiment pipeline
    tokenizer = AutoTokenizer.from_pretrained("kusa04/CustomModel_reddit")
    model = AutoModelForSequenceClassification.from_pretrained(
        "kusa04/CustomModel_reddit",
        use_auth_token=st.secrets["hugging_face_with_my_fine_turning_model"],
        # use_auth_token=st.secrets["hugging_face_token"]
    )
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0)  # device=0: GPU (-1 would be CPU)
    max_tokens = tokenizer.model_max_length
    if max_tokens > 10000:  # some tokenizers report a huge sentinel value here; fall back to a safe length
        max_tokens = 200
    return sentiment_pipeline, tokenizer, max_tokens
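
# A minimal sketch of how these loaders could be cached so the model weights are loaded only once
# per session (assumes Streamlit >= 1.18, which provides st.cache_resource; not wired in here):
#
#   @st.cache_resource
#   def load_sentiment_pipeline():
#       ...
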
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    def __init__(self, model, *args, **kwargs):
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )

    def postprocess(self, all_outputs):
        results = super().postprocess(
            all_outputs=all_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )
        return np.unique([result.get("word").strip() for result in results])


def get_keyword_pipeline():
    model_name = "ml6team/keyphrase-extraction-kbir-inspec"
    return KeyphraseExtractionPipeline(model=model_name)


keyword_pipeline = get_keyword_pipeline()


def keyword_extractor(text):
    try:
        return keyword_pipeline(text)
    except Exception:
        return None
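
# keyword_extractor() returns a NumPy array of unique keyphrase strings for the given text,
# or None if the underlying pipeline raises.
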

st.title("Scraping & Analysis of Reddit")

# --- Sidebar ---
with st.sidebar:
    st.header("Controls")
    user_query = st.text_input("Enter search keyword:", value="Monster Hunter Wilds")
    scrape_btn = st.button("Scrape")
    summarize_btn = st.button("Summarize")
    sentiment_btn = st.button("Sentiment Analysis")
    keyword_extraction_btn = st.button("Keyword Extraction")

# --- User Input ---
# user_query = st.text_input("Enter search keyword:", value="Monster Hunter Wilds")
if user_query:
    search_query = f'"{user_query}" OR "{user_query.replace(" ", "")}"'
    st.session_state["user_query"] = user_query
else:
    search_query = ""

st.write("Search Query:", search_query)

# Button to trigger scraping and summarizing overall text using chunking
if scrape_btn:
    with st.spinner("Scraping..."):
        # progress_bar = st.progress(0)
        progress_text = st.empty()

        total_limit = 5000  # Maximum number of submissions to check
        df = scrape_reddit_data(search_query, total_limit)

        length = len(df)
        progress_text.text(f"Collected {length} valid posts.")
        st.session_state["df"] = df

# Load the summarization pipeline
# with st.spinner("Loading Summarizing Pipeline"):
#     summarizer = summarizer()

if summarize_btn:
    df = st.session_state.get("df")
    if df is None or df.empty:
        st.write("Please run 'Scrape' with an accurate keyword first.")
        st.stop()

    # Split the "Detail" texts into a list
    all_details = df["Detail"].tolist()

    # Divide the list into chunks of 4 posts each
    chunk_size = 4
    chunks = [all_details[i:i + chunk_size] for i in range(0, len(all_details), chunk_size)]
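    # For example, 10 collected posts with chunk_size = 4 would be split into chunks of 4, 4, and 2 posts.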

    chunk_summaries = []
    # Summarize each chunk individually
    for idx, chunk in enumerate(chunks):
        # Join the posts in the chunk with "\n"
        chunk_text = " \n ".join(chunk)
        with st.spinner(f"Summarizing chunk {idx + 1} of {len(chunks)}..."):
            prompt = """
            Please summarize the following Reddit posts formally,
            highlighting each author's key experiences and opinions,
            especially HOW THEY THINK ABOUT THE TOPIC.
            NOTE THAT each indented block is a separate post written by a different user.
            So, TELL ME THE WHOLE TENDENCY ACROSS MULTIPLE REDDIT USERS:
            """
            summary_output = summarizer(chunk_text, prompt, max_length=50, min_length=2)
            chunk_summaries.append(summary_output)

    # Combine all chunk summaries using the same delimiter
    combined_summary_text = " \n ".join(str(chunk) + "\n\n\n\n" for chunk in chunk_summaries if chunk)

    # Generate an overall summary from the combined chunk summaries
    # with st.spinner("Generating overall summary from chunk summaries..."):
    #     prompt = """
    #     Based on the above text, what kind of tendencies do you think can be perceived in the users?
    #     I believe that there is a space between each sentence, which indicates that each belongs to a different user.
    #     With that in mind, PLEASE EXPLAIN to me in an easy-to-understand manner what each user tends to be seeking.
    #     """
    #     overall_summary_output = summarizer(combined_summary_text, prompt, max_length=600)

    # Display the overall summary
    st.subheader("Overall Summary of All Posts")
    st.write(combined_summary_text)

    # Save the DataFrame and overall summary in session state for later use
    st.session_state["df"] = df
    st.session_state["overall_summary"] = combined_summary_text

# Button to trigger sentiment analysis
if sentiment_btn:
    df = st.session_state.get("df")
    if df is None or df.empty:
        st.write("Please run 'Scrape' with an accurate keyword first.")
        st.stop()

    length = len(df)
    with st.spinner("Loading..."):
        sentiment_pipeline, tokenizer, max_tokens = load_sentiment_pipeline()
    st.write("Loaded...")

    with st.spinner("Doing Sentiment Analysis..."):
        progress_bar = st.progress(0)

        # Titles are short, so batch processing is not needed
        df['Title_Sentiment'] = df['Title'].apply(
            lambda x: safe_sentiment(sentiment_pipeline, preprocess_text(x), length, progress_bar) if x else None
        )

        # Parallel processing for each row of 'Detail'
        with ThreadPoolExecutor() as executor:
            detail_sentiments = list(executor.map(
                lambda x: analyze_detail(x, tokenizer, sentiment_pipeline, max_tokens) if x else None,
                df['Detail']
            ))
        df['Detail_Sentiment'] = detail_sentiments

        df["Title_Sentiment_Label"] = df["Title_Sentiment"].apply(lambda x: x["label"] if x else None)
        df["Title_Sentiment_Score"] = df["Title_Sentiment"].apply(lambda x: x["score"] if x else None)
        df["Detail_Sentiment_Label"] = df["Detail_Sentiment"].apply(lambda x: x["label"] if x else None)
        df["Detail_Sentiment_Score"] = df["Detail_Sentiment"].apply(lambda x: x["score"] if x else None)
        df = df.drop(["Title_Sentiment", "Detail_Sentiment"], axis=1)

        cols = ["Title", "Title_Sentiment_Label", "Title_Sentiment_Score",
                "Detail", "Detail_Sentiment_Label", "Detail_Sentiment_Score", "Date"]
        df = df[cols]
        st.session_state["df"] = df

    with st.spinner("Drawing Sentiment Graphs..."):
        # ① Create a yyyy-mm column
        df["YearMonth"] = pd.to_datetime(df["Date"]).dt.to_period("M").astype(str)
        df["Title_Sentiment_Label"] = df["Title_Sentiment_Label"].str.lower()
        df["Detail_Sentiment_Label"] = df["Detail_Sentiment_Label"].str.lower()

        # ② Group by month and sentiment label for title & detail
        title_counts = df.groupby(["YearMonth", "Title_Sentiment_Label"]).size().reset_index(name="count")
        detail_counts = df.groupby(["YearMonth", "Detail_Sentiment_Label"]).size().reset_index(name="count")

        # ③ Pivot → index=YearMonth, columns=sentiment, values=count
        title_pivot = title_counts.pivot(index="YearMonth", columns="Title_Sentiment_Label", values="count").fillna(0)
        detail_pivot = detail_counts.pivot(index="YearMonth", columns="Detail_Sentiment_Label", values="count").fillna(0)

        # Sort chronologically
        title_pivot = title_pivot.sort_index()
        detail_pivot = detail_pivot.sort_index()
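
        # Each pivot table now has one row per YearMonth and one column per sentiment label,
        # with missing month/label combinations filled as 0.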

        # --- ④ Visualize the title graph ---
        fig1, ax1 = plt.subplots(figsize=(15, 6))

        # Stacked bar plot
        title_pivot.plot(kind="bar", stacked=True, ax=ax1, color={
            "positive": "orange",
            "neutral": "yellowgreen",
            "negative": "blue"
        })

        # Line graph
        # for sentiment, color in zip(["positive", "neutral", "negative"], ["darkorange", "green", "navy"]):
        #     if sentiment in title_pivot.columns:
        #         ax1.plot(title_pivot.index, title_pivot[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
        title_cum = title_pivot.cumsum(axis=1)
        for sentiment, color in zip(["positive", "neutral", "negative"], ["darkorange", "green", "navy"]):
            if sentiment in title_cum.columns:
                ax1.plot(title_cum.index, title_cum[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
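        # Because the bars are stacked, plotting the cumulative sums (title_cum) rather than the raw
        # counts lets each dashed trend line track the top of its corresponding bar segment.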

        ax1.set_title("Monthly Title Sentiment Counts")
        ax1.set_xlabel("Time (YYYY-MM)")
        ax1.set_ylabel("Count")
        ax1.legend()
        plt.xticks(rotation=45)
        st.pyplot(fig1)

        # --- ⑤ Visualize the detail graph ---
        fig2, ax2 = plt.subplots(figsize=(15, 6))

        # Stacked bar plot
        detail_pivot.plot(kind="bar", stacked=True, ax=ax2, color={
            "positive": "darkorange",
            "neutral": "forestgreen",
            "negative": "navy"
        })

        # Line graph
        # for sentiment, color in zip(["positive", "neutral", "negative"], ["orangered", "limegreen", "darkblue"]):
        #     if sentiment in detail_pivot.columns:
        #         ax2.plot(detail_pivot.index, detail_pivot[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")
        detail_cum = detail_pivot.cumsum(axis=1)
        for sentiment, color in zip(["positive", "neutral", "negative"], ["orangered", "limegreen", "darkblue"]):
            if sentiment in detail_cum.columns:
                ax2.plot(detail_cum.index, detail_cum[sentiment], label=f"{sentiment} trend", marker="o", color=color, linestyle="--")

        ax2.set_title("Monthly Detail Sentiment Counts")
        ax2.set_xlabel("Time (YYYY-MM)")
        ax2.set_ylabel("Count")
        ax2.legend()
        plt.xticks(rotation=45)
        st.pyplot(fig2)

if keyword_extraction_btn:
    df = st.session_state.get("df")
    user_query = st.session_state.get("user_query")
    if (df is None or df.empty) or (user_query is None):
        st.write("Please run 'Scrape' with an accurate keyword first.")
        st.stop()
    else:
        with st.spinner("Extracting Keywords..."):
            target_col = "Detail_Keyword"

            # Run keyword extraction on each 'Detail' entry in parallel
            with ThreadPoolExecutor() as executor:
                results = list(
                    executor.map(lambda x: keyword_extractor(preprocess_text(x)) if x else None, df['Detail'])
                )
            df[target_col] = results
            # st.write("df: ", df[target_col].head().to_list())

            # Generate variants of the search keyword to exclude from the results
            excluded_keywords = generate_variants(user_query)
            df_filtered = df[~df[target_col].apply(
                lambda cell: contains_excluded_keywords(cell, excluded_keywords=excluded_keywords)
            )].copy()
            # st.write("filtered: ", df_filtered[target_col].head().to_list())

            # Flatten the extracted keyphrases into a single list of terms
            terms_list = df_filtered[target_col].dropna().apply(lambda x: extract_terms(x))
            terms = [term for sublist in terms_list for term in sublist]
            # st.write("term_list: ", terms_list)
            # st.write("terms:", terms)

            # Count term frequencies
            freq = Counter(terms)
            # st.write("freq:", freq)

        with st.spinner("Drawing Keywords Diagram..."):
            if freq:
                # st.write("freq preview:", list(freq.items())[:10])  # optional debug
                wc = WordCloud(width=800, height=400, background_color="white")
                wc.generate_from_frequencies(freq)

                fig, ax = plt.subplots(figsize=(10, 5))
                ax.imshow(wc, interpolation="bilinear")
                ax.axis("off")
                st.pyplot(fig)
            else:
                st.warning("No keywords to display.")