# ScientificArgumentRecommender / retrieve_and_dump_all_relevant.py
# Last commit 2cc87ec by ArneBinder (verified): new demo setup with langchain retriever
import pyrootutils
# Locate the project root: the directory marked by a `.project-root` file,
# searched for starting from this file. With these options pyrootutils also
# loads the root's `.env` and puts the root on the Python path; the
# project-local `demo` / `src` imports that follow depend on that, so this
# call has to stay above them.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    dotenv=True,
    pythonpath=True,
)
import argparse
import logging
import os
import sys

from demo.model_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
# Logger named after this module. Output format and level are set by
# logging.basicConfig once the script runs as __main__.
logger = logging.getLogger(name=__name__)
if __name__ == "__main__":
    # Load a retriever dump, retrieve relevant spans and write them to a JSON
    # file. The scope is chosen by the first option present:
    #   --query_span_id -> results for that single span,
    #   --query_doc_id  -> results for that single document,
    #   neither         -> results for all documents.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Check the output format before instantiating the retriever and loading
    # data, so a bad path fails fast.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_span_id is not None:
        if args.query_doc_id is not None:
            # The span filter wins; say so rather than silently dropping the
            # document filter the user asked for.
            logger.warning(
                "both --query_span_id and --query_doc_id were given; "
                f"ignoring --query_doc_id={args.query_doc_id}"
            )
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        # The bare `exit` builtin is injected by the `site` module for
        # interactive use and is missing under `python -S` or in frozen
        # builds; sys.exit is always available.
        sys.exit(0)

    # Make sure the target directory exists. Otherwise a fresh output path
    # would fail only after the whole retrieval has run.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger.info(f"dumping results to {args.output_path}...")
    all_spans_for_all_documents.to_json(args.output_path)
    logger.info("done")