import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import gradio as gr
import pandas as pd
from acl_anthology import Anthology
from pie_datasets import Dataset, IterableDataset, load_dataset
from pytorch_ie import Pipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from src.demo.annotation_utils import annotate_documents, create_documents
from src.demo.data_utils import load_text_from_arxiv
from src.demo.rendering_utils import (
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
    render_displacy,
    render_pretty_table,
)
from src.demo.retriever_utils import get_text_spans_and_relations_from_document
from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader

logger = logging.getLogger(__name__)


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
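    """Create span embeddings for the given PIE documents and add them to the
    retriever's docstore. Warns if previously stored documents were overwritten
    (i.e. shared a document id with the new ones)."""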
    if verbose:
        gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
    # documents with an already existing id replace the stored ones instead of
    # being appended, so the difference in docstore size reveals the overwrites
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Pipeline,
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
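    """Convert the given texts into PIE documents, annotate them with the
    argumentation model, and add the annotated documents to the retriever.
    Document ids must be unique."""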
    # materialize both inputs: they may be one-shot iterators, and the
    # uniqueness check below needs to iterate over doc_ids twice
    texts = list(texts)
    doc_ids = list(doc_ids)
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = create_documents(
        texts=texts,
        doc_ids=doc_ids,
        split_regex=split_regex_escaped,
    )
    if verbose:
        gr.Info(f"Annotate {len(pie_documents)} documents...")
    pie_documents = annotate_documents(
        documents=pie_documents,
        argumentation_model=argumentation_model,
        handle_parts_of_same=handle_parts_of_same,
    )
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
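    """Load a PIE dataset (all kwargs are passed to pie_datasets.load_dataset),
    convert it to the retriever's document type, and add the documents with
    their gold annotations to the retriever."""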
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        dataset_converted = dataset.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
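    """Process a single text, surfacing any failure as a gr.Error, and return
    the document id."""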
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")
    return doc_id


def process_uploaded_files(
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
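    """Read the uploaded .txt files and add them as annotated documents to the
    retriever. Returns an overview of the docstore contents."""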
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # use the file content as text and the file name as document id
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_uploaded_pdf_files(
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
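    """Extract the fulltext of the uploaded .pdf files and add them as annotated
    documents to the retriever. Returns an overview of the docstore contents."""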
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".pdf"):
                # extract the fulltext and use the file name as document id
                text_and_extraction_data = pdf_fulltext_extractor(file_name)
                if text_and_extraction_data is None:
                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
                text, _ = text_and_extraction_data

                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def load_acl_anthology_venues(
    venues: List[str],
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    acl_anthology_data_dir: Optional[str],
    pdf_output_dir: Optional[str],
    show_progress: bool = True,
    **kwargs,
) -> pd.DataFrame:
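    """Download the PDFs of all papers from the given ACL Anthology venues,
    extract their fulltext, and add them as annotated documents to the
    retriever. Returns an overview of the docstore contents."""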
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        if acl_anthology_data_dir is None:
            raise gr.Error("ACL Anthology data directory is not provided.")
        if pdf_output_dir is None:
            raise gr.Error("PDF output directory is not provided.")
        xml2raw_papers = XML2RawPapers(
            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
            venue_id_whitelist=venues,
            verbose=False,
        )
        pdf_downloader = PDFDownloader()
        doc_ids = []
        texts = []
        os.makedirs(pdf_output_dir, exist_ok=True)
        papers = xml2raw_papers()
        if show_progress:
            papers_list = list(papers)
            papers = tqdm(papers_list, desc="extracting fulltext")
            gr.Info(
                f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
            )
        for paper in papers:
            if paper.url is not None:
                pdf_save_path = pdf_downloader.download(
                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)

                if fulltext_extraction_output:
                    text, _ = fulltext_extraction_output
                    doc_id = f"aclanthology.org/{paper.name}"
                    doc_ids.append(doc_id)
                    texts.append(text)
                else:
                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")

        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to load papers from the ACL Anthology: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
) -> pd.DataFrame:
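    """Same as add_annotated_pie_documents_from_dataset, but surfaces failures as
    gr.Error and returns an overview of the docstore contents."""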
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
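    """Zip the retriever store into the temp directory and return the archive
    path, or None if the docstore is empty."""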
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    file_path = os.path.join(tempfile.gettempdir(), file_name)

    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")

    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
) -> pd.DataFrame:
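    """Restore a retriever store from an uploaded archive (as produced by
    download_processed_documents) and return an overview of the docstore
    contents."""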
    retriever.load_from_disc(file_name)

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
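    """Fetch the fulltext (or only the abstract) of an arXiv paper, process it,
    and return the resulting document id."""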
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
    highlight_span_ids: Optional[List[str]] = None,
) -> str:
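    """Render the document's spans and relations as HTML, either as a pretty
    table or with displaCy. render_kwargs_json must contain a JSON object with
    additional keyword arguments for the chosen renderer."""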
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )

    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            highlight_span_ids=highlight_span_ids,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html