File size: 4,636 Bytes
2cc87ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod

# Module-level logger named after this module; used for the save/load progress messages.
logger = logging.getLogger(__name__)


class SerializableStore(ABC):
    """Abstract base class for stores that can be saved to and loaded from disc.

    Subclasses implement :meth:`_save_to_directory` and
    :meth:`_load_from_directory`. This base class builds directory, archive and
    "path-dispatching" (:meth:`save_to_disc` / :meth:`load_from_disc`) variants
    on top of those two hooks.
    """

    @abstractmethod
    def _save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory.

        Called by :meth:`save_to_directory` after the directory was created.

        Args:
            path (str): The path to an existing directory.
            **kwargs: Implementation specific options.
        """

    @abstractmethod
    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Called by :meth:`load_from_directory` after the path was checked to exist.

        Args:
            path (str): The path to the directory.
            **kwargs: Implementation specific options.
        """

    def save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory, creating the directory if necessary.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to :meth:`_save_to_directory`.
        """
        os.makedirs(path, exist_ok=True)
        self._save_to_directory(path, **kwargs)

    def load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to the directory.
            **kwargs: Forwarded to :meth:`_load_from_directory`.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'Directory "{path}" not found.')

        self._load_from_directory(path, **kwargs)

    def save_to_archive(
        self,
        base_name: str,
        format: str,
        **kwargs,
    ) -> str:
        """Save the store to an archive.

        The store is first written to a temporary staging directory which is
        then packed with :func:`shutil.make_archive`.

        Args:
            base_name (str): The base name of the archive (without the
                format-specific extension).
            format (str): The format of the archive, as understood by
                :func:`shutil.make_archive` (e.g. ``"zip"``).
            **kwargs: Forwarded to :meth:`save_to_directory`.

        Returns:
            str: The path to the archive.
        """
        # Use a unique staging directory per call: a fixed, shared name would let
        # concurrent saves/loads delete each other's files. The context manager
        # also removes the directory when saving or archiving raises.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            self.save_to_directory(temp_dir, **kwargs)
            return shutil.make_archive(base_name=base_name, root_dir=temp_dir, format=format)

    def load_from_archive(
        self,
        file_name: str,
        **kwargs,
    ) -> None:
        """Load the store from an archive.

        The archive is unpacked into a temporary staging directory from which
        the store is then loaded.

        Args:
            file_name (str): The path to the archive.
            **kwargs: Forwarded to :meth:`load_from_directory`.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f'Archive file "{file_name}" not found.')

        # Unique staging directory, removed again even if unpacking or loading fails.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            shutil.unpack_archive(file_name, temp_dir)
            self.load_from_directory(temp_dir, **kwargs)

    def save_to_disc(self, path: str, **kwargs) -> None:
        """Save the store to disc. Depending on the path, this can be a directory or an archive.

        A path ending in ``.zip`` (case-insensitive) is written as a zip archive,
        a path without any extension is treated as a directory.

        Args:
            path (str): The path to a directory or an archive.
            **kwargs: Forwarded to :meth:`save_to_archive` or :meth:`save_to_directory`.

        Raises:
            ValueError: If ``path`` has an extension other than ``.zip``.
        """
        # if it is a zip file, save to archive
        if path.lower().endswith(".zip"):
            logger.info("Saving to archive at %s ...", path)
            # get base name without extension
            base_name = os.path.splitext(path)[0]
            result_path = self.save_to_archive(base_name, format="zip", **kwargs)
            # make_archive may return an absolute / normalized path, so compare
            # absolute paths instead of raw strings to avoid spurious warnings.
            if os.path.abspath(result_path) != os.path.abspath(path):
                logger.warning("Saved to %s instead of %s.", result_path, path)
        # if it does not have an extension, save to directory
        elif not os.path.splitext(path)[1]:
            logger.info("Saving to directory at %s ...", path)
            self.save_to_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")

    def load_from_disc(self, path: str, **kwargs) -> None:
        """Load the store from disc. Depending on the path, this can be a directory or an archive.

        A path ending in ``.zip`` (case-insensitive) is loaded as an archive, an
        existing directory is loaded directly.

        Args:
            path (str): The path to a directory or an archive.
            **kwargs: Forwarded to :meth:`load_from_archive` or :meth:`load_from_directory`.

        Raises:
            FileNotFoundError: If ``path`` ends in ``.zip`` but does not exist.
            ValueError: If ``path`` is neither a ``.zip`` path nor an existing directory.
        """
        # if it is a zip file, load from archive
        if path.lower().endswith(".zip"):
            logger.info("Loading from archive at %s ...", path)
            self.load_from_archive(path, **kwargs)
        # if it is a directory, load from directory
        elif os.path.isdir(path):
            logger.info("Loading from directory at %s ...", path)
            self.load_from_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")