|
import gradio as gr |
|
import whisper |
|
import tempfile |
|
import os |
|
from utils import SocialGraphManager, SuggestionGenerator |
|
|
|
|
|
# Model identifiers offered in the UI's model dropdown, mapped to a short
# descriptive label for each.
AVAILABLE_MODELS = {
    "distilgpt2": "DistilGPT2 (Fast, smaller model)",
    "gpt2": "GPT-2 (Medium size, better quality)",
    "google/gemma-3-1b-it": "Gemma 3 1B-IT (Small, instruction-tuned)",
    "Qwen/Qwen1.5-0.5B": "Qwen 1.5 0.5B (Very small, efficient)",
    "Qwen/Qwen1.5-1.8B": "Qwen 1.5 1.8B (Small, good quality)",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "TinyLlama 1.1B (Small, chat-tuned)",
    "microsoft/phi-3-mini-4k-instruct": "Phi-3 Mini (Small, instruction-tuned)",
    "microsoft/phi-2": "Phi-2 (Small, high quality for size)",
}

# Shared application state, created once when the module is imported.
social_graph = SocialGraphManager("social_graph.json")
suggestion_generator = SuggestionGenerator("distilgpt2")

# Run the generator's self-test and report the outcome on the console.
test_result = suggestion_generator.test_model()
print(f"Model test result: {test_result}")

if not suggestion_generator.model_loaded:
    print("Model failed to load, using fallback responses...")

# Speech-to-text is optional: when Whisper cannot be loaded the app still
# starts, and `whisper_loaded` tells transcribe_audio() to decline requests.
try:
    whisper_model = whisper.load_model("tiny")
except Exception as e:
    print(f"Error loading Whisper model: {e}")
    whisper_loaded = False
else:
    whisper_loaded = True
|
|
|
|
|
def format_person_display(person):
    """Build the dropdown label for a person, in the form ``"<name> (<role>)"``.

    Args:
        person: Mapping that provides the ``"name"`` and ``"role"`` keys.

    Returns:
        The formatted label string.

    Raises:
        KeyError: If either key is missing from ``person``.
    """
    name = person["name"]
    role = person["role"]
    return "{} ({})".format(name, role)
|
|
|
|
|
def get_people_choices():
    """Map each person's dropdown label to that person's id.

    Returns:
        A dict keyed by the label from ``format_person_display`` with the
        matching ``person["id"]`` as value, in the order the social graph
        lists the people. If two people share a label, the later one wins.
    """
    return {
        format_person_display(entry): entry["id"]
        for entry in social_graph.get_people_list()
    }
|
|
|
|
|
def get_topics_for_person(person_id):
    """Return the conversation topics stored for ``person_id``.

    Args:
        person_id: Identifier passed to ``social_graph.get_person_context``.

    Returns:
        The ``"topics"`` entry of that person's context, or an empty list
        when the entry is absent or ``person_id`` is empty/``None`` (in which
        case the social graph is not consulted at all).
    """
    if person_id:
        return social_graph.get_person_context(person_id).get("topics", [])
    return []
|
|
|
|
|
def get_suggestion_categories():
    """List the category names under the graph's ``"common_utterances"`` entry.

    Returns:
        The keys of ``social_graph.graph["common_utterances"]`` as a list, or
        an empty list when the graph has no such entry.
    """
    graph = social_graph.graph
    if "common_utterances" not in graph:
        return []
    return list(graph["common_utterances"].keys())
|
|
|
|
|
def on_person_change(person_id):
    """Collect the display data for a newly selected person.

    Args:
        person_id: Identifier passed to ``social_graph.get_person_context``.

    Returns:
        A ``(context_info, phrases_text, topics)`` tuple: a Markdown summary
        of the relationship, the person's common phrases joined by blank
        lines, and their topic list. ``("", "", [])`` when ``person_id`` is
        empty or ``None``.
    """
    if not person_id:
        return "", "", []

    ctx = social_graph.get_person_context(person_id)

    # Every field is optional; a missing one renders as an empty string.
    name, role, frequency, context_text = (
        ctx.get(field, "") for field in ("name", "role", "frequency", "context")
    )

    context_info = (
        f"### I'm talking to: {name}\n"
        f"**Relationship:** {role}\n"
        f"**How often we talk:** {frequency}\n"
        "\n"
        f"**Our relationship:** {context_text}\n"
    )

    phrases_text = "\n\n".join(ctx.get("common_phrases", []))

    return context_info, phrases_text, ctx.get("topics", [])
