# Last commit by ArneBinder: "new demo setup with langchain retriever" (2cc87ec, verified)
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
# Logger for this module; the store methods below report save/load progress through it.
logger = logging.getLogger(name=__name__)
class SerializableStore(ABC):
    """Abstract base class for stores that can be persisted to and restored from disc.

    Subclasses only implement :meth:`_save_to_directory` and
    :meth:`_load_from_directory`. This base class builds directory, archive and
    path-dispatching (``save_to_disc`` / ``load_from_disc``) helpers on top of them.
    """

    @abstractmethod
    def _save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory.

        Args:
            path (str): The path to a directory. When called through
                :meth:`save_to_directory`, the directory already exists.
            **kwargs: Implementation-specific options.
        """

    @abstractmethod
    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to a directory. When called through
                :meth:`load_from_directory`, the path is known to exist.
            **kwargs: Implementation-specific options.
        """

    def save_to_directory(self, path: str, **kwargs) -> None:
        """Save the store to a directory, creating it (and any parents) if needed.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to :meth:`_save_to_directory`.
        """
        os.makedirs(path, exist_ok=True)
        self._save_to_directory(path, **kwargs)

    def load_from_directory(self, path: str, **kwargs) -> None:
        """Load the store from a directory.

        Args:
            path (str): The path to a directory.
            **kwargs: Forwarded to :meth:`_load_from_directory`.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f'Directory "{path}" not found.')
        self._load_from_directory(path, **kwargs)

    def save_to_archive(
        self,
        base_name: str,
        format: str,
        **kwargs,
    ) -> str:
        """Save the store to an archive.

        Args:
            base_name (str): The archive path without the format-specific extension
                (``shutil.make_archive`` appends it).
            format (str): The archive format, any name accepted by
                ``shutil.make_archive`` (e.g. ``"zip"``).
            **kwargs: Forwarded to :meth:`save_to_directory`.

        Returns:
            str: The path to the created archive, as returned by ``shutil.make_archive``.
        """
        # A fresh private temporary directory per call: a fixed shared path could be
        # clobbered by (or clobber) concurrent saves/loads, and the context manager
        # removes it even if saving or archiving raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            self.save_to_directory(temp_dir, **kwargs)
            # root_dir makes the archive contain the directory's *contents*, with
            # paths relative to temp_dir, so the temp dir name never leaks into it.
            return shutil.make_archive(base_name=base_name, root_dir=temp_dir, format=format)

    def load_from_archive(
        self,
        file_name: str,
        **kwargs,
    ) -> None:
        """Load the store from an archive.

        Args:
            file_name (str): The path to the archive.
            **kwargs: Forwarded to :meth:`load_from_directory`.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f'Archive file "{file_name}" not found.')
        # NOTE(review): unpack_archive extracts whatever the archive contains; only
        # load archives from trusted sources.
        with tempfile.TemporaryDirectory() as temp_dir:
            shutil.unpack_archive(file_name, temp_dir)
            self.load_from_directory(temp_dir, **kwargs)

    def save_to_disc(self, path: str, **kwargs) -> None:
        """Save the store to disc. Depending on the path, this can be a directory or an archive.

        Args:
            path (str): A path ending in ``.zip`` (case-insensitive) to save as a zip
                archive, or a path without an extension to save as a directory.
            **kwargs: Forwarded to the directory/archive save method.

        Raises:
            ValueError: If ``path`` has an extension other than ``.zip``.
        """
        if path.lower().endswith(".zip"):
            logger.info("Saving to archive at %s ...", path)
            # make_archive appends the ".zip" extension itself, so strip it here.
            base_name = os.path.splitext(path)[0]
            result_path = self.save_to_archive(base_name, format="zip", **kwargs)
            # Compare normalized absolute paths; the result's extension casing may
            # differ from the requested one (e.g. "STORE.ZIP" -> "STORE.zip").
            if os.path.normcase(os.path.abspath(result_path)) != os.path.normcase(
                os.path.abspath(path)
            ):
                logger.warning("Saved to %s instead of %s.", result_path, path)
        elif not os.path.splitext(path)[1]:
            logger.info("Saving to directory at %s ...", path)
            self.save_to_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")

    def load_from_disc(self, path: str, **kwargs) -> None:
        """Load the store from disc. Depending on the path, this can be a directory or an archive.

        Args:
            path (str): A path ending in ``.zip`` (case-insensitive) to load a zip
                archive, or the path of an existing directory.
            **kwargs: Forwarded to the directory/archive load method.

        Raises:
            FileNotFoundError: If ``path`` ends in ``.zip`` but does not exist.
            ValueError: If ``path`` is neither a ``.zip`` path nor an existing directory.
        """
        if path.lower().endswith(".zip"):
            logger.info("Loading from archive at %s ...", path)
            self.load_from_archive(path, **kwargs)
        elif os.path.isdir(path):
            logger.info("Loading from directory at %s ...", path)
            self.load_from_directory(path, **kwargs)
        else:
            raise ValueError("Unsupported file format. Only .zip and directories are supported.")