import logging
from typing import Any
import torch
from transformers import PreTrainedModel
from .parametrized_model import ParametrizedModel, ParametrizedModelConfig

logger = logging.getLogger(__name__)


class ACIPModelConfig(ParametrizedModelConfig):
"""
Configuration for `ACIPModel`. Same functionality as `ParametrizedModelConfig`.
See Also:
- `ParametrizedModelConfig`
- `ACIPModel`
"""
    model_type = "acip_model"


class ACIPModel(ParametrizedModel):
"""
This class extends `ParametrizedModel` by additional functionality required for ACIP.
It manages a `score_map` that stores the scores of the parametrized modules' target parameters,
which are updated during tuning by the ACIP method.
Moreover, it provides `prune_model_by_score` that prunes the target parameters of the model according to
their scores to achieve any given size ratio.
    Notes: The `score_map` is managed in float32 internally because a lower precision may lead to unexpected
        numerical inaccuracies in the resulting parameter ranking. Fortunately, its memory consumption is
        negligible compared to the model weights themselves.
See Also:
- `ParametrizedModel`
- `ACIPModelConfig`
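
    Example:
        An illustrative usage sketch (the checkpoint path is hypothetical)::

            model = ACIPModel.from_pretrained("path/to/acip_checkpoint")
            model.prune_model_by_score(size_ratio=0.5)  # keep ~50% of the original parameters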
"""
    config_class = ACIPModelConfig

    def __init__(self, config: ACIPModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
super().__init__(config, base_model)
self.config = config # redundant but enables type hinting for ACIPModelConfig
self._score_map: dict[str, torch.Tensor] | None = None
# Register and initialize score map buffers
# Important: don't run _update_score_map here because load_state_dict might still override the buffers
        self._init_score_map_buffers()

    def _init_score_map_buffers(self):
"""
        Register and initialize score map buffers (filled with ones) in the parametrized modules.
Each target parameter "p_name" is associated with a buffer "p_name_score" that stores its score vector.
"""
        for module in self.parametrized_modules.values():
            for p_name, param in module.parametrization.get_target_params().items():
                module.parametrization.register_buffer(p_name + "_score", torch.ones_like(param.data).float())

    def _update_score_map(self):
"""Render `score_map` from the parametrized modules' score buffers."""
self._score_map = {}
for m_name, module in self.parametrized_modules.items():
for p_name in module.parametrization.get_target_params().keys():
self._score_map[f"{m_name}.parametrization.{p_name}"] = module.parametrization.get_buffer(
p_name + "_score"
            )

    @property
    def score_map(self) -> dict[str, torch.Tensor]:
        """Returns the score map as a tensor dictionary whose keys match those of `self.get_target_params`."""
if self._score_map is None:
self._update_score_map()
        return self._score_map

    @score_map.setter
def score_map(self, score_map: dict[str, torch.Tensor]) -> None:
"""
Updates `score_map` and the corresponding parametrized modules' score buffers.
Args:
score_map: Dictionary whose keys should match (a subset of) `self.get_target_params`.
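
        Example:
            An illustrative sketch that randomizes all scores (assumes `model` is an `ACIPModel` instance)::

                model.score_map = {name: torch.rand_like(score) for name, score in model.score_map.items()}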
"""
if self._score_map is None:
self._update_score_map()
# score_map.keys() can be a subset of self.get_target_params().keys()
for p_name, score in score_map.items():
buffer = self.model.get_buffer(p_name + "_score")
if buffer.shape != score.shape:
raise ValueError(
f"Score map for '{p_name}' has incorrect shape: expected {buffer.shape}, got {score.shape}"
)
# cast to float32 to avoid numerical instabilities
buffer.copy_(score.detach().float())
            self._score_map[p_name] = buffer

    def _predict_size_ratio_by_score(self, k: int, full: bool = False) -> tuple[float, dict[str, torch.Tensor]]:
"""
        Helper function that checks what would happen if the k lowest-scoring target parameter entries were pruned
        according to the global score map ranking. It returns the resulting size ratio
        and the corresponding parameter masks.
Args:
k: Number of target parameters to prune.
full: Whether to count the number of parameters of the entire model or only the parametrized modules.
See also `ParametrizedModel.get_num_params`.
Returns: Tuple of size ratio and parameter masks. The masks indicate which parameters to keep.
"""
# Find the threshold value for the k smallest entries according to the global score map ranking.
score_map_cat = torch.cat([param.flatten() for param in self.score_map.values()])
threshold = torch.kthvalue(score_map_cat, k).values.item()
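        # Note: the strict ">" below also drops any entries that are tied with the k-th smallest score,
        # so slightly more than k entries may be pruned when ties occur (see the tie-break TODO in `_get_param_masks`).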
# Create a set of parameter masks marking which values to keep.
param_masks = {}
for p_name, score in self.score_map.items():
param_masks[p_name] = (score > threshold).to(dtype=score.dtype)
        # Compute the hypothetical size ratio if param_masks were used as masks for the target parameters.
size_ratio = self.get_size_ratio(full=full, target_params=param_masks)
        return size_ratio, param_masks

    def _get_param_masks(self, size_ratio: float, full: bool = False) -> dict[str, torch.Tensor]:
"""
Helper function that determines which parameters to keep to reach a target size ratio.
        Instead of looping over `k -> _predict_size_ratio_by_score(k)`, a binary search can be used because
        the size ratio is monotonically decreasing in k.
Args:
size_ratio: Target size ratio.
full: Whether to count the number of parameters of the entire model or only the parametrized modules.
See also `ParametrizedModel.get_num_params`.
Returns: Parameter masks indicating which parameters to keep to reach the target size ratio.
"""
if size_ratio == 1.0:
return {p_name: torch.ones_like(score) for p_name, score in self.score_map.items()}
        # Perform a binary search to find the largest k such that the resulting size ratio stays above size_ratio.
        # Here, k_lo and k_hi are the lower and upper bounds of the search interval.
k_lo, k_hi = 1, sum(score.numel() for score in self.score_map.values())
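        # Invariant: the largest k with ratio(k) > size_ratio always lies in [k_lo, k_hi]
        # (if no such k exists, the search degenerates to k = 1).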
while k_lo < k_hi:
            k_mid = (k_lo + k_hi + 1) // 2  # round up so that k_mid > k_lo, which guarantees progress
ratio, _ = self._predict_size_ratio_by_score(k=k_mid, full=full)
if ratio > size_ratio:
k_lo = k_mid
else:
k_hi = k_mid - 1
k = k_lo
# TODO: handle tie-breaks
        return self._predict_size_ratio_by_score(k=k, full=full)[1]

    def prune_model_by_score(
self,
size_ratio: float | None = None,
compression_rate: float | None = None,
full: bool = False,
) -> None:
"""
This method prunes the target parameters of the model according to their scores to achieve
a given size ratio.
        This can be efficiently implemented by a simple binary search strategy:
        we find the largest number of parameters that can be pruned according to the score map ranking
        such that the resulting size ratio stays above the target `size_ratio`.
Args:
size_ratio: The target size ratio, which is the ratio between the size of the compressed model and
the original model (where size is measured in number of parameters).
If not provided, `compression_rate` must be provided.
compression_rate: This is a convenience parameter that allows you to set the target compression rate
instead of `size_ratio`. It is equivalent to `size_ratio = 1.0 - compression_rate`.
If both `size_ratio` and `compression_rate` are provided, `size_ratio` is used.
full: Whether to count the number of parameters of the entire model or only the parametrized modules.
See also `ParametrizedModel.get_num_params`.
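
        Example:
            An illustrative call (assumes `model` is a tuned `ACIPModel` instance)::

                model.prune_model_by_score(compression_rate=0.4)  # equivalent to size_ratio=0.6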
"""
        if size_ratio is None and compression_rate is None:
            raise ValueError("Either `size_ratio` or `compression_rate` must be provided.")
        if size_ratio is not None and compression_rate is not None:
            logger.warning("Both `size_ratio` and `compression_rate` are provided. Using `size_ratio`.")
        if size_ratio is None:
            size_ratio = 1.0 - compression_rate
param_masks = self._get_param_masks(size_ratio=size_ratio, full=full)
# Reset the target parameters according to the parameter masks
for p_name, param in self.get_target_params().items():
param.data[param_masks[p_name] > 0.0] = 1.0 # dummy value, will be rescaled by reset_target_params
param.data[param_masks[p_name] == 0.0] = 0.0
for m_name, module in self.parametrized_modules.items():
if any(p_name.startswith(m_name) for p_name in param_masks.keys()):
                module.parametrization.reset_target_params(mode="nonzero")


# Register ACIPModelConfig and ACIPModel for AutoModel
# Required to push custom models to the Hugging Face Hub (see https://huggingface.co/docs/transformers/en/custom_models)
ACIPModelConfig.register_for_auto_class()
ACIPModel.register_for_auto_class("AutoModel")
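
# Usage sketch (illustrative; the Hub repository name is hypothetical):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("some-org/acip-model", trust_remote_code=True)
#   model.prune_model_by_score(size_ratio=0.5)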