|
import contextvars |
|
import importlib |
|
from contextlib import contextmanager |
|
from typing import Any, Type |
|
|
|
|
|
def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]: |
|
""" |
|
Converts a string to the corresponding class object, supporting relative imports. |
|
For relative module paths (starting with '.'), a package must be provided. |
|
|
|
Args: |
|
class_str: String representation of the class, either absolute or relative. |
|
package: Package context, only required for relative imports. |
|
|
|
Returns: Class object corresponding to the provided string. |
|
""" |
|
if not isinstance(class_str, str) and isinstance(class_str, type): |
|
return class_str |
|
|
|
module_path, _, class_name = class_str.rpartition(".") |
|
if not module_path and class_str.startswith("."): |
|
module_path = "." |
|
if module_path.startswith("."): |
|
if not package: |
|
raise ValueError("Relative module path provided without a package context.") |
|
module = importlib.import_module(module_path, package=package) |
|
else: |
|
module = importlib.import_module(module_path) |
|
return getattr(module, class_name) |
|
|
|
|
|
def get_str_from_class(cls: Type[Any], package: str | None = None) -> str: |
|
""" |
|
Converts a class object to its string representation. |
|
If a package is provided and the class's module is a submodule of the package, |
|
the returned string will use a relative import. |
|
Otherwise, an absolute import string is returned. |
|
|
|
Args: |
|
cls: Class object to convert. |
|
package: Package context, only required for relative imports. |
|
|
|
Returns: String representation of the class. |
|
""" |
|
if isinstance(cls, str): |
|
return cls |
|
|
|
module_path = cls.__module__ |
|
class_name = cls.__name__ |
|
|
|
if package: |
|
|
|
if module_path == package: |
|
return f".{class_name}" |
|
|
|
elif module_path.startswith(package + "."): |
|
|
|
relative = module_path[len(package) :] |
|
if not relative.startswith("."): |
|
relative = "." + relative |
|
return f"{relative}.{class_name}" |
|
return f"{module_path}.{class_name}" |
|
|
|
|
|
# Context variable holding the "initialize with empty weights" flag. It
# defaults to False; the `init_empty_weights` context manager sets it for the
# duration of a `with` block. Read the current value with
# `use_init_empty_weights.get()`.
use_init_empty_weights = contextvars.ContextVar("init_empty_weights", default=False)
|
|
|
|
|
@contextmanager
def init_empty_weights(value: bool):
    """
    Context manager that sets the `use_init_empty_weights` flag for a block of code.

    While the block runs, `use_init_empty_weights.get()` returns `value`; on exit
    (including when the block raises) the variable is reset to the value it had
    before the block was entered.

    Args:
        value: Indicates whether the model should be initialized with empty weights or not.
    """
    # The token returned by set() records the prior value for reset().
    restore_token = use_init_empty_weights.set(value)
    try:
        yield
    finally:
        use_init_empty_weights.reset(restore_token)
|
|