# Hugging Face Spaces page header ("Spaces" / "Sleeping") — scrape residue,
# not part of the application code.
from flask import Flask, request, jsonify
from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration, pipeline, AutoTokenizer, AutoModelForTokenClassification
import re
from flask_cors import CORS

# Flask application with cross-origin requests enabled (the UI is served elsewhere).
app = Flask(__name__)
CORS(app)

# Conversational model (Blenderbot distilled 400M) used to generate replies.
model_name = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
model = BlenderbotForConditionalGeneration.from_pretrained(model_name)

# Part-of-speech tagging pipeline (tweet-oriented tagger).
pos_pipe = pipeline("token-classification", model="TweebankNLP/bertweet-tb2-pos-tagging")

# Named-entity recognition model; "simple" aggregation merges subword tokens
# into whole-entity spans with an "entity_group" key.
model_checkpoint = "huggingface-course/bert-finetuned-ner"
ner_model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
ner_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
token_classifier = pipeline(
    "token-classification", model=ner_model, aggregation_strategy="simple", tokenizer=ner_tokenizer,
)
# Function to clean messages
def clean_message(text):
    """Normalize a chat message for downstream tagging and generation.

    Steps: strip punctuation/emojis, collapse letters repeated 3+ times at a
    word's end, POS-tag the result, lowercase ADJ/ADP words while
    title-casing the rest, and drop single-letter words.

    Args:
        text: Raw message string from the user.

    Returns:
        The cleaned, re-joined message string.
    """
    # Remove everything except word characters and whitespace (emojis, punctuation).
    text = re.sub(r'[^\w\s]', '', text)
    # Collapse a letter repeated three or more times at a word's end ("heyyy" -> "hey").
    text = re.sub(r'(\w*?)(\w)\2{2,}\b', r'\1\2', text)
    # Perform POS tagging on the cleaned text.
    pos_tags = pos_pipe(text)
    # Build a word -> tag lookup once (first match wins, matching the original
    # next(...) semantics) instead of re-scanning the tag list for every word.
    tag_by_word = {}
    for tag_info in pos_tags:
        tag_by_word.setdefault(tag_info["word"], tag_info["entity"])
    cleaned_words = []
    for word in text.split():
        if tag_by_word.get(word) in ("ADJ", "ADP"):
            # Keep adjectives and adpositions lowercase.
            cleaned_words.append(word.lower())
        else:
            # Title case for all other (or untagged) words.
            cleaned_words.append(word.title())
    # Drop single-letter words. NOTE(review): this removes "I" and "A" too,
    # despite the original comment hinting they might be kept — confirm intent.
    return " ".join(word for word in cleaned_words if len(word) > 1)
# Function to extract named entities from a single message
def extract_entities(text, message_index, existing_entities=None, threshold=0.85):
    """Run NER over *text* and collect newly seen entities grouped by type.

    Args:
        text: Message text to analyze.
        message_index: Index of the message the entities came from.
        existing_entities: Optional iterable of words already noted; those are
            skipped. (Bug fix: the original used a mutable default ``set()``;
            replaced with a ``None`` sentinel — behavior is unchanged because
            the argument was only ever copied, never mutated.)
        threshold: Minimum model confidence required to accept an entity.

    Returns:
        Dict mapping entity group ("PER", "ORG", "LOC", "MISC") to a list of
        ``{"index", "word", "substring"}`` records.
    """
    entities_dict = {"PER": [], "ORG": [], "LOC": [], "MISC": []}
    # Copy into a fresh set so the caller's collection is never mutated.
    seen_words = set(existing_entities) if existing_entities is not None else set()
    results = token_classifier(text)
    for entity in results:
        word = entity["word"]
        entity_type = entity["entity_group"]
        score = entity["score"]
        # Ignore low-confidence predictions.
        if score < threshold:
            continue
        # Ignore subword tokens (split words like "##word").
        if word.startswith("##"):
            continue
        # Ignore short words (e.g., single letters).
        if len(word) == 1:
            continue
        # Keep multi-word locations intact; split other entity types per word.
        processed_words = [word] if entity_type == "LOC" else word.split()
        for single_word in processed_words:
            # Skip words that have already been noted.
            if single_word in seen_words:
                continue
            seen_words.add(single_word)
            if entity_type in entities_dict:
                # Hoisted: the original called text.find() twice per record.
                # NOTE(review): str.find returns the FIRST occurrence, so the
                # recorded span may be wrong for repeated words, and (-1, n-1)
                # if the surface form is absent from text — confirm intent.
                start = text.find(single_word)
                entities_dict[entity_type].append({
                    "index": message_index,
                    "word": single_word,
                    "substring": (start, start + len(single_word)),
                })
    return entities_dict
# NOTE(review): presumably registered as the Flask "/" route; the decorator
# appears to have been lost in formatting — confirm against the original file.
def home():
    """Simple liveness endpoint returning a fixed greeting."""
    return "Hello, World!"
# NOTE(review): presumably registered as a Flask POST route; the decorator
# appears to have been lost in formatting — confirm against the original file.
def receive_message():
    """Handle an incoming chat message.

    Cleans the user's message, extracts named entities from it, generates a
    Blenderbot reply, extracts entities from that reply, and returns a JSON
    payload with the reply plus PER/LOC entities for both sides.

    Expects a JSON body with keys "index" (int) and "message" (str).
    """
    data = request.get_json()
    # Bug fix: default a missing "index" to 0 — the original data.get("index")
    # returned None, making `message_index + 1` below raise TypeError.
    message_index = data.get("index", 0)
    message = data.get("message", "")
    print(f"Received message at index {message_index}: {message}")
    # Clean user message before tagging and generation.
    cleaned_message = clean_message(message)
    print("Cleaned Message:", cleaned_message)
    # Extract named entities from the user's message.
    user_entities = extract_entities(cleaned_message, message_index)
    print("Extracted Entities from User's Message:", user_entities)
    # Generate the chatbot response.
    inputs = tokenizer(cleaned_message, return_tensors="pt")
    reply_ids = model.generate(**inputs)
    bot_response = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
    print(f"Chatbot Response: {bot_response}")
    # The bot's response index is the user message index + 1.
    bot_index = message_index + 1
    # Extract named entities from the chatbot response.
    bot_entities = extract_entities(bot_response, bot_index)
    print("Extracted Entities from Chatbot's Response:", bot_entities)
    return jsonify({
        'response': bot_response,
        'person_user': user_entities.get("PER", []),
        'location_user': user_entities.get("LOC", []),
        'person_bot': bot_entities.get("PER", []),
        'location_bot': bot_entities.get("LOC", []),
    })
if __name__ == "__main__":
    # WARNING(review): debug=True combined with host="0.0.0.0" exposes the
    # Werkzeug interactive debugger to the whole network — acceptable only in
    # a sandboxed Space; disable debug for any real deployment.
    app.run(host="0.0.0.0", debug=True)