ArneBinder commited on
Commit
b927ce3
·
verified ·
1 Parent(s): b0fe481

prepare for https://github.com/ArneBinder/pie-document-level/pull/312

Browse files
Files changed (1) hide show
  1. retrieve_and_dump_all_relevant.py +0 -101
retrieve_and_dump_all_relevant.py DELETED
@@ -1,101 +0,0 @@
1
- import pyrootutils
2
-
3
- root = pyrootutils.setup_root(
4
- search_from=__file__,
5
- indicator=[".project-root"],
6
- pythonpath=True,
7
- dotenv=True,
8
- )
9
-
10
- import argparse
11
- import logging
12
-
13
- from demo.model_utils import (
14
- retrieve_all_relevant_spans,
15
- retrieve_all_relevant_spans_for_all_documents,
16
- retrieve_relevant_spans,
17
- )
18
- from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
19
-
20
- logger = logging.getLogger(__name__)
21
-
22
-
23
- if __name__ == "__main__":
24
-
25
- parser = argparse.ArgumentParser()
26
- parser.add_argument(
27
- "-c",
28
- "--config_path",
29
- type=str,
30
- default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
31
- )
32
- parser.add_argument(
33
- "--data_path",
34
- type=str,
35
- required=True,
36
- help="Path to a zip or directory containing a retriever dump.",
37
- )
38
- parser.add_argument("-k", "--top_k", type=int, default=10)
39
- parser.add_argument("-t", "--threshold", type=float, default=0.95)
40
- parser.add_argument(
41
- "-o",
42
- "--output_path",
43
- type=str,
44
- required=True,
45
- )
46
- parser.add_argument(
47
- "--query_doc_id",
48
- type=str,
49
- default=None,
50
- help="If provided, retrieve all spans for only this query document.",
51
- )
52
- parser.add_argument(
53
- "--query_span_id",
54
- type=str,
55
- default=None,
56
- help="If provided, retrieve all spans for only this query span.",
57
- )
58
- args = parser.parse_args()
59
-
60
- logging.basicConfig(
61
- format="%(asctime)s %(levelname)-8s %(message)s",
62
- level=logging.INFO,
63
- datefmt="%Y-%m-%d %H:%M:%S",
64
- )
65
-
66
- if not args.output_path.endswith(".json"):
67
- raise ValueError("only support json output")
68
-
69
- logger.info(f"instantiating retriever from {args.config_path}...")
70
- retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
71
- args.config_path
72
- )
73
- logger.info(f"loading data from {args.data_path}...")
74
- retriever.load_from_disc(args.data_path)
75
-
76
- search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
77
- logger.info(f"use search_kwargs: {search_kwargs}")
78
-
79
- if args.query_span_id is not None:
80
- logger.warning(f"retrieving results for single span: {args.query_span_id}")
81
- all_spans_for_all_documents = retrieve_relevant_spans(
82
- retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
83
- )
84
- elif args.query_doc_id is not None:
85
- logger.warning(f"retrieving results for single document: {args.query_doc_id}")
86
- all_spans_for_all_documents = retrieve_all_relevant_spans(
87
- retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
88
- )
89
- else:
90
- all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
91
- retriever=retriever, **search_kwargs
92
- )
93
-
94
- if all_spans_for_all_documents is None:
95
- logger.warning("no relevant spans found in any document")
96
- exit(0)
97
-
98
- logger.info(f"dumping results to {args.output_path}...")
99
- all_spans_for_all_documents.to_json(args.output_path)
100
-
101
- logger.info("done")