File size: 7,148 Bytes
3133b5e
 
e7eaeed
3133b5e
 
 
 
 
 
 
e7eaeed
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7eaeed
3133b5e
 
 
 
 
 
 
 
 
 
e7eaeed
 
 
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7eaeed
3133b5e
 
 
 
 
 
 
 
 
 
e7eaeed
3133b5e
 
 
 
e7eaeed
 
 
 
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7eaeed
3133b5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Type, TypeVar

from pie_datasets import Dataset, DatasetDict, IterableDataset
from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pytorch_ie.core import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type

from src.serializer.interface import DocumentSerializer
from src.utils.logging_utils import get_pylogger

log = get_pylogger(__name__)

D = TypeVar("D", bound=Document)


def as_json_lines(file_name: str) -> bool:
    """Decide the serialization format from a file name's extension.

    Args:
        file_name: name of the target file; matching is case-insensitive.

    Returns:
        True for ".jsonl" (JSON Lines), False for ".json" (a single JSON array).

    Raises:
        ValueError: if the extension is neither ".json" nor ".jsonl".
    """
    lowered = file_name.lower()
    if lowered.endswith(".jsonl"):
        return True
    if lowered.endswith(".json"):
        return False
    # ValueError narrows the previous bare Exception; existing `except Exception`
    # callers still catch it.
    raise ValueError(f"unknown file extension: {file_name}")


class JsonSerializer(DocumentSerializer):
    """Serialize/deserialize pytorch-ie Documents to JSON or JSON Lines files.

    The target format is chosen from the file extension (see ``as_json_lines``).
    A metadata file holding the serialized document type is written next to the
    data so that ``read`` can reconstruct documents without being told the type.
    """

    def __init__(self, **kwargs):
        # Default keyword arguments merged into every read()/write() call made
        # through the *_with_defaults helpers (call-site kwargs take precedence).
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Iterable[Document],
        path: str,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, str]:
        """Write documents to ``path/[split/]file_name`` plus a metadata file.

        Args:
            documents: documents to serialize; must be non-empty. All are assumed
                to share the type of the first document.
            path: target directory (created if missing).
            file_name: data file name; extension selects JSONL vs JSON.
            metadata_file_name: name of the metadata file written directly under
                ``path`` (not under the split subdirectory).
            split: optional subdirectory for the data file.
            **kwargs: forwarded to ``json.dumps`` / ``json.dump``.

        Returns:
            Dict with the resolved "path", "file_name" and "metadata_file_name".

        Raises:
            Exception: if ``documents`` is empty.
            ValueError: if a metadata file already exists with different content.
        """
        realpath = os.path.realpath(path)
        log.info(f'serialize documents to "{realpath}" ...')
        os.makedirs(realpath, exist_ok=True)

        # materialize generators so we can inspect the first element and iterate
        if not isinstance(documents, Sequence):
            documents = list(documents)

        # dump metadata including the document_type
        if len(documents) == 0:
            raise Exception("cannot serialize empty list of documents")
        document_type = type(documents[0])
        metadata = {"document_type": serialize_document_type(document_type)}
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            # load previous metadata and require it to match the current one
            with open(full_metadata_file_name) as f:
                previous_metadata = json.load(f)
            if previous_metadata != metadata:
                # BUGFIX: the two continuation lines were missing the f-prefix,
                # so the literal placeholder text was printed instead of the values
                raise ValueError(
                    f"metadata file {full_metadata_file_name} already exists, "
                    "but the content does not match the current metadata"
                    f"\nprevious metadata: {previous_metadata}"
                    f"\ncurrent metadata: {metadata}"
                )
        else:
            with open(full_metadata_file_name, "w") as f:
                json.dump(metadata, f, indent=2)

        if split is not None:
            realpath = os.path.join(realpath, split)
            os.makedirs(realpath, exist_ok=True)
        full_file_name = os.path.join(realpath, file_name)
        if as_json_lines(file_name):
            # if the file already exists, append to it
            mode = "a" if os.path.exists(full_file_name) else "w"
            with open(full_file_name, mode) as f:
                for doc in documents:
                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
        else:
            docs_list = [doc.asdict() for doc in documents]
            if os.path.exists(full_file_name):
                # load previous documents and prepend them to preserve order
                with open(full_file_name) as f:
                    previous_doc_list = json.load(f)
                docs_list = previous_doc_list + docs_list
            with open(full_file_name, "w") as f:
                json.dump(docs_list, fp=f, **kwargs)
        return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
    ) -> List[D]:
        """Load documents from ``path/[split/]file_name``.

        Args:
            path: directory previously written by ``write``.
            document_type: fallback document type; NOTE a type found in the
                metadata file takes precedence over this argument.
            file_name: data file name; extension selects JSONL vs JSON.
            metadata_file_name: metadata file looked up directly under ``path``.
            split: optional subdirectory of the data file.

        Returns:
            The deserialized documents.

        Raises:
            Exception: if no document type is available from either source.
        """
        realpath = os.path.realpath(path)
        log.info(f'load documents from "{realpath}" ...')

        # try to load metadata including the document_type
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            with open(full_metadata_file_name) as f:
                metadata = json.load(f)
            document_type = resolve_optional_document_type(metadata.get("document_type"))

        if document_type is None:
            raise Exception("document_type is required to load serialized documents")

        if split is not None:
            realpath = os.path.join(realpath, split)
        full_file_name = os.path.join(realpath, file_name)
        documents = []
        if as_json_lines(str(file_name)):
            with open(full_file_name) as f:
                for line in f:
                    json_dict = json.loads(line)
                    documents.append(document_type.fromdict(json_dict))
        else:
            with open(full_file_name) as f:
                json_list = json.load(f)
            for json_dict in json_list:
                documents.append(document_type.fromdict(json_dict))
        return documents

    def read_with_defaults(self, **kwargs) -> List[D]:
        """``read`` with the instance's default kwargs; call kwargs win."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        """``write`` with the instance's default kwargs; call kwargs win."""
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)

