File size: 4,636 Bytes
2cc87ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
class SerializableStore(ABC):
"""Abstract base class for serializable stores."""
@abstractmethod
def _save_to_directory(self, path: str, **kwargs) -> None:
"""Save the store to a directory.
Args:
path (str): The path to a directory.
"""
@abstractmethod
def _load_from_directory(self, path: str, **kwargs) -> None:
"""Load the store from a directory.
Args:
path (str): The path to the file.
"""
def save_to_directory(self, path: str, **kwargs) -> None:
"""Save the store to a directory.
Args:
path (str): The path to a directory.
"""
os.makedirs(path, exist_ok=True)
self._save_to_directory(path, **kwargs)
def load_from_directory(self, path: str, **kwargs) -> None:
"""Load the store from a directory.
Args:
path (str): The path to the file.
"""
if not os.path.exists(path):
raise FileNotFoundError(f'Directory "{path}" not found.')
self._load_from_directory(path, **kwargs)
def save_to_archive(
self,
base_name: str,
format: str,
) -> str:
"""Save the store to an archive.
Args:
base_name (str): The base name of the archive.
format (str): The format of the archive.
Returns:
str: The path to the archive.
"""
temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
# remove the temporary directory if it already exists
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# save the documents to the directory
self.save_to_directory(temp_dir)
# zip the directory
result_file_path = shutil.make_archive(
base_name=base_name, root_dir=temp_dir, format=format
)
# remove the temporary directory
shutil.rmtree(temp_dir)
return result_file_path
def load_from_archive(
self,
file_name: str,
) -> None:
"""Load the store from an archive.
Args:
file_name (str): The path to the archive.
"""
if not os.path.exists(file_name):
raise FileNotFoundError(f'Archive file "{file_name}" not found.')
temp_dir = os.path.join(tempfile.gettempdir(), "retriever_store")
# remove the temporary directory if it already exists
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# unzip the file
shutil.unpack_archive(file_name, temp_dir)
# load the documents from the directory
self.load_from_directory(temp_dir)
# remove the temporary directory
shutil.rmtree(temp_dir)
def save_to_disc(self, path: str) -> None:
"""Save the store to disc. Depending on the path, this can be a directory or an archive.
Args:
path (str): The path to a directory or an archive.
"""
# if it is a zip file, save to archive
if path.lower().endswith(".zip"):
logger.info(f"Saving to archive at {path} ...")
# get base name without extension
base_name = os.path.splitext(path)[0]
result_path = self.save_to_archive(base_name, format="zip")
if not result_path.endswith(path):
logger.warning(f"Saved to {result_path} instead of {path}.")
# if it does not have an extension, save to directory
elif not os.path.splitext(path)[1]:
logger.info(f"Saving to directory at {path} ...")
self.save_to_directory(path)
else:
raise ValueError("Unsupported file format. Only .zip and directories are supported.")
def load_from_disc(self, path: str) -> None:
"""Load the store from disc. Depending on the path, this can be a directory or an archive.
Args:
path (str): The path to a directory or an archive.
"""
# if it is a zip file, load from archive
if path.lower().endswith(".zip"):
logger.info(f"Loading from archive at {path} ...")
self.load_from_archive(path)
# if it is a directory, load from directory
elif os.path.isdir(path):
logger.info(f"Loading from directory at {path} ...")
self.load_from_directory(path)
else:
raise ValueError("Unsupported file format. Only .zip and directories are supported.")
|