File size: 2,990 Bytes
7836cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import contextvars
import importlib
from contextlib import contextmanager
from typing import Any, Type


def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]:
    """
    Converts a string to the corresponding class object, supporting relative imports.
    For relative module paths (starting with '.'), a package must be provided.

    Args:
        class_str: String representation of the class, either absolute or relative.
        package: Package context, only required for relative imports.

    Returns: Class object corresponding to the provided string.
    """
    if not isinstance(class_str, str) and isinstance(class_str, type):
        return class_str

    module_path, _, class_name = class_str.rpartition(".")
    if not module_path and class_str.startswith("."):
        module_path = "."
    if module_path.startswith("."):
        if not package:
            raise ValueError("Relative module path provided without a package context.")
        module = importlib.import_module(module_path, package=package)
    else:
        module = importlib.import_module(module_path)
    return getattr(module, class_name)


def get_str_from_class(cls: Type[Any], package: str | None = None) -> str:
    """
    Converts a class object to its string representation.
    If a package is provided and the class's module is a submodule of the package,
    the returned string will use a relative import.
    Otherwise, an absolute import string is returned.

    Args:
        cls: Class object to convert.
        package: Package context, only required for relative imports.

    Returns: String representation of the class.
    """
    if isinstance(cls, str):
        return cls

    module_path = cls.__module__
    class_name = cls.__name__

    if package:
        # When class is defined directly in the package's __init__.py
        if module_path == package:
            return f".{class_name}"
        # When class is in a submodule of the package
        elif module_path.startswith(package + "."):
            # Get the relative part (including the dot)
            relative = module_path[len(package) :]
            if not relative.startswith("."):
                relative = "." + relative
            return f"{relative}.{class_name}"
    return f"{module_path}.{class_name}"


# Context-local flag indicating whether (parametrized) models should be initialized
# with empty weights. It is False unless changed via ``init_empty_weights``;
# read it with ``use_init_empty_weights.get()``.
use_init_empty_weights: contextvars.ContextVar[bool] = contextvars.ContextVar(
    "init_empty_weights", default=False
)


@contextmanager
def init_empty_weights(value: bool):
    """
    Temporarily set the ``use_init_empty_weights`` flag for the enclosed code.

    Inside the ``with`` block, ``use_init_empty_weights.get()`` returns ``value``. On
    exit, whether normal or via an exception, the flag is restored to the value it had
    before entering (not unconditionally to ``False``), so nested blocks compose.

    Args:
        value: Flag to expose inside the block. ``True`` indicates that a (parametrized)
            model should be initialized with empty weights; ``False`` indicates it should not.
    """
    previous_state = use_init_empty_weights.set(value)
    try:
        yield
    finally:
        # Resetting via the token restores the prior value rather than forcing a default.
        use_init_empty_weights.reset(previous_state)