from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import ClassVar, Literal, Protocol, runtime_checkable

import torch
from torch import nn


class Parametrization(nn.Module, ABC):
"""
Abstract base class for parametrizations.
A parametrization can be injected into any torch module of type `base_class` by `parametrize_module`.
A parametrized module will follow the `ParametrizedModule` interface.
This will overload the weight, bias, and forward of the module so that they play together with
the parametrization. The external behavior of the parametrized module remains unchanged, for instance,
a parametrized `Linear` module will still work as expected.
Attributes:
base_class: The base class of the module that can be parametrized.
initialized: A flag that indicates whether the parametrization has been initialized.
"""
initialized: bool = False
base_class: ClassVar[Type[nn.Module]]
def initialize(self, base_module: "Parametrization.base_class") -> None:
self._initialize(base_module)
self.initialized = True
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the forward pass of the parametrization.

        This is particularly important when a standard forward pass based on `weight` would be inefficient.
        """
        assert self.initialized
        x = self._forward(x)
        return x

    @property
    def weight(self) -> torch.Tensor:
        """Compute the weight tensor of the parametrization."""
        return self._weight()

    @property
    def bias(self) -> torch.Tensor | None:
        """Compute the bias tensor of the parametrization."""
        return self._bias()

    @abstractmethod
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
    @abstractmethod
    def _initialize(self, base_module: nn.Module) -> None:
        """
        Initialize the parametrization based on a given base module.

        This method should build the internal representation of the module's weight and bias,
        registering all required buffers and parameters in `self`.
        """
        raise NotImplementedError
    @abstractmethod
    def _weight(self) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def _bias(self) -> torch.Tensor | None:
        raise NotImplementedError
    @abstractmethod
    def get_target_params(self) -> dict[str, nn.Parameter]:
        """
        Return the (tunable) target parameters of the parametrization.

        Here, "target parameters" means parameters that can be tuned and potentially compressed
        by `self.reset_target_params(mode="compress")`. Other torch parameters of the module may
        be tuned as well, but they should not be returned here. The returned dictionary should be
        compatible with `self.named_parameters()`.

        See Also:
            - `ParametrizedModel.get_target_params`
            - `ParametrizedModel.compress`
        """
        raise NotImplementedError
    @abstractmethod
    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        """
        Reset the target parameters of the parametrization according to a given mode.

        Args:
            mode: The reset mode.
                "full" resets all values to their original values at initialization.
                "nonzero" resets all non-zero values to their original values at initialization.
                "compress" removes all zero values and compresses the parameters accordingly.
        """
        raise NotImplementedError
    @abstractmethod
    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        """
        Computes the (effective) number of parameters of the parametrization.

        Args:
            compressed: Whether to count the number of parameters as if the module was actually compressed.
                If `False`, the number of parameters is the same as in the original module.
            target_params: Count the number of parameters as if `target_params` were used instead of
                `self.get_target_params()`. This "what if" feature is important when pruning
                a full `ParametrizedModel` to a certain target ratio.
        """
        raise NotImplementedError
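

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# `Parametrization` for `nn.Linear` that keeps a dense weight together with a
# binary pruning mask. All names below (`MaskedLinearParametrization`,
# `weight_param`, `mask`, ...) are hypothetical; the class only demonstrates
# how the abstract methods above are meant to fit together.
# ---------------------------------------------------------------------------
class MaskedLinearParametrization(Parametrization):
    base_class = nn.Linear

    def _initialize(self, base_module: nn.Module) -> None:
        # Copy the dense weight and bias of the wrapped `nn.Linear` and register
        # an all-ones mask plus a snapshot of the initial weight as buffers.
        self.weight_param = nn.Parameter(base_module.weight.detach().clone())
        self.bias_param = nn.Parameter(base_module.bias.detach().clone()) if base_module.bias is not None else None
        self.register_buffer("mask", torch.ones_like(self.weight_param))
        self.register_buffer("initial_weight", self.weight_param.detach().clone())

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self._weight(), self._bias())

    def _weight(self) -> torch.Tensor:
        return self.weight_param * self.mask

    def _bias(self) -> torch.Tensor | None:
        return self.bias_param

    def get_target_params(self) -> dict[str, nn.Parameter]:
        return {"weight_param": self.weight_param}

    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        with torch.no_grad():
            if mode == "full":
                # Undo pruning and restore the initial weight values.
                self.mask.fill_(1.0)
                self.weight_param.copy_(self.initial_weight)
            elif mode == "nonzero":
                # Restore initial values only for entries that are still unpruned.
                self.weight_param.copy_(torch.where(self.mask.bool(), self.initial_weight, self.weight_param))
            elif mode == "compress":
                # A dense mask cannot actually shrink storage; zeroing the pruned
                # entries stands in for a real compressed representation here.
                self.weight_param.mul_(self.mask)

    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        weight = target_params["weight_param"] if target_params is not None else self.weight_param
        n_weight = int(torch.count_nonzero(weight * self.mask)) if compressed else weight.numel()
        n_bias = self.bias_param.numel() if self.bias_param is not None else 0
        return n_weight + n_bias
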
@runtime_checkable
class ParametrizedModule(Protocol):
"""
Interface for a parametrized `nn.Module`.
It ensures that `weight` and `bias` are forwarded to the `Parametrization` instance.
Attributes:
parametrization: The `Parametrization` instance of the module.
_forward: The original forward function of the module.
__old_class__: The original class of the module.
Notes:
`_forward` and `__old_class__` are used by `parametrize_module` and `unparametrize_module`
to allow restoring the original behavior of the module.
"""
parametrization: Parametrization
_forward: callable
__old_class__: type[nn.Module]
@property
def weight(self):
return self.parametrization.weight
@property
def bias(self):
return self.parametrization.bias
def parametrize_module(module: nn.Module, parametrization: Parametrization) -> ParametrizedModule and nn.Module:
"""
Parametrize a module using a `Parametrization` instance.
Args:
module: The module to be parametrized.
parametrization: The `Parametrization` instance to be applied to the module.
Returns: The parametrized module using the `ParametrizedModule` interface.
Notes:
Adopted from https://stackoverflow.com/a/31075641
"""
assert isinstance(module, parametrization.base_class)
module.__old_class__ = module.__class__
# Initializes the parametrization and adds it to the module
module.add_module("parametrization", parametrization)
module.parametrization.initialize(module)
# Save the original forward in case we want to remove the parametrization again
module._forward = module.forward
# Cast to new parametrized object class type
del module.weight
del module.bias
module.__class__ = type("Parametrized" + module.__class__.__name__, (module.__class__, ParametrizedModule), {})
# Make sure that we utilize the forward function of the parametrization
module.forward = module.parametrization.forward
return module
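

# Hypothetical usage sketch (not part of the original file): wrap an `nn.Linear`
# with the `MaskedLinearParametrization` example from above. The layer keeps its
# external behavior, but `weight`, `bias`, and `forward` are now served by the
# parametrization.
def _example_parametrize_linear() -> nn.Module:
    layer = nn.Linear(16, 8)
    layer = parametrize_module(layer, MaskedLinearParametrization())
    assert isinstance(layer, ParametrizedModule)
    return layer
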
def unparametrize_module(module: ParametrizedModule) -> nn.Module:
"""
Revert the parametrization of a module.
Args:
module: A module that has been parametrized by `parametrize_module`.
Returns: The original module.
Notes:
Adopted from https://stackoverflow.com/a/31075641
"""
# Make sure to save weight and bias in intermediate variables
weight = module.weight
bias = module.bias
assert isinstance(module, nn.Module)
# This line will remove properties module.weight and module.bias
module.__class__ = type(module.__old_class__.__name__, (module.__old_class__,), {})
delattr(module, "__old_class__")
# Add weight and bias as native parameters to the module again
module.register_parameter("weight", nn.Parameter(weight, weight.requires_grad))
if bias is not None:
module.register_parameter("bias", nn.Parameter(bias, bias.requires_grad))
else:
module.register_parameter("bias", None)
# Recover the original forward pass and get rid of the parametrization
del module.parametrization
module.forward = module._forward
delattr(module, "_forward")
return module
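

# Hypothetical round-trip check (not part of the original file): parametrize a
# layer, run a forward pass, restore the original module with
# `unparametrize_module`, and verify that the outputs match. `torch.no_grad()`
# keeps the restored parameters free of autograd history in this sketch.
if __name__ == "__main__":
    with torch.no_grad():
        layer = _example_parametrize_linear()
        x = torch.randn(2, 16)
        y_parametrized = layer(x)
        layer = unparametrize_module(layer)
        assert type(layer).__name__ == "Linear"
        y_restored = layer(x)
        torch.testing.assert_close(y_restored, y_parametrized)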