Xolkin commited on
Commit
3d556e6
·
verified ·
1 Parent(s): 5b1faf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -8,7 +8,7 @@ logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  # Загружаем модель
11
- model_name = "ai-forever/rugpt3medium_based_on_gpt2"
12
  try:
13
  logger.info(f"Попытка загрузки модели {model_name}...")
14
  generator = pipeline(
@@ -16,7 +16,7 @@ try:
16
  model=model_name,
17
  device=-1, # Используем CPU
18
  framework="pt",
19
- max_length=100, # Уменьшен для стабильности
20
  truncation=True,
21
  model_kwargs={"torch_dtype": torch.float32}
22
  )
@@ -25,8 +25,8 @@ except Exception as e:
25
  logger.error(f"Ошибка загрузки модели: {e}")
26
  exit(1)
27
 
28
- def respond(message, max_tokens=100, temperature=0.5, top_p=0.7):
29
- # Явный промпт с акцентом на медицинский ответ
30
  prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
31
  try:
32
  logger.info(f"Генерация ответа для: {message}")
@@ -36,7 +36,8 @@ def respond(message, max_tokens=100, temperature=0.5, top_p=0.7):
36
  temperature=temperature,
37
  top_p=top_p,
38
  do_sample=True,
39
- num_return_sequences=1
 
40
  )
41
  response = outputs[0]["generated_text"].replace(prompt, "").strip()
42
  logger.info(f"Ответ сгенерирован: {response}")
@@ -45,7 +46,7 @@ def respond(message, max_tokens=100, temperature=0.5, top_p=0.7):
45
  if "Диагноз:" in response and "Лечение:" in response:
46
  return response
47
  else:
48
- # Если формат не соблюден, пытаемся извлечь диагноз и генерируем базовое лечение
49
  diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
50
  return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
51
  except Exception as e:
@@ -56,12 +57,12 @@ demo = gr.Interface(
56
  fn=respond,
57
  inputs=[
58
  gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
59
- gr.Slider(minimum=50, maximum=200, value=100, step=10, label="Макс. токенов"),
60
  gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
61
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
62
  ],
63
  outputs="text",
64
- title="Медицинский чат-бот на базе RuGPT-3medium",
65
  theme=gr.themes.Soft(),
66
  description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
67
  )
 
8
  logger = logging.getLogger(__name__)
9
 
10
  # Загружаем модель
11
+ model_name = "sberbank-ai/rugpt3large_based_on_gpt2"
12
  try:
13
  logger.info(f"Попытка загрузки модели {model_name}...")
14
  generator = pipeline(
 
16
  model=model_name,
17
  device=-1, # Используем CPU
18
  framework="pt",
19
+ max_length=80, # Уменьшен для стабильности на CPU
20
  truncation=True,
21
  model_kwargs={"torch_dtype": torch.float32}
22
  )
 
25
  logger.error(f"Ошибка загрузки модели: {e}")
26
  exit(1)
27
 
28
+ def respond(message, max_tokens=80, temperature=0.5, top_p=0.7):
29
+ # Промпт с акцентом на медицинский ответ
30
  prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
31
  try:
32
  logger.info(f"Генерация ответа для: {message}")
 
36
  temperature=temperature,
37
  top_p=top_p,
38
  do_sample=True,
39
+ num_return_sequences=1,
40
+ no_repeat_ngram_size=2 # Предотвращаем повторы
41
  )
42
  response = outputs[0]["generated_text"].replace(prompt, "").strip()
43
  logger.info(f"Ответ сгенерирован: {response}")
 
46
  if "Диагноз:" in response and "Лечение:" in response:
47
  return response
48
  else:
49
+ # Если формат не соблюден, извлекаем диагноз и добавляем базовое лечение
50
  diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
51
  return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
52
  except Exception as e:
 
57
  fn=respond,
58
  inputs=[
59
  gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
60
+ gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"),
61
  gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
62
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
63
  ],
64
  outputs="text",
65
+ title="Медицинский чат-бот на базе RuGPT-3 Large",
66
  theme=gr.themes.Soft(),
67
  description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
68
  )