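"""Gradio demo for argumentation structure analysis and cross-document retrieval of
argumentative discourse units (ADUs).

Entry point of the demo built on top of the pie-document-level project
(https://github.com/ArneBinder/pie-document-level).
"""
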
import hydra
import pyrootutils
from omegaconf import DictConfig, OmegaConf, SCMode
root = pyrootutils.setup_root(
search_from=__file__,
indicator=[".project-root"],
pythonpath=True,
dotenv=True,
)
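
# The imports below rely on the repository root being on the PYTHONPATH,
# which pyrootutils.setup_root() above takes care of (pythonpath=True).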
import json
import logging
import gradio as gr
import torch
import yaml
from src.demo.annotation_utils import load_argumentation_model
from src.demo.backend_utils import (
download_processed_documents,
load_acl_anthology_venues,
process_text_from_arxiv,
process_uploaded_files,
process_uploaded_pdf_files,
render_annotated_document,
upload_processed_documents,
wrapped_add_annotated_pie_documents_from_dataset,
wrapped_process_text,
)
from src.demo.frontend_utils import (
change_tab,
escape_regex,
get_cell_for_fixed_column_from_df,
open_accordion,
open_accordion_with_stats,
unescape_regex,
)
from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
from src.demo.retriever_utils import (
get_document_as_dict,
get_span_annotation,
load_retriever,
retrieve_all_relevant_spans,
retrieve_all_similar_spans,
retrieve_relevant_spans,
retrieve_similar_spans,
)
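
# Read a YAML file and return its content re-serialized as a normalized YAML string.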
def load_yaml_config(path: str) -> str:
with open(path, "r") as file:
yaml_string = file.read()
config = yaml.safe_load(yaml_string)
return yaml.dump(config)
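
# Resolve all interpolations and convert the OmegaConf config into plain Python containers.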
def resolve_config(cfg) -> dict:
return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)
@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
def main(cfg: DictConfig) -> None:
# configure logging
logging.basicConfig()
    # resolve everything in the config to prevent issues with JSON serialization etc.
cfg = resolve_config(cfg)
example_text = cfg["example_text"]
default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
default_retriever_config_str = yaml.dump(cfg["retriever"])
default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
handle_parts_of_same = cfg["handle_parts_of_same"]
default_arxiv_id = cfg["default_arxiv_id"]
default_load_pie_dataset_kwargs_str = json.dumps(
cfg["default_load_pie_dataset_kwargs"], indent=2
)
default_render_mode = cfg["default_render_mode"]
if default_render_mode not in AVAILABLE_RENDER_MODES:
raise ValueError(
f"Invalid default render mode '{default_render_mode}'. "
f"Choose one of {AVAILABLE_RENDER_MODES}."
)
default_render_kwargs = cfg["default_render_kwargs"]
    default_split_regex = cfg["default_split_regex"]
    # map each render mode to a caption for better readability in the UI
render_mode2caption = {
render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
for render_mode in AVAILABLE_RENDER_MODES
}
render_caption2mode = {v: k for k, v in render_mode2caption.items()}
default_min_similarity = cfg["default_min_similarity"]
default_top_k = cfg["default_top_k"]
layer_caption_mapping = cfg["layer_caption_mapping"]
relation_name_mapping = cfg["relation_name_mapping"]
indexed_documents_label = "Indexed Documents"
indexed_documents_caption2column = {
"documents": "TOTAL",
"ADUs": "num_adus",
"Relations": "num_relations",
}
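
    # Instantiate the default argumentation model and retriever once at startup;
    # both can be re-loaded from the UI with different configurations later on.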
gr.Info("Loading models ...")
argumentation_model = load_argumentation_model(
config_str=default_argumentation_model_config_str,
device=default_device,
)
retriever = load_retriever(
config_str=default_retriever_config_str, device=default_device, config_format="yaml"
)
if cfg.get("pdf_fulltext_extractor"):
gr.Info("Loading PDF fulltext extractor ...")
pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
else:
pdf_fulltext_extractor = None
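
    # Build the UI: the left tabs hold the user input and the analysed document,
    # the right tabs hold the retrieval view and the document import options.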
with gr.Blocks() as demo:
        # wrap the models in single-element tuples so that they do not get called
        # when stored in gr.State
# models_state = gr.State((argumentation_model, embedding_model))
argumentation_model_state = gr.State((argumentation_model,))
retriever_state = gr.State((retriever,))
with gr.Row():
with gr.Tabs() as left_tabs:
with gr.Tab("User Input", id="user_input") as user_input_tab:
doc_id = gr.Textbox(
label="Document ID",
value="user_input",
)
doc_text = gr.Textbox(
label="Text",
lines=20,
value=example_text,
)
with gr.Accordion("Model Configuration", open=False):
with gr.Accordion("argumentation structure", open=True):
argumentation_model_config_str = gr.Code(
language="yaml",
label="Argumentation Model Configuration",
value=default_argumentation_model_config_str,
lines=len(default_argumentation_model_config_str.split("\n")),
)
load_arg_model_btn = gr.Button("Load Argumentation Model")
with gr.Accordion("retriever", open=True):
retriever_config_str = gr.Code(
language="yaml",
label="Retriever Configuration",
value=default_retriever_config_str,
lines=len(default_retriever_config_str.split("\n")),
)
load_retriever_btn = gr.Button("Load Retriever")
device = gr.Textbox(
label="Device (e.g. 'cuda' or 'cpu')",
value=default_device,
)
load_arg_model_btn.click(
fn=lambda _argumentation_model_config_str, _device: (
load_argumentation_model(
config_str=_argumentation_model_config_str,
device=_device,
),
),
inputs=[argumentation_model_config_str, device],
outputs=argumentation_model_state,
)
load_retriever_btn.click(
fn=lambda _retriever_config, _device, _previous_retriever: (
load_retriever(
config_str=_retriever_config,
device=_device,
previous_retriever=_previous_retriever[0],
config_format="yaml",
),
),
inputs=[retriever_config_str, device, retriever_state],
outputs=retriever_state,
)
split_regex_escaped = gr.Textbox(
label="Regex to partition the text",
placeholder="Regular expression pattern to split the text into partitions",
value=escape_regex(default_split_regex),
)
predict_btn = gr.Button("Analyse")
with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
selected_document_id = gr.Textbox(
label="Document ID", max_lines=1, interactive=False
)
rendered_output = gr.HTML(label="Rendered Output")
with gr.Accordion("Render Options", open=False):
render_as = gr.Dropdown(
label="Render with",
choices=list(render_mode2caption.values()),
value=render_mode2caption[default_render_mode],
)
render_kwargs = gr.Code(
language="json",
label="Render Arguments",
lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
value=json.dumps(default_render_kwargs, indent=2),
)
render_btn = gr.Button("Re-render")
with gr.Accordion("See plain result ...", open=False):
get_document_json_btn = gr.Button("Fetch annotated document as JSON")
document_json = gr.JSON(label="Model Output")
with gr.Tabs() as right_tabs:
with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
with gr.Accordion(
indexed_documents_label, open=False
) as processed_documents_accordion:
processed_documents_df = gr.DataFrame(
headers=["id", "num_adus", "num_relations"],
interactive=False,
elem_classes="df-docstore",
)
gr.Markdown("Data Snapshot:")
with gr.Row():
download_processed_documents_btn = gr.DownloadButton("Download")
upload_processed_documents_btn = gr.UploadButton(
"Upload", file_types=["file"]
)
# currently not used
# relation_types = set_relation_types(
# argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
# )
                    # Hidden helper textbox that holds the id of the currently hovered ADU.
                    # On click on the rendered output, its content is copied to
                    # selected_adu_id, which in turn triggers the retrieval.
hover_adu_id = gr.Textbox(
label="ID (hover)",
elem_id="hover_adu_id",
interactive=False,
visible=False,
)
selected_adu_id = gr.Textbox(
label="ID (selected)",
elem_id="selected_adu_id",
interactive=False,
visible=False,
)
selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)
with gr.Accordion("Relevant ADUs from other documents", open=True):
relevant_adus_df = gr.DataFrame(
headers=[
"relation",
"adu",
"reference_adu",
"doc_id",
"sim_score",
"rel_score",
],
interactive=False,
)
with gr.Accordion("Retrieval Configuration", open=False):
min_similarity = gr.Slider(
label="Minimum Similarity",
minimum=0.0,
maximum=1.0,
step=0.01,
value=default_min_similarity,
)
top_k = gr.Slider(
label="Top K",
minimum=2,
maximum=50,
step=1,
value=default_top_k,
)
retrieve_similar_adus_btn = gr.Button(
"Retrieve *similar* ADUs for *selected* ADU"
)
similar_adus_df = gr.DataFrame(
headers=["doc_id", "adu_id", "score", "text"], interactive=False
)
retrieve_all_similar_adus_btn = gr.Button(
"Retrieve *similar* ADUs for *all* ADUs in the document"
)
all_similar_adus_df = gr.DataFrame(
headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
interactive=False,
)
retrieve_all_relevant_adus_btn = gr.Button(
"Retrieve *relevant* ADUs for *all* ADUs in the document"
)
all_relevant_adus_df = gr.DataFrame(
headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
interactive=False,
)
all_relevant_adus_query_doc_id = gr.Textbox(visible=False)
with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
upload_btn = gr.UploadButton(
"Batch Analyse Texts",
file_types=["text"],
file_count="multiple",
)
upload_pdf_btn = gr.UploadButton(
"Batch Analyse PDFs",
# file_types=["pdf"],
file_count="multiple",
visible=pdf_fulltext_extractor is not None,
)
                    enable_acl_venue_loading = bool(
                        pdf_fulltext_extractor is not None
                        and cfg.get("acl_anthology_pdf_dir")
                    )
acl_anthology_venues = gr.Textbox(
label="ACL Anthology Venues",
value="wiesp",
max_lines=1,
visible=enable_acl_venue_loading,
)
load_acl_anthology_venues_btn = gr.Button(
"Import from ACL Anthology",
variant="secondary",
visible=enable_acl_venue_loading,
)
with gr.Accordion("Import text from arXiv", open=False):
arxiv_id = gr.Textbox(
label="arXiv paper ID",
placeholder=f"e.g. {default_arxiv_id}",
max_lines=1,
)
load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
load_arxiv_btn = gr.Button(
"Load & Analyse from arXiv", variant="secondary"
)
with gr.Accordion(
"Import argument structure annotated PIE dataset", open=False
):
load_pie_dataset_kwargs_str = gr.Code(
language="json",
label="Parameters for Loading the PIE Dataset",
value=default_load_pie_dataset_kwargs_str,
lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
)
load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
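
        # Reusable keyword-argument bundles that are shared by several event handlers below.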
render_event_kwargs = dict(
fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
retriever=_retriever[0],
document_id=_document_id,
render_with=render_caption2mode[_render_as],
render_kwargs_json=_render_kwargs,
highlight_span_ids=(
_all_relevant_adus_df["query_span_id"].tolist()
if _document_id == _all_relevant_adus_query_doc_id
else None
),
),
inputs=[
retriever_state,
selected_document_id,
render_as,
render_kwargs,
all_relevant_adus_df,
all_relevant_adus_query_doc_id,
],
outputs=rendered_output,
)
show_overview_kwargs = dict(
fn=lambda _retriever: _retriever[0].docstore.overview(
layer_captions=layer_caption_mapping, use_predictions=True
),
inputs=[retriever_state],
outputs=[processed_documents_df],
)
show_stats_kwargs = dict(
fn=lambda _processed_documents_df: open_accordion_with_stats(
_processed_documents_df,
base_label=indexed_documents_label,
caption2column=indexed_documents_caption2column,
total_column="TOTAL",
),
inputs=[processed_documents_df],
outputs=[processed_documents_accordion],
)
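
        # Analyse the user input: switch to the "Analysed Document" tab, run the argumentation
        # model, then refresh the document overview, the stats and the rendered output.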
predict_btn.click(
fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
).then(
fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
text=_doc_text,
doc_id=_doc_id,
argumentation_model=_argumentation_model[0],
retriever=_retriever[0],
split_regex_escaped=(
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
),
handle_parts_of_same=handle_parts_of_same,
),
inputs=[
doc_text,
doc_id,
argumentation_model_state,
retriever_state,
split_regex_escaped,
],
outputs=[selected_document_id],
api_name="predict",
).success(
**show_overview_kwargs
).success(
**show_stats_kwargs
).success(
**render_event_kwargs
)
render_btn.click(**render_event_kwargs, api_name="render")
load_arxiv_btn.click(
fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
).then(
fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
arxiv_id=_arxiv_id.strip() or default_arxiv_id,
abstract_only=_load_arxiv_only_abstract,
argumentation_model=_argumentation_model[0],
retriever=_retriever[0],
split_regex_escaped=(
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
),
handle_parts_of_same=handle_parts_of_same,
),
inputs=[
arxiv_id,
load_arxiv_only_abstract,
argumentation_model_state,
retriever_state,
split_regex_escaped,
],
outputs=[selected_document_id],
api_name="predict",
).success(
**show_overview_kwargs
).success(
**show_stats_kwargs
)
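
        # The import handlers below (PIE dataset, file/PDF upload, ACL Anthology) all follow the
        # same pattern: switch to the retrieval tab, open the indexed-documents accordion,
        # process the input and finally refresh the overview and stats.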
load_pie_dataset_btn.click(
fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
retriever=_retriever[0],
verbose=True,
layer_captions=layer_caption_mapping,
**json.loads(_load_pie_dataset_kwargs_str),
),
inputs=[retriever_state, load_pie_dataset_kwargs_str],
outputs=[processed_documents_df],
).success(
**show_stats_kwargs
)
selected_document_id.change(
fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
).then(**render_event_kwargs)
get_document_json_btn.click(
fn=lambda _retriever, _document_id: get_document_as_dict(
retriever=_retriever[0], doc_id=_document_id
),
inputs=[retriever_state, selected_document_id],
outputs=[document_json],
)
upload_btn.upload(
fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
file_names=_file_names,
argumentation_model=_argumentation_model[0],
retriever=_retriever[0],
split_regex_escaped=(
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
),
handle_parts_of_same=handle_parts_of_same,
layer_captions=layer_caption_mapping,
),
inputs=[
upload_btn,
argumentation_model_state,
retriever_state,
split_regex_escaped,
],
outputs=[processed_documents_df],
).success(
**show_stats_kwargs
)
upload_pdf_btn.upload(
fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
file_names=_file_names,
argumentation_model=_argumentation_model[0],
retriever=_retriever[0],
split_regex_escaped=(
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
),
handle_parts_of_same=handle_parts_of_same,
layer_captions=layer_caption_mapping,
pdf_fulltext_extractor=pdf_fulltext_extractor,
),
inputs=[
upload_pdf_btn,
argumentation_model_state,
retriever_state,
split_regex_escaped,
],
outputs=[processed_documents_df],
).success(
**show_stats_kwargs
)
load_acl_anthology_venues_btn.click(
fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
pdf_fulltext_extractor=pdf_fulltext_extractor,
venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
argumentation_model=_argumentation_model[0],
retriever=_retriever[0],
split_regex_escaped=(
unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
),
handle_parts_of_same=handle_parts_of_same,
layer_captions=layer_caption_mapping,
acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
),
inputs=[
acl_anthology_venues,
argumentation_model_state,
retriever_state,
split_regex_escaped,
],
outputs=[processed_documents_df],
).success(
**show_stats_kwargs
)
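
        # Selecting a row in the indexed-documents table opens the corresponding document
        # in the "Analysed Document" tab (via the selected_document_id.change handler above).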
processed_documents_df.select(
fn=get_cell_for_fixed_column_from_df,
inputs=[processed_documents_df, gr.State("doc_id")],
outputs=[selected_document_id],
)
download_processed_documents_btn.click(
fn=lambda _retriever: download_processed_documents(
_retriever[0], file_name="processed_documents"
),
inputs=[retriever_state],
outputs=[download_processed_documents_btn],
)
upload_processed_documents_btn.upload(
fn=lambda file_name, _retriever: upload_processed_documents(
file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
),
inputs=[upload_processed_documents_btn, retriever_state],
outputs=[processed_documents_df],
).success(**show_stats_kwargs)
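
        # Retrieval handlers: triggered either by selecting an ADU in the rendered document
        # or by one of the retrieve buttons.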
retrieve_relevant_adus_event_kwargs = dict(
fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
retriever=_retriever[0],
query_span_id=_selected_adu_id,
k=_top_k,
score_threshold=_min_similarity,
relation_label_mapping=relation_name_mapping,
# columns=relevant_adus.headers
),
inputs=[
retriever_state,
selected_adu_id,
min_similarity,
top_k,
],
outputs=[relevant_adus_df],
)
relevant_adus_df.select(
fn=get_cell_for_fixed_column_from_df,
inputs=[relevant_adus_df, gr.State("doc_id")],
outputs=[selected_document_id],
)
selected_adu_id.change(
fn=lambda _retriever, _selected_adu_id: get_span_annotation(
retriever=_retriever[0], span_id=_selected_adu_id
),
inputs=[retriever_state, selected_adu_id],
outputs=[selected_adu_text],
).success(**retrieve_relevant_adus_event_kwargs)
retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
retriever=_retriever[0],
query_span_id=_selected_adu_id,
                k=_top_k,
score_threshold=_min_similarity,
),
inputs=[
retriever_state,
selected_adu_id,
min_similarity,
top_k,
],
outputs=[similar_adus_df],
)
similar_adus_df.select(
fn=get_cell_for_fixed_column_from_df,
inputs=[similar_adus_df, gr.State("doc_id")],
outputs=[selected_document_id],
)
retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
retriever=_retriever[0],
query_doc_id=_document_id,
                k=_top_k,
score_threshold=_min_similarity,
query_span_id_column="query_span_id",
),
inputs=[
retriever_state,
selected_document_id,
min_similarity,
top_k,
],
outputs=[all_similar_adus_df],
)
retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: (
retrieve_all_relevant_spans(
retriever=_retriever[0],
query_doc_id=_document_id,
                    k=_top_k,
score_threshold=_min_similarity,
query_span_id_column="query_span_id",
query_span_text_column="query_span_text",
),
_document_id,
),
inputs=[
retriever_state,
selected_document_id,
min_similarity,
top_k,
],
outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
)
all_relevant_adus_df.change(**render_event_kwargs)
# select query span id from the "retrieve all" result data frames
all_similar_adus_df.select(
fn=get_cell_for_fixed_column_from_df,
inputs=[all_similar_adus_df, gr.State("query_span_id")],
outputs=[selected_adu_id],
)
all_relevant_adus_df.select(
fn=get_cell_for_fixed_column_from_df,
inputs=[all_relevant_adus_df, gr.State("query_span_id")],
outputs=[selected_adu_id],
)
# argumentation_model_state.change(
# fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
# inputs=[argumentation_model_state],
# outputs=[relation_types],
# )
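
        # Re-run the client-side span-highlighting JS whenever the rendered HTML changes.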
rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
demo.launch()
if __name__ == "__main__":
main()