from abc import ABC, abstractmethod
from typing import Callable, ClassVar, Literal, Protocol, runtime_checkable, Type

import torch
from torch import nn


class Parametrization(nn.Module, ABC):
    """
    Abstract base class for parametrizations.

    A parametrization can be injected into any torch module of type `base_class` via `parametrize_module`.
    The parametrized module then follows the `ParametrizedModule` interface.

    Parametrizing replaces the `weight`, `bias`, and `forward` members of the module so that they are
    routed through the parametrization. The external behavior of the parametrized module remains
    unchanged; for instance, a parametrized `Linear` module still works as expected.

    Attributes:
        base_class: The base class of the modules that can be parametrized.
        initialized: A flag that indicates whether the parametrization has been initialized.
    """

    initialized: bool = False
    base_class: ClassVar[Type[nn.Module]]

    def initialize(self, base_module: nn.Module) -> None:
        """Initialize the parametrization from a base module, which must be an instance of `base_class`."""
        self._initialize(base_module)
        self.initialized = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the forward pass of the parametrization.

        This is particularly important when a standard forward pass based on `weight` would be inefficient.
        """
        assert self.initialized
        return self._forward(x)

    @property
    def weight(self) -> torch.Tensor:
        """Compute the weight tensor of the parametrization."""
        return self._weight()

    @property
    def bias(self) -> torch.Tensor | None:
        """Compute the bias tensor of the parametrization."""
        return self._bias()

    @abstractmethod
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def _initialize(self, base_module: nn.Module) -> None:
        """
        Initialize the parametrization based on a given base module.

        This method should build the internal representation of the module's weight and bias,
        registering all required buffers and parameters in `self`.
        """
        raise NotImplementedError

    @abstractmethod
    def _weight(self) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def _bias(self) -> torch.Tensor | None:
        raise NotImplementedError

    @abstractmethod
    def get_target_params(self) -> dict[str, torch.nn.Parameter]:
        """
        Return the (tunable) target parameters of the parametrization.

        Here, "target parameters" means parameters that can be tuned and potentially compressed
        by `self.reset_target_params(mode="compress")`. Other torch parameters of the module may be
        tuned as well, but they should not be returned here.
        The returned dictionary should be compatible with `self.named_parameters()`.

        See Also:
            - `ParametrizedModel.get_target_params`
            - `ParametrizedModel.compress`
        """
        raise NotImplementedError

    @abstractmethod
    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        """
        Reset the target parameters of the parametrization according to a given mode.

        Args:
            mode: The reset mode.
                "full" resets all values to their original values at initialization.
                "nonzero" resets all non-zero values to their original values at initialization.
                "compress" removes all zero values and compresses the parameters accordingly.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        """
        Compute the (effective) number of parameters of the parametrization.

        Args:
            compressed: Whether to count the number of parameters as if the module were actually compressed.
                If `False`, the number of parameters is the same as in the original module.
            target_params: Count the number of parameters as if `target_params` were used instead of
                `self.get_target_params()`. This "what if" feature is important when pruning
                a full `ParametrizedModel` to a certain target ratio.
        """
        raise NotImplementedError
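

# Illustrative example (a minimal sketch, not part of the interface above): a concrete
# `Parametrization` that masks the dense weight of an `nn.Linear`, e.g. for magnitude
# pruning. The class name, the mask-based scheme, and the `prune_` helper are
# assumptions chosen purely for illustration.
class MaskedLinearParametrization(Parametrization):
    base_class = nn.Linear

    def _initialize(self, base_module: nn.Linear) -> None:
        # Copy the dense weight/bias and remember the initial weight for resets.
        self.weight_param = nn.Parameter(base_module.weight.detach().clone())
        self.bias_param = nn.Parameter(base_module.bias.detach().clone()) if base_module.bias is not None else None
        self.register_buffer("mask", torch.ones_like(self.weight_param))
        self.register_buffer("initial_weight", base_module.weight.detach().clone())

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self._weight(), self._bias())

    def _weight(self) -> torch.Tensor:
        return self.weight_param * self.mask

    def _bias(self) -> torch.Tensor | None:
        return self.bias_param

    def get_target_params(self) -> dict[str, torch.nn.Parameter]:
        return {"weight_param": self.weight_param}

    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        with torch.no_grad():
            if mode == "full":
                self.weight_param.copy_(self.initial_weight)
                self.mask.fill_(1.0)
            elif mode == "nonzero":
                # Reset only the entries that are still active under the mask.
                self.weight_param.copy_(torch.where(self.mask.bool(), self.initial_weight, self.weight_param))
            elif mode == "compress":
                # A dense mask has no smaller storage format; as a simplification,
                # "compression" here just zeroes out the masked entries.
                self.weight_param.mul_(self.mask)

    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        weight = target_params["weight_param"] if target_params is not None else self._weight()
        num = int(weight.count_nonzero().item()) if compressed else weight.numel()
        if self.bias_param is not None:
            num += self.bias_param.numel()
        return num

    def prune_(self, ratio: float) -> None:
        # Hypothetical helper: zero out the smallest-magnitude fraction of the weights.
        k = int(ratio * self.weight_param.numel())
        if k > 0:
            with torch.no_grad():
                threshold = self.weight_param.abs().flatten().kthvalue(k).values
                self.mask.copy_((self.weight_param.abs() > threshold).to(self.mask.dtype))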


@runtime_checkable
class ParametrizedModule(Protocol):
    """
    Interface for a parametrized `nn.Module`.

    It ensures that `weight` and `bias` are forwarded to the `Parametrization` instance.

    Attributes:
        parametrization: The `Parametrization` instance of the module.
        _forward: The original forward function of the module.
        __old_class__: The original class of the module.

    Notes:
        `_forward` and `__old_class__` are used by `parametrize_module` and `unparametrize_module`
        to allow restoring the original behavior of the module.
    """

    parametrization: Parametrization
    _forward: Callable[..., torch.Tensor]
    __old_class__: type[nn.Module]

    @property
    def weight(self) -> torch.Tensor:
        return self.parametrization.weight

    @property
    def bias(self) -> torch.Tensor | None:
        return self.parametrization.bias


def parametrize_module(module: nn.Module, parametrization: Parametrization) -> nn.Module:
    """
    Parametrize a module using a `Parametrization` instance.

    Args:
        module: The module to be parametrized.
        parametrization: The `Parametrization` instance to be applied to the module.

    Returns: The parametrized module, which follows the `ParametrizedModule` interface.

    Notes:
        Adapted from https://stackoverflow.com/a/31075641
    """
    assert isinstance(module, parametrization.base_class)
    module.__old_class__ = module.__class__

    # Register the parametrization as a submodule and build its internal state from the base module.
    module.add_module("parametrization", parametrization)
    module.parametrization.initialize(module)

    # Keep a handle to the original forward so that it can be restored later.
    module._forward = module.forward

    # Drop the original parameters and swap in a dynamic subclass whose `weight` and `bias`
    # properties are forwarded to the parametrization.
    del module.weight
    del module.bias
    module.__class__ = type("Parametrized" + module.__class__.__name__, (module.__class__, ParametrizedModule), {})

    module.forward = module.parametrization.forward

    return module
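

# A usage sketch (assuming the illustrative `MaskedLinearParametrization` above);
# wrapped in a function so that nothing executes at import time.
def _demo_parametrize_linear() -> None:
    linear = nn.Linear(8, 4)
    x = torch.randn(2, 8)
    y_before = linear(x)

    parametrized = parametrize_module(linear, MaskedLinearParametrization())
    assert isinstance(parametrized, ParametrizedModule)
    # As long as nothing has been pruned, the external behavior is unchanged.
    torch.testing.assert_close(parametrized(x), y_before)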


def unparametrize_module(module: ParametrizedModule) -> nn.Module:
    """
    Revert the parametrization of a module.

    Args:
        module: A module that has been parametrized by `parametrize_module`.

    Returns: The original module.

    Notes:
        Adapted from https://stackoverflow.com/a/31075641
    """
    # Materialize the current weight and bias before tearing down the parametrization.
    weight = module.weight
    bias = module.bias

    assert isinstance(module, nn.Module)

    # Restore the original class of the module.
    module.__class__ = type(module.__old_class__.__name__, (module.__old_class__,), {})
    delattr(module, "__old_class__")

    # Re-register the materialized tensors as plain parameters, detached from the
    # parametrization's autograd graph.
    module.register_parameter("weight", nn.Parameter(weight.detach(), requires_grad=weight.requires_grad))
    if bias is not None:
        module.register_parameter("bias", nn.Parameter(bias.detach(), requires_grad=bias.requires_grad))
    else:
        module.register_parameter("bias", None)

    # Remove the parametrization and restore the original forward.
    del module.parametrization
    module.forward = module._forward
    delattr(module, "_forward")

    return module
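

# A round-trip sketch (again assuming the illustrative parametrization above):
# reverting restores a plain module with ordinary `nn.Parameter` weights.
def _demo_unparametrize_roundtrip() -> None:
    linear = nn.Linear(8, 4)
    x = torch.randn(2, 8)

    parametrize_module(linear, MaskedLinearParametrization())
    y_parametrized = linear(x)

    restored = unparametrize_module(linear)
    assert isinstance(restored, nn.Linear)
    assert isinstance(restored.weight, nn.Parameter)
    torch.testing.assert_close(restored(x), y_parametrized)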