|
|
|
|
|
def change_model(model_name, progress=gr.Progress()):
    """Switch the shared suggestion generator to a different language model.

    Args:
        model_name: The name of the model to use
        progress: Gradio progress indicator (the ``gr.Progress()`` default is
            how Gradio recognises a progress-tracking parameter)

    Returns:
        A status message about the model change
    """
    print(f"Changing model to: {model_name}")

    # Nothing to load when the requested model is already the active one.
    if model_name == suggestion_generator.model_name:
        return f"Already using model: {model_name}"

    progress(0, desc=f"Loading model: {model_name}")

    if not suggestion_generator.load_model(model_name):
        progress(1.0, desc="Model loading failed")
        return f"Failed to load model: {model_name}. Using fallback responses instead."

    progress(1.0, desc=f"Model loaded: {model_name}")
    return f"Successfully switched to model: {model_name}"
|
|
|
|
|
def generate_suggestions(
    person_id,
    user_input,
    suggestion_type,
    selected_topic=None,
    model_name="distilgpt2",
    temperature=0.7,
    progress=gr.Progress(),
):
    """Generate suggestions based on the selected person and user input.

    Args:
        person_id: Identifier of the person being talked to. When empty, a
            message asking the user to pick someone is returned.
        user_input: What the other person said; may be empty.
        suggestion_type: ``"model"``, ``"auto_detect"``, ``"common_phrases"``
            or one of the names from ``get_suggestion_categories()``.
        selected_topic: Optional topic; when set it is handed to the
            generator as ``person_context["selected_topic"]``.
        model_name: Model to generate with. The shared generator is switched
            to it first when it differs from the currently active model.
        temperature: Forwarded to ``generate_suggestion`` as ``temperature``.
        progress: Gradio progress indicator.

    Returns:
        A Markdown string with a heading and a numbered list of suggestions,
        or a plain message when no person is selected or the suggestion type
        is not recognised.
    """
    print(
        f"Generating suggestions with: person_id={person_id}, user_input={user_input}, "
        f"suggestion_type={suggestion_type}, selected_topic={selected_topic}, "
        f"model={model_name}, temperature={temperature}"
    )

    progress(0, desc="Starting...")

    if not person_id:
        print("No person_id provided")
        return "Please select who you're talking to first."

    if model_name != suggestion_generator.model_name:
        progress(0.1, desc=f"Switching to model: {model_name}")
        change_model(model_name, progress)

    person_context = social_graph.get_person_context(person_id)
    print(f"Person context: {person_context}")

    if selected_topic:
        # Annotate a copy: the object returned by the social graph may be its
        # own stored record, and writing the topic into it would make the
        # topic stick around for later requests that did not select one.
        person_context = {**person_context, "selected_topic": selected_topic}

    if suggestion_type == "model":
        print("Using model for suggestions")
        progress(0.2, desc="Preparing to generate suggestions...")

        suggestions = []
        for i in range(3):
            progress(0.3 + (i * 0.2), desc=f"Generating suggestion {i+1}/3")
            print(f"Generating suggestion {i+1}/3")
            try:
                suggestion = suggestion_generator.generate_suggestion(
                    person_context, user_input, temperature=temperature
                )
                print(f"Generated suggestion: {suggestion}")
            except Exception as e:
                # One failed attempt must not discard the other suggestions.
                print(f"Error generating suggestion: {e}")
                suggestion = "Error generating suggestion"
            suggestions.append(suggestion)

        result = (
            f"### AI-Generated Responses (using {suggestion_generator.model_name}):\n\n"
            + _numbered_markdown(suggestions)
        )
        print(f"Final result: {result[:100]}...")

    elif suggestion_type == "common_phrases":
        phrases = social_graph.get_relevant_phrases(person_id, user_input)
        result = "### My Common Phrases with this Person:\n\n" + _numbered_markdown(
            phrases
        )

    elif suggestion_type == "auto_detect":
        inferred_category = _infer_category(user_input) if user_input else None
        print(f"Auto-detect mode, inferred category: {inferred_category}")

        if inferred_category:
            utterances = social_graph.get_common_utterances(inferred_category)
            print(f"Got utterances for category {inferred_category}: {utterances}")
            result = (
                f"### Auto-detected category: {inferred_category.replace('_', ' ').title()}\n\n"
                + _numbered_markdown(utterances)
            )
        else:
            print("No category inferred, falling back to model")
            progress(0.3, desc="No category detected, using model instead...")
            try:
                suggestions = []
                for i in range(3):
                    progress(
                        0.4 + (i * 0.15),
                        desc=f"Generating fallback suggestion {i+1}/3",
                    )
                    suggestions.append(
                        suggestion_generator.generate_suggestion(
                            person_context, user_input, temperature=temperature
                        )
                    )
                result = (
                    f"### AI-Generated Responses (no category detected, using {suggestion_generator.model_name}):\n\n"
                    + _numbered_markdown(suggestions)
                )
            except Exception as e:
                # In fallback mode any failure replaces the whole answer.
                print(f"Error generating fallback suggestion: {e}")
                result = "### Could not generate a response:\n\n"
                result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"

    elif suggestion_type in get_suggestion_categories():
        print(f"Using category: {suggestion_type}")
        utterances = social_graph.get_common_utterances(suggestion_type)
        print(f"Got utterances: {utterances}")
        result = (
            f"### {suggestion_type.replace('_', ' ').title()} Phrases:\n\n"
            + _numbered_markdown(utterances)
        )

    else:
        print(f"No handler for suggestion type: {suggestion_type}")
        result = "No suggestions available. Please try a different option."

    print(f"Returning result: {result[:100]}...")

    progress(1.0, desc="Completed!")

    return result


