|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
|
|
|
|
try: |
|
classifier = pipeline("zero-shot-classification") |
|
except OSError as e: |
|
st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.") |
|
st.stop() |
|
|
|
|
|
|
|
|
|
|
|
topic_classifier = pipeline("text-classification") |
|
|
|
text = "This is an example sentence for topic classification." |
|
result = topic_classifier(text) |
|
print(result) |
|
|
|
def classify_text(title, description, candidate_labels, show_all=False, threshold=0.95): |
|
""" |
|
Классифицирует текст и возвращает результаты в отсортированном виде. |
|
|
|
Args: |
|
title (str): Заголовок текста. |
|
description (str): Краткое описание текста. |
|
candidate_labels (list): Список меток-кандидатов. |
|
show_all (bool): Показывать ли все результаты, независимо от порога. |
|
threshold (float): Порог суммарной вероятности. |
|
|
|
Returns: |
|
list: Отсортированный список результатов классификации. |
|
""" |
|
text = f"{title} {description}" |
|
try: |
|
results = topic_classifier(text) |
|
|
|
except Exception as e: |
|
st.error(f"Ошибка классификации: {e}") |
|
return [] |
|
|
|
|
|
sorted_results = sorted(zip(results['labels'], results['scores']), key=lambda x: x[1], reverse=True) |
|
|
|
if show_all: |
|
return sorted_results |
|
else: |
|
cumulative_prob = 0 |
|
filtered_results = [] |
|
for label, score in sorted_results: |
|
filtered_results.append((label, score)) |
|
cumulative_prob += score |
|
if cumulative_prob >= threshold: |
|
break |
|
return filtered_results |
|
|
|
|
|
|
|
st.title("Классификация статей") |
|
|
|
|
|
title = st.text_input("Заголовок статьи") |
|
description = st.text_area("Краткое описание статьи", height=150) |
|
|
|
|
|
default_labels = "политика, экономика, спорт, культура, технологии, наука, происшествия" |
|
candidate_labels_str = st.text_input("Метки-кандидаты (через запятую)", default_labels) |
|
candidate_labels = [label.strip() for label in candidate_labels_str.split(",") if label.strip()] |
|
|
|
|
|
if st.button("Классифицировать"): |
|
if not title or not description or not candidate_labels: |
|
st.warning("Пожалуйста, заполните все поля.") |
|
else: |
|
with st.spinner("Идет классификация..."): |
|
results = classify_text(title, description, candidate_labels) |
|
if results: |
|
st.subheader("Результаты классификации (с ограничением по вероятности):") |
|
for label, score in results: |
|
st.write(f"- **{label}**: {score:.4f}") |
|
|
|
|
|
if st.button("Показать все категории"): |
|
all_results = classify_text(title, description, candidate_labels, show_all=True) |
|
st.subheader("Полные результаты классификации:") |
|
for label, score in all_results: |
|
st.write(f"- **{label}**: {score:.4f}") |
|
else: |
|
st.info("Не удалось получить результаты классификации.") |
|
|
|
elif title or description or candidate_labels_str != default_labels: |
|
st.warning("Пожалуйста, заполните все поля.") |