import math
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

from .parametrized_layer import Parametrization
from .utils import use_init_empty_weights

logger = getLogger(__name__)


class CompressionCriterion(ABC):
    """
    Abstract base class for a compression criterion of a (target) parameter of a parametrized module.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: A tensor of any shape.

        Returns: A boolean mask of the same shape as `x` where `False` indicates that the entry can be removed.
        """
        raise NotImplementedError


class ThresholdCriterion(CompressionCriterion):
    """
    Compression criterion based on a threshold. All entries at or below `self.threshold` can be removed.
    """

    def __init__(self, threshold: float = 0.0):
        self.threshold = threshold

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x > self.threshold


class ProjectedLinearParametrization(Parametrization, ABC):
    """
    Implementation of a linear layer parametrization, factorizing the weight matrix as
    `weight = ortho.weight @ torch.diag(mask) @ base.weight`.
    Here, `ortho` is a linear layer with orthogonal columns, `mask` represents a (binary) diagonal matrix
    that can be pruned, and `base` is a linear layer (determined by the choice of `ortho`).
    Any child class needs to implement `_ortho_init`, which creates `ortho`. Based on this, `mask` and `base` are
    initialized such that the original weight matrix is recovered at initialization.

    `mask` is the only target parameter of this parametrization. Pruning it results in
    a low-rank representation of the parametrized linear module.
    """

    base_class = nn.Linear

    def __init__(
        self,
        mask_func: Literal["ste", "relu", "none"] = "ste",
        mask_scaling_factor: float | str = "norm",
        compression_criterion: CompressionCriterion | None = ThresholdCriterion(),
    ):
        """
        Args:
            mask_func: A function applied to the mask parameter in each forward pass, implementing
                custom functionality. Available options: ["ste", "relu", "none"].
                "ste" means using a straight-through estimator: in the forward pass, `mask` is binarized,
                which is ignored in the backward pass. Beforehand, `mask` is passed through a ReLU activation.
                "relu" means that `mask` is passed through a ReLU activation.
                "none" means that `mask` is not modified.
            mask_scaling_factor: Conceptually, `mask` is initialized with ones, but rescaling it to a smaller value
                can vastly improve the training speed. `mask_scaling_factor` specifies this rescaling factor.
                The rescaling should be compensated by scaling `ortho` accordingly in `self._ortho_init`.
                If `mask_scaling_factor='norm'`, the scaling factor is chosen such that `mask` has unit L2 norm
                (note that this can lead to different behavior in model tuning than a fixed factor
                when target parameters have different numbers of elements).
            compression_criterion: `CompressionCriterion` to be used in `self.reset_target_params(mode="compress")`.
        """
        super().__init__()
        self.mask_func = {
            "ste": mask_func_ste,
            "relu": mask_func_relu,
            "none": mask_func_none,
        }[mask_func]
        self._mask_scaling_factor = mask_scaling_factor
        self.compression_criterion = compression_criterion

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
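        # Apply the three factors in sequence: base, then the (rescaled) mask, then ortho.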
        x = self.base(x)
        x = self.mask_func(self.mask, self.mask_scaling_factor) * x
        x = self.ortho(x)
        return x

    def _weight(self) -> torch.Tensor:
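        # Materialize the full weight matrix from the three factors.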
        mask = self.mask_func(self.mask, self.mask_scaling_factor)
        return self.ortho.weight @ torch.diag(mask) @ self.base.weight

    def _bias(self) -> torch.Tensor | None:
        return self.ortho.bias

    def _initialize(self, base_module: base_class) -> None:
        factory_kwargs = {"device": base_module.weight.device, "dtype": base_module.weight.dtype}
        in_dim, out_dim = base_module.in_features, base_module.out_features
        proj_dim = min(in_dim, out_dim)

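        # `ortho` maps from the inner (at most full-rank) dimension to the output dimension;
        # its weight is filled in by the child class via `_ortho_init`.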
        self.add_module(
            "ortho",
            nn.Linear(in_features=proj_dim, out_features=out_dim, bias=base_module.bias is not None, **factory_kwargs),
        )
        self._ortho_init(base_module.weight)
        if base_module.bias is not None:
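            # The original bias is absorbed into `ortho`, the last factor applied in the forward pass.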
            self.ortho.bias.data.copy_(base_module.bias.data)

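        # `base` completes the factorization: since `ortho` has orthogonal columns,
        # ortho.weight.T @ weight recovers the remaining factor.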
        base = base_module.__class__(in_features=in_dim, out_features=proj_dim, bias=False, **factory_kwargs)
        base.weight.data.copy_(self.ortho.weight.data.T @ base_module.weight.data)
        self.add_module("base", base)

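        # `mask` is the target parameter; zeroing one of its entries removes a rank-one component.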
self.register_parameter("mask", torch.nn.Parameter(torch.ones(proj_dim, **factory_kwargs))) |
|
|
|
|
|
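        # Rescale the mask to `mask_scaling_factor` (compensated for in `_ortho_init`).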
        self.reset_target_params()

    @abstractmethod
    def _ortho_init(self, weight: torch.Tensor) -> None:
        """
        Initialize the `ortho` layer. Must be implemented by the child class.

        Args:
            weight: Weight matrix of the original linear module.
        """
        raise NotImplementedError

    def get_target_params(self) -> dict[str, torch.nn.Parameter]:
        return {"mask": self.mask}

    @property
    def mask_scaling_factor(self) -> float:
        if self._mask_scaling_factor == "norm":
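            # A constant mask with `numel` entries has unit L2 norm for the factor 1 / sqrt(numel).
            # Cache the value so the factor stays fixed even if `mask` is compressed later.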
            self._mask_scaling_factor = 1 / math.sqrt(self.mask.numel())
            return self._mask_scaling_factor
        elif isinstance(self._mask_scaling_factor, float):
            return self._mask_scaling_factor
        else:
            raise ValueError(f"Invalid mask_scaling_factor: {self._mask_scaling_factor}")

    @property
    def in_features(self) -> int:
        return self.base.in_features

    @property
    def out_features(self) -> int:
        return self.ortho.out_features

def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None: |
|
with torch.no_grad(): |
|
if mode == "full": |
|
|
|
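                # Reset all mask entries to the scaling factor (conceptually, to ones).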
                self.mask.data = torch.ones_like(self.mask.data) * self.mask_scaling_factor
            elif mode == "nonzero":
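                # Keep positive entries (reset to the scaling factor) and zero out negative ones.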
                self.mask.data[self.mask.data > 0] = 1.0 * self.mask_scaling_factor
                self.mask.data[self.mask.data < 0] = 0.0
            elif mode == "compress":
                if self.compression_criterion is None:
                    logger.warning("Compression criterion is not set. No op...")
                    return

                dim_select = self.compression_criterion(self.mask)

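                # Shrink all three factors to the dimensions selected by the criterion.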
                new_base = new_linear_from_mask(self.base, dim_select, column_select=False)
                new_ortho = new_linear_from_mask(self.ortho, dim_select, column_select=True)
                new_mask = self.mask[dim_select].clone().detach()
                del self.mask, self.base, self.ortho
                self.register_module("base", new_base)
                self.register_module("ortho", new_ortho)
                self.register_parameter("mask", nn.Parameter(new_mask))
            else:
                raise ValueError(f"Invalid mode: {mode}")

    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        if not compressed:
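            # Parameter count of the original dense linear layer.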
            num_params = self.in_features * self.out_features
            if self.bias is not None:
                num_params += self.out_features
            return num_params
        else:
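            # Parameter count of the low-rank factorization, where `sparsity` is the retained rank.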
            if target_params is not None:
                sparsity = mask_sparsity(target_params["mask"] != 0.0, threshold=0.0)
            else:
                sparsity = mask_sparsity(self.mask)

            num_params = self.in_features * sparsity + sparsity * self.out_features
            if self.bias is not None:
                num_params += self.out_features

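            # A (near) full-rank factorization can hold more parameters than the dense layer,
            # so never report more than the uncompressed count.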
            num_params = min(self.get_num_params(compressed=False), num_params)
            return num_params


class SVDLinearParametrization(ProjectedLinearParametrization):
    """
    Implementation of a linear layer parametrization using the singular value decomposition (SVD).
    If the SVD of the weight is U S V^T, then `ortho.weight = U` and `base.weight = S V^T`.
    As `base` is computed automatically by `_initialize`, `_ortho_init` only needs to compute U and
    scale it properly with `mask_scaling_factor`. The singular values S are buffered in case they are needed
    in the tuning process.
    """

    def _ortho_init(self, weight: torch.Tensor) -> None:
        k = min(weight.shape[0], weight.shape[1])
        if use_init_empty_weights.get():
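            # Skip the expensive SVD, e.g., when the parametrized weights are loaded from a checkpoint afterwards.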
logger.debug("Parametrizing with empty weights.") |
|
U = torch.empty(weight.shape[0], k) |
|
S = torch.empty(k, 1) |
|
else: |
|
|
|
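            # Compute the SVD in float32 for numerical stability.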
            U, S, _ = torch.linalg.svd(weight.detach().float(), full_matrices=False)

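        # Compensate the mask rescaling: `base` is built from `ortho.weight.T` in `_initialize`, so scaling
        # U by c scales the product by c^2 * mask_scaling_factor, which must equal 1
        # (for 'norm', the factor is 1 / sqrt(k), hence c = k^(1/4)).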
if self._mask_scaling_factor == "norm": |
|
U = math.pow(k, 1 / 4) * U |
|
else: |
|
U = math.sqrt(1 / self._mask_scaling_factor) * U |
|
factory_kwargs = {"device": weight.device, "dtype": weight.dtype} |
|
self.ortho.weight.data.copy_(U.detach().to(**factory_kwargs)) |
|
self.register_buffer("S", S.detach().flatten().to(**factory_kwargs)) |
|
|
|
|
|
def mask_func_ste(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
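    # Straight-through estimator: the forward value is the binarized, rescaled mask,
    # while the gradient flows through the ReLU-activated mask.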
    mask = F.relu(mask)
    return (mask > 0).to(mask.dtype).detach() * mask_scaling_factor + mask - mask.detach()


def mask_func_relu(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    return F.relu(mask)


def mask_func_none(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    return mask


def mask_sparsity(mask: torch.Tensor, threshold: float = 0.0) -> int:
    """Simple util function to compute the number of non-zero elements of a mask, where an element is considered
    non-zero if its value is strictly greater than `threshold`."""
    return torch.count_nonzero(mask > threshold).item()


def new_linear_from_mask(module: nn.Linear, dim_select: torch.Tensor, column_select: bool = True) -> nn.Linear:
    """
    Creates a new linear layer from an existing one based on a mask indicating which columns/rows to keep.

    Args:
        module: Module to be pruned.
        dim_select: Boolean tensor mask indicating which columns/rows to keep.
        column_select: Whether to prune columns (True) or rows (False) according to `dim_select`.

    Returns: Pruned module.
    """
    assert dim_select.dtype == torch.bool, "dim_select must be boolean"

    in_features, out_features = module.in_features, module.out_features
    sparsity = dim_select.sum().item()
    if column_select:
        in_features = sparsity
    else:
        out_features = sparsity
    new_module = module.__class__(
        in_features=in_features,
        out_features=out_features,
        bias=module.bias is not None,
        device=module.weight.device,
        dtype=module.weight.dtype,
    )
    weight = module.weight.data
    if column_select:
        weight = weight[:, dim_select]
    else:
        weight = weight[dim_select, :]
    new_module.weight.data.copy_(weight.detach())

    if new_module.bias is not None:
        if column_select:
            new_module.bias.data.copy_(module.bias.detach())
        else:
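            # Row pruning changes the output dimension, so the bias is pruned accordingly.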
            new_module.bias.data.copy_(module.bias[dim_select].detach())

    return new_module
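

# A minimal usage sketch (illustrative only). How a `Parametrization` is attached to a model is
# defined in `parametrized_layer` and not shown here, so the direct `_initialize` call below is an
# assumption for demonstration purposes:
#
#     layer = nn.Linear(in_features=512, out_features=128)
#     parametrization = SVDLinearParametrization(mask_func="ste", mask_scaling_factor="norm")
#     parametrization._initialize(layer)  # weight == ortho.weight @ diag(mask) @ base.weight
#     ...                                 # tune `parametrization.mask` (the target parameter)
#     parametrization.reset_target_params(mode="compress")  # drop pruned dims -> low-rank layer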