# Updated from https://github.com/ArneBinder/pie-document-level/pull/397 (commit ced4316)
import abc
import logging
from copy import copy
from typing import Iterator, List, Optional, Sequence, Tuple
import pandas as pd
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore
from pytorch_ie.documents import TextBasedDocument
from .serializable_store import SerializableStore
logger = logging.getLogger(__name__)
class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
    """Abstract base class for document stores specialized in storing and retrieving pie
    documents.

    Pie documents are wrapped in langchain documents (``LCDocument``): the pie document is
    stored in the langchain document metadata under :attr:`METADATA_KEY_PIE_DOCUMENT`, and the
    langchain ``page_content`` is left empty.
    """

    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
    """Key for the pie document in the (langchain) document metadata."""

    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
        """Wrap the pie document in an LCDocument.

        Args:
            pie_document: The pie document to wrap.
            **metadata: Additional entries to store in the langchain document metadata.

        Returns:
            A langchain document with empty page content whose metadata holds the pie
            document (under METADATA_KEY_PIE_DOCUMENT) plus any additional metadata.
        """
        return LCDocument(
            id=pie_document.id,
            page_content="",
            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
        )

    def unwrap(self, document: LCDocument) -> TextBasedDocument:
        """Get the pie document from the langchain document."""
        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]

    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
        """Get the pie document and the remaining metadata from the langchain document.

        Note: the metadata dict is shallow-copied, so its values are shared with the input
        document's metadata.
        """
        metadata = copy(document.metadata)
        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
        return pie_document, metadata

    @abc.abstractmethod
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents stored under the given keys."""

    @abc.abstractmethod
    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""

    @abc.abstractmethod
    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents stored under the given keys."""

    @abc.abstractmethod
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally restricted to keys with the given prefix."""

    def __len__(self) -> int:
        """Return the number of documents in the store."""
        # Count lazily instead of materializing the full key list in memory.
        return sum(1 for _ in self.yield_keys())

    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
        """Get an overview of the document store, including the number of items in each layer
        for each document in the store.

        Args:
            layer_captions: A dictionary mapping layer names to captions.
            use_predictions: Whether to additionally count the predictions of each layer.
                The gold annotations are always counted; predictions are added on top.

        Returns:
            DataFrame: A pandas DataFrame with one row per document ("doc_id" column plus one
                "num_<caption>" column per layer).
        """
        # Fetch all documents with a single mget call instead of one store
        # round-trip per key.
        doc_ids = list(self.yield_keys())
        rows = []
        for doc_id, document in zip(doc_ids, self.mget(doc_ids)):
            pie_document = self.unwrap(document)
            layers = {
                caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
            }
            layer_sizes = {
                f"num_{caption}": len(layer) + (len(layer.predictions) if use_predictions else 0)
                for caption, layer in layers.items()
            }
            rows.append({"doc_id": doc_id, **layer_sizes})
        return pd.DataFrame(rows)

    def as_dict(self, document: LCDocument) -> dict:
        """Convert the langchain document to a dictionary holding the serialized pie document
        and the remaining metadata."""
        pie_document, metadata = self.unwrap_with_metadata(document)
        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}