"""Flask chat service: cleans user messages, tags named entities, and serves HTTP routes."""

import re

from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    BlenderbotForConditionalGeneration,
    BlenderbotTokenizer,
    pipeline,
)

app = Flask(__name__)
CORS(app)

# Load chatbot model
model_name = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
model = BlenderbotForConditionalGeneration.from_pretrained(model_name)

# Load POS tagging pipeline
pos_pipe = pipeline("token-classification", model="TweebankNLP/bertweet-tb2-pos-tagging")

# Load NER model
model_checkpoint = "huggingface-course/bert-finetuned-ner"
ner_model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
ner_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
token_classifier = pipeline(
    "token-classification",
    model=ner_model,
    aggregation_strategy="simple",
    tokenizer=ner_tokenizer,
)

# Anything that is neither a word character nor whitespace (punctuation, emojis, ...).
_PUNCTUATION_RE = re.compile(r"[^\w\s]")
# A *letter* repeated three or more times at the end of a word ("hellooo").
# `[^\W\d_]` is "word character that is not a digit or underscore", so numbers
# such as "1000" are no longer collapsed to "10".
_TRAILING_REPEAT_RE = re.compile(r"([^\W\d_])\1{2,}\b")
# POS tags whose words are lower-cased instead of title-cased.
_LOWERCASE_POS_TAGS = frozenset({"ADJ", "ADP"})


def clean_message(text):
    """Normalise a raw chat message before it is passed to the models.

    Steps, in order:
      1. Strip every character that is not a word character or whitespace.
      2. Collapse a letter repeated 3+ times at the end of a word to one letter.
      3. Lower-case words tagged ADJ or ADP by the POS pipeline; title-case
         every other word.
      4. Drop words that are a single character long.

    Args:
        text: The raw message string.

    Returns:
        The cleaned words joined by single spaces; "" when nothing is left.
    """
    text = _PUNCTUATION_RE.sub("", text)
    text = _TRAILING_REPEAT_RE.sub(r"\1", text)

    words = text.split()
    if not words:
        # Nothing to tag; skip the model call entirely.
        return ""

    # Map each token text to the tag of its first occurrence, built once
    # instead of rescanning the tag list for every word.
    first_tag_by_word = {}
    for tag_info in pos_pipe(text):
        first_tag_by_word.setdefault(tag_info["word"], tag_info["entity"])

    # A word with no exactly matching token text gets no tag and is title-cased.
    cased_words = [
        word.lower() if first_tag_by_word.get(word) in _LOWERCASE_POS_TAGS else word.title()
        for word in words
    ]

    # NOTE(review): every single-character word is removed, including "I" and
    # "A"; an earlier comment suggested those might be kept -- confirm intent.
    return " ".join(word for word in cased_words if len(word) > 1)


def _find_span(text, fragment, lo, hi):
    """Return (start, end) of `fragment` within text[lo:hi], or None if absent."""
    begin = text.find(fragment, lo, hi)
    if begin == -1:
        return None
    return (begin, begin + len(fragment))


