martingenzel committed on
Commit
7836cdd
·
verified ·
1 Parent(s): 1d4c1a8
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
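+
+ Until this section is filled in, here is a minimal, hedged sketch of how this checkpoint can presumably be loaded. It assumes the custom `ACIPModel` code shipped in this repository (see `acip_model.py`), which requires `trust_remote_code=True`; the repository ID below is a placeholder.
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ repo_id = "<this-repo-id>"  # placeholder: replace with the actual Hub repository ID
+
+ # The custom ACIPModel class is registered for AutoModel via auto_map in config.json.
+ model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
+
+ # The tokenizer of the base model referenced in config.json can presumably be reused.
+ tokenizer = AutoTokenizer.from_pretrained("jeffwan/llama-7b-hf")
+
+ # Optionally prune the parametrized layers to a target compression ratio using the stored ACIP scores.
+ model.prune_model_by_score(compression_ratio=0.5)
+
+ inputs = tokenizer("Hello, my name is", return_tensors="pt")
+ outputs = model.generate(**inputs, max_new_tokens=20)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```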
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
acip_model.py ADDED
@@ -0,0 +1,179 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ from transformers import PreTrainedModel
5
+
6
+ from .parametrized_model import ParametrizedModel, ParametrizedModelConfig
7
+
8
+
9
+ class ACIPModelConfig(ParametrizedModelConfig):
10
+ """
11
+ Configuration for `ACIPModel`. Same functionality as `ParametrizedModelConfig`.
12
+
13
+ See Also:
14
+ - `ParametrizedModelConfig`
15
+ - `ACIPModel`
16
+ """
17
+
18
+ model_type = "acip_model"
19
+
20
+
21
+ class ACIPModel(ParametrizedModel):
22
+ """
23
+ This class extends `ParametrizedModel` with additional functionality required for ACIP.
24
+ It manages a `score_map` that stores the scores of the parametrized modules' target parameters,
25
+ which are updated during tuning by the ACIP method.
26
+ Moreover, it provides `prune_model_by_score` that prunes the target parameters of the model according to
27
+ their scores to achieve any given compression ratio.
28
+
29
+ Notes: The `score_map` is managed in float32 internally because a lower precision may lead to unexpected numerical
30
+ inaccuracies in the resulting parameter ranking. Fortunately, the memory consumption is negligible compared to
31
+ the model weights themselves.
32
+
33
+ See Also:
34
+ - `ParametrizedModel`
35
+ - `ACIPModelConfig`
36
+ """
37
+
38
+ config_class = ACIPModelConfig
39
+
40
+ def __init__(self, config: ACIPModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
41
+ super().__init__(config, base_model)
42
+ self.config = config # redundant but enables type hinting for ACIPModelConfig
43
+
44
+ self._score_map: dict[str, torch.Tensor] | None = None
45
+ # Register and initialize score map buffers
46
+ # Important: don't run _update_score_map here because load_state_dict might still override the buffers
47
+ self._init_score_map_buffers()
48
+
49
+ def _init_score_map_buffers(self):
50
+ """
51
+ Register and initialize score map buffers in parametrized modules (filled with ones).
52
+ Each target parameter "p_name" is associated with a buffer "p_name_score" that stores its score vector.
53
+ """
54
+ for m_name, module in self.parametrized_modules.items():
55
+ for p_name, param in module.parametrization.get_target_params().items():
56
+ module.parametrization.register_buffer(p_name + "_score", torch.ones_like(param.data).float())
57
+
58
+ def _update_score_map(self):
59
+ """Render `score_map` from the parametrized modules' score buffers."""
60
+ self._score_map = {}
61
+ for m_name, module in self.parametrized_modules.items():
62
+ for p_name in module.parametrization.get_target_params().keys():
63
+ self._score_map[f"{m_name}.parametrization.{p_name}"] = module.parametrization.get_buffer(
64
+ p_name + "_score"
65
+ )
66
+
67
+ @property
68
+ def score_map(self) -> dict[str, torch.Tensor]:
69
+ """Returns the score map as Tensor dictionary whose keys match those of `self.get_target_params`."""
70
+ if self._score_map is None:
71
+ self._update_score_map()
72
+ return self._score_map
73
+
74
+ @score_map.setter
75
+ def score_map(self, score_map: dict[str, torch.Tensor]) -> None:
76
+ """
77
+ Updates `score_map` and the corresponding parametrized modules' score buffers.
78
+
79
+ Args:
80
+ score_map: Dictionary whose keys should match (a subset of) `self.get_target_params`.
81
+ """
82
+ if self._score_map is None:
83
+ self._update_score_map()
84
+ # score_map.keys() can be a subset of self.get_target_params().keys()
85
+ for p_name, score in score_map.items():
86
+ buffer = self.model.get_buffer(p_name + "_score")
87
+ if buffer.shape != score.shape:
88
+ raise ValueError(
89
+ f"Score map for '{p_name}' has incorrect shape: expected {buffer.shape}, got {score.shape}"
90
+ )
91
+ # cast to float32 to avoid numerical instabilities
92
+ buffer.copy_(score.detach().float())
93
+ self._score_map[p_name] = buffer
94
+
95
+ def _predict_compression_ratio_by_score(self, k: int, full: bool = False) -> tuple[float, dict[str, torch.Tensor]]:
96
+ """
97
+ Helper function that checks what would happen if the k smallest target parameters are pruned
98
+ according to the global score map ranking. It returns the resulting compression ratio
99
+ and the corresponding parameter masks.
100
+
101
+ Args:
102
+ k: Number of target parameters to prune.
103
+ full: Whether to count the number of parameters of the entire model or only the parametrized modules.
104
+ See also `ParametrizedModel.get_num_params`.
105
+
106
+ Returns: Tuple of compression ratio and parameter masks. The masks indicate which parameters to keep.
107
+ """
108
+ # Find the threshold value for the k smallest entries according to the global score map ranking.
109
+ score_map_cat = torch.cat([param.flatten() for param in self.score_map.values()])
110
+ threshold = torch.kthvalue(score_map_cat, k).values.item()
111
+
112
+ # Create a set of parameter masks marking which values to keep.
113
+ param_masks = {}
114
+ for p_name, score in self.score_map.items():
115
+ param_masks[p_name] = (score > threshold).to(dtype=score.dtype)
116
+
117
+ # Compute hypothetical compression ratio if param_masks would be used as masks for the target parameters.
118
+ compression_ratio = self.get_compression_ratio(full=full, target_params=param_masks)
119
+ return compression_ratio, param_masks
120
+
121
+ def _get_param_masks(self, compression_ratio: float, full: bool = False) -> dict[str, torch.Tensor]:
122
+ """
123
+ Helper function that determines which parameters to keep to reach a target compression ratio.
124
+ Instead of looping over `k -> _predict_compression_ratio_by_score(k)`, a binary search can be used because
125
+ the compression ratio is monotone in k (pruning more parameters leaves a smaller fraction of the model).
126
+
127
+ Args:
128
+ compression_ratio: Target compression ratio.
129
+ full: Whether to count the number of parameters of the entire model or only the parametrized modules.
130
+ See also `ParametrizedModel.get_num_params`.
131
+
132
+ Returns: Parameter masks indicating which parameters to keep to reach the target compression ratio.
133
+ """
134
+ if compression_ratio == 1.0:
135
+ return {p_name: torch.ones_like(score) for p_name, score in self.score_map.items()}
136
+
137
+ # Perform a binary search to find the smallest k such that the compression ratio is at least compression_ratio.
138
+ # Here, k_lo and k_hi are the lower and upper bound of the search interval.
139
+ k_lo, k_hi = 1, sum(score.numel() for score in self.score_map.values())
140
+ while k_lo < k_hi:
141
+ k_mid = (k_lo + k_hi + 1) // 2 # round up so that k_mid > k_lo and the search always makes progress
142
+ ratio, _ = self._predict_compression_ratio_by_score(k=k_mid, full=full)
143
+ if ratio > compression_ratio:
144
+ k_lo = k_mid
145
+ else:
146
+ k_hi = k_mid - 1
147
+ k = k_lo
148
+ # TODO: handle tie-breaks
149
+ return self._predict_compression_ratio_by_score(k=k, full=full)[1]
150
+
151
+ def prune_model_by_score(self, compression_ratio: float, full: bool = False) -> None:
152
+ """
153
+ This method prunes the target parameters of the model according to their scores to achieve
154
+ a given compression ratio.
155
+
156
+ This can be efficiently implemented by a simple binary search strategy:
157
+ We find the largest number of parameters that can be pruned according to the score map ranking
158
+ such that the resulting compression ratio is at least the target `compression_ratio`.
159
+
160
+ Args:
161
+ compression_ratio: The target compression ratio.
162
+ full: Whether to count the number of parameters of the entire model or only the parametrized modules.
163
+ See also `ParametrizedModel.get_num_params`.
164
+ """
165
+ param_masks = self._get_param_masks(compression_ratio=compression_ratio, full=full)
166
+
167
+ # Reset the target parameters according to the parameter masks
168
+ for p_name, param in self.get_target_params().items():
169
+ param.data[param_masks[p_name] > 0.0] = 1.0 # dummy value, will be rescaled by reset_target_params
170
+ param.data[param_masks[p_name] == 0.0] = 0.0
171
+ for m_name, module in self.parametrized_modules.items():
172
+ if any(p_name.startswith(m_name) for p_name in param_masks.keys()):
173
+ module.parametrization.reset_target_params(mode="nonzero")
174
+
175
+
176
+ # Register ACIPModelConfig and ACIPModel for AutoModel
177
+ # Required to push custom model to Huggingface Hub (see https://huggingface.co/docs/transformers/en/custom_models)
178
+ ACIPModelConfig.register_for_auto_class()
179
+ ACIPModel.register_for_auto_class("AutoModel")
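+
+
+ # Usage sketch (illustrative only; nothing here is executed): given an `ACIPModel` instance, e.g. loaded
+ # via `AutoModel.from_pretrained(..., trust_remote_code=True)`, the stored scores can be used to prune the
+ # parametrized layers to (approximately) any target size ratio:
+ #
+ #     model.prune_model_by_score(compression_ratio=0.6)
+ #
+ # Because the scores are kept as buffers, calling the method again with a different ratio should re-prune
+ # the model from the same ranking without re-running the ACIP tuning.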
config.json ADDED
@@ -0,0 +1,99 @@
1
+ {
2
+ "_name_or_path": "/rwthfs/rz/cluster/home/cp343770/p-res-mwl-llmcompression/artifacts/runs/paper_v3/compress__llama1_7b/model",
3
+ "adapter_config": {
4
+ "peft_config": {
5
+ "default": {
6
+ "alpha_pattern": {},
7
+ "auto_mapping": null,
8
+ "base_model_name_or_path": "jeffwan/llama-7b-hf",
9
+ "bias": "none",
10
+ "eva_config": null,
11
+ "exclude_modules": [
12
+ "base",
13
+ "parametrization",
14
+ "ortho"
15
+ ],
16
+ "fan_in_fan_out": false,
17
+ "inference_mode": false,
18
+ "init_lora_weights": true,
19
+ "layer_replication": null,
20
+ "layers_pattern": null,
21
+ "layers_to_transform": null,
22
+ "loftq_config": {},
23
+ "lora_alpha": 16,
24
+ "lora_bias": false,
25
+ "lora_dropout": 0.05,
26
+ "megatron_config": null,
27
+ "megatron_core": "megatron.core",
28
+ "modules_to_save": null,
29
+ "peft_type": "LORA",
30
+ "r": 32,
31
+ "rank_pattern": {},
32
+ "revision": null,
33
+ "target_modules": [
34
+ "q_proj",
35
+ "base",
36
+ "o_proj",
37
+ "gate_proj",
38
+ "down_proj",
39
+ "up_proj",
40
+ "v_proj",
41
+ "k_proj",
42
+ "ortho"
43
+ ],
44
+ "task_type": "CAUSAL_LM",
45
+ "use_dora": false,
46
+ "use_rslora": false
47
+ }
48
+ }
49
+ },
50
+ "architectures": [
51
+ "ACIPModel"
52
+ ],
53
+ "auto_map": {
54
+ "AutoConfig": "acip_model.ACIPModelConfig",
55
+ "AutoModel": "acip_model.ACIPModel"
56
+ },
57
+ "base_model_config": {
58
+ "pretrained_config": null,
59
+ "pretrained_model_cls": "transformers.models.auto.modeling_auto.AutoModelForCausalLM",
60
+ "pretrained_model_kwargs": {
61
+ "pretrained_model_name_or_path": "jeffwan/llama-7b-hf",
62
+ "torch_dtype": "bfloat16"
63
+ }
64
+ },
65
+ "model_mode": "train",
66
+ "model_type": "acip_model",
67
+ "parametrization_config": {
68
+ "exclude_modules": null,
69
+ "module_factory_cls": "svd",
70
+ "module_factory_kwargs": {
71
+ "mask_func": "ste",
72
+ "mask_scaling_factor": 0.02
73
+ },
74
+ "target_modules": [
75
+ "k_proj",
76
+ "down_proj",
77
+ "o_proj",
78
+ "v_proj",
79
+ "gate_proj",
80
+ "q_proj",
81
+ "up_proj"
82
+ ]
83
+ },
84
+ "torch_dtype": "bfloat16",
85
+ "transformers_version": "4.46.3",
86
+ "weight_quantization_config": {
87
+ "exclude_modules": null,
88
+ "module_factory_cls": "bitsandbytes.nn.Linear4bit",
89
+ "module_factory_kwargs": {
90
+ "compute_dtype": "torch.bfloat16",
91
+ "quant_type": "fp4"
92
+ },
93
+ "target_modules": [
94
+ "ortho",
95
+ "base",
96
+ "base_layer"
97
+ ]
98
+ }
99
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.46.3"
4
+ }
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ef5fdccfa838748228160684b44f4c37a99a944cd330ae6af63d94188e9c8e3
3
+ size 4979348184
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6182e14726f23f7f25faa53bab676e79acf5d3c3197624a2d15f7d458755db6
3
+ size 4989605584
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dda70e06b09128d1b3309f98651a4fe9fb111bc360716f3d5e10e327a606a847
3
+ size 5000010480
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbf327dd596abbea5916f36ae9b76eea6f45440e8c1fce3430fa8feaa153e203
3
+ size 4910324992
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:392b87946099ac53941f45818bb40be4af393d694f23bb7b7465674ba3503104
3
+ size 1281214432
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
parametrized_layer.py ADDED
@@ -0,0 +1,211 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import ClassVar, Literal, Protocol, runtime_checkable, Type
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class Parametrization(nn.Module, ABC):
9
+ """
10
+ Abstract base class for parametrizations.
11
+ A parametrization can be injected into any torch module of type `base_class` by `parametrize_module`.
12
+ A parametrized module will follow the `ParametrizedModule` interface.
13
+
14
+ This will overload the weight, bias, and forward of the module so that they play together with
15
+ the parametrization. The external behavior of the parametrized module remains unchanged, for instance,
16
+ a parametrized `Linear` module will still work as expected.
17
+
18
+ Attributes:
19
+ base_class: The base class of the module that can be parametrized.
20
+ initialized: A flag that indicates whether the parametrization has been initialized.
21
+ """
22
+
23
+ initialized: bool = False
24
+ base_class: ClassVar[Type[nn.Module]]
25
+
26
+ def initialize(self, base_module: "Parametrization.base_class") -> None:
27
+ self._initialize(base_module)
28
+ self.initialized = True
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ """
32
+ Compute the forward pass of the parametrization.
33
+ This is particularly important when a standard forward pass based on `weight` would be inefficient.
34
+ """
35
+ assert self.initialized
36
+ x = self._forward(x)
37
+ return x
38
+
39
+ @property
40
+ def weight(self) -> torch.Tensor:
41
+ """Compute the weight tensor of the parametrization."""
42
+ return self._weight()
43
+
44
+ @property
45
+ def bias(self) -> torch.Tensor | None:
46
+ """Compute the bias tensor of the parametrization."""
47
+ return self._bias()
48
+
49
+ @abstractmethod
50
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ raise NotImplementedError
52
+
53
+ @abstractmethod
54
+ def _initialize(self, base_module: "Parametrization.base_class") -> None:
55
+ """
56
+ Initialize the parametrization based on a given base module.
57
+ This method should build the internal representation of the module's weight and bias,
58
+ registering all required buffers and parameters in `self`.
59
+ """
60
+ raise NotImplementedError
61
+
62
+ @abstractmethod
63
+ def _weight(self) -> torch.Tensor:
64
+ raise NotImplementedError
65
+
66
+ @abstractmethod
67
+ def _bias(self) -> torch.Tensor | None:
68
+ raise NotImplementedError
69
+
70
+ @abstractmethod
71
+ def get_target_params(self) -> dict[str, torch.nn.Parameter]:
72
+ """
73
+ Return the (tunable) target parameters of the parametrization.
74
+ Here, "target parameters" means that they can be tuned and potentially compressed
75
+ by `self.reset_target_params(mode="compress")`.
76
+ Other torch parameters of the module could be tuned as well, but should not be returned here.
77
+ The returned dictionary should be compatible with `self.named_parameters()`.
78
+
79
+ See Also:
80
+ - `ParametrizedModel.get_target_params`
81
+ - `ParametrizedModel.compress`
82
+ """
83
+ raise NotImplementedError
84
+
85
+ @abstractmethod
86
+ def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
87
+ """
88
+ Reset the target parameters of the parametrization according to a given mode.
89
+
90
+ Args:
91
+ mode: The reset mode.
92
+ "full" means reset to original value at initialization.
93
+ "nonzero" means reset all non-zero values to original value at initialization.
94
+ "compress" means all zero values are removed and the parameters are compressed accordingly.
95
+ """
96
+ raise NotImplementedError
97
+
98
+ @abstractmethod
99
+ def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
100
+ """
101
+ Computes the (effective) number of parameters of the parametrization.
102
+
103
+ Args:
104
+ compressed: Whether to count the number of parameters as if the module was actually compressed.
105
+ If `False`, the number of parameters is the same as in the original module.
106
+ target_params: Count the number of parameters as if `target_params` were used instead of
107
+ `self.get_target_params()`. This "what if" feature is important when pruning
108
+ a full `ParametrizedModel` to a certain target ratio.
109
+ """
110
+ raise NotImplementedError
111
+
112
+
113
+ @runtime_checkable
114
+ class ParametrizedModule(Protocol):
115
+ """
116
+ Interface for a parametrized `nn.Module`.
117
+ It ensures that `weight` and `bias` are forwarded to the `Parametrization` instance.
118
+
119
+ Attributes:
120
+ parametrization: The `Parametrization` instance of the module.
121
+ _forward: The original forward function of the module.
122
+ __old_class__: The original class of the module.
123
+
124
+ Notes:
125
+ `_forward` and `__old_class__` are used by `parametrize_module` and `unparametrize_module`
126
+ to allow restoring the original behavior of the module.
127
+ """
128
+
129
+ parametrization: Parametrization
130
+ _forward: callable
131
+ __old_class__: type[nn.Module]
132
+
133
+ @property
134
+ def weight(self):
135
+ return self.parametrization.weight
136
+
137
+ @property
138
+ def bias(self):
139
+ return self.parametrization.bias
140
+
141
+
142
+ def parametrize_module(module: nn.Module, parametrization: Parametrization) -> ParametrizedModule and nn.Module:
143
+ """
144
+ Parametrize a module using a `Parametrization` instance.
145
+
146
+ Args:
147
+ module: The module to be parametrized.
148
+ parametrization: The `Parametrization` instance to be applied to the module.
149
+
150
+ Returns: The parametrized module using the `ParametrizedModule` interface.
151
+
152
+ Notes:
153
+ Adapted from https://stackoverflow.com/a/31075641
154
+ """
155
+
156
+ assert isinstance(module, parametrization.base_class)
157
+ module.__old_class__ = module.__class__
158
+
159
+ # Initializes the parametrization and adds it to the module
160
+ module.add_module("parametrization", parametrization)
161
+ module.parametrization.initialize(module)
162
+
163
+ # Save the original forward in case we want to remove the parametrization again
164
+ module._forward = module.forward
165
+
166
+ # Cast to new parametrized object class type
167
+ del module.weight
168
+ del module.bias
169
+ module.__class__ = type("Parametrized" + module.__class__.__name__, (module.__class__, ParametrizedModule), {})
170
+ # Make sure that we utilize the forward function of the parametrization
171
+ module.forward = module.parametrization.forward
172
+
173
+ return module
174
+
175
+
176
+ def unparametrize_module(module: ParametrizedModule) -> nn.Module:
177
+ """
178
+ Revert the parametrization of a module.
179
+
180
+ Args:
181
+ module: A module that has been parametrized by `parametrize_module`.
182
+
183
+ Returns: The original module.
184
+
185
+ Notes:
186
+ Adapted from https://stackoverflow.com/a/31075641
187
+ """
188
+
189
+ # Make sure to save weight and bias in intermediate variables
190
+ weight = module.weight
191
+ bias = module.bias
192
+
193
+ assert isinstance(module, nn.Module)
194
+
195
+ # This line will remove properties module.weight and module.bias
196
+ module.__class__ = type(module.__old_class__.__name__, (module.__old_class__,), {})
197
+ delattr(module, "__old_class__")
198
+
199
+ # Add weight and bias as native parameters to the module again
200
+ module.register_parameter("weight", nn.Parameter(weight, weight.requires_grad))
201
+ if bias is not None:
202
+ module.register_parameter("bias", nn.Parameter(bias, bias.requires_grad))
203
+ else:
204
+ module.register_parameter("bias", None)
205
+
206
+ # Recover the original forward pass and get rid of the parametrization
207
+ del module.parametrization
208
+ module.forward = module._forward
209
+ delattr(module, "_forward")
210
+
211
+ return module
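+
+
+ # Usage sketch (hypothetical, for illustration only): `SomeParametrization` stands for any concrete
+ # `Parametrization` subclass, e.g. the SVD-based one used elsewhere in this repository.
+ #
+ #     layer = nn.Linear(16, 8)
+ #     layer = parametrize_module(layer, SomeParametrization())  # weight/bias now come from the parametrization
+ #     y = layer(torch.randn(2, 16))                              # forward is routed through Parametrization.forward
+ #     layer = unparametrize_module(layer)                        # restores plain weight/bias parameters on the module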
parametrized_model.py ADDED
@@ -0,0 +1,747 @@
1
+ import logging
2
+ import os
3
+ from dataclasses import asdict, dataclass, field
4
+ from typing import Any, Literal, Type
5
+
6
+ import torch
7
+ from peft import PeftConfig
8
+ from peft.tuners.tuners_utils import _maybe_include_all_linear_layers, check_target_module_exists
9
+ from torch import nn
10
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
11
+
12
+ from .parametrized_layer import Parametrization, parametrize_module, ParametrizedModule, unparametrize_module
13
+ from .projected_layer import SVDLinearParametrization
14
+ from .utils import get_class_from_str, get_str_from_class, init_empty_weights
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass
20
+ class BaseModelConfig:
21
+ """
22
+ Configuration for the base model to be parametrized by `ParametrizedModel`.
23
+
24
+ Attributes:
25
+ pretrained_model_cls: The class of the base model. Child class of `PreTrainedModel`.
26
+ pretrained_model_kwargs: Keyword arguments used when creating the base model in the constructor
27
+ of `ParametrizedModel` via `from_pretrained`.
28
+ pretrained_config: Optional config used when creating the base model in the constructor
29
+ of `ParametrizedModel` via `from_pretrained`.
30
+
31
+ See Also:
32
+ `ParametrizedModelConfig`
33
+ """
34
+
35
+ pretrained_model_cls: Type[PreTrainedModel]
36
+ pretrained_model_kwargs: dict[str, Any] = field(default_factory=dict)
37
+ pretrained_config: PretrainedConfig | None = None
38
+
39
+ def __post_init__(self):
40
+ # if pretrained_model_cls is a string, convert it to a class (required for deserialization from JSON config)
41
+ if isinstance(self.pretrained_model_cls, str):
42
+ self.pretrained_model_cls = get_class_from_str(self.pretrained_model_cls) # noqa
43
+ else:
44
+ self.pretrained_model_cls = self.pretrained_model_cls
45
+
46
+ def to_dict(self) -> dict[str, Any]:
47
+ config_dict = asdict(self) # type: ignore
48
+ # make sure that pretrained_model_cls and pretrained_config are JSON serializable
49
+ config_dict["pretrained_model_cls"] = get_str_from_class(self.pretrained_model_cls)
50
+ if self.pretrained_config is not None:
51
+ config_dict["pretrained_config"] = self.pretrained_config.to_dict()
52
+ return config_dict
53
+
54
+ @classmethod
55
+ def from_dict(cls, config_dict: dict[str, Any]) -> "BaseModelConfig":
56
+ # try to deserialize pretrained_config with AutoConfig otherwise fall back to PretrainedConfig
57
+ try:
58
+ if config_dict["pretrained_config"] is not None:
59
+ # try AutoConfig to find the right model config class
60
+ config_dict["pretrained_config"] = AutoConfig.for_model(**config_dict["pretrained_config"])
61
+ except ValueError:
62
+ logger.warning("Unrecognized model identifier in AutoConfig, using PretrainedConfig instead.")
63
+ config_dict["pretrained_config"] = PretrainedConfig.from_dict(config_dict["pretrained_config"])
64
+ return cls(**config_dict)
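+
+
+ # Example sketch (mirrors the base_model_config stored in this repository's config.json):
+ #
+ #     BaseModelConfig(
+ #         pretrained_model_cls=AutoModelForCausalLM,
+ #         pretrained_model_kwargs={
+ #             "pretrained_model_name_or_path": "jeffwan/llama-7b-hf",
+ #             "torch_dtype": torch.bfloat16,
+ #         },
+ #     )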
65
+
66
+
67
+ # Predefined parametrization classes for `ParametrizationConfig.module_factory_cls` (avoids absolute package imports)
68
+ PARAMETRIZATION_FACTORY_REGISTRY: dict[str, Type[Parametrization]] = {
69
+ "svd": SVDLinearParametrization,
70
+ }
71
+
72
+
73
+ @dataclass
74
+ class ParametrizationConfig:
75
+ """
76
+ Configuration for the parametrization to be applied to the linear layers of the base model in `ParametrizedModel`.
77
+
78
+ Attributes:
79
+ module_factory_cls: The class name of the parametrization to be applied to linear layers.
80
+ Can be a string representing a class name (with absolute module path) or a predefined key
81
+ from `PARAMETRIZATION_FACTORY_REGISTRY`.
82
+ Use `parse_module_factory_cls` to get the actual class when creating the parametrization.
83
+ module_factory_kwargs: Keyword arguments used when creating the parametrization with `module_factory_cls`.
84
+ target_modules: A (list of) string(s) specifying the names of the linear layers to be parametrized.
85
+ Follows the same semantics as Huggingface's `PeftConfig`, see also `check_target_module_exists`.
86
+ If a string, a regex match will be performed; if a list, a module will be parametrized if its name ends
87
+ with any of the strings in `target_modules`.
88
+ exclude_modules: A list of strings specifying the names of the linear layers to be excluded from
89
+ parametrization. A module will be excluded if any of the strings in `exclude_modules` is in its name.
90
+
91
+ See Also:
92
+ `ParametrizedModelConfig`
93
+ """
94
+
95
+ module_factory_cls: str
96
+ module_factory_kwargs: dict[str, Any] = field(default_factory=dict)
97
+ target_modules: str | list[str] | None = None
98
+ exclude_modules: list[str] | None = None
99
+
100
+ def parse_module_factory_cls(self) -> Type[Parametrization]:
101
+ """Returns the class of the parametrization to be applied to linear layers."""
102
+ try:
103
+ if self.module_factory_cls in PARAMETRIZATION_FACTORY_REGISTRY:
104
+ module_factory_cls = PARAMETRIZATION_FACTORY_REGISTRY[self.module_factory_cls]
105
+ else:
106
+ module_factory_cls = get_class_from_str(self.module_factory_cls)
107
+ except Exception:
108
+ raise ValueError(f"Unrecognized parametrization class: {self.module_factory_cls}")
109
+ return module_factory_cls
110
+
111
+ def to_dict(self) -> dict[str, Any]:
112
+ config_dict = asdict(self) # type: ignore
113
+ # _maybe_include_all_linear_layers creates sets, which do not work with JSON serialization, so cast to list
114
+ for key, value in config_dict.items():
115
+ if isinstance(value, set):
116
+ config_dict[key] = list(value)
117
+ return config_dict
118
+
119
+ @classmethod
120
+ def from_dict(cls, config_dict: dict[str, Any]) -> "ParametrizationConfig":
121
+ return cls(**config_dict)
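+
+
+ # Example sketch (mirrors the parametrization_config stored in this repository's config.json):
+ #
+ #     ParametrizationConfig(
+ #         module_factory_cls="svd",  # resolved via PARAMETRIZATION_FACTORY_REGISTRY
+ #         module_factory_kwargs={"mask_func": "ste", "mask_scaling_factor": 0.02},
+ #         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ #     )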
122
+
123
+
124
+ @dataclass
125
+ class AdapterConfig:
126
+ """
127
+ Configuration for the Huggingface Peft adapters to be applied to the base model.
128
+
129
+ Attributes:
130
+ peft_config: One or more adapter `PeftConfig`s to be applied to the base model.
131
+ If a single `PeftConfig` is provided, it will be wrapped in a dict with key "default".
132
+ The dictionary keys will be used as adapter names in `PretrainedModel.add_adapter`.
133
+
134
+ See Also:
135
+ `ParametrizedModelConfig`
136
+ """
137
+
138
+ peft_config: PeftConfig | dict[str, PeftConfig]
139
+
140
+ def __post_init__(self):
141
+ if isinstance(self.peft_config, PeftConfig):
142
+ self.peft_config = {"default": self.peft_config}
143
+
144
+ def to_dict(self) -> dict[str, Any]:
145
+ config_dict = asdict(self) # type: ignore
146
+ # Make each PeftConfig JSON serializable
147
+ for adapter_name, peft_config in self.peft_config.items():
148
+ peft_config_dict = peft_config.to_dict()
149
+ # Peft casts lists to sets, which are not JSON serializable, so cast to list manually
150
+ for key, value in peft_config_dict.items():
151
+ if isinstance(value, set):
152
+ peft_config_dict[key] = list(value)
153
+ config_dict["peft_config"][adapter_name] = peft_config_dict
154
+ return config_dict
155
+
156
+ @classmethod
157
+ def from_dict(cls, config_dict: dict[str, Any]) -> "AdapterConfig":
158
+ # Deserialize each PeftConfig automatically with from_peft_type
159
+ for key, peft_config in config_dict["peft_config"].items():
160
+ config_dict["peft_config"][key] = PeftConfig.from_peft_type(**peft_config)
161
+ return cls(**config_dict)
162
+
163
+
164
+ try:
165
+ # Prevent import errors: on some systems (e.g., macOS), bitsandbytes cannot be installed directly
166
+ import bitsandbytes
167
+
168
+ # Predefined quantization classes for `WeightQuantizationConfig.module_factory_cls`
169
+ # (avoids absolute package imports)
170
+ QUANTIZATION_FACTORY_REGISTRY: dict[str, Type[nn.Linear]] = {
171
+ "bnb4bit": bitsandbytes.nn.Linear4bit,
172
+ }
173
+ except ImportError:
174
+ logger.warning("bitsandbytes is not installed, skipping quantization.")
175
+ QUANTIZATION_FACTORY_REGISTRY: dict[str, Type[nn.Linear]] = {}
176
+
177
+
178
+ @dataclass
179
+ class WeightQuantizationConfig:
180
+ """
181
+ Configuration for an (optional) weight quantization to be applied to the base model.
182
+ So far, only fp4 quantization with bitsandbytes has been tested, but analogous bitsandbytes
183
+ quantizations should work as well. `module_factory_cls` might also use a different quantization library,
184
+ as long as it is compatible with the module replacement strategy in `ParametrizedModel.quantize`.
185
+
186
+ Attributes:
187
+ module_factory_cls: The class name of the quantization to be applied to linear layers.
188
+ Can be a string representing a class name (with absolute module path) or a predefined key
189
+ from `QUANTIZATION_FACTORY_REGISTRY`.
190
+ Use `parse_module_factory_cls` to get the actual class when creating the quantization.
191
+ module_factory_kwargs: Keyword arguments used when creating the quantization with `module_factory_cls`.
192
+ target_modules: A (list of) string(s) specifying the names of the linear layers to be quantized.
193
+ Follows the same semantics as Huggingface's `PeftConfig`, see also `check_target_module_exists`.
194
+ If a string, a regex match will be performed; if a list, a module will be quantized if its name ends
195
+ with any of the strings in `target_modules`.
196
+ exclude_modules: A list of strings specifying the names of the linear layers to be excluded from
197
+ quantization. A module will be excluded if any of the strings in `exclude_modules` is in its name.
198
+
199
+ See Also:
200
+ `ParametrizedModelConfig`
201
+ """
202
+
203
+ module_factory_cls: str
204
+ module_factory_kwargs: dict[str, Any] = field(default_factory=dict)
205
+ target_modules: str | list[str] | None = None
206
+ exclude_modules: list[str] | None = None
207
+
208
+ def parse_module_factory_cls(self) -> Type[nn.Linear]:
209
+ """Returns the class of the quantization to be applied to linear layers."""
210
+ try:
211
+ if self.module_factory_cls in QUANTIZATION_FACTORY_REGISTRY:
212
+ module_factory_cls = QUANTIZATION_FACTORY_REGISTRY[self.module_factory_cls]
213
+ else:
214
+ module_factory_cls = get_class_from_str(self.module_factory_cls)
215
+ except Exception:
216
+ raise ValueError(f"Unrecognized quantization class: {self.module_factory_cls}")
217
+ return module_factory_cls
218
+
219
+ def to_dict(self) -> dict[str, Any]:
220
+ config_dict = asdict(self) # type: ignore
221
+ # Make torch.dtype fields JSON serializable
222
+ for key, value in config_dict["module_factory_kwargs"].items():
223
+ if isinstance(value, torch.dtype):
224
+ config_dict["module_factory_kwargs"][key] = str(value)
225
+ # _maybe_include_all_linear_layers creates sets, which do not work with JSON serialization, so cast to list
226
+ for key, value in config_dict.items():
227
+ if isinstance(value, set):
228
+ config_dict[key] = list(value)
229
+ return config_dict
230
+
231
+ @classmethod
232
+ def from_dict(cls, config_dict: dict[str, Any]) -> "WeightQuantizationConfig":
233
+ # Deserialize torch.dtype fields
234
+ for key, value in config_dict["module_factory_kwargs"].items():
235
+ if isinstance(value, str) and value.startswith("torch."):
236
+ dtype_name = value.split(".")[-1]
237
+ config_dict["module_factory_kwargs"][key] = getattr(torch, dtype_name)
238
+ return cls(**config_dict)
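+
+
+ # Example sketch (mirrors the weight_quantization_config stored in this repository's config.json):
+ #
+ #     WeightQuantizationConfig(
+ #         module_factory_cls="bitsandbytes.nn.Linear4bit",  # or the registry key "bnb4bit"
+ #         module_factory_kwargs={"compute_dtype": torch.bfloat16, "quant_type": "fp4"},
+ #         target_modules=["ortho", "base", "base_layer"],
+ #     )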
239
+
240
+
241
+ class ParametrizedModelConfig(PretrainedConfig):
242
+ """
243
+ Configuration for `ParametrizedModel` implementing a `PretrainedConfig` to be fully compatible with
244
+ Huggingface's `PreTrainedModel` framework.
245
+
246
+ See Also:
247
+ - `BaseModelConfig`
248
+ - `ParametrizationConfig`
249
+ - `AdapterConfig`
250
+ - `WeightQuantizationConfig`
251
+ - `ParametrizedModel`
252
+ """
253
+
254
+ model_type = "parametrized_model"
255
+
256
+ def __init__(
257
+ self,
258
+ base_model_config: BaseModelConfig | None = None,
259
+ parametrization_config: ParametrizationConfig | None = None,
260
+ adapter_config: AdapterConfig | None = None,
261
+ weight_quantization_config: WeightQuantizationConfig | None = None,
262
+ model_mode: Literal["train", "eval"] = "train",
263
+ **kwargs: Any,
264
+ ):
265
+ """
266
+ Initializes a `ParametrizedModelConfig`, serving as a container for `BaseModelConfig`, `ParametrizationConfig`,
267
+ `AdapterConfig`, and `WeightQuantizationConfig`.
268
+
269
+ Args:
270
+ base_model_config: `BaseModelConfig`
271
+ parametrization_config: `ParametrizationConfig`
272
+ adapter_config: `AdapterConfig`
273
+ weight_quantization_config: `WeightQuantizationConfig`
274
+ model_mode: Whether to initialize the model in train or eval mode.
275
+ **kwargs: Keyword arguments forwarded to `PretrainedConfig`.
276
+ """
277
+ self.base_model_config = base_model_config
278
+ self.parametrization_config = parametrization_config
279
+ self.adapter_config = adapter_config
280
+ self.weight_quantization_config = weight_quantization_config
281
+ self.model_mode = model_mode
282
+ super().__init__(**kwargs)
283
+
284
+ def _convert_to_dict(self, config_dict: dict[str, Any]) -> dict[str, Any]:
285
+ if self.base_model_config is not None:
286
+ config_dict["base_model_config"] = self.base_model_config.to_dict()
287
+ if self.parametrization_config is not None:
288
+ config_dict["parametrization_config"] = self.parametrization_config.to_dict()
289
+ if self.adapter_config is not None:
290
+ config_dict["adapter_config"] = self.adapter_config.to_dict()
291
+ if self.weight_quantization_config is not None:
292
+ config_dict["weight_quantization_config"] = self.weight_quantization_config.to_dict()
293
+ return config_dict
294
+
295
+ def to_diff_dict(self):
296
+ # Override PretrainedConfig to_diff_dict to make subconfigs JSON serializable.
297
+ config_dict = super().to_diff_dict()
298
+ return self._convert_to_dict(config_dict)
299
+
300
+ def to_dict(self):
301
+ # Override PretrainedConfig to_dict to make subconfigs JSON serializable.
302
+ config_dict = super().to_dict()
303
+ return self._convert_to_dict(config_dict)
304
+
305
+ @classmethod
306
+ def from_dict(cls, config_dict: dict[str, Any], **kwargs: Any) -> PretrainedConfig:
307
+ # Deserialize BaseModelConfig
308
+ base_model_config_dict: dict[str, Any] | None = config_dict.pop("base_model_config", None)
309
+ if base_model_config_dict is not None:
310
+ base_model_config = BaseModelConfig.from_dict(base_model_config_dict)
311
+ else:
312
+ base_model_config = None
313
+ # Deserialize ParametrizationConfig
314
+ parametrization_config_dict: dict[str, Any] | None = config_dict.pop("parametrization_config", None)
315
+ if parametrization_config_dict is not None:
316
+ parametrization_config = ParametrizationConfig.from_dict(parametrization_config_dict)
317
+ else:
318
+ parametrization_config = None
319
+ # Deserialize AdapterConfig
320
+ adapter_config_dict: dict[str, Any] | None = config_dict.pop("adapter_config", None)
321
+ if adapter_config_dict is not None:
322
+ adapter_config = AdapterConfig.from_dict(adapter_config_dict)
323
+ else:
324
+ adapter_config = None
325
+ # Deserialize WeightQuantizationConfig
326
+ weight_quantization_config_dict: dict[str, Any] | None = config_dict.pop("weight_quantization_config", None)
327
+ if weight_quantization_config_dict is not None:
328
+ weight_quantization_config = WeightQuantizationConfig.from_dict(weight_quantization_config_dict)
329
+ else:
330
+ weight_quantization_config = None
331
+
332
+ config = super().from_dict(config_dict, **kwargs)
333
+
334
+ # Handle special case when return_unused_kwargs is True
335
+ if "return_unused_kwargs" in kwargs and kwargs["return_unused_kwargs"] is True:
336
+ config[0].base_model_config = base_model_config
337
+ config[0].parametrization_config = parametrization_config
338
+ config[0].adapter_config = adapter_config
339
+ config[0].weight_quantization_config = weight_quantization_config
340
+ else:
341
+ config.base_model_config = base_model_config
342
+ config.parametrization_config = parametrization_config
343
+ config.adapter_config = adapter_config
344
+ config.weight_quantization_config = weight_quantization_config
345
+ return config
346
+
347
+
348
+ class ParametrizedModel(PreTrainedModel):
349
+ """
350
+ Base class for parametrized models implemented as a custom Huggingface `PreTrainedModel`.
351
+ It wraps any base model of type `PreTrainedModel` in `self.model`, whose linear layers can be
352
+ parametrized (`parametrize`), equipped with adapters (`inject_adapters`), and quantized (`quantize`).
353
+ The corresponding modules are accessed via `parametrized_modules`, `adapter_modules`,
354
+ and `quantized_modules`, respectively.
355
+ The class also provides several convenience methods to manage the parametrization: `get_target_params`,
356
+ `get_num_params`, `get_compression_ratio`, `reset_target_params`, `compress`.
357
+
358
+ Standard functionality (`forward`, `generate`, `save_pretrained`, `from_pretrained`) is essentially forwarded
359
+ to the wrapped model.
360
+
361
+ See Also:
362
+ `ParametrizedModelConfig`
363
+ """
364
+
365
+ config_class = ParametrizedModelConfig
366
+
367
+ def __init__(self, config: ParametrizedModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
368
+ """
369
+ Initialize the `ParametrizedModel` from a given configuration or an existing base model.
370
+
371
+ Args:
372
+ config: `ParametrizedModelConfig` to be used.
373
+ base_model: If provided, this base model is used instead of creating it from `config.base_model_config`.
374
+ **_: Ignored keyword arguments to prevent unexpected keyword errors.
375
+
376
+ See Also: `BaseModelConfig`
377
+ """
378
+ super().__init__(config)
379
+ self.config = config # redundant but enables type hinting for ParametrizedModelConfig
380
+
381
+ # Either use an existing base model or create a new one from config.base_model_config
382
+ if base_model is None:
383
+ if self.config.base_model_config is None:
384
+ raise ValueError("Either base_model or base_model_config must be provided.")
385
+ self.model = self.config.base_model_config.pretrained_model_cls.from_pretrained(
386
+ config=self.config.base_model_config.pretrained_config,
387
+ **self.config.base_model_config.pretrained_model_kwargs,
388
+ )
389
+ else:
390
+ self.model = base_model
391
+
392
+ # Set base model to train or eval mode.
393
+ self.train(self.config.model_mode == "train")
394
+ logger.info(f"Base model {self.model.__class__} created.")
395
+
396
+ # Perform parametrization.
397
+ self._parametrized_modules: dict[str, ParametrizedModule] | None = None
398
+ self.parametrize()
399
+
400
+ # Inject adapters.
401
+ self._adapter_modules: dict[str, nn.Module] | None = None
402
+ self.inject_adapters()
403
+
404
+ # Quantization needs to be performed manually via `quantize` because this is fully optional.
405
+ self._quantized_modules: dict[str, nn.Linear] | None = None
406
+
407
+ # Modified modules are initialized after parametrize and inject_adapters because they may alter the nested
408
+ # module and parameter structure of the model.
409
+ _ = self.parametrized_modules
410
+ _ = self.adapter_modules
411
+ _ = self.quantized_modules
412
+
413
+ # Initially disable all tunable parameters to avoid unexpected behavior.
414
+ # Tunable parameter selection should be handled by the optimizer factory in `BaseLitModule`.
415
+ for param in self.parameters():
416
+ param.requires_grad = False
417
+
418
+ @property
419
+ def base_model_name_or_path(self) -> str:
420
+ """Convenience method to return the name or path of the base model."""
421
+ return self.model.name_or_path # type: ignore
422
+
423
+ def forward(self, *args, **kwargs) -> Any:
424
+ return self.model(*args, **kwargs)
425
+
426
+ def generate(self, *args, **kwargs) -> Any:
427
+ return self.model.generate(*args, **kwargs)
428
+
429
+ def save_pretrained(
430
+ self,
431
+ save_directory: str | os.PathLike,
432
+ state_dict: dict | None = None,
433
+ include_filter: list[str] | None = None,
434
+ exclude_filter: list[str] | None = None,
435
+ **kwargs: Any,
436
+ ) -> None:
437
+ """
438
+ Override of the default `save_pretrained` method to allow filtering of the saved state dict.
439
+
440
+ Args:
441
+ save_directory: Directory to save the model to.
442
+ state_dict: Manual override of the state dict to be saved.
443
+ If None, `include_filter` and `exclude_filter` are applied to `self.state_dict()`.
444
+ include_filter: List of state dict keys to include in the saved state dict.
445
+ Match when the key ends with any of the strings in the list.
446
+ If None, all keys are included.
447
+ exclude_filter: List of state dict keys to exclude from the saved state dict.
448
+ Match when the key ends with any of the strings in the list.
449
+ If None, no keys are excluded.
450
+ **kwargs: Keyword arguments to be passed to the default `save_pretrained` method.
451
+
452
+ See Also:
453
+ `PreTrainedModel.save_pretrained`
454
+ """
455
+ if state_dict is None:
456
+ state_dict = self.state_dict()
457
+ if include_filter is not None:
458
+ state_dict = {k: v for k, v in state_dict.items() if any(k.endswith(f) for f in include_filter)}
459
+ if exclude_filter is not None:
460
+ state_dict = {k: v for k, v in state_dict.items() if not any(k.endswith(f) for f in exclude_filter)}
461
+
462
+ super().save_pretrained(save_directory=save_directory, state_dict=state_dict, **kwargs)
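+
+ # Usage sketch (hypothetical filter strings, for illustration only): save only the ACIP score buffers and
+ # LoRA adapter weights instead of the full state dict, e.g.
+ #     model.save_pretrained("out/", include_filter=["_score", "lora_A.default.weight", "lora_B.default.weight"])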
463
+
464
+ @classmethod
465
+ def from_pretrained(
466
+ cls,
467
+ pretrained_model_name_or_path: str | os.PathLike | None,
468
+ *model_args: Any,
469
+ with_init_empty_weights: bool = True,
470
+ **kwargs: Any,
471
+ ) -> PreTrainedModel:
472
+ """
473
+ Override of the default `from_pretrained` method to allow initialization with empty weights.
474
+
475
+ Args:
476
+ pretrained_model_name_or_path: Model name or path.
477
+ *model_args: Arguments to be passed to the default `from_pretrained` method.
478
+ with_init_empty_weights: Whether to initialize the model with empty weights or not.
479
+ **kwargs: Keyword arguments to be passed to the default `from_pretrained` method.
480
+ """
481
+ with init_empty_weights(with_init_empty_weights):
482
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
483
+
484
+ @property
485
+ def parametrized_modules(self) -> dict[str, ParametrizedModule]:
486
+ """
487
+ Returns a dictionary of all parametrized modules in the model.
488
+ The returned dictionary is compatible with `self.model.named_modules()`.
489
+ """
490
+ if self._parametrized_modules is None:
491
+ self._parametrized_modules = {}
492
+ if self.config.parametrization_config is None:
493
+ return self._parametrized_modules
494
+ for m_name, module in self.model.named_modules():
495
+ if isinstance(module, ParametrizedModule):
496
+ self._parametrized_modules[m_name] = module
497
+ return self._parametrized_modules
498
+
499
+ @property
500
+ def adapter_modules(self) -> dict[str, nn.Module]:
501
+ """
502
+ Returns a dictionary of all adapter modules in the model.
503
+ The returned dictionary is compatible with `self.model.named_modules()`.
504
+ """
505
+ if self._adapter_modules is None:
506
+ self._adapter_modules = {}
507
+ if self.config.adapter_config is None:
508
+ return self._adapter_modules
509
+ try:
510
+ # Use the adapter management of `PreTrainedModel` to retrieve the adapter modules.
511
+ for adapter_name in self.model.active_adapters():
512
+ for m_name in self.model.get_adapter_state_dict(adapter_name).keys():
513
+ adapter_m_name = f"{m_name.rsplit('.', 1)[0]}.{adapter_name}"
514
+ self._adapter_modules[adapter_m_name] = self.model.get_submodule(adapter_m_name)
515
+ except ValueError as e:
516
+ logger.warning(e)
517
+ return self._adapter_modules
518
+
519
+ @property
520
+ def quantized_modules(self) -> dict[str, nn.Linear]:
521
+ """
522
+ Returns a dictionary of all quantized modules in the model.
523
+ The returned dictionary is compatible with `self.model.named_modules()`.
524
+ """
525
+ if self._quantized_modules is None:
526
+ self._quantized_modules = {}
527
+ if self.config.weight_quantization_config is None:
528
+ return self._quantized_modules
529
+ try:
530
+ module_factory_cls = self.config.weight_quantization_config.parse_module_factory_cls()
531
+ except Exception as e:
532
+ logger.warning(f"Could not parse weight quantization config, quantization not available.\nError: {e}")
533
+ return self._quantized_modules
534
+ for m_name, module in self.model.named_modules():
535
+ if isinstance(module, module_factory_cls):
536
+ self._quantized_modules[m_name] = module
537
+ return self._quantized_modules
538
+
539
+ def parametrize(self) -> None:
540
+ """
541
+ Parametrize the `target_modules` from `ParametrizationConfig` using `parametrized_layer.parametrize_module`.
542
+
543
+ See Also: `ParametrizationConfig`
544
+ """
545
+ if self.config.parametrization_config is None:
546
+ logger.debug("Model parametrization is disabled.")
547
+ return
548
+
549
+ # Use peft semantics, e.g., "all-linear" to include all linear layers
550
+ # TODO: Replace by own helper function to avoid unnecessary dependencies
551
+ config: ParametrizationConfig = _maybe_include_all_linear_layers( # type: ignore
552
+ self.config.parametrization_config, # type: ignore
553
+ self.model,
554
+ )
555
+ module_factory_cls = config.parse_module_factory_cls()
556
+
557
+ for m_name, module in self.model.named_modules():
558
+ # Only modify the modules that are targeted
559
+ if config.exclude_modules is not None and any(key in m_name for key in config.exclude_modules):
560
+ continue
561
+ if not check_target_module_exists(config, m_name):
562
+ continue
563
+
564
+ parametrization = module_factory_cls(**config.module_factory_kwargs)
565
+ parametrize_module(module=module, parametrization=parametrization)
566
+ logger.debug(f"Parametrized {module.__class__} module {m_name} as {parametrization.__class__}")
567
+
568
+ self._parametrized_modules = None # reset parametrized modules
569
+ logger.info("Parametrization completed.")
570
+
571
+ def inject_adapters(self) -> None:
572
+ """
573
+ Inject adapters according to `AdapterConfig` using the adapter management of `PreTrainedModel`.
574
+
575
+ See Also: `AdapterConfig`
576
+ """
577
+ if self.config.adapter_config is None:
578
+ logger.debug("Adapter injection is disabled.")
579
+ return
580
+
581
+ for adapter_name, peft_config in self.config.adapter_config.peft_config.items():
582
+ self.model.add_adapter(peft_config, adapter_name=adapter_name)
583
+ self.model.set_adapter(list(self.config.adapter_config.peft_config.keys()))
584
+
585
+ self._adapter_modules = None # reset adapter modules
586
+ logger.info("Adapters injected.")
587
+
588
+ def quantize(self) -> None:
589
+ """
590
+ Quantize the `target_modules` from `WeightQuantizationConfig`.
591
+
592
+ See Also: `WeightQuantizationConfig`
593
+ """
594
+ if self.config.weight_quantization_config is None:
595
+ logger.debug("Weight quantization is disabled.")
596
+ return
597
+
598
+ # Use peft semantics, e.g., "all-linear" to include all linear layers
599
+ # TODO: Replace by own helper function to avoid unnecessary dependencies
600
+ config: WeightQuantizationConfig = _maybe_include_all_linear_layers( # type: ignore
601
+ self.config.weight_quantization_config, # type: ignore
602
+ self.model,
603
+ )
604
+ module_factory_cls = config.parse_module_factory_cls()
605
+
606
+ for m_name, module in self.model.named_modules():
607
+ # Only modify the modules that are targeted
608
+ if config.exclude_modules is not None and any(key in m_name for key in config.exclude_modules):
609
+ continue
610
+ if not check_target_module_exists(config, m_name) or isinstance(module, ParametrizedModule):
611
+ continue
612
+ if not isinstance(module, nn.Linear):
613
+ continue
614
+
615
+ # Important: This module must NOT be created in a device context like with_init_device("cuda")
616
+ quantized_module = module_factory_cls(
617
+ module.in_features,
618
+ module.out_features,
619
+ bias=module.bias is not None,
620
+ device=module.weight.device,
621
+ **config.module_factory_kwargs,
622
+ )
623
+ # cf. https://huggingface.co/docs/bitsandbytes/reference/nn/linear4bit#bitsandbytes.nn.Linear4bit.example
624
+ quantized_module.load_state_dict(module.state_dict())
625
+ quantized_module = quantized_module.to(module.weight.device)
626
+ quantized_module.weight.requires_grad = False
627
+ logger.debug(f"Quantized {module.__class__} module {m_name} to {quantized_module.__class__}")
628
+
629
+ # Replace the target module by the quantized module
630
+ parent_name, child_name = m_name.rsplit(".", 1)
631
+ parent_module = self.model.get_submodule(parent_name)
632
+ parent_module.add_module(child_name, quantized_module)
633
+
634
+ self._quantized_modules = None # reset quantized modules
635
+ logger.info("Quantization completed.")
636
+
637
+ def get_target_params(self) -> dict[str, nn.Parameter]:
638
+ """
639
+ Lifts `Parametrization.get_target_params` to the model scope.
640
+ The returned dictionary should be compatible with `self.model.named_parameters()`.
641
+
642
+ See Also:
643
+ `Parametrization.get_target_params`
644
+ """
645
+ target_params = {}
646
+ for m_name, module in self.parametrized_modules.items():
647
+ for p_name, param in module.parametrization.get_target_params().items():
648
+ target_params[f"{m_name}.parametrization.{p_name}"] = param
649
+ return target_params
650
+
651
+ def get_num_params(
652
+ self, compressed: bool = False, full: bool = False, target_params: dict[str, torch.Tensor] | None = None
653
+ ) -> int:
654
+ """
655
+ Lifts `Parametrization.get_num_params` to the model scope.
656
+ Computes the (effective) number of parameters of the entire model.
657
+
658
+ Args:
659
+ compressed: Whether to count the number of parameters as if the parametrized modules were actually
660
+ compressed. If `False`, the number of parameters is the same as in the original module.
661
+ full: If `True`, all parameters of the model are counted; if `False`, only those of the parametrized modules are counted.
662
+ Default is `False`, which follows the most common convention in the compression literature.
663
+ target_params: Count the number of parameters as if `target_params` were used instead of
664
+ the parametrized modules' target parameters. The dictionary keys should be compatible with those of
665
+ `self.get_target_params`.
666
+
667
+ See Also:
668
+ `Parametrization.get_num_params`
669
+ """
670
+ num_params_full = 0
671
+ if full:
672
+ for name, param in self.model.named_parameters():
673
+ if "parametrization" not in name: # exclude parametrized modules here (counted below)
674
+ if hasattr(param, "quant_state"): # HOTFIX: special case for bitsandbytes-quantized parameters
675
+ num_params_full += param.numel() * 2
676
+ else:
677
+ num_params_full += param.numel()
678
+
679
+ num_params = 0
680
+ for module_name, module in self.parametrized_modules.items():
681
+ module_target_params = None
682
+ if compressed and target_params is not None:
683
+ # Make target_params' keys match those of the parametrized modules, i.e., trim the f"{module_name}.parametrization." prefix
684
+ prefix = f"{module_name}.parametrization."
685
+ # Filter and re-map keys for the current module
686
+ module_target_params = {
687
+ key[len(prefix) :]: value for key, value in target_params.items() if key.startswith(prefix)
688
+ }
689
+ if not module_target_params:
690
+ module_target_params = None
691
+
692
+ num_params += module.parametrization.get_num_params(
693
+ compressed=compressed, target_params=module_target_params
694
+ )
695
+ num_params = num_params + num_params_full
696
+ if num_params == 0:
697
+ # dummy to avoid division by zero (e.g., if there are no parametrized_modules and full=False)
698
+ num_params = 1e-6
699
+ return num_params
700
+
701
+ def get_compression_ratio(self, full: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> float:
702
+ """
703
+ Convenience function to compute the compression ratio of the present model.
704
+
705
+ See Also:
706
+ `get_num_params`
707
+ """
708
+ return self.get_num_params(compressed=True, full=full, target_params=target_params) / self.get_num_params(
709
+ full=full
710
+ )
711
+
712
+ def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
713
+ """
714
+ Lifts `Parametrization.reset_target_params` to the model scope.
715
+
716
+ Args:
717
+ mode: The reset mode, see `Parametrization.reset_target_params`.
718
+
719
+ See Also:
720
+ `Parametrization.reset_target_params`
721
+ """
722
+ for m_name, module in self.parametrized_modules.items():
723
+ module.parametrization.reset_target_params(mode=mode)
724
+
725
+ def compress(self) -> None:
726
+ """
727
+ Compresses all parametrized modules using `Parametrization.reset_target_params(mode="compress")`.
728
+ If no compression is possible, the module is unparametrized and removed from `parametrized_modules`.
729
+ """
730
+ removed_parametrized_modules = []
731
+ for m_name, module in self.parametrized_modules.items():
732
+ if module.parametrization.get_num_params(compressed=True) / module.parametrization.get_num_params() >= 1.0:
733
+ unparametrize_module(module)
734
+ removed_parametrized_modules.append(m_name)
735
+ logger.debug(f"Unparametrizing {module.__class__} module {m_name}")
736
+ else:
737
+ module.parametrization.reset_target_params(mode="compress")
738
+ logger.debug(f"Compressing {module.__class__} module {m_name}")
739
+ for m_name in removed_parametrized_modules:
740
+ self.parametrized_modules.pop(m_name)
741
+ logger.info("Compression completed.")
742
+
743
+
744
+ # Register ParametrizedModelConfig and ParametrizedModel for AutoModel
745
+ # Required to push a custom model to the Hugging Face Hub (see https://huggingface.co/docs/transformers/en/custom_models)
746
+ ParametrizedModelConfig.register_for_auto_class()
747
+ ParametrizedModel.register_for_auto_class("AutoModel")
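For orientation, here is a minimal usage sketch of how the methods above are meant to be chained. The constructor arguments of `ParametrizedModelConfig` and `ParametrizationConfig` are defined earlier in this file and are not reproduced in this diff excerpt, so the field names and import paths below are illustrative assumptions rather than the exact API.

```python
# Hypothetical sketch only: field names and import paths are assumptions.
from parametrized_model import ParametrizedModel, ParametrizedModelConfig, ParametrizationConfig

config = ParametrizedModelConfig(
    parametrization_config=ParametrizationConfig(
        target_modules="all-linear",  # peft-style semantics, as handled in parametrize()
        module_factory_cls=".projected_layer.SVDLinearParametrization",
    ),
)
pmodel = ParametrizedModel(config)  # assumed to wrap a pretrained transformers model

pmodel.parametrize()      # replace targeted linear layers by parametrized modules
pmodel.inject_adapters()  # optional: add PEFT adapters from AdapterConfig
pmodel.quantize()         # optional: swap remaining targeted nn.Linear layers for quantized ones

# ... tune/prune the mask parameters, then materialize the compression:
pmodel.compress()
print(f"compression ratio: {pmodel.get_compression_ratio():.3f}")
```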
projected_layer.py ADDED
@@ -0,0 +1,308 @@
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from logging import getLogger
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .parametrized_layer import Parametrization
11
+ from .utils import use_init_empty_weights
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class CompressionCriterion(ABC):
17
+ """
18
+ Abstract base class for a compression criterion applied to a (target) parameter of a parametrized module.
19
+ """
20
+
21
+ @abstractmethod
22
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
23
+ """
24
+ Args:
25
+ x: A tensor of any shape
26
+
27
+ Returns: A boolean mask of the same shape as `x` where `False` indicates that the entry can be removed.
28
+ """
29
+ raise NotImplementedError
30
+
31
+
32
+ class ThresholdCriterion(CompressionCriterion):
33
+ """
34
+ Compression criterion based on a threshold. All entries below `self.threshold` can be removed.
35
+ """
36
+
37
+ def __init__(self, threshold: float = 0.0):
38
+ self.threshold = threshold
39
+
40
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
41
+ return x > self.threshold
42
+
43
+
44
+ class ProjectedLinearParametrization(Parametrization, ABC):
45
+ """
46
+ Implementation of a linear layer parametrization, factorizing the weight matrix as
47
+ `weight = ortho.weight @ torch.diag(mask) @ base.weight`.
48
+ Here, `ortho` is a linear layer with orthogonal columns, `mask` represents a (binary) diagonal matrix
49
+ that can be pruned, and `base` is a linear layer (determined by the choice of `ortho`).
50
+ Any child class needs to implement `_ortho_init` which creates `ortho`. Based on this, `mask` and `base` are
51
+ initialized such that the original weight matrix is obtained at initialization.
52
+
53
+ `mask` corresponds to the only target parameter of this parametrization. Pruning it will result in
54
+ a low-rank matrix representation of the parametrized linear module.
55
+ """
56
+
57
+ base_class = nn.Linear
58
+
59
+ def __init__(
60
+ self,
61
+ mask_func: Literal["ste", "relu", "none"] = "ste",
62
+ mask_scaling_factor: float | str = "norm",
63
+ compression_criterion: CompressionCriterion = ThresholdCriterion(),
64
+ ):
65
+ """
66
+ Args:
67
+ mask_func: A function applied to the mask parameter in each forward pass implementing
68
+ custom functionalities. Available options: ["ste", "relu", "none"].
69
+ "ste" means using a straight-through estimator, i.e., in the forward pass, `mask` is binarized, which
70
+ is ignored in the backward pass. Before `mask` passed through a ReLU activation.
71
+ "relu" means that `mask` is passed through a ReLU activation.
72
+ "none" means that `mask` is not modified.
73
+ mask_scaling_factor: Conceptually, `mask` is initialized with ones, but rescaling to a smaller value
74
+ can vastly improve the training speed. `mask_scaling_factor` specifies this rescaling factor.
75
+ The rescaling should be compensated by scaling `ortho` accordingly in `self._ortho_init`.
76
+ If `mask_scaling_factor='norm'`, the scaling factor is chosen such that `mask` has unit L2 norm
77
+ (note that this can lead to a different behavior in model tuning than for a fixed factor
78
+ when some target parameters have different numbers of elements).
79
+ compression_criterion: `CompressionCriterion` to be used in `self.reset_target_params(mode="compress")`.
80
+ """
81
+ super().__init__()
82
+ self.mask_func = {
83
+ "ste": mask_func_ste,
84
+ "relu": mask_func_relu,
85
+ "none": mask_func_none,
86
+ }[mask_func]
87
+ self._mask_scaling_factor = mask_scaling_factor
88
+ self.compression_criterion = compression_criterion
89
+
90
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
91
+ # This implementation avoids an explicit materialization of `weight`.
92
+ x = self.base(x)
93
+ x = self.mask_func(self.mask, self.mask_scaling_factor) * x
94
+ x = self.ortho(x)
95
+ return x
96
+
97
+ def _weight(self) -> torch.Tensor:
98
+ # Compute the original weight matrix; for efficiency reasons, do not use this in the forward pass
99
+ mask = self.mask_func(self.mask, self.mask_scaling_factor)
100
+ return self.ortho.weight @ torch.diag(mask) @ self.base.weight
101
+
102
+ def _bias(self) -> torch.Tensor | None:
103
+ return self.ortho.bias
104
+
105
+ def _initialize(self, base_module: base_class) -> None:
106
+ factory_kwargs = {"device": base_module.weight.device, "dtype": base_module.weight.dtype}
107
+ in_dim, out_dim = base_module.in_features, base_module.out_features
108
+ proj_dim = min(in_dim, out_dim) # infer mask (bottleneck) dimension
109
+
110
+ # Initialize ortho layer ....
111
+ self.add_module(
112
+ "ortho",
113
+ nn.Linear(in_features=proj_dim, out_features=out_dim, bias=base_module.bias is not None, **factory_kwargs),
114
+ )
115
+ self._ortho_init(base_module.weight)
116
+ if base_module.bias is not None:
117
+ # It is important that ortho carries the bias (and not base) because ortho is used to compute the final
118
+ # output of the forward pass
119
+ self.ortho.bias.data.copy_(base_module.bias.data)
120
+
121
+ # ... and compute the base layer based on the choice of ortho (this only works if ortho has orthogonal columns)
122
+ base = base_module.__class__(in_features=in_dim, out_features=proj_dim, bias=False, **factory_kwargs)
123
+ base.weight.data.copy_(self.ortho.weight.data.T @ base_module.weight.data)
124
+ self.add_module("base", base)
125
+
126
+ # Creating (tunable) mask parameter ...
127
+ self.register_parameter("mask", torch.nn.Parameter(torch.ones(proj_dim, **factory_kwargs)))
128
+ # ... and rescale mask properly in a separate step
129
+ # (because reset_target_params calls mask_scaling_factor, which in turn may require mask to already exist)
130
+ self.reset_target_params()
131
+
132
+ @abstractmethod
133
+ def _ortho_init(self, weight: torch.Tensor) -> None:
134
+ """
135
+ Initialize ortho layer. Must be implemented by child class.
136
+
137
+ Args:
138
+ weight: Weight matrix of the original linear layer module.
139
+ """
140
+ raise NotImplementedError
141
+
142
+ def get_target_params(self) -> dict[str, torch.nn.Parameter]:
143
+ return {"mask": self.mask}
144
+
145
+ @property
146
+ def mask_scaling_factor(self) -> float:
147
+ if self._mask_scaling_factor == "norm":
148
+ # Choose scaling factor such that mask has unit L2 norm.
149
+ # Note: mask already needs to exist at this point to infer its shape.
150
+ self._mask_scaling_factor = 1 / math.sqrt(self.mask.numel())
151
+ return self._mask_scaling_factor
152
+ elif isinstance(self._mask_scaling_factor, float):
153
+ return self._mask_scaling_factor
154
+ else:
155
+ raise ValueError(f"Invalid mask_scaling_factor: {self._mask_scaling_factor}")
156
+
157
+ @property
158
+ def in_features(self) -> int:
159
+ return self.base.in_features
160
+
161
+ @property
162
+ def out_features(self) -> int:
163
+ return self.ortho.out_features
164
+
165
+ def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
166
+ with torch.no_grad():
167
+ if mode == "full":
168
+ # Scale mask values properly by self.mask_scaling_factor
169
+ self.mask.data = torch.ones_like(self.mask.data) * self.mask_scaling_factor
170
+ elif mode == "nonzero":
171
+ # Scale mask values properly by self.mask_scaling_factor
172
+ self.mask.data[self.mask.data > 0] = 1.0 * self.mask_scaling_factor
173
+ self.mask.data[self.mask.data < 0] = 0.0
174
+ elif mode == "compress":
175
+ if self.compression_criterion is None:
176
+ logger.warning("Compression criterion is not set. No op...")
177
+ return
178
+ # Select entries of parameter mask that should be kept
179
+ dim_select = self.compression_criterion(self.mask)
180
+ # Create and register compressed layers and mask
181
+ new_base = new_linear_from_mask(self.base, dim_select, column_select=False)
182
+ new_ortho = new_linear_from_mask(self.ortho, dim_select, column_select=True)
183
+ new_mask = self.mask[dim_select].clone().detach()
184
+ del self.mask, self.base, self.ortho
185
+ self.register_module("base", new_base)
186
+ self.register_module("ortho", new_ortho)
187
+ self.register_parameter("mask", nn.Parameter(new_mask))
188
+ else:
189
+ raise ValueError(f"Invalid mode: {mode}")
190
+
191
+ def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
192
+ if not compressed:
193
+ # Compute number of parameters for full linear layer
194
+ num_params = self.in_features * self.out_features
195
+ if self.bias is not None:
196
+ num_params += self.out_features
197
+ return num_params
198
+ else:
199
+ # Compute the number of mask values that would be kept by self.reset_target_params(mode="compress") ...
200
+ if target_params is not None:
201
+ sparsity = mask_sparsity(target_params["mask"] != 0.0, threshold=0.0)
202
+ else:
203
+ sparsity = mask_sparsity(self.mask)
204
+ # ... and compute the (hypothetical) number of parameters for a compressed module.
205
+ num_params = self.in_features * sparsity + sparsity * self.out_features
206
+ if self.bias is not None:
207
+ num_params += self.out_features
208
+ # If the number of parameters for the compressed module would be larger than the number of parameters
209
+ # for the full module, return the latter because we can always unparametrize to the original module if
210
+ # compression would not be effective.
211
+ num_params = min(self.get_num_params(compressed=False), num_params)
212
+ return num_params
213
+
214
+
215
+ class SVDLinearParametrization(ProjectedLinearParametrization):
216
+ """
217
+ Implementation of a linear layer parametrization using a singular value decomposition (SVD).
218
+ If the SVD of weight is U * S * V^T, then `ortho.weight = U` and `base.weight = S * V^T`.
219
+ As base is computed automatically by `_initialize`, `_ortho_init` only needs to compute U and
220
+ scale it properly with `mask_scaling_factor`. The singular values S are buffered just in case they are needed
221
+ in the tuning process.
222
+ """
223
+
224
+ def _ortho_init(self, weight: torch.Tensor) -> None:
225
+ k = min(weight.shape[0], weight.shape[1])
226
+ if use_init_empty_weights.get():
227
+ # Check if the init_empty_weights context is active, which avoids a (costly) SVD computation and just
228
+ # initializes U and S as empty tensors. They are loaded later from a pretrained model.
229
+ logger.debug("Parametrizing with empty weights.")
230
+ U = torch.empty(weight.shape[0], k)
231
+ S = torch.empty(k, 1)
232
+ else:
233
+ # Detaching is important to avoid memory leaks. torch.linalg.svd only works with float32.
234
+ U, S, _ = torch.linalg.svd(weight.detach().float(), full_matrices=False)
235
+ # Rescaling U based on mask_scaling_factor
236
+ # This step is somewhat manual because calling mask_scaling_factor requires the mask to already exist
237
+ if self._mask_scaling_factor == "norm":
238
+ U = math.pow(k, 1 / 4) * U
239
+ else:
240
+ U = math.sqrt(1 / self._mask_scaling_factor) * U
241
+ factory_kwargs = {"device": weight.device, "dtype": weight.dtype}
242
+ self.ortho.weight.data.copy_(U.detach().to(**factory_kwargs))
243
+ self.register_buffer("S", S.detach().flatten().to(**factory_kwargs))
244
+
245
+
246
+ def mask_func_ste(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
247
+ # See ProjectedLinearParametrization.__init__ for more details.
248
+ mask = F.relu(mask)
249
+ return (mask > 0).to(mask.dtype).detach() * mask_scaling_factor + mask - mask.detach()
250
+
251
+
252
+ def mask_func_relu(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
253
+ # See ProjectedLinearParametrization.__init__ for more details.
254
+ return F.relu(mask)
255
+
256
+
257
+ def mask_func_none(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
258
+ # See ProjectedLinearParametrization.__init__ for more details.
259
+ return mask
260
+
261
+
262
+ def mask_sparsity(mask: torch.Tensor, threshold: float = 0.0) -> int:
263
+ """Simple util function to compute the number of non-zero elements of a mask, where an element is considered
264
+ non-zero if its value is strictly greater than `threshold`."""
265
+ return torch.count_nonzero(mask > threshold).item()
266
+
267
+
268
+ def new_linear_from_mask(module: nn.Linear, dim_select: torch.Tensor, column_select=True) -> nn.Linear:
269
+ """
270
+ Creates a new linear layer from an existing one based on a mask indicating which columns/rows to keep.
271
+
272
+ Args:
273
+ module: Module to be pruned.
274
+ dim_select: Boolean tensor mask indicating which columns/rows to keep.
275
+ column_select: Whether to prune columns (True) or rows (False) according to `dim_select`.
276
+
277
+ Returns: Pruned module.
278
+ """
279
+ assert dim_select.dtype == torch.bool, "dim_select must be boolean"
280
+
281
+ in_features, out_features = module.in_features, module.out_features
282
+ sparsity = dim_select.sum().item()
283
+ if column_select:
284
+ in_features = sparsity
285
+ else:
286
+ out_features = sparsity
287
+ new_module = module.__class__(
288
+ in_features=in_features,
289
+ out_features=out_features,
290
+ bias=module.bias is not None,
291
+ device=module.weight.device,
292
+ dtype=module.weight.dtype,
293
+ )
294
+ weight = module.weight.data
295
+ if column_select:
296
+ weight = weight[:, dim_select]
297
+ else:
298
+ weight = weight[dim_select, :]
299
+ new_module.weight.data.copy_(weight.detach())
300
+
301
+ if new_module.bias is not None:
302
+ if column_select:
303
+ new_module.bias.data.copy_(module.bias.detach())
304
+ else:
305
+ # If rows are pruned, the bias needs to be pruned as well
306
+ new_module.bias.data.copy_(module.bias[dim_select].detach())
307
+
308
+ return new_module
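As a sanity check on the factorization that `ProjectedLinearParametrization` relies on, the following self-contained PyTorch snippet (independent of the classes above, with the mask at its initial value and `mask_func` details ignored) reproduces the SVD-based initialization of `SVDLinearParametrization` with `mask_scaling_factor="norm"`, verifies that the original weight is recovered at initialization, and shows the parameter-count arithmetic behind `get_num_params(compressed=True)` for a hypothetical pruned bottleneck size r.

```python
import math
import torch

torch.manual_seed(0)
out_dim, in_dim = 8, 16
W = torch.randn(out_dim, in_dim)

# SVD-based initialization as in SVDLinearParametrization._ortho_init (mask_scaling_factor="norm")
k = min(out_dim, in_dim)
U, S, _ = torch.linalg.svd(W, full_matrices=False)
ortho = math.pow(k, 1 / 4) * U      # rescaled left singular vectors
base = ortho.T @ W                  # computed from ortho, as in _initialize
mask = torch.ones(k) / math.sqrt(k) # mask initialized to have unit L2 norm

# The parametrization reproduces the original weight at initialization
W_rec = ortho @ torch.diag(mask) @ base
assert torch.allclose(W, W_rec, atol=1e-4)

# Effective parameter count if only r mask entries survive pruning (cf. get_num_params)
r = 4
compressed = in_dim * r + r * out_dim  # base (r x in) + ortho (out x r), no bias
full = in_dim * out_dim
print(compressed / full)               # per-layer compression ratio
```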
utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import contextvars
2
+ import importlib
3
+ from contextlib import contextmanager
4
+ from typing import Any, Type
5
+
6
+
7
+ def get_class_from_str(class_str: str, package: str | None = None) -> Type[Any]:
8
+ """
9
+ Converts a string to the corresponding class object, supporting relative imports.
10
+ For relative module paths (starting with '.'), a package must be provided.
11
+
12
+ Args:
13
+ class_str: String representation of the class, either absolute or relative.
14
+ package: Package context, only required for relative imports.
15
+
16
+ Returns: Class object corresponding to the provided string.
17
+ """
18
+ if not isinstance(class_str, str) and isinstance(class_str, type):
19
+ return class_str
20
+
21
+ module_path, _, class_name = class_str.rpartition(".")
22
+ if not module_path and class_str.startswith("."):
23
+ module_path = "."
24
+ if module_path.startswith("."):
25
+ if not package:
26
+ raise ValueError("Relative module path provided without a package context.")
27
+ module = importlib.import_module(module_path, package=package)
28
+ else:
29
+ module = importlib.import_module(module_path)
30
+ return getattr(module, class_name)
31
+
32
+
33
+ def get_str_from_class(cls: Type[Any], package: str | None = None) -> str:
34
+ """
35
+ Converts a class object to its string representation.
36
+ If a package is provided and the class's module is a submodule of the package,
37
+ the returned string will use a relative import.
38
+ Otherwise, an absolute import string is returned.
39
+
40
+ Args:
41
+ cls: Class object to convert.
42
+ package: Package context, only required for relative imports.
43
+
44
+ Returns: String representation of the class.
45
+ """
46
+ if isinstance(cls, str):
47
+ return cls
48
+
49
+ module_path = cls.__module__
50
+ class_name = cls.__name__
51
+
52
+ if package:
53
+ # When class is defined directly in the package's __init__.py
54
+ if module_path == package:
55
+ return f".{class_name}"
56
+ # When class is in a submodule of the package
57
+ elif module_path.startswith(package + "."):
58
+ # Get the relative part (including the dot)
59
+ relative = module_path[len(package) :]
60
+ if not relative.startswith("."):
61
+ relative = "." + relative
62
+ return f"{relative}.{class_name}"
63
+ return f"{module_path}.{class_name}"
64
+
65
+
66
+ use_init_empty_weights = contextvars.ContextVar("init_empty_weights", default=False)
67
+
68
+
69
+ @contextmanager
70
+ def init_empty_weights(value: bool):
71
+ """
72
+ Context manager to indicate that a (parametrized) model should be initialized with empty weights or not.
73
+ If active, `use_init_empty_weights` will be set to `True` otherwise to `False`.
74
+ To check if the context is active, import and check `use_init_empty_weights.get()`.
75
+
76
+ Args:
77
+ value: Indicates whether the model should be initialized with empty weights or not.
78
+ """
79
+ token = use_init_empty_weights.set(value)
80
+ try:
81
+ yield
82
+ finally:
83
+ use_init_empty_weights.reset(token)
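A brief usage sketch for the helpers above, assuming they are imported from this `utils.py` module; the package name in the commented relative-import example is illustrative only.

```python
from torch import nn

# Absolute class-path round trip with get_class_from_str / get_str_from_class
cls = get_class_from_str("torch.nn.Linear")
assert cls is nn.Linear
assert get_str_from_class(nn.Linear) == "torch.nn.modules.linear.Linear"

# With a package context, classes inside that package serialize to relative paths, e.g.
# (hypothetical package name):
# get_str_from_class(SVDLinearParametrization, package="compression_pkg")
#   -> ".projected_layer.SVDLinearParametrization"

# init_empty_weights toggles the context variable checked, e.g., in SVDLinearParametrization._ortho_init
with init_empty_weights(True):
    assert use_init_empty_weights.get() is True  # costly initialization (like an SVD) can be skipped here
assert use_init_empty_weights.get() is False     # reset to the default on exit
```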