File size: 3,152 Bytes
2cc87ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import pyrootutils

# Resolve the project root with pyrootutils, searching from this file and using a
# ".project-root" marker file as the indicator.
# NOTE: this call intentionally sits above the `demo` / `src` imports below. With
# pythonpath=True the root is expected to be added to the Python import path
# (presumably what makes those project-local imports resolvable), and dotenv=True
# requests loading of environment variables from a .env file -- see pyrootutils docs.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging

from demo.model_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

# Module-level logger; format and level are configured via logging.basicConfig
# in the script entry point.
logger = logging.getLogger(__name__)


if __name__ == "__main__":
    # Script entry point: load a retriever dump, retrieve relevant spans (for a
    # single span, a single document, or all documents) and dump the result to a
    # JSON file.

    # Local import: only the entry point needs sys (for sys.exit below). The bare
    # `exit()` helper is injected by the `site` module for interactive use and is
    # not guaranteed to exist in every interpreter setup.
    import sys

    parser = argparse.ArgumentParser(
        description=(
            "Retrieve relevant spans with a DocumentAwareSpanRetrieverWithRelations "
            "and dump the results to a JSON file."
        )
    )
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
        help="Path to the retriever config file.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=10,
        help="Passed to the retrieval call as search kwarg 'k'.",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.95,
        help="Passed to the retrieval call as search kwarg 'score_threshold'.",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
        help="Path of the output file; must end with '.json'.",
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Fail fast on an unsupported output format, before the (potentially slow)
    # retriever instantiation and data loading.
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    # The span id takes precedence over the document id in the dispatch below;
    # make that explicit instead of silently dropping one of the options.
    if args.query_span_id is not None and args.query_doc_id is not None:
        logger.warning(
            "both --query_span_id and --query_doc_id were provided; "
            "--query_doc_id is ignored because --query_span_id takes precedence"
        )

    # Dispatch on the requested granularity: single span > single document > everything.
    if args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    # The retrieval helpers signal "nothing found" with None; this is not an
    # error, so exit with status 0 without writing an output file.
    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results to {args.output_path}...")
    # NOTE(review): the result object is expected to expose `to_json(path)`
    # (presumably a pandas DataFrame) -- confirm against demo.model_utils.
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")