|
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Optional
|
|
|
# Logger named after this module so callers can configure it by module path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
class SerializableStore(ABC):
    """Abstract base class for stores that can be persisted to and restored from disc.

    Subclasses implement :meth:`_save_to_directory` and :meth:`_load_from_directory`.
    On top of those hooks, this class provides:

    * directory creation and validation,
    * packing into and unpacking from archives (e.g. ``.zip``),
    * path-based dispatch via :meth:`save_to_disc` / :meth:`load_from_disc`.
    """

    @abstractmethod
    def _save_to_directory(self, path: str, **kwargs) -> None:
        """Write the store's contents into a directory.

        Args:
            path (str): The path to a directory. :meth:`save_to_directory` creates it
                before calling this hook.
            **kwargs: Subclass-specific options forwarded by the public save methods.
        """

    @abstractmethod
    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Populate the store from a directory.

        Args:
            path (str): The path to a directory, typically one written by
                :meth:`_save_to_directory`.
            **kwargs: Subclass-specific options forwarded by the public load methods.
        """

    def save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory, creating it (and any parents) if needed.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to :meth:`_save_to_directory`.
        """
        os.makedirs(path, exist_ok=True)
        self._save_to_directory(path, **kwargs)

    def load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to :meth:`_load_from_directory`.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'Directory "{path}" not found.')

        self._load_from_directory(path, **kwargs)

    def save_to_archive(self, base_name: str, format: str, **kwargs) -> str:
        """Save the store to an archive.

        The store is first written to a private, uniquely named temporary directory.
        That directory is then packed with ``shutil.make_archive``.

        Args:
            base_name (str): The base name of the archive, without extension;
                ``shutil.make_archive`` appends the format's extension.
            format (str): The archive format (any format accepted by
                ``shutil.make_archive``, e.g. ``"zip"``).
            **kwargs: Forwarded to :meth:`save_to_directory`.

        Returns:
            str: The path to the archive, as returned by ``shutil.make_archive``.
        """
        # Why not a fixed, shared name under the system temp dir:
        # - concurrent saves/loads would clobber each other;
        # - a pre-existing directory of that name would be deleted.
        # TemporaryDirectory also cleans up when saving or archiving raises.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            self.save_to_directory(temp_dir, **kwargs)
            return shutil.make_archive(base_name=base_name, root_dir=temp_dir, format=format)

    def load_from_archive(self, file_name: str, format: Optional[str] = None, **kwargs) -> None:
        """Load the store from an archive.

        Args:
            file_name (str): The path to the archive.
            format (Optional[str]): The archive format. If None,
                ``shutil.unpack_archive`` infers it from the file extension. That
                lookup is case-sensitive, so e.g. ``"store.ZIP"`` needs
                ``format="zip"``.
            **kwargs: Forwarded to :meth:`load_from_directory`.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f'Archive file "{file_name}" not found.')

        # Unique, self-cleaning scratch directory. The fixed shared path used
        # previously was racy and leaked when unpacking or loading raised.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            shutil.unpack_archive(file_name, temp_dir, format=format)
            self.load_from_directory(temp_dir, **kwargs)

    def save_to_disc(self, path: str, **kwargs) -> None:
        """Save the store to disc. Depending on the path, this can be a directory or an archive.

        How ``path`` is interpreted:

        * ending in ``.zip`` (case-insensitive): written as a zip archive;
        * no extension, or an existing directory: written as a directory.

        Args:
            path (str): The path to a directory or an archive.
            **kwargs: Forwarded to the underlying save method.

        Raises:
            ValueError: If ``path`` has any other extension and is not an existing
                directory.
        """
        if path.lower().endswith(".zip"):
            logger.info("Saving to archive at %s ...", path)
            base_name = os.path.splitext(path)[0]
            result_path = self.save_to_archive(base_name, format="zip", **kwargs)
            # make_archive always appends a lowercase ".zip". On some Python versions
            # it also returns an absolute path. Comparing absolute paths means only a
            # genuine mismatch (e.g. "X.ZIP" -> "X.zip") triggers the warning.
            if os.path.abspath(result_path) != os.path.abspath(path):
                logger.warning("Saved to %s instead of %s.", result_path, path)

        elif not os.path.splitext(path)[1] or os.path.isdir(path):
            # Existing directories are accepted even if their name contains a dot,
            # mirroring load_from_disc's isdir-based dispatch.
            logger.info("Saving to directory at %s ...", path)
            self.save_to_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")

    def load_from_disc(self, path: str, **kwargs) -> None:
        """Load the store from disc. Depending on the path, this can be a directory or an archive.

        A path ending in ``.zip`` (case-insensitive) is read as a zip archive. An
        existing directory is read as a directory.

        Args:
            path (str): The path to a directory or an archive.
            **kwargs: Forwarded to the underlying load method.

        Raises:
            FileNotFoundError: If a ``.zip`` path does not exist.
            ValueError: If ``path`` neither ends in ``.zip`` nor is an existing directory.
        """
        if path.lower().endswith(".zip"):
            logger.info("Loading from archive at %s ...", path)
            # Pass the format explicitly: unpack_archive's extension lookup is
            # case-sensitive and would reject e.g. "store.ZIP".
            self.load_from_archive(path, format="zip", **kwargs)

        elif os.path.isdir(path):
            logger.info("Loading from directory at %s ...", path)
            self.load_from_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")
|
|