prepare for https://github.com/ArneBinder/pie-document-level/pull/312
Browse files
retrieve_and_dump_all_relevant.py
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
import pyrootutils
|
2 |
-
|
3 |
-
root = pyrootutils.setup_root(
|
4 |
-
search_from=__file__,
|
5 |
-
indicator=[".project-root"],
|
6 |
-
pythonpath=True,
|
7 |
-
dotenv=True,
|
8 |
-
)
|
9 |
-
|
10 |
-
import argparse
|
11 |
-
import logging
|
12 |
-
|
13 |
-
from demo.model_utils import (
|
14 |
-
retrieve_all_relevant_spans,
|
15 |
-
retrieve_all_relevant_spans_for_all_documents,
|
16 |
-
retrieve_relevant_spans,
|
17 |
-
)
|
18 |
-
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
|
19 |
-
|
20 |
-
logger = logging.getLogger(__name__)
|
21 |
-
|
22 |
-
|
23 |
-
if __name__ == "__main__":
|
24 |
-
|
25 |
-
parser = argparse.ArgumentParser()
|
26 |
-
parser.add_argument(
|
27 |
-
"-c",
|
28 |
-
"--config_path",
|
29 |
-
type=str,
|
30 |
-
default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
|
31 |
-
)
|
32 |
-
parser.add_argument(
|
33 |
-
"--data_path",
|
34 |
-
type=str,
|
35 |
-
required=True,
|
36 |
-
help="Path to a zip or directory containing a retriever dump.",
|
37 |
-
)
|
38 |
-
parser.add_argument("-k", "--top_k", type=int, default=10)
|
39 |
-
parser.add_argument("-t", "--threshold", type=float, default=0.95)
|
40 |
-
parser.add_argument(
|
41 |
-
"-o",
|
42 |
-
"--output_path",
|
43 |
-
type=str,
|
44 |
-
required=True,
|
45 |
-
)
|
46 |
-
parser.add_argument(
|
47 |
-
"--query_doc_id",
|
48 |
-
type=str,
|
49 |
-
default=None,
|
50 |
-
help="If provided, retrieve all spans for only this query document.",
|
51 |
-
)
|
52 |
-
parser.add_argument(
|
53 |
-
"--query_span_id",
|
54 |
-
type=str,
|
55 |
-
default=None,
|
56 |
-
help="If provided, retrieve all spans for only this query span.",
|
57 |
-
)
|
58 |
-
args = parser.parse_args()
|
59 |
-
|
60 |
-
logging.basicConfig(
|
61 |
-
format="%(asctime)s %(levelname)-8s %(message)s",
|
62 |
-
level=logging.INFO,
|
63 |
-
datefmt="%Y-%m-%d %H:%M:%S",
|
64 |
-
)
|
65 |
-
|
66 |
-
if not args.output_path.endswith(".json"):
|
67 |
-
raise ValueError("only support json output")
|
68 |
-
|
69 |
-
logger.info(f"instantiating retriever from {args.config_path}...")
|
70 |
-
retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
|
71 |
-
args.config_path
|
72 |
-
)
|
73 |
-
logger.info(f"loading data from {args.data_path}...")
|
74 |
-
retriever.load_from_disc(args.data_path)
|
75 |
-
|
76 |
-
search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
|
77 |
-
logger.info(f"use search_kwargs: {search_kwargs}")
|
78 |
-
|
79 |
-
if args.query_span_id is not None:
|
80 |
-
logger.warning(f"retrieving results for single span: {args.query_span_id}")
|
81 |
-
all_spans_for_all_documents = retrieve_relevant_spans(
|
82 |
-
retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
|
83 |
-
)
|
84 |
-
elif args.query_doc_id is not None:
|
85 |
-
logger.warning(f"retrieving results for single document: {args.query_doc_id}")
|
86 |
-
all_spans_for_all_documents = retrieve_all_relevant_spans(
|
87 |
-
retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
|
88 |
-
)
|
89 |
-
else:
|
90 |
-
all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
|
91 |
-
retriever=retriever, **search_kwargs
|
92 |
-
)
|
93 |
-
|
94 |
-
if all_spans_for_all_documents is None:
|
95 |
-
logger.warning("no relevant spans found in any document")
|
96 |
-
exit(0)
|
97 |
-
|
98 |
-
logger.info(f"dumping results to {args.output_path}...")
|
99 |
-
all_spans_for_all_documents.to_json(args.output_path)
|
100 |
-
|
101 |
-
logger.info("done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|