# Keyword table for auto-detect mode, as (category, keywords) pairs.
# Order matters: the first category with a matching keyword wins.
_CATEGORY_KEYWORDS = (
    ("greetings", ("hi", "hello", "morning", "afternoon", "evening")),
    ("emotions", ("feel", "tired", "happy", "sad", "frustrated")),
    ("needs", ("need", "want", "help", "water", "toilet", "loo")),
    ("questions", ("what", "how", "when", "where", "why", "did")),
    ("reminiscing", ("remember", "used to", "back then", "when we")),
    ("tech_talk", ("code", "program", "software", "app", "tech")),
    ("organization", ("plan", "schedule", "appointment", "tomorrow", "later")),
)


def _infer_category(user_input):
    """Return the first category with a keyword present in ``user_input``.

    Keywords are matched case-insensitively as whole words or whole phrases.
    Plain substring matching is deliberately avoided: it made "hi" fire
    inside "this"/"think" and "loo" fire inside "looking". The trade-off is
    that inflected forms (e.g. "feeling") no longer match "feel"; add such
    forms to ``_CATEGORY_KEYWORDS`` if they should count.

    Returns:
        The category name, or ``None`` when no keyword is found.
    """
    # Turn punctuation into spaces and collapse runs of whitespace, then pad
    # both ends so that " keyword " can only match at word boundaries.
    lowered = user_input.lower()
    spaced = "".join(ch if ch.isalnum() else " " for ch in lowered)
    padded = " " + " ".join(spaced.split()) + " "

    for category, keywords in _CATEGORY_KEYWORDS:
        if any(f" {keyword} " in padded for keyword in keywords):
            return category
    return None


