willwade committed on
Commit c8d94c7 · 1 Parent(s): d8b372b

adding better style changes and tone

Files changed (3):
  1. app.py +82 -31
  2. custom.css +137 -0
  3. utils.py +23 -0
app.py CHANGED
@@ -6,21 +6,23 @@ from utils import SocialGraphManager, SuggestionGenerator
 
 # Define available models
 AVAILABLE_MODELS = {
-    "distilgpt2": "DistilGPT2 (Fast, smaller model)",
-    "gpt2": "GPT-2 (Medium size, better quality)",
     "google/gemma-3-1b-it": "Gemma 3 1B-IT (Small, instruction-tuned)",
+    "google/gemma-3-2b-it": "Gemma 3 2B-IT (Default, instruction-tuned)",
+    "google/gemma-3-4b-it": "Gemma 3 4B-IT (Better quality, instruction-tuned)",
     "Qwen/Qwen1.5-0.5B": "Qwen 1.5 0.5B (Very small, efficient)",
     "Qwen/Qwen1.5-1.8B": "Qwen 1.5 1.8B (Small, good quality)",
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "TinyLlama 1.1B (Small, chat-tuned)",
     "microsoft/phi-3-mini-4k-instruct": "Phi-3 Mini (Small, instruction-tuned)",
     "microsoft/phi-2": "Phi-2 (Small, high quality for size)",
+    "distilgpt2": "DistilGPT2 (Fast, smaller model)",
+    "gpt2": "GPT-2 (Medium size, better quality)",
 }
 
 # Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")
 
-# Initialize the suggestion generator with Gemma 3B (default)
-suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")
+# Initialize the suggestion generator with Gemma 3 2B (default)
+suggestion_generator = SuggestionGenerator("google/gemma-3-2b-it")
 
 # Test the model to make sure it's working
 test_result = suggestion_generator.test_model()
@@ -67,9 +69,19 @@ def get_topics_for_person(person_id):
 
 
 def get_suggestion_categories():
-    """Get suggestion categories from the social graph."""
+    """Get suggestion categories from the social graph with emoji prefixes."""
     if "common_utterances" in social_graph.graph:
-        return list(social_graph.graph["common_utterances"].keys())
+        categories = list(social_graph.graph["common_utterances"].keys())
+        emoji_map = {
+            "greetings": "👋 greetings",
+            "needs": "🆘 needs",
+            "emotions": "😊 emotions",
+            "questions": "❓ questions",
+            "tech_talk": "💻 tech_talk",
+            "reminiscing": "🔙 reminiscing",
+            "organization": "📅 organization",
+        }
+        return [emoji_map.get(cat, cat) for cat in categories]
     return []
 
 
@@ -140,15 +152,16 @@ def generate_suggestions(
     user_input,
     suggestion_type,
    selected_topic=None,
-    model_name="google/gemma-3-1b-it",
+    model_name="google/gemma-3-2b-it",
    temperature=0.7,
+    mood=3,
    progress=gr.Progress(),
 ):
     """Generate suggestions based on the selected person and user input."""
     print(
         f"Generating suggestions with: person_id={person_id}, user_input={user_input}, "
         f"suggestion_type={suggestion_type}, selected_topic={selected_topic}, "
-        f"model={model_name}, temperature={temperature}"
+        f"model={model_name}, temperature={temperature}, mood={mood}"
     )
 
     # Initialize progress
