File size: 9,570 Bytes
9de62de
7836cdd
 
 
 
 
 
 
9de62de
 
7836cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de62de
7836cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de62de
7836cdd
 
9de62de
7836cdd
 
 
 
 
 
 
9de62de
7836cdd
 
 
 
 
 
 
 
 
 
9de62de
 
 
7836cdd
9de62de
7836cdd
9de62de
 
 
7836cdd
 
9de62de
7836cdd
 
 
9de62de
7836cdd
9de62de
7836cdd
 
9de62de
7836cdd
 
 
 
9de62de
 
7836cdd
 
 
 
 
9de62de
 
 
 
 
 
 
 
7836cdd
 
9de62de
7836cdd
 
 
9de62de
7836cdd
 
9de62de
 
 
 
 
 
7836cdd
 
 
9de62de
 
 
 
 
 
 
 
7836cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import logging
from typing import Any

import torch
from transformers import PreTrainedModel

from .parametrized_model import ParametrizedModel, ParametrizedModelConfig

logger = logging.getLogger(__name__)


class ACIPModelConfig(ParametrizedModelConfig):
    """
    Configuration class for `ACIPModel`.

    It introduces no options of its own: all behavior is inherited unchanged from `ParametrizedModelConfig`.
    The only override is `model_type`, which identifies this configuration as an ACIP model.

    See Also:
        - `ParametrizedModelConfig`
        - `ACIPModel`
    """

    # Model-type identifier for this configuration class.
    model_type = "acip_model"