def _numbered_markdown(items):
    """Render ``items`` as a 1-based numbered list with a blank line after each."""
    return "".join(f"{i}. {item}\n\n" for i, item in enumerate(items, 1))
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe audio using Whisper.

    Args:
        audio_path: Path of the recorded audio file, or ``None``/empty when
            nothing has been recorded yet.

    Returns:
        The transcribed text with surrounding whitespace removed, or a
        human-readable message when Whisper is unavailable, no audio was
        supplied, or transcription failed.
    """
    if not whisper_loaded:
        return "Whisper model not loaded. Please check your installation."

    # The Transcribe button can be pressed before anything is recorded; say
    # so explicitly instead of letting Whisper fail on a missing path.
    if not audio_path:
        return "No audio recorded. Please record something first."

    try:
        result = whisper_model.transcribe(audio_path)
        # NOTE(review): Whisper transcripts commonly carry leading whitespace;
        # strip it so the text box does not start with a blank.
        return result["text"].strip()
    except Exception as e:
        # Best effort: keep the UI responsive, but leave a trace of the cause.
        print(f"Error transcribing audio: {e}")
        return "Could not transcribe audio. Please try again."
|
|
|
|
|
|
|
# Build the Gradio interface: inputs on the left, results on the right, with
# the event wiring at the end of the Blocks context.
with gr.Blocks(title="Will's AAC Communication Aid") as demo:
    gr.Markdown("# Will's AAC Communication Aid")
    gr.Markdown(
        """
    This demo simulates an AAC system from Will's perspective (a 38-year-old with MND).

    **How to use this demo:**
    1. Select who you (Will) are talking to from the dropdown
    2. Optionally select a conversation topic
    3. Enter or record what the other person said to you
    4. Get suggested responses based on your relationship with that person
    """
    )

    # Background on the persona, collapsed by default.
    with gr.Accordion("About Me (Will)", open=False):
        gr.Markdown(
            """
        I'm Will, a 38-year-old computer programmer from Manchester with MND (diagnosed 5 months ago).
        I live with my wife Emma and two children (Mabel, 4 and Billy, 7).
        Originally from South East London, I enjoy technology, Manchester United, and have fond memories of cycling and hiking.
        I'm increasingly using this AAC system as my speech becomes more difficult.
        """
        )

    with gr.Row():
        # Left column: who is speaking, what they said, and how to respond.
        with gr.Column(scale=1):
            # Gradio reads each choice as either a plain string or a
            # (label, value) pair. Passing the dict itself would iterate only
            # its keys, so callbacks would receive the display label rather
            # than the person id; hand over the (label, id) pairs instead.
            person_dropdown = gr.Dropdown(
                choices=list(get_people_choices().items()),
                label="I'm talking to:",
                info="Select who you (Will) are talking to",
            )

            def get_filtered_topics(person_id):
                """Return the topics for ``person_id``.

                Kept as a thin alias of ``get_topics_for_person`` so the
                lookup logic lives in one place.
                """
                return get_topics_for_person(person_id)

            # Starts empty; filled in once a person has been selected.
            topic_dropdown = gr.Dropdown(
                choices=[],
                label="Topic (optional):",
                info="Select a topic to discuss or respond about",
                allow_custom_value=True,
            )

            context_display = gr.Markdown(label="Relationship Context")

            # Typed input.
            with gr.Row():
                user_input = gr.Textbox(
                    label="What they said to me: (leave empty to start a conversation)",
                    placeholder='Examples:\n"How was your physio session today?"\n"The kids are asking if you want to watch a movie tonight"\n"I\'ve been looking at that new AAC software you mentioned"',
                    lines=3,
                )

            # Spoken input; the transcript is written into the text box above.
            with gr.Row():
                audio_input = gr.Audio(
                    label="Or record what they said:",
                    type="filepath",
                    sources=["microphone"],
                )
                transcribe_btn = gr.Button("Transcribe", variant="secondary")

            # Fixed response modes followed by the categories in the graph.
            suggestion_type = gr.Radio(
                choices=[
                    "model",
                    "auto_detect",
                    "common_phrases",
                ]
                + get_suggestion_categories(),
                value="model",
                label="How should I respond?",
                info="Choose response type (model = AI-generated, auto_detect = automatic category detection)",
            )

            # Generation settings.
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value="distilgpt2",
                    label="Language Model",
                    info="Select which AI model to use for generating responses",
                )

                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness (higher = more creative, lower = more focused)",
                )

            generate_btn = gr.Button(
                "Generate My Responses/Conversation Starters", variant="primary"
            )

            model_status = gr.Markdown(
                value=f"Current model: {suggestion_generator.model_name}",
                label="Model Status",
            )

        # Right column: stored phrases and generated suggestions.
        with gr.Column(scale=1):
            common_phrases = gr.Textbox(
                label="My Common Phrases",
                placeholder="Common phrases I often use with this person will appear here...",
                lines=5,
            )

            suggestions_output = gr.Markdown(
                label="My Suggested Responses",
                value="Suggested responses will appear here...",
            )

    def handle_person_change(person_id):
        """Handle person selection change and update UI elements.

        Returns the relationship summary, the common-phrases text and an
        update for the topic dropdown.
        """
        context_info, phrases_text, topics = on_person_change(person_id)

        # Clear the current topic as well: one picked for the previous person
        # would otherwise be sent along with requests about the new person.
        return context_info, phrases_text, gr.update(choices=topics, value=None)

    def handle_model_change(model_name, progress=gr.Progress()):
        """Handle model selection change.

        Accepting a ``gr.Progress`` parameter lets Gradio display the
        progress that ``change_model`` reports while the model loads.
        """
        return change_model(model_name, progress)

    # Event wiring.
    person_dropdown.change(
        handle_person_change,
        inputs=[person_dropdown],
        outputs=[context_display, common_phrases, topic_dropdown],
    )

    model_dropdown.change(
        handle_model_change,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

    generate_btn.click(
        generate_suggestions,
        inputs=[
            person_dropdown,
            user_input,
            suggestion_type,
            topic_dropdown,
            model_dropdown,
            temperature_slider,
        ],
        outputs=[suggestions_output],
    )

    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[user_input],
    )


# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|