|
import pyrootutils |
|
|
|
# Resolve the project root by searching from this file's location for a
# ".project-root" marker file. Per the pyrootutils documentation,
# pythonpath=True puts the root on sys.path and dotenv=True loads a ".env"
# file found there.
# NOTE(review): this call presumably has to stay above the `demo` / `src`
# imports so those packages are importable -- confirm before reordering.
root = pyrootutils.setup_root(
    search_from=__file__, indicator=[".project-root"], pythonpath=True, dotenv=True
)
|
|
|
import argparse |
|
import logging |
|
|
|
from demo.model_utils import ( |
|
retrieve_all_relevant_spans, |
|
retrieve_all_relevant_spans_for_all_documents, |
|
retrieve_relevant_spans, |
|
) |
|
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations |
|
|
|
logger = logging.getLogger(__name__)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the span retrieval script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
        help="Path to the config file the retriever is instantiated from.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=10,
        help="Value passed to the retrieval helpers as 'k'.",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.95,
        help="Value passed to the retrieval helpers as 'score_threshold'.",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
        help="File the results are written to; must end with '.json'.",
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    return parser


def main(argv=None) -> None:
    """Retrieve relevant spans with a dumped retriever and write them to JSON.

    Steps: parse the command line, configure logging, instantiate the
    retriever from the config file, load the dumped data, run the retrieval
    (for a single span, a single document, or all documents) and dump the
    result to ``--output_path``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. normal command-line use.

    Raises:
        ValueError: If ``--output_path`` does not end with ``.json``.
        SystemExit: Raised by argparse on invalid or missing arguments.
    """
    args = _build_arg_parser().parse_args(argv)

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Reject an unsupported output path before the retriever is built and
    # the data is loaded, so a bad invocation fails immediately.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info("instantiating retriever from %s...", args.config_path)
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info("loading data from %s...", args.data_path)
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info("use search_kwargs: %s", search_kwargs)

    if args.query_span_id is not None:
        if args.query_doc_id is not None:
            # The span query takes precedence; say so instead of silently
            # dropping the document id.
            logger.warning(
                "both --query_span_id and --query_doc_id were given; "
                "--query_doc_id is ignored"
            )
        logger.warning("retrieving results for single span: %s", args.query_span_id)
        result = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning("retrieving results for single document: %s", args.query_doc_id)
        result = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        result = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if result is None:
        # Nothing to write. Returning normally ends the script with exit
        # status 0 without relying on the site-provided `exit` builtin.
        logger.warning("no relevant spans found in any document")
        return

    logger.info("dumping results to %s...", args.output_path)
    result.to_json(args.output_path)

    logger.info("done")


if __name__ == "__main__":
    main()
|
|