# acip_llama1_7b / utils.py
# Source: Hugging Face Hub, last commit 7836cdd ("Add model") by martingenzel; file size 2.99 kB.
import contextvars
import importlib
from contextlib import contextmanager
from typing import Any, Type
def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]:
"""
Converts a string to the corresponding class object, supporting relative imports.
For relative module paths (starting with '.'), a package must be provided.
Args:
class_str: String representation of the class, either absolute or relative.
package: Package context, only required for relative imports.
Returns: Class object corresponding to the provided string.
"""
if not isinstance(class_str, str) and isinstance(class_str, type):
return class_str
module_path, _, class_name = class_str.rpartition(".")
if not module_path and class_str.startswith("."):
module_path = "."
if module_path.startswith("."):
if not package:
raise ValueError("Relative module path provided without a package context.")
module = importlib.import_module(module_path, package=package)
else:
module = importlib.import_module(module_path)
return getattr(module, class_name)
def get_str_from_class(cls: Type[Any], package: str | None = None) -> str:
"""
Converts a class object to its string representation.
If a package is provided and the class's module is a submodule of the package,
the returned string will use a relative import.
Otherwise, an absolute import string is returned.
Args:
cls: Class object to convert.
package: Package context, only required for relative imports.
Returns: String representation of the class.
"""
if isinstance(cls, str):
return cls
module_path = cls.__module__
class_name = cls.__name__
if package:
# When class is defined directly in the package's __init__.py
if module_path == package:
return f".{class_name}"
# When class is in a submodule of the package
elif module_path.startswith(package + "."):
# Get the relative part (including the dot)
relative = module_path[len(package) :]
if not relative.startswith("."):
relative = "." + relative
return f"{relative}.{class_name}"
return f"{module_path}.{class_name}"
# Flag indicating whether a (parametrized) model should be initialized with empty
# weights; set it with the `init_empty_weights` context manager below and read it
# with `use_init_empty_weights.get()`. Being a ContextVar, its value is local to the
# current context (e.g. thread / asyncio task) rather than process-global.
use_init_empty_weights: contextvars.ContextVar[bool] = contextvars.ContextVar(
    "init_empty_weights", default=False
)


@contextmanager
def init_empty_weights(value: bool = True):
    """
    Context manager to indicate whether a (parametrized) model should be initialized
    with empty weights.

    Inside the context, `use_init_empty_weights.get()` returns `value`. On exit,
    including exit via an exception, the value that was in effect before entering is
    restored, so nested uses behave as expected. To check the setting, import
    `use_init_empty_weights` and call `use_init_empty_weights.get()`.

    Args:
        value: Value of the flag inside the context. Defaults to `True`, so
            `with init_empty_weights():` enables empty-weight initialization.
    """
    token = use_init_empty_weights.set(value)
    try:
        yield
    finally:
        # reset() restores the pre-set() value (not blindly the default), which
        # keeps nested contexts correct.
        use_init_empty_weights.reset(token)