|
import pyrootutils |
|
|
|
root = pyrootutils.setup_root( |
|
search_from=__file__, |
|
indicator=[".project-root"], |
|
pythonpath=True, |
|
dotenv=True, |
|
) |
|
|
|
import argparse |
|
import logging |
|
import os |
|
|
|
import pandas as pd |
|
|
|
from src.demo.retriever_utils import ( |
|
retrieve_all_relevant_spans, |
|
retrieve_all_relevant_spans_for_all_documents, |
|
retrieve_relevant_spans, |
|
) |
|
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations |
|
|
|
# Module-level logger; format and level are configured via logging.basicConfig
# in the __main__ block below.
logger = logging.getLogger(__name__)
|
|
|
|
|
if __name__ == "__main__":
    # CLI for dumping retrieved spans to a JSON file. The retrieval mode is
    # selected by which query arguments are provided, checked in this priority
    # order further below:
    #   --query_target_doc_id_pairs > --query_span_id > --query_doc_id > all documents.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Fail fast before the (potentially slow) retriever setup below.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    # Keyword arguments shared by all retrieval modes.
    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_target_doc_id_pairs is not None:
        # Pairs mode: retrieve spans per (query, target) document pair.
        # Collect the per-pair frames and concatenate once at the end instead
        # of calling pd.concat inside the loop (which is quadratic).
        pair_results = []
        for doc_id_pair in args.query_target_doc_id_pairs:
            try:
                query_doc_id, target_doc_id = doc_id_pair.split(":")
            except ValueError:
                # Give a descriptive error instead of a bare unpacking failure
                # when the pair is malformed (no colon, or more than one).
                raise ValueError(
                    f"invalid query/target pair {doc_id_pair!r}: expected "
                    'exactly one ":" separator, e.g. "query_doc:target_doc"'
                ) from None
            # BUGFIX: build per-pair kwargs instead of passing doc_id_whitelist
            # both explicitly and via **search_kwargs — with --doc_id_whitelist
            # set, the original call raised TypeError ("got multiple values for
            # keyword argument 'doc_id_whitelist'"). The pair's target document
            # deliberately supersedes any global whitelist here.
            pair_search_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            # Record which query document each row came from.
            current_result["query_doc_id"] = query_doc_id
            pair_results.append(current_result)
        if not pair_results:
            all_spans_for_all_documents = None
        elif len(pair_results) == 1:
            # A single frame is kept as-is (same as the previous behavior of
            # assigning the first result directly, preserving its index).
            all_spans_for_all_documents = pair_results[0]
        else:
            all_spans_for_all_documents = pd.concat(pair_results, ignore_index=True)

    elif args.query_span_id is not None:
        # Single-span mode.
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        # Single-document mode.
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        # Default mode: retrieve spans for every document in the dump.
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        # Nothing to write; terminate successfully. SystemExit is raised
        # directly instead of exit(), which is only available when the site
        # module has installed it.
        raise SystemExit(0)

    logger.info(
        f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}..."
    )
    # BUGFIX: os.makedirs("") raises FileNotFoundError, so only create the
    # parent directory when the output path actually contains one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")
|
|