File size: 7,148 Bytes
3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e e7eaeed 3133b5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar
from pie_datasets import Dataset, DatasetDict, IterableDataset
from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pytorch_ie.core import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
from src.serializer.interface import DocumentSerializer
from src.utils.logging_utils import get_pylogger
log = get_pylogger(__name__)
D = TypeVar("D", bound=Document)
def as_json_lines(file_name: str) -> bool:
    """Decide the serialization format from the file extension.

    Args:
        file_name: Name of the target file; only the extension is inspected
            (case-insensitive).

    Returns:
        True for JSON-lines files (".jsonl"), False for plain JSON files
        (".json").

    Raises:
        ValueError: If the extension is neither ".jsonl" nor ".json".
    """
    # Lower-case once; ".jsonl" must be checked before ".json" would match.
    lowered = file_name.lower()
    if lowered.endswith(".jsonl"):
        return True
    if lowered.endswith(".json"):
        return False
    raise ValueError(f"unknown file extension: {file_name}")
class JsonSerializer(DocumentSerializer):
    """Serialize and deserialize documents to/from JSON or JSON-lines files.

    The format is chosen from the file extension (see ``as_json_lines``):
    ``.jsonl`` files are written one JSON object per line and appended to on
    repeated writes; ``.json`` files hold a single JSON list and are merged
    with previously written content. A metadata file recording the document
    type is written next to the data so that ``read`` can reconstruct the
    documents without an explicit ``document_type`` argument.
    """

    def __init__(self, **kwargs):
        # Default keyword arguments merged into every read/write call
        # (explicit call arguments take precedence).
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Iterable[Document],
        path: str,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, str]:
        """Serialize documents to ``path``.

        Args:
            documents: Documents to serialize; must be non-empty and all of
                the same type (the type of the first document is recorded).
            path: Target directory; created if missing.
            file_name: Data file name; its extension selects the format.
            metadata_file_name: Name of the metadata file written in ``path``.
            split: Optional sub-directory for the data file (metadata stays
                at the top level so it is shared across splits).
            **kwargs: Passed through to ``json.dump(s)``.

        Returns:
            A dict with the keys "path", "file_name" and "metadata_file_name"
            describing where the data was written.

        Raises:
            ValueError: If ``documents`` is empty, or if an existing metadata
                file does not match the current document type.
        """
        realpath = os.path.realpath(path)
        log.info(f'serialize documents to "{realpath}" ...')
        os.makedirs(realpath, exist_ok=True)

        # Materialize iterables so we can check emptiness and take the first
        # element without consuming the input.
        if not isinstance(documents, Sequence):
            documents = list(documents)

        # dump metadata including the document_type
        if len(documents) == 0:
            raise ValueError("cannot serialize empty list of documents")
        document_type = type(documents[0])
        metadata = {"document_type": serialize_document_type(document_type)}
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            # load previous metadata and refuse to mix document types
            with open(full_metadata_file_name) as f:
                previous_metadata = json.load(f)
            if previous_metadata != metadata:
                # NOTE: the two value lines were missing the f-prefix before,
                # so the placeholders were printed literally.
                raise ValueError(
                    f"metadata file {full_metadata_file_name} already exists, "
                    "but the content does not match the current metadata"
                    f"\nprevious metadata: {previous_metadata}"
                    f"\ncurrent metadata: {metadata}"
                )
        else:
            with open(full_metadata_file_name, "w") as f:
                json.dump(metadata, f, indent=2)

        if split is not None:
            realpath = os.path.join(realpath, split)
            os.makedirs(realpath, exist_ok=True)
        full_file_name = os.path.join(realpath, file_name)
        if as_json_lines(file_name):
            # if the file already exists, append to it
            mode = "a" if os.path.exists(full_file_name) else "w"
            with open(full_file_name, mode) as f:
                for doc in documents:
                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
        else:
            docs_list = [doc.asdict() for doc in documents]
            if os.path.exists(full_file_name):
                # load previous documents and prepend them to the new ones
                with open(full_file_name) as f:
                    previous_doc_list = json.load(f)
                docs_list = previous_doc_list + docs_list
            with open(full_file_name, "w") as f:
                json.dump(docs_list, fp=f, **kwargs)
        return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
    ) -> List[D]:
        """Load documents previously written by ``write``.

        Args:
            path: Directory the documents were serialized to.
            document_type: Document class to reconstruct; overridden by the
                type recorded in the metadata file if that file exists.
            file_name: Data file name; its extension selects the format.
            metadata_file_name: Name of the metadata file to look for.
            split: Optional sub-directory holding the data file.

        Returns:
            The deserialized documents.

        Raises:
            Exception: If no document type is available (neither passed in
                nor found in the metadata).
        """
        realpath = os.path.realpath(path)
        log.info(f'load documents from "{realpath}" ...')

        # try to load metadata including the document_type
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            with open(full_metadata_file_name) as f:
                metadata = json.load(f)
            document_type = resolve_optional_document_type(metadata.get("document_type"))
        if document_type is None:
            raise Exception("document_type is required to load serialized documents")

        if split is not None:
            realpath = os.path.join(realpath, split)
        full_file_name = os.path.join(realpath, file_name)
        documents = []
        if as_json_lines(str(file_name)):
            with open(full_file_name) as f:
                for line in f:
                    json_dict = json.loads(line)
                    documents.append(document_type.fromdict(json_dict))
        else:
            with open(full_file_name) as f:
                json_list = json.load(f)
            for json_dict in json_list:
                documents.append(document_type.fromdict(json_dict))
        return documents

    def read_with_defaults(self, **kwargs) -> List[D]:
        """``read`` with ``self.default_kwargs`` as fallback arguments."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        """``write`` with ``self.default_kwargs`` as fallback arguments."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)
class JsonSerializer2(DocumentSerializer):
    """Serializer backed by pie_datasets' ``DatasetDict`` JSON export/import."""

    def __init__(self, **kwargs):
        # Defaults applied to every read/write invocation; explicit call
        # arguments win over these.
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Iterable[Document],
        path: str,
        split: str = "train",
    ) -> Dict[str, str]:
        """Dump ``documents`` as a single-split ``DatasetDict`` under ``path``.

        Returns a dict with the keys "path" and "split".
        """
        if not isinstance(documents, (Dataset, IterableDataset)):
            # Materialized sequences become a Dataset; lazy iterables become
            # an IterableDataset so they are not forced into memory.
            converter = (
                Dataset.from_documents
                if isinstance(documents, Sequence)
                else IterableDataset.from_documents
            )
            documents = converter(documents)
        DatasetDict({split: documents}).to_json(path=path)
        return {"path": path, "split": split}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        split: Optional[str] = None,
    ) -> Dataset[Document]:
        """Load a ``DatasetDict`` from ``path`` and return a single split.

        If ``split`` is None, the dataset must contain exactly one split;
        otherwise a ValueError is raised.
        """
        loaded = DatasetDict.from_json(
            data_dir=path, document_type=document_type, split=split
        )
        if split is not None:
            return loaded[split]
        split_names = list(loaded.keys())
        if len(split_names) == 1:
            return loaded[split_names[0]]
        raise ValueError(f"multiple splits found in dataset_dict: {split_names}")

    def read_with_defaults(self, **kwargs) -> Sequence[D]:
        """``read`` with ``self.default_kwargs`` as fallback arguments."""
        return self.read(**{**self.default_kwargs, **kwargs})

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        """``write`` with ``self.default_kwargs`` as fallback arguments."""
        return self.write(**{**self.default_kwargs, **kwargs})

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)
|