def extract_entities(text, message_index, existing_entities=None, threshold=0.85):
    """Extract confident, not-yet-seen named entities from a single message.

    Args:
        text: The message to analyse.
        message_index: Index of the message; copied into every returned entry.
        existing_entities: Optional iterable of words that were already noted
            and must not be reported again. It is copied, never mutated.
        threshold: Minimum pipeline score an entity needs to be kept.

    Returns:
        A dict with the keys "PER", "ORG", "LOC" and "MISC". Each value is a
        list of {"index", "word", "substring"} dicts, where "substring" is the
        (start, end) character span of the word in `text`. Multi-word LOC
        entities stay whole; entities of other types are split into words.
    """
    entities_dict = {"PER": [], "ORG": [], "LOC": [], "MISC": []}
    # Default is None rather than a shared mutable set() default argument.
    seen_words = set(existing_entities) if existing_entities is not None else set()

    if not text or not text.strip():
        return entities_dict

    for entity in token_classifier(text):
        word = entity["word"]
        entity_type = entity["entity_group"]

        # Ignore low-confidence entities
        if entity["score"] < threshold:
            continue
        # Ignore subword pieces ("##word") and empty/single-character words
        if word.startswith("##") or len(word) <= 1:
            continue

        # Prefer the character offsets reported by the pipeline so that a word
        # occurring several times is located at *this* mention, not the first
        # one in the text. Offsets can be missing (None), in which case the
        # whole text is searched.
        ent_start, ent_end = entity.get("start"), entity.get("end")
        if ent_start is None or ent_end is None:
            search_from, search_to, entity_span = 0, len(text), None
        else:
            search_from, search_to = int(ent_start), int(ent_end)
            entity_span = (search_from, search_to)

        # Keep multi-word locations intact
        fragments = [word] if entity_type == "LOC" else word.split()

        for fragment in fragments:
            span = _find_span(text, fragment, search_from, search_to)
            if span is not None:
                # Continue after this word so a repeated word inside the same
                # entity maps to its own position.
                search_from = span[1]

            # Check if the word has been already noted
            if fragment in seen_words:
                continue
            seen_words.add(fragment)

            if entity_type not in entities_dict:
                continue

            if span is None:
                # The decoded word does not appear literally in the text;
                # fall back to the span of the whole entity.
                span = entity_span
            if span is None:
                # No trustworthy position exists; skip instead of emitting a
                # bogus (-1, ...) span.
                continue

            entities_dict[entity_type].append({
                "index": message_index,
                "word": fragment,
                "substring": span,
            })

    return entities_dict


@app.route("/")
def home():
    """Return a static greeting for the root URL."""
    return "Hello, World!"
@app.route("/api/home", methods=['POST', 'GET'])
def receive_message():
    """Handle one chat turn: clean the message, reply, and tag entities.

    Expects a JSON object body with:
        index: Integer index of the user's message.
        message: The user's text (treated as "" when omitted).

    Returns:
        A JSON object with the bot reply under 'response' and the PER / LOC
        entities found in the user's message ('person_user', 'location_user')
        and in the reply ('person_bot', 'location_bot'). The reply's entities
        carry index + 1. Malformed requests get a 400 response with an
        'error' message instead of failing with an unhandled exception.
    """
    # silent=True yields None instead of aborting when the body is missing or
    # is not valid JSON (e.g. a plain GET request).
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': "Request body must be a JSON object."}), 400

    message_index = data.get("index")
    message = data.get("message", "")

    # Validate up front: the index is incremented below, and a bad value used
    # to fail only after the expensive generation step. bool is excluded
    # explicitly because it is a subclass of int.
    if isinstance(message_index, bool) or not isinstance(message_index, int):
        return jsonify({'error': "'index' must be an integer."}), 400
    if not isinstance(message, str):
        return jsonify({'error': "'message' must be a string."}), 400

    print(f"Received message at index {message_index}: {message}")

    # Clean user message
    cleaned_message = clean_message(message)
    print("Cleaned Message:", cleaned_message)

    # Extract named entities from user message
    user_entities = extract_entities(cleaned_message, message_index)
    print("Extracted Entities from User's Message:", user_entities)

    # Generate chatbot response. truncation=True caps the input at the
    # tokenizer's configured maximum length so an over-long message cannot
    # exceed what the model accepts.
    inputs = tokenizer(cleaned_message, return_tensors="pt", truncation=True)
    reply_ids = model.generate(**inputs)
    bot_response = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
    print(f"Chatbot Response: {bot_response}")

    # The bot's response index will be the user message index + 1
    bot_index = message_index + 1

    # Extract named entities from chatbot response (bot index)
    bot_entities = extract_entities(bot_response, bot_index)
    print("Extracted Entities from Chatbot's Response:", bot_entities)

    return jsonify({
        'response': bot_response,
        'person_user': user_entities.get("PER", []),
        'location_user': user_entities.get("LOC", []),
        'person_bot': bot_entities.get("PER", []),
        'location_bot': bot_entities.get("LOC", []),
    })


if __name__ == "__main__":
    import os

    # The server listens on every interface, so the interactive debugger must
    # not be on by default: it lets anyone who can reach the port execute
    # code. Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(host="0.0.0.0", debug=debug_enabled)