import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod

logger = logging.getLogger(__name__)


class SerializableStore(ABC):
    """Abstract base class for stores that can be persisted to disc.

    Subclasses implement `_save_to_directory` and `_load_from_directory`;
    this class builds directory, archive and path-dispatching save/load
    helpers on top of those two hooks.
    """

    @abstractmethod
    def _save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory.

        Args:
            path (str): The path to a directory.
            **kwargs: Implementation-specific options.
        """

    @abstractmethod
    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to a directory.
            **kwargs: Implementation-specific options.
        """

    def save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory, creating the directory if needed.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to `_save_to_directory`.
        """
        os.makedirs(path, exist_ok=True)
        self._save_to_directory(path, **kwargs)

    def load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to `_load_from_directory`.

        Raises:
            FileNotFoundError: If `path` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'Directory "{path}" not found.')
        self._load_from_directory(path, **kwargs)

    def save_to_archive(
        self,
        base_name: str,
        format: str,
    ) -> str:
        """Save the store to an archive.

        The store is first written to a private temporary directory, which is
        then packed with `shutil.make_archive`.

        Args:
            base_name (str): The base name of the archive (without the
                format-specific extension).
            format (str): The archive format, passed to
                `shutil.make_archive`.

        Returns:
            str: The path to the archive.
        """
        # A unique staging directory avoids clobbering a directory used by a
        # concurrent save/load, and the context manager removes it even when
        # saving or archiving raises.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            self.save_to_directory(temp_dir)
            return shutil.make_archive(
                base_name=base_name, root_dir=temp_dir, format=format
            )

    def load_from_archive(
        self,
        file_name: str,
    ) -> None:
        """Load the store from an archive.

        The archive is unpacked into a private temporary directory, loaded
        from there, and the directory is removed afterwards.

        Args:
            file_name (str): The path to the archive.

        Raises:
            FileNotFoundError: If `file_name` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f'Archive file "{file_name}" not found.')

        # Unique staging directory; cleaned up even if unpacking or loading
        # fails.
        with tempfile.TemporaryDirectory(prefix="retriever_store_") as temp_dir:
            # NOTE: the archive is unpacked without any member filtering, so
            # only archives from trusted sources should be loaded.
            shutil.unpack_archive(file_name, temp_dir)
            self.load_from_directory(temp_dir)

    def save_to_disc(self, path: str) -> None:
        """Save the store to disc.

        Depending on the path, this can be a directory or an archive.

        Args:
            path (str): The path to a directory (no file extension) or to a
                `.zip` archive.

        Raises:
            ValueError: If `path` has an extension other than `.zip`.
        """
        if path.lower().endswith(".zip"):
            logger.info("Saving to archive at %s ...", path)
            # make_archive appends the extension itself, so strip it here.
            base_name = os.path.splitext(path)[0]
            result_path = self.save_to_archive(base_name, format="zip")
            # Compare normalised absolute paths so that e.g. "./store.zip"
            # does not trigger a spurious warning; a differing extension case
            # (".ZIP" vs ".zip") is still reported.
            if os.path.abspath(result_path) != os.path.abspath(path):
                logger.warning("Saved to %s instead of %s.", result_path, path)
        elif not os.path.splitext(path)[1]:
            # No extension: treat the path as a directory.
            logger.info("Saving to directory at %s ...", path)
            self.save_to_directory(path)
        else:
            raise ValueError(
                "Unsupported file format. Only .zip and directories are supported."
            )

    def load_from_disc(self, path: str) -> None:
        """Load the store from disc.

        Depending on the path, this can be a directory or an archive.

        Args:
            path (str): The path to a directory or to a `.zip` archive.

        Raises:
            ValueError: If `path` is neither a `.zip` path nor an existing
                directory.
            FileNotFoundError: If `path` ends with `.zip` but does not exist.
        """
        if path.lower().endswith(".zip"):
            logger.info("Loading from archive at %s ...", path)
            self.load_from_archive(path)
        elif os.path.isdir(path):
            logger.info("Loading from directory at %s ...", path)
            self.load_from_directory(path)
        else:
            raise ValueError(
                "Unsupported file format. Only .zip and directories are supported."
            )