class JsonSerializer2(DocumentSerializer):
    """Serializer that delegates to pie_datasets' ``DatasetDict`` JSON export/import."""

    def __init__(self, **kwargs):
        # Defaults merged into every read()/write() call made through the
        # *_with_defaults helpers; explicit call kwargs take precedence.
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Iterable[Document],
        path: str,
        split: str = "train",
    ) -> Dict[str, str]:
        """Export *documents* as a single-split ``DatasetDict`` under *path*."""
        if not isinstance(documents, (Dataset, IterableDataset)):
            # wrap plain collections: sequences become a Dataset, anything
            # else (e.g. a generator) becomes an IterableDataset
            wrapper = Dataset if isinstance(documents, Sequence) else IterableDataset
            documents = wrapper.from_documents(documents)
        DatasetDict({split: documents}).to_json(path=path)
        return {"path": path, "split": split}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        split: Optional[str] = None,
    ) -> Dataset[Document]:
        """Load a ``DatasetDict`` from *path* and return the requested (or only) split."""
        loaded = DatasetDict.from_json(
            data_dir=path, document_type=document_type, split=split
        )
        if split is not None:
            return loaded[split]
        split_names = list(loaded.keys())
        if len(split_names) == 1:
            return loaded[split_names[0]]
        raise ValueError(f"multiple splits found in dataset_dict: {split_names}")

    def read_with_defaults(self, **kwargs) -> Sequence[D]:
        """``read`` with the instance defaults applied; call kwargs win."""
        merged = {**self.default_kwargs, **kwargs}
        return self.read(**merged)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        """``write`` with the instance defaults applied; call kwargs win."""
        merged = {**self.default_kwargs, **kwargs}
        return self.write(**merged)

    def __call__(self, documents: Iterable[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)