@@ -166,9 +179,16 @@ def generate_suggestions(
     person_context = social_graph.get_person_context(person_id)
     print(f"Person context: {person_context}")
 
+    # Remove emoji prefix from suggestion_type if present
+    clean_suggestion_type = suggestion_type
+    if suggestion_type.startswith(
+        ("🤖", "🔍", "💬", "👋", "🆘", "😊", "❓", "💻", "🔙", "📅")
+    ):
+        clean_suggestion_type = suggestion_type[2:].strip()  # Remove emoji and space
+
     # Try to infer conversation type if user input is provided
     inferred_category = None
-    if user_input and suggestion_type == "auto_detect":
+    if user_input and clean_suggestion_type == "auto_detect":
         # Simple keyword matching for now - could be enhanced with ML
         user_input_lower = user_input.lower()
         if any(
@@ -215,7 +235,7 @@
     result = ""
 
     # If suggestion type is "model", use the language model for multiple suggestions
-    if suggestion_type == "model":
+    if clean_suggestion_type == "model":
         print("Using model for suggestions")
         progress(0.2, desc="Preparing to generate suggestions...")
 
@@ -226,6 +246,8 @@
             progress(progress_value, desc=f"Generating suggestion {i+1}/3")
             print(f"Generating suggestion {i+1}/3")
             try:
+                # Add mood to person context
+                person_context["mood"] = mood
                 suggestion = suggestion_generator.generate_suggestion(
                     person_context, user_input, temperature=temperature
                 )
@@ -244,14 +266,14 @@
     print(f"Final result: {result[:100]}...")
 
     # If suggestion type is "common_phrases", use the person's common phrases
-    elif suggestion_type == "common_phrases":
+    elif clean_suggestion_type == "common_phrases":
         phrases = social_graph.get_relevant_phrases(person_id, user_input)
         result = "### My Common Phrases with this Person:\n\n"
         for i, phrase in enumerate(phrases, 1):
             result += f"{i}. {phrase}\n\n"
 
     # If suggestion type is "auto_detect", use the inferred category or default to model
-    elif suggestion_type == "auto_detect":
+    elif clean_suggestion_type == "auto_detect":
         print(f"Auto-detect mode, inferred category: {inferred_category}")
         if inferred_category:
             utterances = social_graph.get_common_utterances(inferred_category)
@@ -270,6 +292,8 @@
                 progress(
                     progress_value, desc=f"Generating fallback suggestion {i+1}/3"
                 )
+                # Add mood to person context
+                person_context["mood"] = mood
                 suggestion = suggestion_generator.generate_suggestion(
                     person_context, user_input, temperature=temperature
                 )
@@ -284,17 +308,25 @@
                 result += "1. Sorry, I couldn't generate a suggestion at this time.\n\n"
 
     # If suggestion type is a category from common_utterances
-    elif suggestion_type in get_suggestion_categories():
-        print(f"Using category: {suggestion_type}")
-        utterances = social_graph.get_common_utterances(suggestion_type)
+    elif clean_suggestion_type in [
+        "greetings",
+        "needs",
+        "emotions",
+        "questions",
+        "tech_talk",
+        "reminiscing",
+        "organization",
+    ]:
+        print(f"Using category: {clean_suggestion_type}")
+        utterances = social_graph.get_common_utterances(clean_suggestion_type)
         print(f"Got utterances: {utterances}")
-        result = f"### {suggestion_type.replace('_', ' ').title()} Phrases:\n\n"
+        result = f"### {clean_suggestion_type.replace('_', ' ').title()} Phrases:\n\n"
         for i, utterance in enumerate(utterances, 1):
             result += f"{i}. {utterance}\n\n"
 
     # Default fallback
     else:
-        print(f"No handler for suggestion type: {suggestion_type}")
+        print(f"No handler for suggestion type: {clean_suggestion_type}")
         result = "No suggestions available. Please try a different option."
 
     print(f"Returning result: {result[:100]}...")
@@ -325,7 +357,7 @@ def transcribe_audio(audio_path):
 
 
 # Create the Gradio interface
-with gr.Blocks(title="Will's AAC Communication Aid") as demo:
+with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
     gr.Markdown("# Will's AAC Communication Aid")
     gr.Markdown(
         """
@@ -385,33 +417,51 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
                 lines=3,
            )
 
-            # Audio input
-            with gr.Row():
+            # Audio input with auto-transcription
+            with gr.Column(elem_classes="audio-recorder-container"):
+                gr.Markdown("### 🎤 Or record what they said")
                audio_input = gr.Audio(
-                    label="Or record what they said:",
+                    label="",
                    type="filepath",
                    sources=["microphone"],
+                    elem_classes="audio-recorder",
+                )
+                gr.Markdown(
+                    "*Recording will auto-transcribe when stopped*",
+                    elem_classes="auto-transcribe-hint",
                )
-                transcribe_btn = gr.Button("Transcribe", variant="secondary")
 
-            # Suggestion type selection
+            # Suggestion type selection with emojis
            suggestion_type = gr.Radio(
                choices=[
-                    "model",
-                    "auto_detect",
-                    "common_phrases",
+                    "🤖 model",
+                    "🔍 auto_detect",
+                    "💬 common_phrases",
                ]
                + get_suggestion_categories(),
-                value="model",  # Default to model for better results
+                value="🤖 model",  # Default to model for better results
                label="How should I respond?",
-                info="Choose response type (model = AI-generated, auto_detect = automatic category detection)",
+                info="Choose response type",
+                elem_classes="emoji-response-options",
            )
 
+            # Add a mood slider with emoji indicators at the ends
+            with gr.Column(elem_classes="mood-slider-container"):
+                mood_slider = gr.Slider(
+                    minimum=1,
+                    maximum=5,
+                    value=3,
+                    step=1,
+                    label="How am I feeling today?",
+                    info="This will influence the tone of your responses (😢 Sad → Happy 😄)",
+                    elem_classes="mood-slider",
+                )
+
            # Model selection
            with gr.Row():
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
-                    value="google/gemma-3-1b-it",
+                    value="google/gemma-3-2b-it",
                    label="Language Model",
                    info="Select which AI model to use for generating responses",
                )
@@ -491,12 +541,13 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
            topic_dropdown,
            model_dropdown,
            temperature_slider,
+            mood_slider,
        ],
        outputs=[suggestions_output],
    )
 
-    # Transcribe audio to text
-    transcribe_btn.click(
+    # Auto-transcribe audio to text when recording stops
+    audio_input.stop_recording(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[user_input],

custom.css ADDED
@@ -0,0 +1,137 @@
+/* Custom CSS for Will's AAC Communication Aid */
+
+/* Main container styling */
+.gradio-container {
+    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+}
+
+/* Emoji response options */
+.emoji-response-options .gr-form {
+    margin-top: 10px;
+}
+
+/* Direct emoji labels for radio buttons */
+.emoji-response-options label[for$="model"] span:first-child::before {
+    content: "🤖 ";
+}
+
+.emoji-response-options label[for$="auto_detect"] span:first-child::before {
+    content: "🔍 ";
+}
+
+.emoji-response-options label[for$="common_phrases"] span:first-child::before {
+    content: "💬 ";
+}
+
+.emoji-response-options label[for$="greetings"] span:first-child::before {
+    content: "👋 ";
+}
+
+.emoji-response-options label[for$="needs"] span:first-child::before {
+    content: "🆘 ";
+}
+
+.emoji-response-options label[for$="emotions"] span:first-child::before {
+    content: "😊 ";
+}
+
+.emoji-response-options label[for$="questions"] span:first-child::before {
+    content: "❓ ";
+}
+
+.emoji-response-options label[for$="tech_talk"] span:first-child::before {
+    content: "💻 ";
+}
+
+.emoji-response-options label[for$="reminiscing"] span:first-child::before {
+    content: "🔙 ";
+}
+
+.emoji-response-options label[for$="organization"] span:first-child::before {
+    content: "📅 ";
+}
+
+/* Mood slider styling */
+.mood-slider-container {
+    margin-bottom: 20px;
+    position: relative;
+}
+
+.mood-slider .gr-slider {
+    height: 20px;
+    border-radius: 10px;
+}
+
+.mood-slider .gr-slider-value {
+    font-weight: bold;
+}
+
+/* Add emoji indicators to the ends of the slider */
+.mood-slider::before {
+    content: "😢";
+    position: absolute;
+    left: 0;
+    bottom: 5px;
+    font-size: 24px;
+}
+
+.mood-slider::after {
+    content: "😄";
+    position: absolute;
+    right: 0;
+    bottom: 5px;
+    font-size: 24px;
+}
+
+/* Style for audio recorder */
+.audio-recorder-container {
+    margin-top: 15px;
+    margin-bottom: 15px;
+    border: 2px solid #2563eb;
+    border-radius: 8px;
+    padding: 10px;
+    background-color: rgba(37, 99, 235, 0.05);
+}
+
+.audio-recorder-container h3 {
+    margin-top: 0;
+    color: #2563eb;
+}
+
+.audio-recorder {
+    margin: 10px 0;
+}
+
+.audio-recorder .mic-icon {
+    color: #2563eb;
+    font-size: 24px;
+}
+
+.auto-transcribe-hint {
+    font-size: 12px;
+    color: #666;
+    margin-top: 0;
+    text-align: center;
+}
+
+/* Improve button styling */
+.gr-button-primary {
+    background-color: #2563eb;
+    border-radius: 8px;
+    font-weight: 600;
+    transition: all 0.3s ease;
+}
+
+.gr-button-primary:hover {
+    background-color: #1d4ed8;
+    transform: translateY(-2px);
+    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
+}
+
+/* Improve markdown output */
+#suggestions_output {
+    border-radius: 8px;
+    padding: 15px;
+    background-color: #f8fafc;
+    border-left: 4px solid #2563eb;
+}

utils.py CHANGED
@@ -277,6 +277,26 @@ class SuggestionGenerator:
             self.model_loaded = False
             return False
 
+    def _get_mood_description(self, mood_value: int) -> str:
+        """Convert mood value (1-5) to a descriptive string.
+
+        Args:
+            mood_value: Integer from 1-5 representing mood (1=sad, 5=happy)
+
+        Returns:
+            String description of the mood
+        """
+        mood_descriptions = {
+            1: "I'm feeling quite down and sad today. My responses might be more subdued.",
+            2: "I'm feeling a bit low today. I might be less enthusiastic than usual.",
+            3: "I'm feeling okay today - neither particularly happy nor sad.",
+            4: "I'm feeling pretty good today. I'm in a positive mood.",
+            5: "I'm feeling really happy and upbeat today! I'm in a great mood.",
+        }
+
+        # Default to neutral if value is out of range
+        return mood_descriptions.get(mood_value, mood_descriptions[3])
+
     def test_model(self) -> str:
         """Test if the model is working correctly."""
         if not self.model_loaded:
@@ -330,6 +350,7 @@ class SuggestionGenerator:
         selected_topic = person_context.get("selected_topic", "")
         common_phrases = person_context.get("common_phrases", [])
         frequency = person_context.get("frequency", "")
+        mood = person_context.get("mood", 3)  # Default to neutral mood (3)
 
         # Get AAC user information
         aac_user = self.aac_user_info
@@ -344,6 +365,8 @@ I am talking to {name}, who is my {role}.
 About {name}: {context}
 We typically talk about: {', '.join(topics)}
 We communicate {frequency}.
+
+My current mood: {self._get_mood_description(mood)}
 """
 
         # Add communication style based on relationship