File size: 2,990 Bytes
7836cdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import contextvars
import importlib
from contextlib import contextmanager
from typing import Any, Type
def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]:
"""
Converts a string to the corresponding class object, supporting relative imports.
For relative module paths (starting with '.'), a package must be provided.
Args:
class_str: String representation of the class, either absolute or relative.
package: Package context, only required for relative imports.
Returns: Class object corresponding to the provided string.
"""
if not isinstance(class_str, str) and isinstance(class_str, type):
return class_str
module_path, _, class_name = class_str.rpartition(".")
if not module_path and class_str.startswith("."):
module_path = "."
if module_path.startswith("."):
if not package:
raise ValueError("Relative module path provided without a package context.")
module = importlib.import_module(module_path, package=package)
else:
module = importlib.import_module(module_path)
return getattr(module, class_name)
def get_str_from_class(cls: Type[Any], package: str | None = None) -> str:
"""
Converts a class object to its string representation.
If a package is provided and the class's module is a submodule of the package,
the returned string will use a relative import.
Otherwise, an absolute import string is returned.
Args:
cls: Class object to convert.
package: Package context, only required for relative imports.
Returns: String representation of the class.
"""
if isinstance(cls, str):
return cls
module_path = cls.__module__
class_name = cls.__name__
if package:
# When class is defined directly in the package's __init__.py
if module_path == package:
return f".{class_name}"
# When class is in a submodule of the package
elif module_path.startswith(package + "."):
# Get the relative part (including the dot)
relative = module_path[len(package) :]
if not relative.startswith("."):
relative = "." + relative
return f"{relative}.{class_name}"
return f"{module_path}.{class_name}"
# Flag read by model-construction code; False unless inside `init_empty_weights(True)`.
use_init_empty_weights: contextvars.ContextVar[bool] = contextvars.ContextVar(
    "init_empty_weights", default=False
)


@contextmanager
def init_empty_weights(value: bool):
    """
    Context manager controlling whether a (parametrized) model is initialized with empty weights.

    While the block runs, `use_init_empty_weights` holds `value`; on exit (normal
    or via an exception) the previous value is restored. Outside any such block
    the flag is `False`. To check the current setting, import
    `use_init_empty_weights` and call `use_init_empty_weights.get()`.

    Args:
        value: Whether the model should be initialized with empty weights inside the block.
    """
    previous = use_init_empty_weights.set(value)
    try:
        yield
    finally:
        # reset() restores the value that was active before the matching set(),
        # so nested uses unwind in the right order.
        use_init_empty_weights.reset(previous)
|