import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar

from pie_datasets import Dataset, DatasetDict, IterableDataset
from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pytorch_ie.core import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type

from src.serializer.interface import DocumentSerializer
from src.utils.logging_utils import get_pylogger

log = get_pylogger(__name__)

D = TypeVar("D", bound=Document)


def as_json_lines(file_name: str) -> bool:
    """Return True if `file_name` denotes a JSON Lines file (.jsonl) and False if it
    denotes a plain JSON file (.json)."""
    if file_name.lower().endswith(".jsonl"):
        return True
    elif file_name.lower().endswith(".json"):
        return False
    else:
        raise ValueError(f"unknown file extension: {file_name}")
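
# e.g., as_json_lines("docs.jsonl") -> True, as_json_lines("docs.json") -> False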


class JsonSerializer(DocumentSerializer):
    """Serialize and deserialize documents to/from a single JSON or JSON Lines file,
    storing the document type in a separate metadata file."""

    def __init__(self, **kwargs):
        self.default_kwargs = kwargs

@classmethod
def write(
cls,
documents: Iterable[Document],
path: str,
file_name: str = "documents.jsonl",
metadata_file_name: str = METADATA_FILE_NAME,
split: Optional[str] = None,
**kwargs,
) -> Dict[str, str]:
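        """Serialize `documents` into a single file below `path` and return the
        parameters (path, file_name, metadata_file_name) needed to locate the result."""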
realpath = os.path.realpath(path)
log.info(f'serialize documents to "{realpath}" ...')
os.makedirs(realpath, exist_ok=True)
if not isinstance(documents, Sequence):
documents = list(documents)
# dump metadata including the document_type
        if len(documents) == 0:
            raise ValueError("cannot serialize an empty list of documents")
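        # infer the document type from the first document; all documents in the
        # iterable are assumed to share this type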
document_type = type(documents[0])
metadata = {"document_type": serialize_document_type(document_type)}
full_metadata_file_name = os.path.join(realpath, metadata_file_name)
if os.path.exists(full_metadata_file_name):
# load previous metadata
with open(full_metadata_file_name) as f:
previous_metadata = json.load(f)
            if previous_metadata != metadata:
                raise ValueError(
                    f"metadata file {full_metadata_file_name} already exists, "
                    "but the content does not match the current metadata"
                    f"\nprevious metadata: {previous_metadata}"
                    f"\ncurrent metadata: {metadata}"
                )
else:
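            # first write at this location: persist the metadata next to the documents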
with open(full_metadata_file_name, "w") as f:
json.dump(metadata, f, indent=2)
if split is not None:
realpath = os.path.join(realpath, split)
os.makedirs(realpath, exist_ok=True)
full_file_name = os.path.join(realpath, file_name)
if as_json_lines(file_name):
# if the file already exists, append to it
mode = "a" if os.path.exists(full_file_name) else "w"
with open(full_file_name, mode) as f:
for doc in documents:
f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
else:
docs_list = [doc.asdict() for doc in documents]
if os.path.exists(full_file_name):
# load previous documents
with open(full_file_name) as f:
previous_doc_list = json.load(f)
docs_list = previous_doc_list + docs_list
with open(full_file_name, "w") as f:
json.dump(docs_list, fp=f, **kwargs)
return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}
@classmethod
def read(
cls,
path: str,
document_type: Optional[Type[D]] = None,
file_name: str = "documents.jsonl",
metadata_file_name: str = METADATA_FILE_NAME,
split: Optional[str] = None,
) -> List[D]:
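        """Load serialized documents from `path`. The document type is read from the
        metadata file if present; otherwise `document_type` must be provided."""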
realpath = os.path.realpath(path)
log.info(f'load documents from "{realpath}" ...')
# try to load metadata including the document_type
full_metadata_file_name = os.path.join(realpath, metadata_file_name)
if os.path.exists(full_metadata_file_name):
with open(full_metadata_file_name) as f:
metadata = json.load(f)
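            # a document_type stored in the metadata takes precedence over the
            # document_type argument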
document_type = resolve_optional_document_type(metadata.get("document_type"))
        if document_type is None:
            raise ValueError("document_type is required to load serialized documents")
if split is not None:
realpath = os.path.join(realpath, split)
full_file_name = os.path.join(realpath, file_name)
documents = []
        if as_json_lines(file_name):
with open(full_file_name) as f:
for line in f:
json_dict = json.loads(line)
documents.append(document_type.fromdict(json_dict))
else:
with open(full_file_name) as f:
json_list = json.load(f)
for json_dict in json_list:
documents.append(document_type.fromdict(json_dict))
return documents

    def read_with_defaults(self, **kwargs) -> List[D]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)
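

# Instance-level usage sketch (illustrative; "./out" is an assumed location): kwargs
# passed to the constructor act as defaults that call-time kwargs override.
#   serializer = JsonSerializer(path="./out")
#   serializer(documents=docs)  # equivalent to serializer.write_with_defaults(...)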


class JsonSerializer2(DocumentSerializer):
    """Serialize and deserialize documents via a pie_datasets `DatasetDict`, writing
    one directory per split."""

    def __init__(self, **kwargs):
        self.default_kwargs = kwargs

@classmethod
def write(
cls,
documents: Iterable[Document],
path: str,
split: str = "train",
) -> Dict[str, str]:
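        # wrap plain iterables and sequences in pie_datasets containers so that
        # they can be written via DatasetDict.to_json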
if not isinstance(documents, (Dataset, IterableDataset)):
if not isinstance(documents, Sequence):
documents = IterableDataset.from_documents(documents)
else:
documents = Dataset.from_documents(documents)
dataset_dict = DatasetDict({split: documents})
dataset_dict.to_json(path=path)
return {"path": path, "split": split}
@classmethod
def read(
cls,
path: str,
document_type: Optional[Type[D]] = None,
split: Optional[str] = None,
) -> Dataset[Document]:
dataset_dict = DatasetDict.from_json(
data_dir=path, document_type=document_type, split=split
)
if split is not None:
return dataset_dict[split]
        if len(dataset_dict) == 1:
            return next(iter(dataset_dict.values()))
        raise ValueError(f"multiple splits found in dataset_dict: {list(dataset_dict.keys())}")

    def read_with_defaults(self, **kwargs) -> Sequence[D]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)
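

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module):
    # round-trip a single TextBasedDocument through JsonSerializer. The temporary
    # directory and the example document are assumptions made for this demo.
    import tempfile

    from pytorch_ie.documents import TextBasedDocument

    with tempfile.TemporaryDirectory() as tmp_dir:
        doc = TextBasedDocument(text="Hello, world!", id="doc-0")
        info = JsonSerializer.write(documents=[doc], path=tmp_dir)
        (reloaded,) = JsonSerializer.read(path=info["path"])
        assert reloaded.text == doc.text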