import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

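# NOTE: setup_root() above puts the project root on the PYTHONPATH and loads the .env file,
# so the project-local imports below (model_utils, rendering_utils, src.langchain_modules)
# resolve regardless of the current working directory.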
import json
import logging
import os.path
import re
import tempfile
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import arxiv
import gradio as gr
import pandas as pd
import requests
import torch
import yaml
from bs4 import BeautifulSoup
from model_utils import (
    add_annotated_pie_documents_from_dataset,
    load_argumentation_model,
    load_retriever,
    process_texts,
    retrieve_all_relevant_spans,
    retrieve_all_similar_spans,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Annotation, Pipeline
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table

from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)

logger = logging.getLogger(__name__)


def load_retriever_config(path: str) -> str:
    with open(path, "r") as file:
        yaml_string = file.read()
    config = yaml.safe_load(yaml_string)
    return yaml.dump(config)

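# Demo defaults: render modes, the argumentation model checkpoint, the retriever configuration,
# and the parameters for importing documents (arXiv ID, PIE dataset kwargs). All of these can be
# adjusted in the UI at runtime.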
RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
RENDER_WITH_PRETTY_TABLE = "Pretty Table"

DEFAULT_MODEL_NAME = "ArneBinder/sam-pointer-bart-base-v0.3"
DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"

DEFAULT_RETRIEVER_CONFIG = load_retriever_config(
    "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
)

DEFAULT_MIN_SIMILARITY = 0.95
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_SPLIT_REGEX = "\n\n\n+"
DEFAULT_ARXIV_ID = "1706.03762"
DEFAULT_LOAD_PIE_DATASET_KWARGS_STR = json.dumps(
    dict(path="pie/sciarg", name="resolve_parts_of_same", split="train"), indent=2
)

HANDLE_PARTS_OF_SAME = True
LAYER_CAPTIONS = {
    "labeled_multi_spans": "adus",
    "binary_relations": "relations",
    "labeled_partitions": "partitions",
}
RELATION_NAME_MAPPING = {
    "supports_reversed": "supported by",
    "contradicts_reversed": "contradicts",
}

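# The partition regex contains real newline characters, which the single-line textbox in the UI
# cannot show verbatim. escape_regex() converts them into the visible two-character sequences
# ("\n") for display, and unescape_regex() reverses this before the pattern is applied.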
def escape_regex(regex: str) -> str:
    """Escape special characters (e.g. newlines) so the regex can be shown and edited as plain text."""
    result = regex.encode("unicode_escape").decode("utf-8")
    return result


def unescape_regex(regex: str) -> str:
    """Revert escape_regex(), i.e. turn the displayed text back into the actual regex pattern."""
    result = regex.encode("utf-8").decode("unicode_escape")
    return result

def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    document = retriever.get_document(doc_id=doc_id)
    return retriever.docstore.as_dict(document)

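# An annotated document can be rendered either with displaCy (argument spans highlighted inline)
# or as a plain HTML table of spans and relations; both variants receive the JSON render kwargs
# from the "Render Options" accordion.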
def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
) -> str:
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )

    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html

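# The processing helpers below (wrapped_process_text, process_uploaded_files,
# wrapped_add_annotated_pie_documents_from_dataset) convert exceptions into gr.Error so that
# failures surface as error messages in the UI instead of crashing the event handler.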
def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")

    return doc_id


def process_uploaded_files(
    file_names: List[str], retriever: DocumentAwareSpanRetriever, **kwargs
) -> pd.DataFrame:
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, **kwargs
) -> pd.DataFrame:
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)


def open_accordion():
    return gr.Accordion(open=True)


def close_accordion():
    return gr.Accordion(open=False)


def get_cell_for_fixed_column_from_df(
    evt: gr.SelectData,
    df: pd.DataFrame,
    column: str,
) -> Any:
    """Get the value of the fixed column for the selected row in the DataFrame.

    This can *not* be implemented as a lambda function because a lambda would not
    receive the evt parameter.

    Args:
        evt: The event object.
        df: The DataFrame.
        column: The name of the column.

    Returns:
        The value of the fixed column for the selected row.
    """
    row_idx, col_idx = evt.index
    doc_id = df.iloc[row_idx][column]
    return doc_id


def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[List[str]] = None,
) -> gr.Dropdown:
    if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
    else:
        raise gr.Error("Unsupported taskmodule for relation types")

    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )


def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    if span_id.strip() == "":
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        raise gr.Error(f"Failed to retrieve span annotation: {e}")


def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    document = retriever.get_document(doc_id=document_id)
    pie_document = retriever.docstore.unwrap(document)
    use_predicted_annotations = retriever.use_predicted_annotations(document)
    spans = retriever.get_base_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    relations = retriever.get_relation_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    span_id2idx = retriever.get_span_id2idx_from_doc(document)
    return pie_document.text, spans, span_id2idx, relations

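# The two helpers below back the "Download"/"Upload" buttons in the "Indexed Documents"
# accordion: the whole retriever store is zipped for download and restored again on upload.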
def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    file_path = os.path.join(tempfile.gettempdir(), file_name)

    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")

    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
) -> pd.DataFrame:
    retriever.load_from_disc(file_name)
    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)

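# arXiv import: the paper metadata is fetched via the arxiv API; unless only the abstract is
# requested, the HTML version of the paper (the entry URL with /abs/ replaced by /html/) is
# downloaded and stripped down to plain text.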
def clean_spaces(text: str) -> str:
    # collapse runs of spaces into a single space
    text = re.sub(" +", " ", text)
    # collapse runs of blank lines into a single blank line
    text = re.sub("\n\n+", "\n\n", text)
    text = text.strip()
    return text


def get_cleaned_arxiv_paper_text(html_content: str) -> str:
    soup = BeautifulSoup(html_content, "html.parser")
    # drop the arXiv "package alerts" banner (if present) so it does not leak into the paper text
    alerts = soup.find("div", class_="package-alerts ltx_document")
    if alerts is not None:
        alerts.extract()
    article = soup.find("article")
    article_text = article.get_text()
    article_text_clean = clean_spaces(article_text)
    return article_text_clean


def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]:
    arxiv_id = arxiv_id.strip()
    if not arxiv_id:
        arxiv_id = DEFAULT_ARXIV_ID

    search_by_id = arxiv.Search(id_list=[arxiv_id])
    try:
        result = list(arxiv.Client().results(search_by_id))
    except arxiv.HTTPError as e:
        raise gr.Error(f"Failed to fetch arXiv data: {e}")
    if len(result) == 0:
        raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
    first_result = result[0]
    if abstract_only:
        abstract_clean = first_result.summary.replace("\n", " ")
        return abstract_clean, first_result.entry_id
    if "/abs/" not in first_result.entry_id:
        raise gr.Error(
            f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
            f"an unexpected format: {first_result.entry_id}"
        )
    html_url = first_result.entry_id.replace("/abs/", "/html/")
    request_result = requests.get(html_url)
    if request_result.status_code != 200:
        raise gr.Error(
            f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
            f"{request_result.status_code}"
        )
    html_content = request_result.text
    text_clean = get_cleaned_arxiv_paper_text(html_content)
    return text_clean, html_url


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)

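# main() builds the Gradio UI: the left column holds the text input and model configuration,
# the middle column shows the rendered analysis of the selected document, and the right column
# manages the indexed documents (uploads, arXiv/PIE imports) and the span retrieval controls.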
def main():
    example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."

    print("Loading argumentation model ...")
    argumentation_model = load_argumentation_model(
        model_name=DEFAULT_MODEL_NAME,
        revision=DEFAULT_MODEL_REVISION,
        device=DEFAULT_DEVICE,
    )
    print("Loading retriever ...")
    retriever = load_retriever(
        DEFAULT_RETRIEVER_CONFIG, device=DEFAULT_DEVICE, config_format="yaml"
    )
    print("Models loaded.")

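    # Render options passed to the displaCy-based renderer: fixed colors per ADU type and
    # hover colors for the head/tail spans of the different relation types.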
    default_render_kwargs = {
        "entity_options": {
            "colors": {
                "own_claim".upper(): "#009933",
                "background_claim".upper(): "#99ccff",
                "data".upper(): "#993399",
            }
        },
        "colors_hover": {
            "selected": "#ffa",
            "tail": {
                "supports": "#9f9",
                "contradicts": "#f99",
                "parts_of_same": None,
            },
            "head": None,
            "other": None,
        },
    }

    with gr.Blocks() as demo:
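        # The loaded model and retriever are wrapped in one-element tuples and kept in gr.State;
        # the event handlers below unwrap them again with [0].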
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))
        with gr.Row():
            with gr.Column(scale=1):
                doc_id = gr.Textbox(
                    label="Document ID",
                    value="user_input",
                )
                doc_text = gr.Textbox(
                    label="Text",
                    lines=20,
                    value=example_text,
                )

                with gr.Accordion("Model Configuration", open=False):
                    with gr.Accordion("argumentation structure", open=True):
                        model_name = gr.Textbox(
                            label="Model Name",
                            value=DEFAULT_MODEL_NAME,
                        )
                        model_revision = gr.Textbox(
                            label="Model Revision",
                            value=DEFAULT_MODEL_REVISION,
                        )
                        load_arg_model_btn = gr.Button("Load Argumentation Model")

                    with gr.Accordion("retriever", open=True):
                        retriever_config = gr.Textbox(
                            label="Retriever Configuration",
                            placeholder="Configuration for the retriever",
                            value=DEFAULT_RETRIEVER_CONFIG,
                            lines=len(DEFAULT_RETRIEVER_CONFIG.split("\n")),
                        )
                        load_retriever_btn = gr.Button("Load Retriever")

                    device = gr.Textbox(
                        label="Device (e.g. 'cuda' or 'cpu')",
                        value=DEFAULT_DEVICE,
                    )
                    load_arg_model_btn.click(
                        fn=lambda _model_name, _model_revision, _device: (
                            load_argumentation_model(
                                model_name=_model_name, revision=_model_revision, device=_device
                            ),
                        ),
                        inputs=[model_name, model_revision, device],
                        outputs=argumentation_model_state,
                    )
                    load_retriever_btn.click(
                        fn=lambda _retriever_config, _device, _previous_retriever: (
                            load_retriever(
                                retriever_config=_retriever_config,
                                device=_device,
                                previous_retriever=_previous_retriever[0],
                                config_format="yaml",
                            ),
                        ),
                        inputs=[retriever_config, device, retriever_state],
                        outputs=retriever_state,
                    )

                split_regex_escaped = gr.Textbox(
                    label="Regex to partition the text",
                    placeholder="Regular expression pattern to split the text into partitions",
                    value=escape_regex(DEFAULT_SPLIT_REGEX),
                )

                predict_btn = gr.Button("Analyse")

            with gr.Column(scale=1):
                selected_document_id = gr.Textbox(
                    label="Selected Document", max_lines=1, interactive=False
                )
                rendered_output = gr.HTML(label="Rendered Output")

                with gr.Accordion("Render Options", open=False):
                    render_as = gr.Dropdown(
                        label="Render with",
                        choices=[RENDER_WITH_PRETTY_TABLE, RENDER_WITH_DISPLACY],
                        value=RENDER_WITH_DISPLACY,
                    )
                    render_kwargs = gr.Textbox(
                        label="Render Arguments",
                        lines=5,
                        value=json.dumps(default_render_kwargs, indent=2),
                    )
                    render_btn = gr.Button("Re-render")

                with gr.Accordion("See plain result ...", open=False) as document_json_accordion:
                    get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                    document_json = gr.JSON(label="Model Output")

            with gr.Column(scale=1):
                with gr.Accordion(
                    "Indexed Documents", open=False
                ) as processed_documents_accordion:
                    processed_documents_df = gr.DataFrame(
                        headers=["id", "num_adus", "num_relations"],
                        interactive=False,
                    )
                    gr.Markdown("Data Snapshot:")
                    with gr.Row():
                        download_processed_documents_btn = gr.DownloadButton("Download")
                        upload_processed_documents_btn = gr.UploadButton(
                            "Upload", file_types=["file"]
                        )

                upload_btn = gr.UploadButton(
                    "Upload & Analyse Reference Documents",
                    file_types=["text"],
                    file_count="multiple",
                )

                with gr.Accordion("Import text from arXiv", open=False):
                    arxiv_id = gr.Textbox(
                        label="arXiv paper ID",
                        placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
                        max_lines=1,
                    )
                    load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                    load_arxiv_btn = gr.Button("Load & process from arXiv", variant="secondary")

                with gr.Accordion("Import annotated PIE dataset", open=False):
                    load_pie_dataset_kwargs_str = gr.Textbox(
                        label="Parameters for Loading the PIE Dataset",
                        value=DEFAULT_LOAD_PIE_DATASET_KWARGS_STR,
                        lines=len(DEFAULT_LOAD_PIE_DATASET_KWARGS_STR.split("\n")),
                    )
                    load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

                with gr.Accordion("Retrieval Configuration", open=False):
                    min_similarity = gr.Slider(
                        label="Minimum Similarity",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=DEFAULT_MIN_SIMILARITY,
                    )
                    top_k = gr.Slider(
                        label="Top K",
                        minimum=2,
                        maximum=50,
                        step=1,
                        value=10,
                    )
                    retrieve_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *selected* ADU"
                    )
                    similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )
                    retrieve_all_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *all* ADUs in the document"
                    )
                    all_similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                        interactive=False,
                    )
                    retrieve_all_relevant_adus_btn = gr.Button(
                        "Retrieve *relevant* ADUs for *all* ADUs in the document"
                    )
                    all_relevant_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )

                hover_adu_id = gr.Textbox(
                    label="ID (hover)",
                    elem_id="hover_adu_id",
                    interactive=False,
                    visible=False,
                )
                selected_adu_id = gr.Textbox(
                    label="ID (selected)",
                    elem_id="selected_adu_id",
                    interactive=False,
                    visible=False,
                )
                selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                with gr.Accordion("Relevant ADUs from other documents", open=True):
                    relevant_adus_df = gr.DataFrame(
                        headers=[
                            "relation",
                            "adu",
                            "reference_adu",
                            "doc_id",
                            "sim_score",
                            "rel_score",
                        ],
                        interactive=False,
                    )

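        # Event wiring. render_event_kwargs and show_overview_kwargs are shared between several
        # triggers (running the analysis, re-rendering, and selecting a document).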
        render_event_kwargs = dict(
            fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=_render_as,
                render_kwargs_json=_render_kwargs,
            ),
            inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
            outputs=rendered_output,
        )

        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=LAYER_CAPTIONS, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )
        predict_btn.click(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(**show_overview_kwargs).success(**render_event_kwargs)
        render_btn.click(**render_event_kwargs, api_name="render")

        load_arxiv_btn.click(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            # "predict" is already used by the plain-text analysis above, so expose this
            # endpoint under a distinct api_name
            api_name="predict_from_arxiv",
        ).success(**show_overview_kwargs)

        load_pie_dataset_btn.click(
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0], verbose=True, **json.loads(_load_pie_dataset_kwargs_str)
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        )

        selected_document_id.change(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=unescape_regex(_split_regex_escaped),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        )
        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0]
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        )

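        # Span retrieval events: selecting an ADU (via the hidden selected_adu_id textbox, which
        # is updated from the rendered document) shows its text and fetches relevant ADUs from
        # other documents; the buttons in the "Retrieval Configuration" accordion run the
        # similar/relevant span retrieval for the selected ADU or the whole document.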
        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
                relation_label_mapping=RELATION_NAME_MAPPING,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[relevant_adus_df],
        )
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_similar_adus_df],
        )

        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_relevant_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_relevant_adus_df],
        )

        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

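        # Re-attach the client-side span highlighting whenever the rendered HTML changes; the JS
        # from rendering_utils drives the hidden hover_adu_id / selected_adu_id textboxes above.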
        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()


if __name__ == "__main__":
    logging.basicConfig()

    main()