class ACIPModel(ParametrizedModel):
    """
    This class extends `ParametrizedModel` by additional functionality required for ACIP.
    It manages a `score_map` that stores the scores of the parametrized modules' target parameters,
    which are updated during tuning by the ACIP method.
    Moreover, it provides `prune_model_by_score` that prunes the target parameters of the model according to
    their scores to achieve any given size ratio.

    Notes: The `score_map` is managed in float32 internally because a lower precision may lead to unexpected numerical
        inaccuracies in the resulting parameter ranking. Fortunately, the memory consumption is negligible compared to
        the model weights themselves.

    See Also:
        - `ParametrizedModel`
        - `ACIPModelConfig`
    """

    config_class = ACIPModelConfig

    def __init__(self, config: ACIPModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
        super().__init__(config, base_model)
        self.config = config  # redundant but enables type hinting for ACIPModelConfig

        # Lazily rendered view onto the score buffers; built on first access of `score_map`.
        self._score_map: dict[str, torch.Tensor] | None = None
        # Register and initialize score map buffers
        # Important: don't run _update_score_map here because load_state_dict might still override the buffers
        self._init_score_map_buffers()

    def _init_score_map_buffers(self) -> None:
        """
        Register and initialize score map buffers in parametrized modules (filled with ones, stored as float32).
        Each target parameter "p_name" is associated with a buffer "p_name_score" that stores its score vector.
        """
        for module in self.parametrized_modules.values():
            for p_name, param in module.parametrization.get_target_params().items():
                module.parametrization.register_buffer(
                    p_name + "_score", torch.ones_like(param.data, dtype=torch.float32)
                )

    def _update_score_map(self) -> None:
        """
        Render `score_map` from the parametrized modules' score buffers.

        The map stores the buffer tensors returned by `get_buffer` (not copies), keyed by
        "<module name>.parametrization.<target param name>".
        """
        self._score_map = {}
        for m_name, module in self.parametrized_modules.items():
            for p_name in module.parametrization.get_target_params().keys():
                self._score_map[f"{m_name}.parametrization.{p_name}"] = module.parametrization.get_buffer(
                    p_name + "_score"
                )

    @property
    def score_map(self) -> dict[str, torch.Tensor]:
        """Returns the score map as Tensor dictionary whose keys match those of `self.get_target_params`."""
        if self._score_map is None:
            self._update_score_map()
        return self._score_map

    @score_map.setter
    def score_map(self, score_map: dict[str, torch.Tensor]) -> None:
        """
        Updates `score_map` and the corresponding parametrized modules' score buffers in place.

        Args:
            score_map: Dictionary whose keys should match (a subset of) `self.get_target_params`.

        Raises:
            ValueError: If a provided score tensor does not have the shape of its score buffer.
        """
        if self._score_map is None:
            self._update_score_map()
        # score_map.keys() can be a subset of self.get_target_params().keys()
        for p_name, score in score_map.items():
            # NOTE(review): assumes score-map keys are module paths relative to `self.model` — confirm.
            buffer = self.model.get_buffer(p_name + "_score")
            if buffer.shape != score.shape:
                raise ValueError(
                    f"Score map for '{p_name}' has incorrect shape: expected {buffer.shape}, got {score.shape}"
                )
            # cast to float32 to avoid numerical instabilities
            buffer.copy_(score.detach().float())
            self._score_map[p_name] = buffer

    def _predict_size_ratio_by_score(self, k: int, full: bool = False) -> tuple[float, dict[str, torch.Tensor]]:
        """
        Helper function that checks what would happen if the k smallest target parameters are pruned
        according to the global score map ranking. It returns the resulting size ratio
        and the corresponding parameter masks.

        Every entry whose score is <= the k-th smallest score is pruned, so with tied scores more than
        k entries may be pruned.

        Args:
            k: Number of target parameters to prune. Must satisfy 1 <= k <= total number of scores
                (requirement of `torch.kthvalue`).
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Tuple of size ratio and parameter masks. The masks indicate which parameters to keep
            (1.0 = keep, 0.0 = prune) and share the dtype of the scores.
        """
        # Find the threshold value for the k smallest entries according to the global score map ranking.
        score_map_cat = torch.cat([param.flatten() for param in self.score_map.values()])
        threshold = torch.kthvalue(score_map_cat, k).values.item()

        # Create a set of parameter masks marking which values to keep.
        param_masks = {}
        for p_name, score in self.score_map.items():
            param_masks[p_name] = (score > threshold).to(dtype=score.dtype)

        # Compute hypothetical size ratio if param_masks would be used as masks for the target parameters.
        size_ratio = self.get_size_ratio(full=full, target_params=param_masks)
        return size_ratio, param_masks

    def _keep_all_masks(self) -> dict[str, torch.Tensor]:
        """Return masks that keep every target parameter (all ones, same shape/dtype as the scores)."""
        return {p_name: torch.ones_like(score) for p_name, score in self.score_map.items()}

    def _get_param_masks(self, size_ratio: float, full: bool = False) -> dict[str, torch.Tensor]:
        """
        Helper function that determines which parameters to keep to reach a target size ratio.
        Instead of looping over `k -> _predict_size_ratio_by_score(k)`, a binary search can be used because
        the predicted size ratio is monotonically non-increasing in k (pruning more never increases the size).

        The search returns the masks for the largest k (i.e., the most pruning) whose predicted size ratio is
        still at least `size_ratio`. If even pruning a single entry undershoots the target, or if there is
        nothing to prune, all parameters are kept.

        Args:
            size_ratio: Target size ratio.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Returns: Parameter masks indicating which parameters to keep to reach the target size ratio.
        """
        num_scores = sum(score.numel() for score in self.score_map.values())
        if size_ratio == 1.0 or num_scores == 0:
            # Nothing to prune; also avoids torch.cat / torch.kthvalue on empty input.
            return self._keep_all_masks()

        # Binary search for the largest k such that the predicted size ratio is >= size_ratio.
        # Invariant: every k > k_hi undershoots the target; k_lo is the best candidate found so far
        # (only the initial k_lo = 1 has not been evaluated yet).
        k_lo, k_hi = 1, num_scores
        while k_lo < k_hi:
            k_mid = (k_lo + k_hi + 1) // 2  # round up so that `k_lo = k_mid` always makes progress
            ratio, _ = self._predict_size_ratio_by_score(k=k_mid, full=full)
            if ratio >= size_ratio:
                k_lo = k_mid  # still meets the target -> try pruning more
            else:
                k_hi = k_mid - 1  # undershoots the target -> prune less
        ratio, param_masks = self._predict_size_ratio_by_score(k=k_lo, full=full)
        if ratio < size_ratio:
            # Only possible for the unevaluated k_lo == 1: pruning even the lowest-scored entries (plus ties)
            # already undershoots the target, so keep everything instead of over-pruning.
            logger.warning(
                "Target size ratio %s cannot be reached by pruning; keeping all target parameters.", size_ratio
            )
            return self._keep_all_masks()
        # TODO: handle tie-breaks
        return param_masks

    def prune_model_by_score(
        self,
        size_ratio: float | None = None,
        compression_rate: float | None = None,
        full: bool = False,
    ) -> None:
        """
        This method prunes the target parameters of the model according to their scores to achieve
        a given size ratio.

        This can be efficiently implemented by a simple binary search strategy:
        We find the largest number of parameters to be pruned according to the score map ranking
        such that the resulting size ratio is still at least the target `size_ratio`.

        Args:
            size_ratio: The target size ratio, which is the ratio between the size of the compressed model and
                the original model (where size is measured in number of parameters).
                If not provided, `compression_rate` must be provided.
            compression_rate: This is a convenience parameter that allows you to set the target compression rate
                instead of `size_ratio`. It is equivalent to `size_ratio = 1.0 - compression_rate`.
                If both `size_ratio` and `compression_rate` are provided, `size_ratio` is used.
            full: Whether to count the number of parameters of the entire model or only the parametrized modules.
                See also `ParametrizedModel.get_num_params`.

        Raises:
            ValueError: If neither `size_ratio` nor `compression_rate` is provided.
        """
        if size_ratio is None:
            if compression_rate is None:
                raise ValueError("Either `size_ratio` or `compression_rate` must be provided.")
            size_ratio = 1.0 - compression_rate
        elif compression_rate is not None:
            # Only warn when both were actually given; `size_ratio` alone is the normal case.
            logger.warning("Both `size_ratio` and `compression_rate` are provided. Using `size_ratio`.")

        param_masks = self._get_param_masks(size_ratio=size_ratio, full=full)

        # Reset the target parameters according to the parameter masks
        for p_name, param in self.get_target_params().items():
            param.data[param_masks[p_name] > 0.0] = 1.0  # dummy value, will be rescaled by reset_target_params
            param.data[param_masks[p_name] == 0.0] = 0.0
        for m_name, module in self.parametrized_modules.items():
            # Match the full key prefix built in `_update_score_map`, so that e.g. module "layers.1"
            # is not also triggered by keys belonging to "layers.10".
            key_prefix = f"{m_name}.parametrization."
            if any(p_name.startswith(key_prefix) for p_name in param_masks):
                module.parametrization.reset_target_params(mode="nonzero")


# Hook ACIPModelConfig / ACIPModel into the Hugging Face auto classes (AutoConfig / AutoModel).
# This registration is what allows the custom model to be pushed to the Hugging Face Hub,
# see https://huggingface.co/docs/transformers/en/custom_models
ACIPModelConfig.register_for_auto_class(auto_class="AutoConfig")
ACIPModel.register_for_auto_class(auto_class="AutoModel")