willwade committed on
Commit fa1bef5 · 1 Parent(s): b929813

add in changes to make app work as well as demo

Files changed (3):
  1. app.py +116 -15
  2. demo.py +11 -5
  3. utils.py +173 -22
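
The sketch below is not part of the commit; it only summarises, under the names used in the diff, how the new pieces fit together: app.py gains an AVAILABLE_MODELS dropdown and a change_model() helper, which delegates to the new SuggestionGenerator.load_model() added in utils.py. The person id used here is hypothetical.

# Minimal sketch (assumed usage, not from the commit) of the new model-switching flow.
from utils import SocialGraphManager, SuggestionGenerator

social_graph = SocialGraphManager("social_graph.json")
suggestion_generator = SuggestionGenerator("distilgpt2")  # default model

# Switch models at runtime, as the Gradio dropdown now does via change_model():
if suggestion_generator.load_model("gpt2"):  # returns True on success
    print(f"Now using: {suggestion_generator.model_name}")

# Generate a response for a person from the social graph ("billy" is a made-up id):
person = social_graph.get_person_context("billy")
print(suggestion_generator.generate_suggestion(person, "Did you watch the match?", temperature=0.7))
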
app.py CHANGED

@@ -4,10 +4,22 @@ import tempfile
 import os
 from utils import SocialGraphManager, SuggestionGenerator
 
-# Initialize the social graph manager and suggestion generator
+# Define available models
+AVAILABLE_MODELS = {
+    "distilgpt2": "DistilGPT2 (Fast, smaller model)",
+    "gpt2": "GPT-2 (Medium size, better quality)",
+    "google/gemma-3-1b-it": "Gemma 3 1B-IT (Small, instruction-tuned)",
+    "Qwen/Qwen1.5-0.5B": "Qwen 1.5 0.5B (Very small, efficient)",
+    "Qwen/Qwen1.5-1.8B": "Qwen 1.5 1.8B (Small, good quality)",
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "TinyLlama 1.1B (Small, chat-tuned)",
+    "microsoft/phi-3-mini-4k-instruct": "Phi-3 Mini (Small, instruction-tuned)",
+    "microsoft/phi-2": "Phi-2 (Small, high quality for size)",
+}
+
+# Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")
 
-# Initialize the suggestion generator with distilgpt2
+# Initialize the suggestion generator with distilgpt2 (default)
 suggestion_generator = SuggestionGenerator("distilgpt2")
 
 # Test the model to make sure it's working

@@ -23,7 +35,8 @@ if not suggestion_generator.model_loaded:
 try:
     whisper_model = whisper.load_model("tiny")
     whisper_loaded = True
-except Exception:
+except Exception as e:
+    print(f"Error loading Whisper model: {e}")
     whisper_loaded = False
 
 

@@ -90,16 +103,55 @@ def on_person_change(person_id):
     return context_info, phrases_text, topics
 
 
-def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=None):
+def change_model(model_name):
+    """Change the language model used for generation.
+
+    Args:
+        model_name: The name of the model to use
+
+    Returns:
+        A status message about the model change
+    """
+    global suggestion_generator
+
+    print(f"Changing model to: {model_name}")
+
+    # Check if we need to change the model
+    if model_name == suggestion_generator.model_name:
+        return f"Already using model: {model_name}"
+
+    # Try to load the new model
+    success = suggestion_generator.load_model(model_name)
+
+    if success:
+        return f"Successfully switched to model: {model_name}"
+    else:
+        return f"Failed to load model: {model_name}. Using fallback responses instead."
+
+
+def generate_suggestions(
+    person_id,
+    user_input,
+    suggestion_type,
+    selected_topic=None,
+    model_name="distilgpt2",
+    temperature=0.7,
+):
     """Generate suggestions based on the selected person and user input."""
     print(
-        f"Generating suggestions with: person_id={person_id}, user_input={user_input}, suggestion_type={suggestion_type}, selected_topic={selected_topic}"
+        f"Generating suggestions with: person_id={person_id}, user_input={user_input}, "
+        f"suggestion_type={suggestion_type}, selected_topic={selected_topic}, "
+        f"model={model_name}, temperature={temperature}"
     )
 
     if not person_id:
         print("No person_id provided")
        return "Please select who you're talking to first."
 
+    # Make sure we're using the right model
+    if model_name != suggestion_generator.model_name:
+        change_model(model_name)
+
     person_context = social_graph.get_person_context(person_id)
     print(f"Person context: {person_context}")
 

@@ -160,7 +212,7 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
         print(f"Generating suggestion {i+1}/3")
         try:
             suggestion = suggestion_generator.generate_suggestion(
-                person_context, user_input, temperature=0.7
+                person_context, user_input, temperature=temperature
             )
             print(f"Generated suggestion: {suggestion}")
             suggestions.append(suggestion)

@@ -168,11 +220,13 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
             print(f"Error generating suggestion: {e}")
             suggestions.append("Error generating suggestion")
 
-        result = "### AI-Generated Responses:\n\n"
+        result = (
+            f"### AI-Generated Responses (using {suggestion_generator.model_name}):\n\n"
+        )
         for i, suggestion in enumerate(suggestions, 1):
             result += f"{i}. {suggestion}\n\n"
 
-        print(f"Final result: {result}")
+        print(f"Final result: {result[:100]}...")
 
     # If suggestion type is "common_phrases", use the person's common phrases
     elif suggestion_type == "common_phrases":

@@ -194,12 +248,16 @@ def generate_suggestions(person_id, user_input, suggestion_type, selected_topic=
             print("No category inferred, falling back to model")
             # Fall back to model if we couldn't infer a category
             try:
-                suggestion = suggestion_generator.generate_suggestion(
-                    person_context, user_input
-                )
-                print(f"Generated fallback suggestion: {suggestion}")
-                result = "### AI-Generated Response (no category detected):\n\n"
-                result += f"1. {suggestion}\n\n"
+                suggestions = []
+                for i in range(3):
+                    suggestion = suggestion_generator.generate_suggestion(
+                        person_context, user_input, temperature=temperature
+                    )
+                    suggestions.append(suggestion)
+
+                result = f"### AI-Generated Responses (no category detected, using {suggestion_generator.model_name}):\n\n"
+                for i, suggestion in enumerate(suggestions, 1):
+                    result += f"{i}. {suggestion}\n\n"
             except Exception as e:
                 print(f"Error generating fallback suggestion: {e}")
                 result = "### Could not generate a response:\n\n"

@@ -319,9 +377,33 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
                 info="Choose what kind of responses you want (model = AI-generated)",
             )
 
+            # Model selection
+            with gr.Row():
+                model_dropdown = gr.Dropdown(
+                    choices=list(AVAILABLE_MODELS.keys()),
+                    value="distilgpt2",
+                    label="Language Model",
+                    info="Select which AI model to use for generating responses",
+                )
+
+                temperature_slider = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.5,
+                    value=0.7,
+                    step=0.1,
+                    label="Temperature",
+                    info="Controls randomness (higher = more creative, lower = more focused)",
+                )
+
             # Generate button
             generate_btn = gr.Button("Generate My Responses", variant="primary")
 
+            # Model status
+            model_status = gr.Markdown(
+                value=f"Current model: {suggestion_generator.model_name}",
+                label="Model Status",
+            )
+
         with gr.Column(scale=1):
             # Common phrases
             common_phrases = gr.Textbox(

@@ -347,6 +429,11 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
         # Update the context, phrases, and topic dropdown
         return context_info, phrases_text, gr.update(choices=topics)
 
+    def handle_model_change(model_name):
+        """Handle model selection change."""
+        status = change_model(model_name)
+        return status
+
     # Set up the person change event
     person_dropdown.change(
         handle_person_change,

@@ -354,10 +441,24 @@ with gr.Blocks(title="Will's AAC Communication Aid") as demo:
         outputs=[context_display, common_phrases, topic_dropdown],
     )
 
+    # Set up the model change event
+    model_dropdown.change(
+        handle_model_change,
+        inputs=[model_dropdown],
+        outputs=[model_status],
+    )
+
     # Set up the generate button click event
     generate_btn.click(
         generate_suggestions,
-        inputs=[person_dropdown, user_input, suggestion_type, topic_dropdown],
+        inputs=[
+            person_dropdown,
+            user_input,
+            suggestion_type,
+            topic_dropdown,
+            model_dropdown,
+            temperature_slider,
+        ],
         outputs=[suggestions_output],
     )
 
demo.py CHANGED

@@ -398,7 +398,10 @@ class LLMToolInterface(LLMInterface):
                 print("llm install llm-mlx")
             elif "ollama" in self.model_name.lower():
                 print("llm install llm-ollama")
-                print("ollama pull " + self.model_name.replace("ollama/", ""))
+                model_name = self.model_name
+                if "/" in model_name:
+                    model_name = model_name.split("/")[1]
+                print("ollama pull " + model_name)
             else:
                 print("Warning: LLM tool may be installed but returned an error.")
         except Exception as e:

@@ -610,7 +613,7 @@ def main():
         "- hf: 'distilgpt2', 'gpt2-medium', 'google/gemma-2b-it'\n"
         "- llm: 'gemini-1.5-pro-latest', 'gemma-3-27b-it' (requires llm-gemini plugin)\n"
         "  'mlx-community/gemma-7b-it' (requires llm-mlx plugin)\n"
-        "  'ollama/gemma3:4b-it-qat', 'ollama/llama3:8b' (requires llm-ollama plugin)",
+        "  'Ollama: gemma3:4b-it-qat', 'Ollama: llama3:8b' (requires llm-ollama plugin)",
     )
     parser.add_argument(
         "--num_responses", type=int, default=3, help="Number of responses to generate"

@@ -705,9 +708,12 @@ def main():
             print("1. Install from https://ollama.ai/")
             print("2. Start Ollama with: ollama serve")
             print("3. Install the llm-ollama plugin: llm install llm-ollama")
-            print(
-                f"4. Pull the model: ollama pull {args.model.replace('ollama/', '')}"
-            )
+            model_name = args.model
+            if "ollama:" in model_name.lower():
+                model_name = model_name.replace("Ollama: ", "")
+            elif "/" in model_name:
+                model_name = model_name.split("/")[1]
+            print(f"4. Pull the model: ollama pull {model_name}")
         else:
             print("\nMake sure Simon Willison's LLM tool is installed:")
             print("pip install llm")
utils.py CHANGED

@@ -159,16 +159,20 @@ class SuggestionGenerator:
         """
         self.model_name = model_name
         self.model_loaded = False
+        self.generator = None
+        self.aac_user_info = None
 
+        # Load AAC user information from social graph
         try:
-            print(f"Loading model: {model_name}")
-            # Use a simpler approach with a pre-built pipeline
-            self.generator = pipeline("text-generation", model=model_name)
-            self.model_loaded = True
-            print(f"Model loaded successfully: {model_name}")
+            with open("social_graph.json", "r") as f:
+                social_graph = json.load(f)
+                self.aac_user_info = social_graph.get("aac_user", {})
         except Exception as e:
-            print(f"Error loading model: {e}")
-            self.model_loaded = False
+            print(f"Error loading AAC user info from social graph: {e}")
+            self.aac_user_info = {}
+
+        # Try to load the model
+        self.load_model(model_name)
 
         # Fallback responses if model fails to load or generate
         self.fallback_responses = [

@@ -176,8 +180,92 @@ class SuggestionGenerator:
             "That's interesting. Tell me more.",
             "I'd like to talk about that further.",
             "I appreciate you sharing that with me.",
+            "Could we talk about something else?",
+            "I need some time to think about that.",
         ]
 
+    def load_model(self, model_name: str) -> bool:
+        """Load a Hugging Face model.
+
+        Args:
+            model_name: Name of the HuggingFace model to use
+
+        Returns:
+            bool: True if model loaded successfully, False otherwise
+        """
+        self.model_name = model_name
+        self.model_loaded = False
+
+        try:
+            print(f"Loading model: {model_name}")
+
+            # Check if this is a gated model that requires authentication
+            is_gated_model = any(
+                name in model_name.lower()
+                for name in ["gemma", "llama", "mistral", "qwen", "phi"]
+            )
+
+            if is_gated_model:
+                # Try to get token from environment
+                import os
+
+                token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
+                    "HF_TOKEN"
+                )
+
+                if token:
+                    print(f"Using token for gated model: {model_name}")
+                    from huggingface_hub import login
+
+                    login(token=token, add_to_git_credential=False)
+
+                    # Explicitly pass token to pipeline
+                    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model_name, token=token
+                        )
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_name, token=token
+                        )
+                        self.generator = pipeline(
+                            "text-generation", model=model, tokenizer=tokenizer
+                        )
+                    except Exception as e:
+                        print(f"Error loading gated model with token: {e}")
+                        print(
+                            "This may be due to not having accepted the model license or insufficient permissions."
+                        )
+                        print(
+                            "Please visit the model page on Hugging Face Hub and accept the license."
+                        )
+                        raise
+                else:
+                    print("No Hugging Face token found in environment variables.")
+                    print(
+                        "To use gated models like Gemma, you need to set up a token with the right permissions."
+                    )
+                    print("1. Create a token at https://huggingface.co/settings/tokens")
+                    print(
+                        "2. Make sure to enable 'Access to public gated repositories'"
+                    )
+                    print(
+                        "3. Set it as an environment variable: export HUGGING_FACE_HUB_TOKEN=your_token_here"
+                    )
+                    raise ValueError("Authentication token required for gated model")
+            else:
+                # For non-gated models, use the standard pipeline
+                self.generator = pipeline("text-generation", model=model_name)
+
+            self.model_loaded = True
+            print(f"Model loaded successfully: {model_name}")
+            return True
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            self.model_loaded = False
+            return False
+
     def test_model(self) -> str:
         """Test if the model is working correctly."""
         if not self.model_loaded:

@@ -186,7 +274,9 @@ class SuggestionGenerator:
         try:
             test_prompt = "I am Will. My son Billy asked about football. I respond:"
             print(f"Testing model with prompt: {test_prompt}")
-            response = self.generator(test_prompt, max_length=30, do_sample=True)
+            response = self.generator(
+                test_prompt, max_new_tokens=30, do_sample=True, truncation=True
+            )
             result = response[0]["generated_text"][len(test_prompt) :]
             print(f"Test response: {result}")
             return f"Model test successful: {result}"

@@ -222,39 +312,100 @@ class SuggestionGenerator:
         # Extract context information
         name = person_context.get("name", "")
         role = person_context.get("role", "")
-        topics = ", ".join(person_context.get("topics", []))
+        topics = person_context.get("topics", [])
         context = person_context.get("context", "")
         selected_topic = person_context.get("selected_topic", "")
+        common_phrases = person_context.get("common_phrases", [])
+        frequency = person_context.get("frequency", "")
 
-        # Build prompt
-        prompt = f"""I am Will, a person with MND (Motor Neuron Disease).
-        I'm talking to {name}, who is my {role}.
-        """
+        # Get AAC user information
+        aac_user = self.aac_user_info
 
-        if context:
-            prompt += f"Context: {context}\n"
+        # Build enhanced prompt
+        prompt = f"""I am {aac_user.get('name', 'Will')}, a {aac_user.get('age', 38)}-year-old with MND (Motor Neuron Disease) from {aac_user.get('location', 'Manchester')}.
+{aac_user.get('background', '')}
 
-        if topics:
-            prompt += f"Topics of interest: {topics}\n"
+My communication needs: {aac_user.get('communication_needs', '')}
 
-        if selected_topic:
-            prompt += f"We're currently talking about: {selected_topic}\n"
+I am talking to {name}, who is my {role}.
+About {name}: {context}
+We typically talk about: {', '.join(topics)}
+We communicate {frequency}.
+"""
+
+        # Add communication style based on relationship
+        if role in ["wife", "son", "daughter", "mother", "father"]:
+            prompt += "I communicate with my family in a warm, loving way, sometimes using inside jokes.\n"
+        elif role in ["doctor", "therapist", "nurse"]:
+            prompt += "I communicate with healthcare providers in a direct, informative way.\n"
+        elif role in ["best mate", "friend"]:
+            prompt += "I communicate with friends casually, often with humor and sometimes swearing.\n"
+        elif role in ["work colleague", "boss"]:
+            prompt += (
+                "I communicate with colleagues professionally but still friendly.\n"
+            )
 
+        # Add topic information if provided
+        if selected_topic:
+            prompt += f"\nWe are currently discussing {selected_topic}.\n"
+
+            # Add specific context about this topic with this person
+            if selected_topic == "football" and "Manchester United" in context:
+                prompt += "We both support Manchester United and often discuss recent matches.\n"
+            elif selected_topic == "programming" and "software developer" in context:
+                prompt += "We both work in software development and share technical interests.\n"
+            elif selected_topic == "family plans" and role in ["wife", "husband"]:
+                prompt += (
+                    "We make family decisions together, considering my condition.\n"
+                )
+            elif selected_topic == "old scout adventures" and role == "best mate":
+                prompt += "We often reminisce about our Scout camping trips in South East London.\n"
+            elif selected_topic == "cycling" and "cycling" in context:
+                prompt += "I miss being able to cycle but enjoy talking about past cycling adventures.\n"
+
+        # Add the user's message if provided
         if user_input:
             prompt += f'\n{name} just said to me: "{user_input}"\n'
-
-        prompt += "\nMy response:"
+        elif common_phrases:
+            # Use a common phrase from the person if no message is provided
+            default_message = common_phrases[0]
+            prompt += f'\n{name} typically says things like: "{default_message}"\n'
+
+        # Add the response prompt with specific guidance
+        # Check if this is an instruction-tuned model
+        is_instruction_model = any(
+            marker in self.model_name.lower()
+            for marker in ["-it", "instruct", "chat", "phi-3", "phi-2"]
+        )
+
+        if is_instruction_model:
+            # Use instruction format for instruction-tuned models
+            prompt += f"""
+<instruction>
+Respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said.
+Use language appropriate for our relationship.
+</instruction>
+
+My response to {name}:"""
+        else:
+            # Use standard format for non-instruction models
+            prompt += f"""
+I want to respond to {name} in a way that is natural, brief (1-2 sentences), and directly relevant to what they just said. I'll use language appropriate for our relationship.
+
+My response to {name}:"""
 
         # Generate suggestion
         try:
             print(f"Generating suggestion with prompt: {prompt}")
+            # Use max_new_tokens instead of max_length to avoid the error
             response = self.generator(
                 prompt,
-                max_length=len(prompt.split()) + max_length,
+                max_new_tokens=max_length,  # Generate new tokens, not including prompt
                 temperature=temperature,
                 do_sample=True,
                 top_p=0.92,
                 top_k=50,
+                truncation=True,
            )
             # Extract only the generated part, not the prompt
             result = response[0]["generated_text"][len(prompt) :]
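
For reference, a standalone sketch of the generation call pattern utils.py now uses (max_new_tokens plus truncation instead of max_length); the model name and prompt below are placeholders, not values from the commit, and the environment variable is the one the new load_model() checks for gated models.

# Minimal sketch of the updated transformers pipeline call (assumed usage).
from transformers import pipeline

# Gated models (Gemma, Llama, Qwen, Phi, ...) additionally need a Hugging Face token
# in the environment, e.g.: export HUGGING_FACE_HUB_TOKEN=your_token_here
generator = pipeline("text-generation", model="distilgpt2")  # non-gated default

response = generator(
    "I am Will. My son Billy asked about football. I respond:",
    max_new_tokens=30,   # count only newly generated tokens, not the prompt
    do_sample=True,
    temperature=0.7,
    top_p=0.92,
    top_k=50,
    truncation=True,
)
print(response[0]["generated_text"])
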