DoctorAI / app.py
Xolkin's picture
Update app.py
3d556e6 verified
import gradio as gr
from transformers import pipeline
import torch
import logging
# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Загружаем модель
model_name = "sberbank-ai/rugpt3large_based_on_gpt2"
try:
logger.info(f"Попытка загрузки модели {model_name}...")
generator = pipeline(
"text-generation",
model=model_name,
device=-1, # Используем CPU
framework="pt",
max_length=80, # Уменьшен для стабильности на CPU
truncation=True,
model_kwargs={"torch_dtype": torch.float32}
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
exit(1)
def respond(message, max_tokens=80, temperature=0.5, top_p=0.7):
# Промпт с акцентом на медицинский ответ
prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
try:
logger.info(f"Генерация ответа для: {message}")
outputs = generator(
prompt,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
num_return_sequences=1,
no_repeat_ngram_size=2 # Предотвращаем повторы
)
response = outputs[0]["generated_text"].replace(prompt, "").strip()
logger.info(f"Ответ сгенерирован: {response}")
# Проверка и форматирование ответа
if "Диагноз:" in response and "Лечение:" in response:
return response
else:
# Если формат не соблюден, извлекаем диагноз и добавляем базовое лечение
diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
except Exception as e:
logger.error(f"Ошибка генерации ответа: {e}")
return "Ошибка генерации. Проконсультируйтесь с врачом."
demo = gr.Interface(
fn=respond,
inputs=[
gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
],
outputs="text",
title="Медицинский чат-бот на базе RuGPT-3 Large",
theme=gr.themes.Soft(),
description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
)
if __name__ == "__main__":
demo.launch()