|
import gradio as gr |
|
from transformers import pipeline |
|
import torch |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
model_name = "sberbank-ai/rugpt3large_based_on_gpt2" |
|
try: |
|
logger.info(f"Попытка загрузки модели {model_name}...") |
|
generator = pipeline( |
|
"text-generation", |
|
model=model_name, |
|
device=-1, |
|
framework="pt", |
|
max_length=80, |
|
truncation=True, |
|
model_kwargs={"torch_dtype": torch.float32} |
|
) |
|
logger.info("Модель успешно загружена.") |
|
except Exception as e: |
|
logger.error(f"Ошибка загрузки модели: {e}") |
|
exit(1) |
|
|
|
def respond(message, max_tokens=80, temperature=0.5, top_p=0.7): |
|
|
|
prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]." |
|
try: |
|
logger.info(f"Генерация ответа для: {message}") |
|
outputs = generator( |
|
prompt, |
|
max_length=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
do_sample=True, |
|
num_return_sequences=1, |
|
no_repeat_ngram_size=2 |
|
) |
|
response = outputs[0]["generated_text"].replace(prompt, "").strip() |
|
logger.info(f"Ответ сгенерирован: {response}") |
|
|
|
|
|
if "Диагноз:" in response and "Лечение:" in response: |
|
return response |
|
else: |
|
|
|
diagnosis = response.split(".")[0].strip() if response else "Неизвестно" |
|
return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи." |
|
except Exception as e: |
|
logger.error(f"Ошибка генерации ответа: {e}") |
|
return "Ошибка генерации. Проконсультируйтесь с врачом." |
|
|
|
demo = gr.Interface( |
|
fn=respond, |
|
inputs=[ |
|
gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."), |
|
gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"), |
|
gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"), |
|
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p") |
|
], |
|
outputs="text", |
|
title="Медицинский чат-бот на базе RuGPT-3 Large", |
|
theme=gr.themes.Soft(), |
|
description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу." |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |