add in changes to make app work as well as demo
app.py
CHANGED
@@ -4,10 +4,22 @@ import tempfile
 import os
 from utils import SocialGraphManager, SuggestionGenerator
 
-#
+# Define available models
+AVAILABLE_MODELS = {
+    "distilgpt2": "DistilGPT2 (Fast, smaller model)",
+    "gpt2": "GPT-2 (Medium size, better quality)",
+    "google/gemma-3-1b-it": "Gemma 3 1B-IT (Small, instruction-tuned)",
+    "Qwen/Qwen1.5-0.5B": "Qwen 1.5 0.5B (Very small, efficient)",
+    "Qwen/Qwen1.5-1.8B": "Qwen 1.5 1.8B (Small, good quality)",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "TinyLlama 1.1B (Small, chat-tuned)",
+    "microsoft/phi-3-mini-4k-instruct": "Phi-3 Mini (Small, instruction-tuned)",
+    "microsoft/phi-2": "Phi-2 (Small, high quality for size)",
+}
+
+# Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")
 
-# Initialize the suggestion generator with distilgpt2
+# Initialize the suggestion generator with distilgpt2 (default)
 suggestion_generator = SuggestionGenerator("distilgpt2")
 
 # Test the model to make sure it's working

@@ -23,7 +35,8 @@ if not suggestion_generator.model_loaded:
 try:
     whisper_model = whisper.load_model("tiny")
     whisper_loaded = True
-except Exception:
+except Exception as e:
+    print(f"Error loading Whisper model: {e}")
     whisper_loaded = False
 
 

@@ -90,16 +103,55 @@ def on_person_change(person_id):
     return context_info, phrases_text, topics
 
 
-def
+def change_model(model_name):
+    """Change the language model used for generation.
+
+    Args:
+        model_name: The name of the model to use
+
+    Returns:
+        A status message about the model change
+    """
+    global suggestion_generator
+
+    print(f"Changing model to: {model_name}")
+
+    # Check if we need to change the model
+    if model_name == suggestion_generator.model_name:
+        return f"Already using model: {model_name}"
+
+    # Try to load the new model
+    success = suggestion_generator.load_model(model_name)
+
+    if success:
+        return f"Successfully switched to model: {model_name}"
+    else:
+        return f"Failed to load model: {model_name}. Using fallback responses instead."
+
+
+def generate_suggestions(
+    person_id,
+    user_input,
+    suggestion_type,
+    selected_topic=None,
+    model_name="distilgpt2",
+    temperature=0.7,
+):
     """Generate suggestions based on the selected person and user input."""
     print(
-        f"Generating suggestions with: person_id={person_id}, user_input={user_input},
+        f"Generating suggestions with: person_id={person_id}, user_input={user_input}, "
+        f"suggestion_type={suggestion_type}, selected_topic={selected_topic}, "
+        f"model={model_name}, temperature={temperature}"
     )
 
     if not person_id:
         print("No person_id provided")
         return "Please select who you're talking to first."
 
+    # Make sure we're using the right model
+    if model_name != suggestion_generator.model_name:
+        change_model(model_name)
+
     person_context = social_graph.get_person_context(person_id)
     print(f"Person context: {person_context}")
 

@@ -160,7 +212,7 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
             print(f"Generating suggestion {i+1}/3")
             try:
                 suggestion = suggestion_generator.generate_suggestion(
-                    person_context, user_input, temperature=
+                    person_context, user_input, temperature=temperature
                 )
                 print(f"Generated suggestion: {suggestion}")
                 suggestions.append(suggestion)

@@ -168,11 +220,13 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
                 print(f"Error generating suggestion: {e}")
                 suggestions.append("Error generating suggestion")
 
-        result =
+        result = (
+            f"### AI-Generated Responses (using {suggestion_generator.model_name}):\n\n"
+        )
         for i, suggestion in enumerate(suggestions, 1):
             result += f"{i}. {suggestion}\n\n"
 
-        print(f"Final result: {result}")
+        print(f"Final result: {result[:100]}...")
 
     # If suggestion type is "common_phrases", use the person's common phrases
     elif suggestion_type == "common_phrases":

@@ -194,12 +248,16 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
             print("No category inferred, falling back to model")
             # Fall back to model if we couldn't infer a category
             try:
-
-
-
-
-
-
+                suggestions = []
+                for i in range(3):
+                    suggestion = suggestion_generator.generate_suggestion(
+                        person_context, user_input, temperature=temperature
+                    )
+                    suggestions.append(suggestion)
+
+                result = f"### AI-Generated Responses (no category detected, using {suggestion_generator.model_name}):\n\n"
+                for i, suggestion in enumerate(suggestions, 1):
+                    result += f"{i}. {suggestion}\n\n"
             except Exception as e:
                 print(f"Error generating fallback suggestion: {e}")
                 result = "### Could not generate a response:\n\n"

@@ -319,9 +377,33 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
                     info="Choose what kind of responses you want (model = AI-generated)",
                 )
 
+                # Model selection
+                with gr.Row():
+                    model_dropdown = gr.Dropdown(
+                        choices=list(AVAILABLE_MODELS.keys()),
+                        value="distilgpt2",
+                        label="Language Model",
+                        info="Select which AI model to use for generating responses",
+                    )
+
+                    temperature_slider = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.5,
+                        value=0.7,
+                        step=0.1,
+                        label="Temperature",
+                        info="Controls randomness (higher = more creative, lower = more focused)",
+                    )
+
                 # Generate button
                 generate_btn = gr.Button("Generate My Responses", variant="primary")
 
+                # Model status
+                model_status = gr.Markdown(
+                    value=f"Current model: {suggestion_generator.model_name}",
+                    label="Model Status",
+                )
+
             with gr.Column(scale=1):
                 # Common phrases
                 common_phrases = gr.Textbox(

@@ -347,6 +429,11 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
         # Update the context, phrases, and topic dropdown
         return context_info, phrases_text, gr.update(choices=topics)
 
+    def handle_model_change(model_name):
+        """Handle model selection change."""
+        status = change_model(model_name)
+        return status
+
     # Set up the person change event
     person_dropdown.change(
         handle_person_change,

@@ -354,10 +441,24 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
         outputs=[context_display, common_phrases, topic_dropdown],
     )
 
+    # Set up the model change event
+    model_dropdown.change(
+        handle_model_change,
+        inputs=[model_dropdown],
+        outputs=[model_status],
+    )
+
     # Set up the generate button click event
     generate_btn.click(
         generate_suggestions,
-        inputs=[
+        inputs=[
+            person_dropdown,
+            user_input,
+            suggestion_type,
+            topic_dropdown,
+            model_dropdown,
+            temperature_slider,
+        ],
         outputs=[suggestions_output],
     )
 
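The switch-then-generate flow that app.py now wires into the UI can also be exercised directly against utils.py. A minimal sketch, assuming social_graph.json sits in the working directory; the person id "billy" is hypothetical:

# Sketch only: mirrors what change_model() and generate_suggestions() do in app.py.
from utils import SocialGraphManager, SuggestionGenerator

graph = SocialGraphManager("social_graph.json")
generator = SuggestionGenerator("distilgpt2")

# Switch models the same way change_model() does; load_model() returns False on failure.
if generator.load_model("gpt2"):
    person = graph.get_person_context("billy")  # hypothetical person id
    print(generator.generate_suggestion(person, "Did you watch the match?", temperature=0.7))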
demo.py
CHANGED
@@ -398,7 +398,10 @@ class LLMToolInterface(LLMInterface):
                     print("llm install llm-mlx")
                 elif "ollama" in self.model_name.lower():
                     print("llm install llm-ollama")
-
+                    model_name = self.model_name
+                    if "/" in model_name:
+                        model_name = model_name.split("/")[1]
+                    print("ollama pull " + model_name)
                 else:
                     print("Warning: LLM tool may be installed but returned an error.")
             except Exception as e:

@@ -610,7 +613,7 @@ def main():
         "- hf: 'distilgpt2', 'gpt2-medium', 'google/gemma-2b-it'\n"
         "- llm: 'gemini-1.5-pro-latest', 'gemma-3-27b-it' (requires llm-gemini plugin)\n"
         "       'mlx-community/gemma-7b-it' (requires llm-mlx plugin)\n"
-        "       '
+        "       'Ollama: gemma3:4b-it-qat', 'Ollama: llama3:8b' (requires llm-ollama plugin)",
     )
     parser.add_argument(
         "--num_responses", type=int, default=3, help="Number of responses to generate"

@@ -705,9 +708,12 @@ def main():
            print("1. Install from https://ollama.ai/")
            print("2. Start Ollama with: ollama serve")
            print("3. Install the llm-ollama plugin: llm install llm-ollama")
-
-
-
+            model_name = args.model
+            if "ollama:" in model_name.lower():
+                model_name = model_name.replace("Ollama: ", "")
+            elif "/" in model_name:
+                model_name = model_name.split("/")[1]
+            print(f"4. Pull the model: ollama pull {model_name}")
        else:
            print("\nMake sure Simon Willison's LLM tool is installed:")
            print("pip install llm")
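The new "ollama pull" hints in demo.py normalize the model name before printing it. A small sketch of that normalization; the helper name ollama_pull_target is illustrative and not part of the codebase:

# Sketch of the name normalization used for the "ollama pull" hint above.
def ollama_pull_target(model_name: str) -> str:
    if "ollama:" in model_name.lower():
        return model_name.replace("Ollama: ", "")
    if "/" in model_name:
        return model_name.split("/")[1]
    return model_name

assert ollama_pull_target("Ollama: gemma3:4b-it-qat") == "gemma3:4b-it-qat"
assert ollama_pull_target("mlx-community/gemma-7b-it") == "gemma-7b-it"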
utils.py
CHANGED
@@ -159,16 +159,20 @@ class SuggestionGenerator:
         """
         self.model_name = model_name
         self.model_loaded = False
+        self.generator = None
+        self.aac_user_info = None
 
+        # Load AAC user information from social graph
         try:
-
-
-
-            self.model_loaded = True
-            print(f"Model loaded successfully: {model_name}")
+            with open("social_graph.json", "r") as f:
+                social_graph = json.load(f)
+            self.aac_user_info = social_graph.get("aac_user", {})
         except Exception as e:
-            print(f"Error loading
-            self.
+            print(f"Error loading AAC user info from social graph: {e}")
+            self.aac_user_info = {}
+
+        # Try to load the model
+        self.load_model(model_name)
 
         # Fallback responses if model fails to load or generate
         self.fallback_responses = [

@@ -176,8 +180,92 @@ class SuggestionGenerator:
             "That's interesting. Tell me more.",
             "I'd like to talk about that further.",
             "I appreciate you sharing that with me.",
+            "Could we talk about something else?",
+            "I need some time to think about that.",
         ]
 
+    def load_model(self, model_name: str) -> bool:
+        """Load a Hugging Face model.
+
+        Args:
+            model_name: Name of the HuggingFace model to use
+
+        Returns:
+            bool: True if model loaded successfully, False otherwise
+        """
+        self.model_name = model_name
+        self.model_loaded = False
+
+        try:
+            print(f"Loading model: {model_name}")
+
+            # Check if this is a gated model that requires authentication
+            is_gated_model = any(
+                name in model_name.lower()
+                for name in ["gemma", "llama", "mistral", "qwen", "phi"]
+            )
+
+            if is_gated_model:
+                # Try to get token from environment
+                import os
+
+                token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
+                    "HF_TOKEN"
+                )
+
+                if token:
+                    print(f"Using token for gated model: {model_name}")
+                    from huggingface_hub import login
+
+                    login(token=token, add_to_git_credential=False)
+
+                    # Explicitly pass token to pipeline
+                    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model_name, token=token
+                        )
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_name, token=token
+                        )
+                        self.generator = pipeline(
+                            "text-generation", model=model, tokenizer=tokenizer
+                        )
+                    except Exception as e:
+                        print(f"Error loading gated model with token: {e}")
+                        print(
+                            "This may be due to not having accepted the model license or insufficient permissions."
+                        )
+                        print(
+                            "Please visit the model page on Hugging Face Hub and accept the license."
+                        )
+                        raise
+                else:
+                    print("No Hugging Face token found in environment variables.")
+                    print(
+                        "To use gated models like Gemma, you need to set up a token with the right permissions."
+                    )
+                    print("1. Create a token at https://huggingface.co/settings/tokens")
+                    print(
+                        "2. Make sure to enable 'Access to public gated repositories'"
+                    )
+                    print(
+                        "3. Set it as an environment variable: export HUGGING_FACE_HUB_TOKEN=your_token_here"
+                    )
+                    raise ValueError("Authentication token required for gated model")
+            else:
+                # For non-gated models, use the standard pipeline
+                self.generator = pipeline("text-generation", model=model_name)
+
+            self.model_loaded = True
+            print(f"Model loaded successfully: {model_name}")
+            return True
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            self.model_loaded = False
+            return False
+
     def test_model(self) -> str:
         """Test if the model is working correctly."""
         if not self.model_loaded:

@@ -186,7 +274,9 @@ class SuggestionGenerator:
         try:
             test_prompt = "I am Will. My son Billy asked about football. I respond:"
             print(f"Testing model with prompt: {test_prompt}")
-            response = self.generator(
+            response = self.generator(
+                test_prompt, max_new_tokens=30, do_sample=True, truncation=True
+            )
             result = response[0]["generated_text"][len(test_prompt) :]
             print(f"Test response: {result}")
             return f"Model test successful: {result}"

@@ -222,39 +312,100 @@ class SuggestionGenerator:
         # Extract context information
         name = person_context.get("name", "")
         role = person_context.get("role", "")
-        topics =
+        topics = person_context.get("topics", [])
         context = person_context.get("context", "")
         selected_topic = person_context.get("selected_topic", "")
+        common_phrases = person_context.get("common_phrases", [])
+        frequency = person_context.get("frequency", "")
 
-        #
-
-    I'm talking to {name}, who is my {role}.
-    """
-
-
-
-
-        prompt += f"Topics of interest: {topics}\n"
-
-
 
+        # Get AAC user information
+        aac_user = self.aac_user_info
+
+        # Build enhanced prompt
+        prompt = f"""I am {aac_user.get('name', 'Will')}, a {aac_user.get('age', 38)}-year-old with MND (Motor Neuron Disease) from {aac_user.get('location', 'Manchester')}.
+{aac_user.get('background', '')}
+
+My communication needs: {aac_user.get('communication_needs', '')}
+
+I am talking to {name}, who is my {role}.
+About {name}: {context}
+We typically talk about: {', '.join(topics)}
+We communicate {frequency}.
+"""
+
+        # Add communication style based on relationship
+        if role in ["wife", "son", "daughter", "mother", "father"]:
+            prompt += "I communicate with my family in a warm, loving way, sometimes using inside jokes.\n"
+        elif role in ["doctor", "therapist", "nurse"]:
+            prompt += "I communicate with healthcare providers in a direct, informative way.\n"
+        elif role in ["best mate", "friend"]:
+            prompt += "I communicate with friends casually, often with humor and sometimes swearing.\n"
+        elif role in ["work colleague", "boss"]:
+            prompt += (
+                "I communicate with colleagues professionally but still friendly.\n"
+            )
+
+        # Add topic information if provided
+        if selected_topic:
+            prompt += f"\nWe are currently discussing {selected_topic}.\n"
+
+            # Add specific context about this topic with this person
+            if selected_topic == "football" and "Manchester United" in context:
+                prompt += "We both support Manchester United and often discuss recent matches.\n"
+            elif selected_topic == "programming" and "software developer" in context:
+                prompt += "We both work in software development and share technical interests.\n"
+            elif selected_topic == "family plans" and role in ["wife", "husband"]:
+                prompt += (
+                    "We make family decisions together, considering my condition.\n"
+                )
+            elif selected_topic == "old scout adventures" and role == "best mate":
+                prompt += "We often reminisce about our Scout camping trips in South East London.\n"
+            elif selected_topic == "cycling" and "cycling" in context:
+                prompt += "I miss being able to cycle but enjoy talking about past cycling adventures.\n"
+
+        # Add the user's message if provided
         if user_input:
             prompt += f'\n{name} just said to me: "{user_input}"\n'
-
-
+        elif common_phrases:
+            # Use a common phrase from the person if no message is provided
+            default_message = common_phrases[0]
+            prompt += f'\n{name} typically says things like: "{default_message}"\n'
+
+        # Add the response prompt with specific guidance
+        # Check if this is an instruction-tuned model
+        is_instruction_model = any(
+            marker in self.model_name.lower()
+            for marker in ["-it", "instruct", "chat", "phi-3", "phi-2"]
+        )
+
+        if is_instruction_model:
+            # Use instruction format for instruction-tuned models
+            prompt += f"""
+<instruction>
+Respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said.
+Use language appropriate for our relationship.
+</instruction>
+
+My response to {name}:"""
+        else:
+            # Use standard format for non-instruction models
+            prompt += f"""
+I want to respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said. I'll use language appropriate for our relationship.
+
+My response to {name}:"""
 
         # Generate suggestion
         try:
             print(f"Generating suggestion with prompt: {prompt}")
+            # Use max_new_tokens instead of max_length to avoid the error
             response = self.generator(
                 prompt,
-                max_length
+                max_new_tokens=max_length,  # Generate new tokens, not including prompt
                 temperature=temperature,
                 do_sample=True,
                 top_p=0.92,
                 top_k=50,
+                truncation=True,
             )
             # Extract only the generated part, not the prompt
             result = response[0]["generated_text"][len(prompt) :]