1NEYRON1's picture
Update app.py
b4c0a34
raw
history blame
4.79 kB
import streamlit as st
from transformers import pipeline
# Загружаем модель (замените на вашу модель, если нужно)
# Для примера используем zero-shot-classification
try:
classifier = pipeline("zero-shot-classification")
except OSError as e:
st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
st.stop() # Остановка выполнения приложения при ошибке
# model =
# tokenizer =
# topic_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
topic_classifier = pipeline("text-classification")
text = "This is an example sentence for topic classification."
result = topic_classifier(text)
print(result)
def classify_text(title, description, candidate_labels, show_all=False, threshold=0.95):
"""
Классифицирует текст и возвращает результаты в отсортированном виде.
Args:
title (str): Заголовок текста.
description (str): Краткое описание текста.
candidate_labels (list): Список меток-кандидатов.
show_all (bool): Показывать ли все результаты, независимо от порога.
threshold (float): Порог суммарной вероятности.
Returns:
list: Отсортированный список результатов классификации.
"""
text = f"{title} {description}" # Объединяем заголовок и описание
try:
results = topic_classifier(text)
# results = topic_classifier(text, candidate_labels, multi_label=True) # multi_label=True для нескольких меток
except Exception as e:
st.error(f"Ошибка классификации: {e}")
return []
# Сортируем результаты по убыванию вероятности
sorted_results = sorted(zip(results['labels'], results['scores']), key=lambda x: x[1], reverse=True)
if show_all:
return sorted_results
else:
cumulative_prob = 0
filtered_results = []
for label, score in sorted_results:
filtered_results.append((label, score))
cumulative_prob += score
if cumulative_prob >= threshold:
break
return filtered_results
# --- Интерфейс Streamlit ---
st.title("Классификация статей")
# Ввод данных
title = st.text_input("Заголовок статьи")
description = st.text_area("Краткое описание статьи", height=150)
# Ввод меток-кандидатов (разделенных запятыми)
default_labels = "политика, экономика, спорт, культура, технологии, наука, происшествия"
candidate_labels_str = st.text_input("Метки-кандидаты (через запятую)", default_labels)
candidate_labels = [label.strip() for label in candidate_labels_str.split(",") if label.strip()]
# Кнопка "Классифицировать"
if st.button("Классифицировать"):
if not title or not description or not candidate_labels:
st.warning("Пожалуйста, заполните все поля.")
else:
with st.spinner("Идет классификация..."): # Индикатор загрузки
results = classify_text(title, description, candidate_labels)
if results:
st.subheader("Результаты классификации (с ограничением по вероятности):")
for label, score in results:
st.write(f"- **{label}**: {score:.4f}")
# Кнопка "Показать все"
if st.button("Показать все категории"):
all_results = classify_text(title, description, candidate_labels, show_all=True)
st.subheader("Полные результаты классификации:")
for label, score in all_results:
st.write(f"- **{label}**: {score:.4f}")
else:
st.info("Не удалось получить результаты классификации.")
elif title or description or candidate_labels_str != default_labels: #небольшой костыль, чтобы при старте не было предупреждения
st.warning("Пожалуйста, заполните все поля.")