# Command-line script: load a span retriever dump, retrieve relevant spans
# (for one span, one document, or all documents) and write them to a JSON file.

import pyrootutils

# NOTE: keep this call before the `demo` / `src` imports below. With
# pythonpath=True, setup_root puts the project root on the import path, which
# those project-local imports presumably rely on.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging

from demo.model_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    return parser.parse_args()


def main() -> None:
    """Run the retrieval and dump the result as JSON.

    Raises:
        ValueError: If ``--output_path`` does not end with ``.json``.
    """
    args = _parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Validate the output target up front, before the retriever is built and
    # the dump is loaded.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    # Query scope: a single span takes precedence over a single document;
    # with neither given, every document is processed.
    if args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        # Return instead of calling the bare `exit()` helper: that name is
        # injected by the `site` module and is not guaranteed to exist (e.g.
        # under `python -S`). The process still ends with status 0.
        return

    logger.info(f"dumping results to {args.output_path}...")
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")


if __name__ == "__main__":
    main()