from collections import Counter
import re  # Regular expression module
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import praw  # Reddit API wrapper
import streamlit as st
from wordcloud import WordCloud
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
)
from transformers.pipelines import AggregationStrategy


# Normalize text by collapsing runs of whitespace/newlines into a single space
def normalize_text(text):
    if not isinstance(text, str):
        return ""
    return re.sub(r'\s+', ' ', text).strip()


# ---------- Cached function for scraping Reddit data ----------
@st.cache_data(show_spinner=False)
def scrape_reddit_data(search_query, total_limit):
    # Retrieve API credentials from st.secrets
    reddit = praw.Reddit(
        client_id=st.secrets["reddit_client_id"],
        client_secret=st.secrets["reddit_client_secret"],
        user_agent=st.secrets["reddit_user_agent"],
    )
    subreddit = reddit.subreddit("all")
    posts_data = []
    # Iterate over submissions matching the search query, up to the limit.
    # No UI updates here: live progress updates are not possible inside a cached function.
    for submission in subreddit.search(search_query, sort="relevance", limit=total_limit):
        if submission.title and submission.selftext:
            posts_data.append([
                submission.title,
                submission.url,
                submission.created_utc,
                submission.selftext,
            ])
        time.sleep(0.25)

    df = pd.DataFrame(posts_data, columns=["Title", "URL", "Date", "Detail"])
    for col in ["Title", "Detail"]:
        df[col] = df[col].apply(normalize_text)
    # Filter out rows with an empty Title or Detail
    df = df[(df["Title"] != "") & (df["Detail"] != "")]
    df["Date"] = pd.to_datetime(df["Date"], unit="s")
    df = df.sort_values(by="Date", ascending=True).reset_index(drop=True)
    return df


# ------------------ Sentiment Analysis Functions ------------------ #
def split_text_by_token_limit(text, tokenizer, max_tokens):
    # Split a long text into chunks of at most max_tokens tokens
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
        chunks.append(chunk_text)
    return chunks


def safe_sentiment(sentiment_pipeline, text):
    # Run the pipeline on a single text, returning None on failure
    try:
        result = sentiment_pipeline(text)[0]
    except Exception:
        result = None
    return result


def safe_sentiment_batch(sentiment_pipeline, texts):
    # Run the pipeline on a batch of texts, returning Nones on failure
    try:
        results = sentiment_pipeline(texts)
    except Exception:
        results = [None] * len(texts)
    return results


def analyze_detail(text, tokenizer, sentiment_pipeline, max_tokens):
    text = preprocess_text(text)
    chunks = split_text_by_token_limit(text, tokenizer, max_tokens)
    if not chunks:
        return None
    # Run batched inference over all chunks at once
    results = safe_sentiment_batch(sentiment_pipeline, chunks)
    # Aggregate the per-chunk results by summing confidence scores per label
    scores = {"POSITIVE": 0.0, "NEGATIVE": 0.0, "NEUTRAL": 0.0}
    for result in results:
        if result is not None:
            label = result["label"].upper()
            if label in scores:
                scores[label] += result["score"]
    # The final label is the one with the highest summed confidence across chunks
    final_label = max(scores, key=scores.get)
    final_score = scores[final_label]
    return {"label": final_label, "score": final_score}


def preprocess_text(text):
    # Replace URLs and user mentions with placeholder tokens
    text = re.sub(r'http\S+', 'http', text)
    text = re.sub(r'@\w+', '@user', text)
    return text


# def keyword_extraction(text):
#     try:
#         extractor = keyword_extractor()
#         result = extractor(text)
#     except Exception as e:
#         # Optionally, log the error: print(f"Error processing text: {e}")
#         result = None
#     return result
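
# --------------------------------------------------------------------------- #
# Hedged usage sketch (an assumption, not the app's confirmed wiring): the
# helpers above expect a Hugging Face tokenizer and sentiment pipeline to be
# passed in. preprocess_text() follows the cardiffnlp Twitter-RoBERTa
# preprocessing convention, so that model family is a plausible choice; the
# model name, search query, and max_tokens value below are all assumptions.
def build_sentiment_assets(model_name="cardiffnlp/twitter-roberta-base-sentiment-latest"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
    return tokenizer, sentiment_pipeline

# Example wiring (assumed):
#   tokenizer, sentiment_pipe = build_sentiment_assets()
#   df = scrape_reddit_data("some search query", total_limit=100)
#   df["Sentiment"] = df["Detail"].apply(
#       lambda t: analyze_detail(t, tokenizer, sentiment_pipe, max_tokens=510)
#   )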
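
# --------------------------------------------------------------------------- #
# Hedged sketch of the missing keyword_extractor() factory (an assumption, not
# the author's confirmed implementation). The commented-out keyword_extraction()
# helper above calls keyword_extractor(), which is not defined in this file.
# Given the imports of AutoModelForTokenClassification, TokenClassificationPipeline,
# and AggregationStrategy, one plausible implementation is a keyphrase-extraction
# token-classification pipeline like the one below; the model name is an assumption.
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    def __init__(self, model, **kwargs):
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            **kwargs,
        )

    def postprocess(self, all_outputs):
        # Merge sub-word tokens into whole keyphrases and de-duplicate them
        results = super().postprocess(
            all_outputs=all_outputs,
            aggregation_strategy=AggregationStrategy.FIRST,
        )
        return np.unique([result.get("word").strip() for result in results])


def keyword_extractor(model_name="ml6team/keyphrase-extraction-distilbert-inspec"):
    # model_name is an assumed default; swap in whichever keyphrase model the app uses
    return KeyphraseExtractionPipeline(model=model_name)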