import json
import logging
import os
import shutil
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple

from datasets import Dataset as HFDataset
from langchain_core.documents import Document as LCDocument
from pie_datasets import Dataset, DatasetDict, concatenate_datasets
from pytorch_ie.documents import TextBasedDocument

from .pie_document_store import PieDocumentStore

logger = logging.getLogger(__name__)


class DatasetsPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses Huggingface Datasets as the backend.

    Internal state:

    * ``_data``: a single PIE ``Dataset`` holding the stored documents
      (``None`` until the first document is added).
    * ``_keys``: maps each document key to the row index of its document
      in ``_data``.
    * ``_metadata``: maps document keys to their (non-empty) metadata dicts.

    Rows are only ever appended to ``_data``. Deleting or overwriting a key
    merely updates ``_keys`` / ``_metadata`` and leaves the old row in place;
    such orphaned rows are dropped by ``_purge_invalid_entries`` (which is
    called before saving).
    """

    def __init__(self) -> None:
        self._data: Optional[Dataset] = None
        # keys map to indices in the dataset
        self._keys: Dict[str, int] = {}
        self._metadata: Dict[str, Any] = {}

    def __len__(self) -> int:
        """Return the number of documents that are reachable via a key."""
        return len(self._keys)

    def _get_pie_docs_by_indices(self, indices: Iterable[int]) -> Sequence[TextBasedDocument]:
        """Select the rows at ``indices`` from the backing dataset.

        Returns an empty list if no dataset has been created yet, otherwise
        the result of applying ``HFDataset.select`` to the backing dataset.
        """
        if self._data is None:
            return []
        return self._data.apply_hf_func(func=HFDataset.select, indices=indices)

    def _register_entries(
        self,
        keys: Iterable[str],
        metadatas: Iterable[Dict[str, Any]],
        idx_start: int,
    ) -> None:
        """Point ``keys`` at consecutive dataset rows starting at ``idx_start``.

        The i-th key is mapped to row ``idx_start + i`` and its metadata is
        stored. If the metadata is empty, any metadata previously stored under
        that key is removed so that an overwritten document does not inherit
        the metadata of the document it replaced.
        """
        for idx, (key, metadata) in enumerate(zip(keys, metadatas), start=idx_start):
            self._keys[key] = idx
            if metadata:
                self._metadata[key] = metadata
            else:
                self._metadata.pop(key, None)

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Return the stored documents for ``keys``, wrapped with their metadata.

        Keys that are not present in the store are skipped, so the result may
        be shorter than ``keys``.
        """
        if self._data is None or len(keys) == 0:
            return []
        keys_in_data = [key for key in keys if key in self._keys]
        if not keys_in_data:
            return []
        indices = [self._keys[key] for key in keys_in_data]
        dataset = self._get_pie_docs_by_indices(indices)
        metadatas = [self._metadata.get(key, {}) for key in keys_in_data]
        return [self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(dataset, metadatas)]

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given ``(key, document)`` pairs.

        The documents are appended to the backing dataset. Setting an already
        existing key re-points it to the newly appended row; the old row stays
        in the dataset until ``_purge_invalid_entries`` is called.
        """
        if len(items) == 0:
            return
        keys, new_docs = zip(*items)
        pie_docs, metadatas = zip(*[self.unwrap_with_metadata(doc) for doc in new_docs])
        if self._data is None:
            idx_start = 0
            self._data = Dataset.from_documents(pie_docs)
        else:
            # we pass the features to the new dataset to mitigate issues caused by
            # slightly different inferred features
            dataset = Dataset.from_documents(pie_docs, features=self._data.features)
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        self._register_entries(keys, metadatas, idx_start)

    def add_pie_dataset(
        self,
        dataset: Dataset,
        keys: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Append a whole PIE dataset to the store.

        Args:
            dataset: The documents to add. An empty dataset is a no-op.
            keys: One key per document. Defaults to the ``id`` of each document.
            metadatas: One metadata dict per document. Defaults to empty dicts.

        Raises:
            ValueError: If the keys are not unique, contain ``None``, or if
                ``keys``, ``dataset`` and ``metadatas`` differ in length.
        """
        if len(dataset) == 0:
            return
        if keys is None:
            keys = [doc.id for doc in dataset]
        if len(keys) != len(set(keys)):
            raise ValueError("Keys must be unique.")
        if None in keys:
            raise ValueError("Keys must not be None.")
        if metadatas is None:
            metadatas = [{} for _ in range(len(dataset))]
        if len(keys) != len(dataset) or len(keys) != len(metadatas):
            raise ValueError("Keys, dataset and metadatas must have the same length.")

        if self._data is None:
            idx_start = 0
            self._data = dataset
        else:
            idx_start = len(self._data)
            self._data = concatenate_datasets([self._data, dataset], clear_metadata=False)
        self._register_entries(keys, metadatas, idx_start)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Forget the given keys and their metadata; unknown keys are ignored.

        The corresponding rows are not removed from the backing dataset here,
        see ``_purge_invalid_entries``.
        """
        for key in keys:
            idx = self._keys.pop(key, None)
            if idx is not None:
                self._metadata.pop(key, None)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over all keys, optionally restricted to those starting with ``prefix``."""
        return (key for key in self._keys if prefix is None or key.startswith(prefix))

    def _purge_invalid_entries(self) -> None:
        """Drop dataset rows that are no longer referenced by any key.

        Afterwards the dataset contains exactly one row per key, in the
        iteration order of ``_keys``, and every key maps to its new row index.
        """
        if self._data is None or len(self._keys) == len(self._data):
            return
        live_indices = list(self._keys.values())
        self._data = self._get_pie_docs_by_indices(live_indices)
        # The selected rows are now numbered 0..n-1 in key order, so the old
        # indices must not be kept: they would point at wrong (or missing) rows.
        self._keys = {key: new_idx for new_idx, key in enumerate(self._keys)}

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Write the store to ``path``.

        Creates ``pie_documents/`` (the documents as JSON, replacing an already
        existing directory of that name), ``doc_ids.json`` (the keys) and
        ``metadata.json`` (one metadata dict per key, in the same order).
        Nothing is written if the store is empty. ``batch_size`` and
        ``kwargs`` are currently unused.
        """
        self._purge_invalid_entries()
        if len(self) == 0:
            logger.warning("No documents to save.")
            return

        # After purging, the dataset rows are aligned with the key order, so
        # the i-th id / metadata entry belongs to the i-th saved document.
        all_doc_ids = list(self._keys)
        all_metadatas: List[Dict[str, Any]] = [self._metadata.get(key, {}) for key in all_doc_ids]
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        DatasetDict({"train": self._data}).to_json(pie_documents_path)
        doc_ids_path = os.path.join(path, "doc_ids.json")
        with open(doc_ids_path, "w", encoding="utf-8") as f:
            json.dump(all_doc_ids, f)
        metadata_path = os.path.join(path, "metadata.json")
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(all_metadatas, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Add the documents stored at ``path`` to this store.

        Reads ``doc_ids.json`` and ``metadata.json`` if they exist (a missing
        file only triggers a warning) and the documents from
        ``pie_documents/``. If that directory is missing, a warning is logged
        and nothing is loaded. ``kwargs`` are currently unused.
        """
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r", encoding="utf-8") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = None
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r", encoding="utf-8") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = None
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None
        # If we have a dataset already loaded, we use its features to load the new dataset
        # This is to mitigate issues caused by slightly different inferred features.
        features = self._data.features if self._data is not None else None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path, features=features)
        pie_docs = pie_dataset["train"]
        self.add_pie_dataset(pie_docs, keys=all_doc_ids, metadatas=all_metadata)
        logger.info(f"Loaded {len(pie_docs)} documents from {path} into docstore")