"""Gradio demo app: analyse texts with an argumentation model and retrieve related spans."""

import pyrootutils

# NOTE: this has to run before the remaining imports because ``pythonpath=True`` is what
# makes the project-local ``src`` package importable below.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import json
import logging
import os.path
import re
import tempfile
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import arxiv
import gradio as gr
import pandas as pd
import requests
import torch
import yaml
from bs4 import BeautifulSoup
from model_utils import (
    add_annotated_pie_documents_from_dataset,
    load_argumentation_model,
    load_retriever,
    process_texts,
    retrieve_all_relevant_spans,
    retrieve_all_similar_spans,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
from pytorch_ie import Annotation, Pipeline
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table

from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)

logger = logging.getLogger(__name__)


def load_retriever_config(path: str) -> str:
    """Read a YAML file and return its content re-serialized as a YAML string.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed configuration dumped back to YAML via ``yaml.dump``.
    """
    # Explicit encoding so that the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return yaml.dump(config)


RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
RENDER_WITH_PRETTY_TABLE = "Pretty Table"

DEFAULT_MODEL_NAME = "ArneBinder/sam-pointer-bart-base-v0.3"
DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
# local path
# DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
# DEFAULT_MODEL_REVISION = None
DEFAULT_RETRIEVER_CONFIG = load_retriever_config(
    "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
)
# 0.943180 from data_dir="predictions/default/2024-10-15_23-40-18"
DEFAULT_MIN_SIMILARITY = 0.95
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_SPLIT_REGEX = "\n\n\n+"
DEFAULT_ARXIV_ID = "1706.03762"
DEFAULT_LOAD_PIE_DATASET_KWARGS_STR = json.dumps(
    dict(path="pie/sciarg", name="resolve_parts_of_same", split="train"), indent=2
)
# Timeout (in seconds) for fetching the HTML version of an arXiv paper. Without a timeout,
# requests.get() may block forever if the server does not answer.
DEFAULT_REQUEST_TIMEOUT = 30

# Whether to handle segmented entities in the document. If True, labeled_spans are converted
# to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
# This requires the networkx package to be installed.
HANDLE_PARTS_OF_SAME = True

LAYER_CAPTIONS = {
    "labeled_multi_spans": "adus",
    "binary_relations": "relations",
    "labeled_partitions": "partitions",
}
RELATION_NAME_MAPPING = {
    "supports_reversed": "supported by",
    "contradicts_reversed": "contradicts",
}


def escape_regex(regex: str) -> str:
    """Return ``regex`` with backslashes, control and non-ASCII characters backslash-escaped.

    This "double escapes" the pattern so that, e.g., a newline character is shown as the
    two characters ``\\n`` in a text box.
    """
    return regex.encode("unicode_escape").decode("utf-8")


def unescape_regex(regex: str) -> str:
    """Reverse :func:`escape_regex`, i.e. interpret backslash escapes in ``regex``.

    The ``unicode_escape`` codec decodes raw bytes as Latin-1. Encoding with UTF-8 first
    (as done previously) therefore garbled literal non-ASCII characters ("é" became "Ã©").
    Encoding with Latin-1 keeps code points below 256 intact, and ``backslashreplace``
    converts everything else into ``\\uXXXX`` escapes that the decoder restores.
    """
    return regex.encode("latin-1", "backslashreplace").decode("unicode_escape")


def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    """Fetch the document with id ``doc_id`` and convert it via ``docstore.as_dict``."""
    document = retriever.get_document(doc_id=doc_id)
    return retriever.docstore.as_dict(document)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
) -> str:
    """Render the spans and relations of a document as HTML.

    Args:
        retriever: The retriever that holds the document.
        document_id: The id of the document to render.
        render_with: Either RENDER_WITH_PRETTY_TABLE or RENDER_WITH_DISPLACY.
        render_kwargs_json: JSON object with additional keyword arguments for the renderer.

    Returns:
        The HTML produced by the selected rendering function.

    Raises:
        gr.Error: If ``render_kwargs_json`` is not valid JSON.
        ValueError: If ``render_with`` is not one of the known renderers.
    """
    # The JSON comes from a user-editable text box, so report syntax errors in the UI.
    try:
        render_kwargs = json.loads(render_kwargs_json)
    except json.JSONDecodeError as e:
        raise gr.Error(f"Render arguments are not valid JSON: {e}") from e

    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )

    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
    """Process a single text and return its document id.

    Any exception raised by ``process_texts`` is converted into a ``gr.Error``.
    """
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}") from e
    # Only the id is handed back to the UI; the event handlers in main() look the
    # document up via the retriever when they need it.
    return doc_id


