import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple

from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend."""

    def __init__(self) -> None:
        self._data: Optional[Dataset] = None
        # keys map to indices in the dataset
        self._keys: Dict[str, int] = {}
        self._metadata: Dict[str, Any] = {}

    def __len__(self) -> int:
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
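        """Select the PIE documents at the given dataset indices.

        Delegates row selection to ``HFDataset.select`` on the backing Huggingface dataset.
        """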
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
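        """Get the wrapped documents for the given keys; unknown keys are silently skipped."""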
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
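        """Store the given (key, document) pairs by appending them to the backing dataset.

        Keys that are set again point to the newly appended rows; the superseded rows are
        only removed once ``_purge_invalid_entries`` runs (e.g. when saving).
        """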
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        self._metadata.update(
            {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        )

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
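        """Add all documents of a PIE dataset to the store.

        If ``keys`` is not given, the document ids are used. ``metadatas``, if provided,
        must align with the dataset; empty metadata entries are not stored.
        """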
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")

        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        keys_dict = {key: idx for idx, key in zip(range(idx_start, len(self._data)), keys)}
        self._keys.update(keys_dict)
        metadatas_dict = {key: metadata for key, metadata in zip(keys, metadatas) if metadata}
        self._metadata.update(metadatas_dict)

    def mdelete(self, keys: Sequence[str]) -> None:
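        """Forget the given keys.

        The corresponding dataset rows are only dropped when ``_purge_invalid_entries`` runs.
        """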
        for key in keys:
            idx = self._keys.pop(key, None)
            if idx is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
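        """Yield all stored keys, optionally restricted to keys starting with ``prefix``."""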
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self) -> None:
        if self._data is None or len(self._keys) == len(self._data):
            return
        # keep only the rows that are still referenced by a key
        self._data = self._get_pie_docs_by_indices(self._keys.values())
        # the purged dataset is ordered like self._keys, so re-map the keys to the new
        # consecutive row indices
        self._keys = {key: idx for idx, key in enumerate(self._keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
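        """Write the store to ``path``.

        The PIE documents are saved as a JSON dataset under ``pie_documents``, the keys to
        ``doc_ids.json`` and the per-document metadata to ``metadata.json``.
        """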
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path)
        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w") as f:
            json.dump(all_doc_ids, f)
        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
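        """Load documents, keys and metadata previously written by ``_save_to_directory``."""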
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = None
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = None
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, not loading any documents."
            )
            return
        # If we have a dataset already loaded, we use its features to load the new dataset
        # This is to mitigate issues caused by slightly different inferred features.
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")