"""Gradio demo: argumentation structure analysis with cross-document ADU retrieval.

The project root is resolved before anything else so that the `src.*` imports below
work regardless of the directory the script is started from.
"""

import hydra
import pyrootutils
from omegaconf import DictConfig, OmegaConf, SCMode

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
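
# NOTE: the imports below come after `setup_root` on purpose: `pythonpath=True` puts the
# repository root on sys.path (making the `src.*` packages importable) and `dotenv=True`
# loads environment variables from a local `.env` file.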
import json
import logging

import gradio as gr
import torch
import yaml

from src.demo.annotation_utils import load_argumentation_model
from src.demo.backend_utils import (
    download_processed_documents,
    load_acl_anthology_venues,
    process_text_from_arxiv,
    process_uploaded_files,
    process_uploaded_pdf_files,
    render_annotated_document,
    upload_processed_documents,
    wrapped_add_annotated_pie_documents_from_dataset,
    wrapped_process_text,
)
from src.demo.frontend_utils import (
    change_tab,
    escape_regex,
    get_cell_for_fixed_column_from_df,
    open_accordion,
    open_accordion_with_stats,
    unescape_regex,
)
from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
from src.demo.retriever_utils import (
    get_document_as_dict,
    get_span_annotation,
    load_retriever,
    retrieve_all_relevant_spans,
    retrieve_all_similar_spans,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)


def load_yaml_config(path: str) -> str:
    """Read a YAML file and return it re-serialized with normalized formatting."""
    with open(path, "r") as file:
        yaml_string = file.read()
    config = yaml.safe_load(yaml_string)
    return yaml.dump(config)


def resolve_config(cfg: DictConfig) -> dict:
    """Resolve all interpolations and convert the Hydra config to a plain dict."""
    return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)
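
# `structured_config_mode=SCMode.DICT` also converts nested structured configs to plain
# dicts, so the result contains only ordinary Python containers and scalars.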


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
def main(cfg: DictConfig) -> None:
    logging.basicConfig()

    # resolve the config once so the remainder of the function works with plain Python data
    cfg = resolve_config(cfg)

    example_text = cfg["example_text"]

    # prefer the first GPU if available; the UI allows re-loading models onto another device
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
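
    # All values below are initial defaults from the Hydra config; the model/retriever
    # configs, the partition regex, and the retrieval parameters can be edited in the UI.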
    default_retriever_config_str = yaml.dump(cfg["retriever"])
    default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
    handle_parts_of_same = cfg["handle_parts_of_same"]

    default_arxiv_id = cfg["default_arxiv_id"]
    default_load_pie_dataset_kwargs_str = json.dumps(
        cfg["default_load_pie_dataset_kwargs"], indent=2
    )

    default_render_mode = cfg["default_render_mode"]
    if default_render_mode not in AVAILABLE_RENDER_MODES:
        raise ValueError(
            f"Invalid default render mode '{default_render_mode}'. "
            f"Choose one of {AVAILABLE_RENDER_MODES}."
        )
    default_render_kwargs = cfg["default_render_kwargs"]

    default_split_regex = cfg["default_split_regex"]

    # map each render mode to its (optional) human-readable caption, and back
    render_mode2caption = {
        render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
        for render_mode in AVAILABLE_RENDER_MODES
    }
    render_caption2mode = {v: k for k, v in render_mode2caption.items()}
    default_min_similarity = cfg["default_min_similarity"]
    default_top_k = cfg["default_top_k"]
    layer_caption_mapping = cfg["layer_caption_mapping"]
    relation_name_mapping = cfg["relation_name_mapping"]

    indexed_documents_label = "Indexed Documents"
    indexed_documents_caption2column = {
        "documents": "TOTAL",
        "ADUs": "num_adus",
        "Relations": "num_relations",
    }

    gr.Info("Loading models ...")
    argumentation_model = load_argumentation_model(
        config_str=default_argumentation_model_config_str,
        device=default_device,
    )
    retriever = load_retriever(
        config_str=default_retriever_config_str, device=default_device, config_format="yaml"
    )

    # the PDF fulltext extractor is optional; when it is not configured, the PDF upload
    # and ACL Anthology import controls below stay hidden
    if cfg.get("pdf_fulltext_extractor"):
        gr.Info("Loading PDF fulltext extractor ...")
        pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
    else:
        pdf_fulltext_extractor = None

    with gr.Blocks() as demo:

        # models are wrapped in single-element tuples before they are put into the state
        # components; all event handlers unpack them again with `[0]`
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))
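
        # Layout: the left tab group holds the text input and the rendered analysis,
        # the right tab group holds retrieval results and document import.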
        with gr.Row():
            with gr.Tabs() as left_tabs:
                with gr.Tab("User Input", id="user_input") as user_input_tab:
                    doc_id = gr.Textbox(
                        label="Document ID",
                        value="user_input",
                    )
                    doc_text = gr.Textbox(
                        label="Text",
                        lines=20,
                        value=example_text,
                    )

                    with gr.Accordion("Model Configuration", open=False):
                        with gr.Accordion("argumentation structure", open=True):
                            argumentation_model_config_str = gr.Code(
                                language="yaml",
                                label="Argumentation Model Configuration",
                                value=default_argumentation_model_config_str,
                                lines=len(default_argumentation_model_config_str.split("\n")),
                            )
                            load_arg_model_btn = gr.Button("Load Argumentation Model")

                        with gr.Accordion("retriever", open=True):
                            retriever_config_str = gr.Code(
                                language="yaml",
                                label="Retriever Configuration",
                                value=default_retriever_config_str,
                                lines=len(default_retriever_config_str.split("\n")),
                            )
                            load_retriever_btn = gr.Button("Load Retriever")

                        device = gr.Textbox(
                            label="Device (e.g. 'cuda' or 'cpu')",
                            value=default_device,
                        )
                        # wrap the freshly loaded model in a 1-tuple to match the state format
                        load_arg_model_btn.click(
                            fn=lambda _argumentation_model_config_str, _device: (
                                load_argumentation_model(
                                    config_str=_argumentation_model_config_str,
                                    device=_device,
                                ),
                            ),
                            inputs=[argumentation_model_config_str, device],
                            outputs=argumentation_model_state,
                        )
                        load_retriever_btn.click(
                            fn=lambda _retriever_config, _device, _previous_retriever: (
                                load_retriever(
                                    config_str=_retriever_config,
                                    device=_device,
                                    previous_retriever=_previous_retriever[0],
                                    config_format="yaml",
                                ),
                            ),
                            inputs=[retriever_config_str, device, retriever_state],
                            outputs=retriever_state,
                        )

                    split_regex_escaped = gr.Textbox(
                        label="Regex to partition the text",
                        placeholder="Regular expression pattern to split the text into partitions",
                        value=escape_regex(default_split_regex),
                    )

                    predict_btn = gr.Button("Analyse")

                with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
                    selected_document_id = gr.Textbox(
                        label="Document ID", max_lines=1, interactive=False
                    )
                    rendered_output = gr.HTML(label="Rendered Output")

                    with gr.Accordion("Render Options", open=False):
                        render_as = gr.Dropdown(
                            label="Render with",
                            choices=list(render_mode2caption.values()),
                            value=render_mode2caption[default_render_mode],
                        )
                        render_kwargs = gr.Code(
                            language="json",
                            label="Render Arguments",
                            lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
                            value=json.dumps(default_render_kwargs, indent=2),
                        )
                        render_btn = gr.Button("Re-render")

                    with gr.Accordion("See plain result ...", open=False):
                        get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                        document_json = gr.JSON(label="Model Output")

            with gr.Tabs() as right_tabs:
                with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                    with gr.Accordion(
                        indexed_documents_label, open=False
                    ) as processed_documents_accordion:
                        processed_documents_df = gr.DataFrame(
                            headers=["id", "num_adus", "num_relations"],
                            interactive=False,
                            elem_classes="df-docstore",
                        )
                        gr.Markdown("Data Snapshot:")
                        with gr.Row():
                            download_processed_documents_btn = gr.DownloadButton("Download")
                            upload_processed_documents_btn = gr.UploadButton(
                                "Upload", file_types=["file"]
                            )

                    hover_adu_id = gr.Textbox(
                        label="ID (hover)",
                        elem_id="hover_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_id = gr.Textbox(
                        label="ID (selected)",
                        elem_id="selected_adu_id",
                        interactive=False,
                        visible=False,
                    )
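
                    # The two hidden textboxes above (elem_ids "hover_adu_id" and
                    # "selected_adu_id") are presumably written from JavaScript; see the
                    # HIGHLIGHT_SPANS_JS handler attached at the bottom of the file.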
                    selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                    with gr.Accordion("Relevant ADUs from other documents", open=True):
                        relevant_adus_df = gr.DataFrame(
                            headers=[
                                "relation",
                                "adu",
                                "reference_adu",
                                "doc_id",
                                "sim_score",
                                "rel_score",
                            ],
                            interactive=False,
                        )

                    with gr.Accordion("Retrieval Configuration", open=False):
                        min_similarity = gr.Slider(
                            label="Minimum Similarity",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_similarity,
                        )
                        top_k = gr.Slider(
                            label="Top K",
                            minimum=2,
                            maximum=50,
                            step=1,
                            value=default_top_k,
                        )
                    retrieve_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *selected* ADU"
                    )
                    similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text"], interactive=False
                    )
                    retrieve_all_similar_adus_btn = gr.Button(
                        "Retrieve *similar* ADUs for *all* ADUs in the document"
                    )
                    all_similar_adus_df = gr.DataFrame(
                        headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                        interactive=False,
                    )
                    retrieve_all_relevant_adus_btn = gr.Button(
                        "Retrieve *relevant* ADUs for *all* ADUs in the document"
                    )
                    all_relevant_adus_df = gr.DataFrame(
                        headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
                        interactive=False,
                    )
                    # remembers which document the "all relevant ADUs" query was run for
                    all_relevant_adus_query_doc_id = gr.Textbox(visible=False)

                with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
                    upload_btn = gr.UploadButton(
                        "Batch Analyse Texts",
                        file_types=["text"],
                        file_count="multiple",
                    )

                    upload_pdf_btn = gr.UploadButton(
                        "Batch Analyse PDFs",
                        file_types=[".pdf"],  # assumption: restrict to PDFs, mirroring the text button above
                        file_count="multiple",
                        visible=pdf_fulltext_extractor is not None,
                    )

                    enable_acl_venue_loading = (
                        pdf_fulltext_extractor is not None
                        and cfg.get("acl_anthology_pdf_dir") is not None
                    )
                    acl_anthology_venues = gr.Textbox(
                        label="ACL Anthology Venues",
                        value="wiesp",
                        max_lines=1,
                        visible=enable_acl_venue_loading,
                    )
                    load_acl_anthology_venues_btn = gr.Button(
                        "Import from ACL Anthology",
                        variant="secondary",
                        visible=enable_acl_venue_loading,
                    )

                    with gr.Accordion("Import text from arXiv", open=False):
                        arxiv_id = gr.Textbox(
                            label="arXiv paper ID",
                            placeholder=f"e.g. {default_arxiv_id}",
                            max_lines=1,
                        )
                        load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                        load_arxiv_btn = gr.Button(
                            "Load & Analyse from arXiv", variant="secondary"
                        )

                    with gr.Accordion(
                        "Import argument structure annotated PIE dataset", open=False
                    ):
                        load_pie_dataset_kwargs_str = gr.Code(
                            language="json",
                            label="Parameters for Loading the PIE Dataset",
                            value=default_load_pie_dataset_kwargs_str,
                            lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
                        )
                        load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

        render_event_kwargs = dict(
            fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=render_caption2mode[_render_as],
                render_kwargs_json=_render_kwargs,
                highlight_span_ids=(
                    _all_relevant_adus_df["query_span_id"].tolist()
                    if _document_id == _all_relevant_adus_query_doc_id
                    else None
                ),
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                render_as,
                render_kwargs,
                all_relevant_adus_df,
                all_relevant_adus_query_doc_id,
            ],
            outputs=rendered_output,
        )
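
        # Reusable event kwargs: `render_event_kwargs` (above) re-renders the currently
        # selected document, while the two dicts below refresh the indexed-documents
        # overview table and the stats shown in its accordion label.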
        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=layer_caption_mapping, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )
        show_stats_kwargs = dict(
            fn=lambda _processed_documents_df: open_accordion_with_stats(
                _processed_documents_df,
                base_label=indexed_documents_label,
                caption2column=indexed_documents_caption2column,
                total_column="TOTAL",
            ),
            inputs=[processed_documents_df],
            outputs=[processed_documents_accordion],
        )

        predict_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        ).success(
            **render_event_kwargs
        )
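
        # NOTE: `.then(...)` always runs after the previous event, whereas `.success(...)`
        # only fires if the previous step completed without raising.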

        render_btn.click(**render_event_kwargs, api_name="render")

        load_arxiv_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id.strip() or default_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            # "predict" is already taken by the manual-input chain above; a duplicate
            # api_name would be auto-renamed by Gradio (with a warning), so use a distinct one
            api_name="predict_arxiv",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        )

        load_pie_dataset_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0],
                verbose=True,
                layer_captions=layer_caption_mapping,
                **json.loads(_load_pie_dataset_kwargs_str),
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        selected_document_id.change(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )
        upload_pdf_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                pdf_fulltext_extractor=pdf_fulltext_extractor,
            ),
            inputs=[
                upload_pdf_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        load_acl_anthology_venues_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
                pdf_fulltext_extractor=pdf_fulltext_extractor,
                venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
                pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
            ),
            inputs=[
                acl_anthology_venues,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        ).success(**show_stats_kwargs)

        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
                relation_label_mapping=relation_name_mapping,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[relevant_adus_df],
        )
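
        # Selecting a row in an ADU table jumps to the corresponding document; changing
        # the selected ADU re-runs the relevant-span retrieval defined above.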
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_similar_adus_df],
        )

        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: (
                retrieve_all_relevant_spans(
                    retriever=_retriever[0],
                    query_doc_id=_document_id,
                    k=_top_k,
                    score_threshold=_min_similarity,
                    query_span_id_column="query_span_id",
                    query_span_text_column="query_span_text",
                ),
                # remember the query document so that re-rendering can highlight its spans
                _document_id,
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
        )

        all_relevant_adus_df.change(**render_event_kwargs)

        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
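
        # NOTE: the two `.select(...)` handlers above write the query span id back into
        # `selected_adu_id`, whose `change` event then re-runs the relevant-span retrieval.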
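
        # Frontend-only handler (fn=None): HIGHLIGHT_SPANS_JS is assumed to implement the
        # span highlighting and the hover/click wiring for the hidden ADU-id textboxes.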
        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()


if __name__ == "__main__":
    main()