import hydra
import pyrootutils
from omegaconf import DictConfig, OmegaConf, SCMode

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import json
import logging

import gradio as gr
import torch
import yaml

from src.demo.annotation_utils import load_argumentation_model
from src.demo.backend_utils import (
    download_processed_documents,
    load_acl_anthology_venues,
    process_text_from_arxiv,
    process_uploaded_files,
    process_uploaded_pdf_files,
    render_annotated_document,
    upload_processed_documents,
    wrapped_add_annotated_pie_documents_from_dataset,
    wrapped_process_text,
)
from src.demo.frontend_utils import (
    change_tab,
    escape_regex,
    get_cell_for_fixed_column_from_df,
    open_accordion,
    open_accordion_with_stats,
    unescape_regex,
)
from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
from src.demo.retriever_utils import (
    get_document_as_dict,
    get_span_annotation,
    load_retriever,
    retrieve_all_relevant_spans,
    retrieve_all_similar_spans,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)


def load_yaml_config(path: str) -> str:
    # read the YAML file, parse it, and dump it back as a normalized YAML string
    with open(path, "r") as file:
        yaml_string = file.read()
    config = yaml.safe_load(yaml_string)
    return yaml.dump(config)


def resolve_config(cfg) -> dict:
    return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
def main(cfg: DictConfig) -> None:
    # configure logging
    logging.basicConfig()

    # resolve everything in the config to prevent any issues with JSON serialization etc.
    cfg = resolve_config(cfg)

    example_text = cfg["example_text"]
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    default_retriever_config_str = yaml.dump(cfg["retriever"])
    default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
    handle_parts_of_same = cfg["handle_parts_of_same"]
    default_arxiv_id = cfg["default_arxiv_id"]
    default_load_pie_dataset_kwargs_str = json.dumps(
        cfg["default_load_pie_dataset_kwargs"], indent=2
    )

    default_render_mode = cfg["default_render_mode"]
    if default_render_mode not in AVAILABLE_RENDER_MODES:
        raise ValueError(
            f"Invalid default render mode '{default_render_mode}'. "
            f"Choose one of {AVAILABLE_RENDER_MODES}."
        )
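    # For reference, the config keys read above and below imply a `configs/demo.yaml`
    # roughly shaped like the following sketch. The values are illustrative placeholders
    # only (not taken from the actual config file):
    #
    #   example_text: "..."
    #   default_arxiv_id: "..."
    #   handle_parts_of_same: true
    #   default_split_regex: "\n\n"
    #   default_render_mode: "..."            # one of AVAILABLE_RENDER_MODES
    #   render_mode_captions: {}
    #   default_render_kwargs: {}
    #   default_load_pie_dataset_kwargs: {}
    #   default_min_similarity: 0.8
    #   default_top_k: 10
    #   layer_caption_mapping: {}
    #   relation_name_mapping: {}
    #   argumentation_model: {...}            # passed to load_argumentation_model as YAML
    #   retriever: {...}                      # passed to load_retriever as YAML
    #   pdf_fulltext_extractor: null          # optional; enables the PDF import features
    #   acl_anthology_data_dir: null          # optional; enables ACL Anthology import
    #   acl_anthology_pdf_dir: null           # optional; enables ACL Anthology import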
    default_render_kwargs = cfg["default_render_kwargs"]
    default_split_regex = cfg["default_split_regex"]

    # captions for better readability:
    # map from render mode to the corresponding caption
    render_mode2caption = {
        render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
        for render_mode in AVAILABLE_RENDER_MODES
    }
    render_caption2mode = {v: k for k, v in render_mode2caption.items()}
    default_min_similarity = cfg["default_min_similarity"]
    default_top_k = cfg["default_top_k"]
    layer_caption_mapping = cfg["layer_caption_mapping"]
    relation_name_mapping = cfg["relation_name_mapping"]
    indexed_documents_label = "Indexed Documents"
    indexed_documents_caption2column = {
        "documents": "TOTAL",
        "ADUs": "num_adus",
        "Relations": "num_relations",
    }

    gr.Info("Loading models ...")
    argumentation_model = load_argumentation_model(
        config_str=default_argumentation_model_config_str,
        device=default_device,
    )
    retriever = load_retriever(
        config_str=default_retriever_config_str, device=default_device, config_format="yaml"
    )

    if cfg.get("pdf_fulltext_extractor"):
        gr.Info("Loading PDF fulltext extractor ...")
        pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
    else:
        pdf_fulltext_extractor = None

    with gr.Blocks() as demo:
        # wrap the pipeline and the embedding model/tokenizer in a tuple so that gradio
        # does not try to call it
        # models_state = gr.State((argumentation_model, embedding_model))
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))
        with gr.Row():
            with gr.Tabs() as left_tabs:
                with gr.Tab("User Input", id="user_input") as user_input_tab:
                    doc_id = gr.Textbox(
                        label="Document ID",
                        value="user_input",
                    )
                    doc_text = gr.Textbox(
                        label="Text",
                        lines=20,
                        value=example_text,
                    )

                    with gr.Accordion("Model Configuration", open=False):
                        with gr.Accordion("argumentation structure", open=True):
                            argumentation_model_config_str = gr.Code(
                                language="yaml",
                                label="Argumentation Model Configuration",
                                value=default_argumentation_model_config_str,
                                lines=len(default_argumentation_model_config_str.split("\n")),
                            )
                            load_arg_model_btn = gr.Button("Load Argumentation Model")
                        with gr.Accordion("retriever", open=True):
                            retriever_config_str = gr.Code(
                                language="yaml",
                                label="Retriever Configuration",
                                value=default_retriever_config_str,
                                lines=len(default_retriever_config_str.split("\n")),
                            )
                            load_retriever_btn = gr.Button("Load Retriever")
                        device = gr.Textbox(
                            label="Device (e.g. 'cuda' or 'cpu')",
                            value=default_device,
                        )
                        load_arg_model_btn.click(
                            fn=lambda _argumentation_model_config_str, _device: (
                                load_argumentation_model(
                                    config_str=_argumentation_model_config_str,
                                    device=_device,
                                ),
                            ),
                            inputs=[argumentation_model_config_str, device],
                            outputs=argumentation_model_state,
                        )
                        load_retriever_btn.click(
                            fn=lambda _retriever_config, _device, _previous_retriever: (
                                load_retriever(
                                    config_str=_retriever_config,
                                    device=_device,
                                    previous_retriever=_previous_retriever[0],
                                    config_format="yaml",
                                ),
                            ),
                            inputs=[retriever_config_str, device, retriever_state],
                            outputs=retriever_state,
                        )

                    split_regex_escaped = gr.Textbox(
                        label="Regex to partition the text",
                        placeholder="Regular expression pattern to split the text into partitions",
                        value=escape_regex(default_split_regex),
                    )
                    predict_btn = gr.Button("Analyse")

                with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
                    selected_document_id = gr.Textbox(
                        label="Document ID", max_lines=1, interactive=False
                    )
                    rendered_output = gr.HTML(label="Rendered Output")
                    with gr.Accordion("Render Options", open=False):
                        render_as = gr.Dropdown(
                            label="Render with",
                            choices=list(render_mode2caption.values()),
                            value=render_mode2caption[default_render_mode],
                        )
                        render_kwargs = gr.Code(
                            language="json",
                            label="Render Arguments",
                            lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
                            value=json.dumps(default_render_kwargs, indent=2),
                        )
                        render_btn = gr.Button("Re-render")
                    with gr.Accordion("See plain result ...", open=False):
                        get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                        document_json = gr.JSON(label="Model Output")

            with gr.Tabs() as right_tabs:
                with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                    with gr.Accordion(
                        indexed_documents_label, open=False
                    ) as processed_documents_accordion:
                        processed_documents_df = gr.DataFrame(
                            headers=["id", "num_adus", "num_relations"],
                            interactive=False,
                            elem_classes="df-docstore",
                        )
                        gr.Markdown("Data Snapshot:")
                        with gr.Row():
                            download_processed_documents_btn = gr.DownloadButton("Download")
                            upload_processed_documents_btn = gr.UploadButton(
                                "Upload", file_types=["file"]
                            )

                    # currently not used
                    # relation_types = set_relation_types(
                    #     argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
                    # )

                    # Dummy textbox to hold the hovered ADU id. On click on the rendered output,
                    # its content is copied to selected_adu_id, which triggers the retrieval.
                    hover_adu_id = gr.Textbox(
                        label="ID (hover)",
                        elem_id="hover_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_id = gr.Textbox(
                        label="ID (selected)",
                        elem_id="selected_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                    with gr.Accordion("Relevant ADUs from other documents", open=True):
                        relevant_adus_df = gr.DataFrame(
                            headers=[
                                "relation",
                                "adu",
                                "reference_adu",
                                "doc_id",
                                "sim_score",
                                "rel_score",
                            ],
                            interactive=False,
                        )

                    with gr.Accordion("Retrieval Configuration", open=False):
                        min_similarity = gr.Slider(
                            label="Minimum Similarity",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_similarity,
                        )
                        top_k = gr.Slider(
                            label="Top K",
                            minimum=2,
                            maximum=50,
                            step=1,
                            value=default_top_k,
                        )

                    retrieve_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *selected* ADU"
                    )
                    similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )
                    retrieve_all_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *all* ADUs in the document"
                    )
                    all_similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                        interactive=False,
                    )
                    retrieve_all_relevant_adus_btn = gr.Button(
                        "Retrieve *relevant* ADUs for *all* ADUs in the document"
                    )
                    all_relevant_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
                        interactive=False,
                    )
                    all_relevant_adus_query_doc_id = gr.Textbox(visible=False)

                with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
                    upload_btn = gr.UploadButton(
                        "Batch Analyse Texts",
                        file_types=["text"],
                        file_count="multiple",
                    )
                    upload_pdf_btn = gr.UploadButton(
                        "Batch Analyse PDFs",
                        # file_types=["pdf"],
                        file_count="multiple",
                        visible=pdf_fulltext_extractor is not None,
                    )
                    enable_acl_venue_loading = pdf_fulltext_extractor is not None and cfg.get(
                        "acl_anthology_pdf_dir"
                    )
                    acl_anthology_venues = gr.Textbox(
                        label="ACL Anthology Venues",
                        value="wiesp",
                        max_lines=1,
                        visible=enable_acl_venue_loading,
                    )
                    load_acl_anthology_venues_btn = gr.Button(
                        "Import from ACL Anthology",
                        variant="secondary",
                        visible=enable_acl_venue_loading,
                    )
                    with gr.Accordion("Import text from arXiv", open=False):
                        arxiv_id = gr.Textbox(
                            label="arXiv paper ID",
                            placeholder=f"e.g. {default_arxiv_id}",
                            max_lines=1,
                        )
                        load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                        load_arxiv_btn = gr.Button(
                            "Load & Analyse from arXiv", variant="secondary"
                        )
                    with gr.Accordion(
                        "Import argument structure annotated PIE dataset", open=False
                    ):
                        load_pie_dataset_kwargs_str = gr.Code(
                            language="json",
                            label="Parameters for Loading the PIE Dataset",
                            value=default_load_pie_dataset_kwargs_str,
                            lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
                        )
                        load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
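        # Event wiring. The chains below follow a common pattern: `.then(...)` runs after
        # the preceding step has finished, while `.success(...)` only runs if the preceding
        # handler did not raise. The models are passed around via the one-element tuples
        # stored in `argumentation_model_state` / `retriever_state` and are unpacked with
        # `[0]` inside the lambdas.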
        render_event_kwargs = dict(
            fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=render_caption2mode[_render_as],
                render_kwargs_json=_render_kwargs,
                highlight_span_ids=(
                    _all_relevant_adus_df["query_span_id"].tolist()
                    if _document_id == _all_relevant_adus_query_doc_id
                    else None
                ),
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                render_as,
                render_kwargs,
                all_relevant_adus_df,
                all_relevant_adus_query_doc_id,
            ],
            outputs=rendered_output,
        )
        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=layer_caption_mapping, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )
        show_stats_kwargs = dict(
            fn=lambda _processed_documents_df: open_accordion_with_stats(
                _processed_documents_df,
                base_label=indexed_documents_label,
                caption2column=indexed_documents_caption2column,
                total_column="TOTAL",
            ),
            inputs=[processed_documents_df],
            outputs=[processed_documents_accordion],
        )

        predict_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        ).success(
            **render_event_kwargs
        )
        render_btn.click(**render_event_kwargs, api_name="render")

        load_arxiv_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id.strip() or default_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        )

        load_pie_dataset_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0],
                verbose=True,
                layer_captions=layer_caption_mapping,
                **json.loads(_load_pie_dataset_kwargs_str),
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        selected_document_id.change(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        upload_pdf_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                pdf_fulltext_extractor=pdf_fulltext_extractor,
            ),
            inputs=[
                upload_pdf_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        load_acl_anthology_venues_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
                pdf_fulltext_extractor=pdf_fulltext_extractor,
                venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
                pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
            ),
            inputs=[
                acl_anthology_venues,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        ).success(**show_stats_kwargs)
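        # Span-level retrieval wiring. Clicking a highlighted span in the rendered HTML
        # (see HIGHLIGHT_SPANS_JS attached below) copies its id into the hidden
        # `selected_adu_id` textbox; the `.change` handler then looks up the span text and
        # fetches relevant ADUs from other documents via the retriever.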
        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
                relation_label_mapping=relation_name_mapping,
                # columns=relevant_adus.headers
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[relevant_adus_df],
        )
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_similar_adus_df],
        )

        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: (
                retrieve_all_relevant_spans(
                    retriever=_retriever[0],
                    query_doc_id=_document_id,
                    k=_top_k,
                    score_threshold=_min_similarity,
                    query_span_id_column="query_span_id",
                    query_span_text_column="query_span_text",
                ),
                _document_id,
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
        )
        all_relevant_adus_df.change(**render_event_kwargs)

        # select query span id from the "retrieve all" result data frames
        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

        # argumentation_model_state.change(
        #     fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
        #     inputs=[argumentation_model_state],
        #     outputs=[relation_types],
        # )

        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()


if __name__ == "__main__":
    main()
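# Note: because `main` is wrapped in `@hydra.main`, any value of the demo config can be
# overridden on the command line when launching the app, e.g. (illustrative invocation,
# the script name is a placeholder):
#
#   python <this_script>.py default_top_k=20 default_min_similarity=0.5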