def process_uploaded_files(
    file_names: List[str], retriever: DocumentAwareSpanRetriever, **kwargs
) -> pd.DataFrame:
    """Read the uploaded ``.txt`` files, process them and return the docstore overview.

    The base name of each file is used as its document id.

    Raises:
        gr.Error: If a file does not end with ".txt" or reading/processing fails.
    """
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # read the file content
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}") from e

    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, **kwargs
) -> pd.DataFrame:
    """Call ``add_annotated_pie_documents_from_dataset`` and return the docstore overview.

    Any exception is converted into a ``gr.Error``.
    """
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}") from e
    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)


def open_accordion():
    """Return an accordion update with ``open=True``."""
    return gr.Accordion(open=True)


def close_accordion():
    """Return an accordion update with ``open=False``."""
    return gr.Accordion(open=False)


def get_cell_for_fixed_column_from_df(
    evt: gr.SelectData,
    df: pd.DataFrame,
    column: str,
) -> Any:
    """Get the value of the fixed column for the selected row in the DataFrame.

    This can *not* be done with a lambda function because that would not get
    the evt parameter.

    Args:
        evt: The event object.
        df: The DataFrame.
        column: The name of the column.

    Returns:
        The value of the fixed column for the selected row.
    """
    # evt.index holds (row, column) of the selected cell; only the row is needed here.
    row_idx, _ = evt.index
    return df.iloc[row_idx][column]


def set_relation_types(
    argumentation_model: Pipeline,
    default: Optional[List[str]] = None,
) -> gr.Dropdown:
    """Create a multi-select dropdown with the relation labels of the model's taskmodule.

    Raises:
        gr.Error: If the taskmodule is not a PointerNetworkTaskModuleForEnd2EndRE.
    """
    if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
        relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
    else:
        raise gr.Error("Unsupported taskmodule for relation types")

    return gr.Dropdown(
        choices=relation_types,
        label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )


def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    """Return the span annotation with id ``span_id`` from the retriever.

    Raises:
        gr.Error: If ``span_id`` is blank or the lookup fails.
    """
    if span_id.strip() == "":
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        raise gr.Error(f"Failed to retrieve span annotation: {e}") from e


def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    """Collect everything required to render a document.

    Returns:
        A tuple of the document text, the base layer (spans), the mapping from span ids
        to indices, and the relation layer.
    """
    document = retriever.get_document(doc_id=document_id)
    pie_document = retriever.docstore.unwrap(document)
    use_predicted_annotations = retriever.use_predicted_annotations(document)
    spans = retriever.get_base_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    relations = retriever.get_relation_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    span_id2idx = retriever.get_span_id2idx_from_doc(document)
    return pie_document.text, spans, span_id2idx, relations


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
    """Save the retriever to a zip archive in the temp directory and return its path.

    Returns:
        The path returned by ``retriever.save_to_archive``, or None (after showing a
        warning) if the docstore is empty.
    """
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    # zip the directory
    file_path = os.path.join(tempfile.gettempdir(), file_name)

    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")

    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
) -> pd.DataFrame:
    """Load documents into the retriever and return the docstore overview.

    Raises:
        gr.Error: If loading fails (consistent with the other UI wrappers in this file).
    """
    # load the documents from the zip file or directory
    try:
        retriever.load_from_disc(file_name)
    except Exception as e:
        raise gr.Error(f"Failed to load processed documents: {e}") from e
    # return the overview of the document store
    return retriever.docstore.overview(layer_captions=LAYER_CAPTIONS, use_predictions=True)


def clean_spaces(text: str) -> str:
    """Collapse runs of spaces and of blank lines, then strip the text."""
    # replace all multiple spaces with a single space
    text = re.sub(" +", " ", text)
    # reduce more than two newlines to two newlines
    text = re.sub("\n\n+", "\n\n", text)
    # remove leading and trailing whitespaces
    text = text.strip()
    return text


def get_cleaned_arxiv_paper_text(html_content: str) -> str:
    """Extract the text of the ``<article>`` element from HTML and clean up its spacing.

    Raises:
        ValueError: If the HTML does not contain an ``<article>`` element.
    """
    # parse the HTML content with BeautifulSoup
    soup = BeautifulSoup(html_content, "html.parser")
    # get the "article" html element
    article = soup.find("article")
    if article is None:
        # Without this check, the failure would be an opaque AttributeError on None.
        raise ValueError("Could not find an <article> element in the HTML content")
    # cleanup the text
    return clean_spaces(article.get_text())


def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]:
    """Load the abstract or the full text of an arXiv paper.

    Args:
        arxiv_id: The arXiv id. If blank, DEFAULT_ARXIV_ID is used.
        abstract_only: If True, return only the abstract (newlines replaced by spaces).

    Returns:
        A tuple ``(text, source)`` where ``source`` is the arXiv entry id for abstracts
        and the HTML URL for full texts.

    Raises:
        gr.Error: If the paper can not be found or its content can not be fetched.
    """
    arxiv_id = arxiv_id.strip()
    if not arxiv_id:
        arxiv_id = DEFAULT_ARXIV_ID

    search_by_id = arxiv.Search(id_list=[arxiv_id])
    try:
        result = list(arxiv.Client().results(search_by_id))
    except arxiv.HTTPError as e:
        raise gr.Error(f"Failed to fetch arXiv data: {e}") from e
    if len(result) == 0:
        raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
    first_result = result[0]
    if abstract_only:
        abstract_clean = first_result.summary.replace("\n", " ")
        return abstract_clean, first_result.entry_id
    if "/abs/" not in first_result.entry_id:
        raise gr.Error(
            f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
            f"an unexpected format: {first_result.entry_id}"
        )
    html_url = first_result.entry_id.replace("/abs/", "/html/")
    # Always pass a timeout: requests waits indefinitely otherwise.
    request_result = requests.get(html_url, timeout=DEFAULT_REQUEST_TIMEOUT)
    if request_result.status_code != 200:
        raise gr.Error(
            f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
            f"{request_result.status_code}"
        )
    html_content = request_result.text
    text_clean = get_cleaned_arxiv_paper_text(html_content)
    return text_clean, html_url


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
    """Load a paper from arXiv, process it and return its document id.

    The second element returned by ``load_text_from_arxiv`` is used as document id.

    Raises:
        gr.Error: If loading or processing the text fails.
    """
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}") from e
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def main():
    """Load the default models, build the Gradio UI, wire up all events and launch it."""
    example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."

    print("Loading argumentation model ...")
    argumentation_model = load_argumentation_model(
        model_name=DEFAULT_MODEL_NAME,
        revision=DEFAULT_MODEL_REVISION,
        device=DEFAULT_DEVICE,
    )
    print("Loading retriever ...")
    retriever = load_retriever(
        DEFAULT_RETRIEVER_CONFIG, device=DEFAULT_DEVICE, config_format="yaml"
    )
    print("Models loaded.")

    default_render_kwargs = {
        "entity_options": {
            # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
            "colors": {
                "own_claim".upper(): "#009933",
                "background_claim".upper(): "#99ccff",
                "data".upper(): "#993399",
            }
        },
        "colors_hover": {
            "selected": "#ffa",
            # "tail": "#aff",
            "tail": {
                # green
                "supports": "#9f9",
                # red
                "contradicts": "#f99",
                # do not highlight
                "parts_of_same": None,
            },
            "head": None,  # "#faf",
            "other": None,
        },
    }

    with gr.Blocks() as demo:
        # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
        # models_state = gr.State((argumentation_model, embedding_model))
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))

        with gr.Row():
            with gr.Column(scale=1):
                doc_id = gr.Textbox(
                    label="Document ID",
                    value="user_input",
                )
                doc_text = gr.Textbox(
                    label="Text",
                    lines=20,
                    value=example_text,
                )

                with gr.Accordion("Model Configuration", open=False):
                    with gr.Accordion("argumentation structure", open=True):
                        model_name = gr.Textbox(
                            label="Model Name",
                            value=DEFAULT_MODEL_NAME,
                        )
                        model_revision = gr.Textbox(
                            label="Model Revision",
                            value=DEFAULT_MODEL_REVISION,
                        )
                        load_arg_model_btn = gr.Button("Load Argumentation Model")

                    with gr.Accordion("retriever", open=True):
                        retriever_config = gr.Textbox(
                            label="Retriever Configuration",
                            placeholder="Configuration for the retriever",
                            value=DEFAULT_RETRIEVER_CONFIG,
                            lines=len(DEFAULT_RETRIEVER_CONFIG.split("\n")),
                        )
                        load_retriever_btn = gr.Button("Load Retriever")

                    device = gr.Textbox(
                        label="Device (e.g. 'cuda' or 'cpu')",
                        value=DEFAULT_DEVICE,
                    )
                    # The loaded objects are wrapped in 1-tuples, see the state definitions above.
                    load_arg_model_btn.click(
                        fn=lambda _model_name, _model_revision, _device: (
                            load_argumentation_model(
                                model_name=_model_name, revision=_model_revision, device=_device
                            ),
                        ),
                        inputs=[model_name, model_revision, device],
                        outputs=argumentation_model_state,
                    )
                    load_retriever_btn.click(
                        fn=lambda _retriever_config, _device, _previous_retriever: (
                            load_retriever(
                                retriever_config=_retriever_config,
                                device=_device,
                                previous_retriever=_previous_retriever[0],
                                config_format="yaml",
                            ),
                        ),
                        inputs=[retriever_config, device, retriever_state],
                        outputs=retriever_state,
                    )

                    split_regex_escaped = gr.Textbox(
                        label="Regex to partition the text",
                        placeholder="Regular expression pattern to split the text into partitions",
                        value=escape_regex(DEFAULT_SPLIT_REGEX),
                    )

                predict_btn = gr.Button("Analyse")

            with gr.Column(scale=1):
                selected_document_id = gr.Textbox(
                    label="Selected Document", max_lines=1, interactive=False
                )
                rendered_output = gr.HTML(label="Rendered Output")

                with gr.Accordion("Render Options", open=False):
                    render_as = gr.Dropdown(
                        label="Render with",
                        choices=[RENDER_WITH_PRETTY_TABLE, RENDER_WITH_DISPLACY],
                        value=RENDER_WITH_DISPLACY,
                    )
                    render_kwargs = gr.Textbox(
                        label="Render Arguments",
                        lines=5,
                        value=json.dumps(default_render_kwargs, indent=2),
                    )
                    render_btn = gr.Button("Re-render")

                with gr.Accordion("See plain result ...", open=False) as document_json_accordion:
                    get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                    document_json = gr.JSON(label="Model Output")

            with gr.Column(scale=1):
                with gr.Accordion(
                    "Indexed Documents", open=False
                ) as processed_documents_accordion:
                    processed_documents_df = gr.DataFrame(
                        headers=["id", "num_adus", "num_relations"],
                        interactive=False,
                    )
                    gr.Markdown("Data Snapshot:")
                    with gr.Row():
                        download_processed_documents_btn = gr.DownloadButton("Download")
                        upload_processed_documents_btn = gr.UploadButton(
                            "Upload", file_types=["file"]
                        )

                upload_btn = gr.UploadButton(
                    "Upload & Analyse Reference Documents",
                    file_types=["text"],
                    file_count="multiple",
                )

                with gr.Accordion("Import text from arXiv", open=False):
                    arxiv_id = gr.Textbox(
                        label="arXiv paper ID",
                        placeholder=f"e.g. {DEFAULT_ARXIV_ID}",
                        max_lines=1,
                    )
                    load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                    load_arxiv_btn = gr.Button("Load & process from arXiv", variant="secondary")

                with gr.Accordion("Import annotated PIE dataset", open=False):
                    load_pie_dataset_kwargs_str = gr.Textbox(
                        label="Parameters for Loading the PIE Dataset",
                        value=DEFAULT_LOAD_PIE_DATASET_KWARGS_STR,
                        lines=len(DEFAULT_LOAD_PIE_DATASET_KWARGS_STR.split("\n")),
                    )
                    load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

                with gr.Accordion("Retrieval Configuration", open=False):
                    min_similarity = gr.Slider(
                        label="Minimum Similarity",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=DEFAULT_MIN_SIMILARITY,
                    )
                    top_k = gr.Slider(
                        label="Top K",
                        minimum=2,
                        maximum=50,
                        step=1,
                        value=10,
                    )
                    retrieve_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *selected* ADU"
                    )
                    similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )
                    retrieve_all_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *all* ADUs in the document"
                    )
                    all_similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                        interactive=False,
                    )
                    retrieve_all_relevant_adus_btn = gr.Button(
                        "Retrieve *relevant* ADUs for *all* ADUs in the document"
                    )
                    all_relevant_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )

                # currently not used
                # relation_types = set_relation_types(
                #     argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
                # )

                # Dummy textbox to hold the hover adu id. On click on the rendered output,
                # its content will be copied to selected_adu_id which will trigger the retrieval.
                hover_adu_id = gr.Textbox(
                    label="ID (hover)",
                    elem_id="hover_adu_id",
                    interactive=False,
                    visible=False,
                )
                selected_adu_id = gr.Textbox(
                    label="ID (selected)",
                    elem_id="selected_adu_id",
                    interactive=False,
                    visible=False,
                )
                selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                with gr.Accordion("Relevant ADUs from other documents", open=True):
                    relevant_adus_df = gr.DataFrame(
                        headers=[
                            "relation",
                            "adu",
                            "reference_adu",
                            "doc_id",
                            "sim_score",
                            "rel_score",
                        ],
                        interactive=False,
                    )

        # Shared event definitions that are attached to several triggers below.
        render_event_kwargs = dict(
            fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=_render_as,
                render_kwargs_json=_render_kwargs,
            ),
            inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
            outputs=rendered_output,
        )

        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=LAYER_CAPTIONS, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )

        predict_btn.click(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(**show_overview_kwargs).success(**render_event_kwargs)
        render_btn.click(**render_event_kwargs, api_name="render")

        # NOTE(review): this reuses api_name="predict" from predict_btn above. It is kept
        # unchanged here because renaming would alter the exposed endpoint name; consider
        # giving it a distinct name.
        load_arxiv_btn.click(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(**show_overview_kwargs)

        load_pie_dataset_btn.click(
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0], verbose=True, **json.loads(_load_pie_dataset_kwargs_str)
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        )

        selected_document_id.change(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                # Same guard as in the "Analyse" and arXiv handlers: an empty text box
                # means "no partitioning regex" and is passed as None.
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=HANDLE_PARTS_OF_SAME,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        )
        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0]
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        )

        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
                relation_label_mapping=RELATION_NAME_MAPPING,
                # columns=relevant_adus.headers
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[relevant_adus_df],
        )
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_similar_adus_df],
        )

        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_relevant_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_relevant_adus_df],
        )

        # select query span id from the "retrieve all" result data frames
        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

        # argumentation_model_state.change(
        #     fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
        #     inputs=[argumentation_model_state],
        #     outputs=[relation_types],
        # )

        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()


if __name__ == "__main__":
    # configure logging
    logging.basicConfig()

    main()