import logging
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Type

import torch
from peft import PeftConfig
from peft.tuners.tuners_utils import _maybe_include_all_linear_layers, check_target_module_exists
from torch import nn
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel

from .parametrized_layer import Parametrization, ParametrizedModule, parametrize_module, unparametrize_module
from .projected_layer import SVDLinearParametrization
from .utils import get_class_from_str, get_str_from_class, init_empty_weights

logger = logging.getLogger(__name__)


@dataclass
class BaseModelConfig:
    """
    Configuration for the base model to be parametrized by `ParametrizedModel`.

    Attributes:
        pretrained_model_cls: The class of the base model. Child class of `PreTrainedModel`.
        pretrained_model_kwargs: Keyword arguments used when creating the base model in the constructor
            of `ParametrizedModel` via `from_pretrained`.
        pretrained_config: Optional config used when creating the base model in the constructor
            of `ParametrizedModel` via `from_pretrained`.

    See Also:
        `ParametrizedModelConfig`
    """

    pretrained_model_cls: Type[PreTrainedModel]
    pretrained_model_kwargs: dict[str, Any] = field(default_factory=dict)
    pretrained_config: PretrainedConfig | None = None

    def __post_init__(self):
        # Allow specifying the class as a string, e.g. when loading from a serialized config.
        if isinstance(self.pretrained_model_cls, str):
            self.pretrained_model_cls = get_class_from_str(self.pretrained_model_cls)

    def to_dict(self) -> dict[str, Any]:
        config_dict = asdict(self)
        config_dict["pretrained_model_cls"] = get_str_from_class(self.pretrained_model_cls)
        if self.pretrained_config is not None:
            config_dict["pretrained_config"] = self.pretrained_config.to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "BaseModelConfig":
        try:
            if config_dict["pretrained_config"] is not None:
                config_dict["pretrained_config"] = AutoConfig.for_model(**config_dict["pretrained_config"])
        except ValueError:
            logger.warning("Unrecognized model identifier in AutoConfig, using PretrainedConfig instead.")
            config_dict["pretrained_config"] = PretrainedConfig.from_dict(config_dict["pretrained_config"])
        return cls(**config_dict)
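

# Usage sketch (illustrative; the concrete model class and checkpoint name are assumptions, not part of this API):
#
#     from transformers import AutoModelForCausalLM
#
#     base_model_config = BaseModelConfig(
#         pretrained_model_cls=AutoModelForCausalLM,  # a dotted-path string also works, see `__post_init__`
#         pretrained_model_kwargs={"pretrained_model_name_or_path": "gpt2"},
#     )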


PARAMETRIZATION_FACTORY_REGISTRY: dict[str, Type[Parametrization]] = {
    "svd": SVDLinearParametrization,
}


@dataclass
class ParametrizationConfig:
    """
    Configuration for the parametrization to be applied to the linear layers of the base model in `ParametrizedModel`.

    Attributes:
        module_factory_cls: The class name of the parametrization to be applied to linear layers.
            Can be a string representing a class name (with absolute module path) or a predefined key
            from `PARAMETRIZATION_FACTORY_REGISTRY`.
            Use `parse_module_factory_cls` to get the actual class when creating the parametrization.
        module_factory_kwargs: Keyword arguments used when creating the parametrization with `module_factory_cls`.
        target_modules: A string or list of strings specifying the names of the linear layers to be parametrized.
            Follows the same semantics as Huggingface's `PeftConfig`, see also `check_target_module_exists`.
            If a string, a regex match will be performed; if a list, a module will be parametrized if its name ends
            with any of the strings in `target_modules`.
        exclude_modules: A list of strings specifying the names of the linear layers to be excluded from
            parametrization. A module will be excluded if any of the strings in `exclude_modules` is in its name.

    See Also:
        `ParametrizedModelConfig`
    """

    module_factory_cls: str
    module_factory_kwargs: dict[str, Any] = field(default_factory=dict)
    target_modules: str | list[str] | None = None
    exclude_modules: list[str] | None = None

    def parse_module_factory_cls(self) -> Type[Parametrization]:
        """Returns the class of the parametrization to be applied to linear layers."""
        try:
            if self.module_factory_cls in PARAMETRIZATION_FACTORY_REGISTRY:
                module_factory_cls = PARAMETRIZATION_FACTORY_REGISTRY[self.module_factory_cls]
            else:
                module_factory_cls = get_class_from_str(self.module_factory_cls)
        except Exception as e:
            raise ValueError(f"Unrecognized parametrization class: {self.module_factory_cls}") from e
        return module_factory_cls

    def to_dict(self) -> dict[str, Any]:
        config_dict = asdict(self)
        # Sets (e.g. in target_modules) are not JSON-serializable; convert them to lists.
        for key, value in config_dict.items():
            if isinstance(value, set):
                config_dict[key] = list(value)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "ParametrizationConfig":
        return cls(**config_dict)
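

# Usage sketch (illustrative; the module names are assumptions about the base model's architecture):
#
#     parametrization_config = ParametrizationConfig(
#         module_factory_cls="svd",  # key in PARAMETRIZATION_FACTORY_REGISTRY
#         target_modules=["q_proj", "v_proj"],
#         exclude_modules=["lm_head"],
#     )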


@dataclass
class AdapterConfig:
    """
    Configuration for the Huggingface Peft adapters to be applied to the base model.

    Attributes:
        peft_config: One or more adapter `PeftConfig`s to be applied to the base model.
            If a single `PeftConfig` is provided, it will be wrapped in a dict with key "default".
            The dictionary keys will be used as adapter names in `PreTrainedModel.add_adapter`.

    See Also:
        `ParametrizedModelConfig`
    """

    peft_config: PeftConfig | dict[str, PeftConfig]

    def __post_init__(self):
        if isinstance(self.peft_config, PeftConfig):
            self.peft_config = {"default": self.peft_config}

    def to_dict(self) -> dict[str, Any]:
        config_dict = asdict(self)
        for adapter_name, peft_config in self.peft_config.items():
            peft_config_dict = peft_config.to_dict()
            # Sets (e.g. in target_modules) are not JSON-serializable; convert them to lists.
            for key, value in peft_config_dict.items():
                if isinstance(value, set):
                    peft_config_dict[key] = list(value)
            config_dict["peft_config"][adapter_name] = peft_config_dict
        return config_dict

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "AdapterConfig":
        for key, peft_config in config_dict["peft_config"].items():
            config_dict["peft_config"][key] = PeftConfig.from_peft_type(**peft_config)
        return cls(**config_dict)
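

# Usage sketch (illustrative; the LoRA hyperparameters and module names are assumptions):
#
#     from peft import LoraConfig
#
#     adapter_config = AdapterConfig(peft_config=LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))
#     # Equivalent to passing {"default": LoraConfig(...)} explicitly.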


try:
    import bitsandbytes

    QUANTIZATION_FACTORY_REGISTRY: dict[str, Type[nn.Linear]] = {
        "bnb4bit": bitsandbytes.nn.Linear4bit,
    }
except ImportError:
    logger.warning("bitsandbytes is not installed, the 'bnb4bit' quantization factory is unavailable.")
    QUANTIZATION_FACTORY_REGISTRY: dict[str, Type[nn.Linear]] = {}


@dataclass
class WeightQuantizationConfig:
    """
    Configuration for an (optional) weight quantization to be applied to the base model.
    So far, only fp4 quantization with bitsandbytes has been tested, but analogous bitsandbytes
    quantizations should work as well. `module_factory_cls` might also use a different quantization library,
    as long as it is compatible with the module replacement strategy in `ParametrizedModel.quantize`.

    Attributes:
        module_factory_cls: The class name of the quantization to be applied to linear layers.
            Can be a string representing a class name (with absolute module path) or a predefined key
            from `QUANTIZATION_FACTORY_REGISTRY`.
            Use `parse_module_factory_cls` to get the actual class when creating the quantization.
        module_factory_kwargs: Keyword arguments used when creating the quantization with `module_factory_cls`.
        target_modules: A string or list of strings specifying the names of the linear layers to be quantized.
            Follows the same semantics as Huggingface's `PeftConfig`, see also `check_target_module_exists`.
            If a string, a regex match will be performed; if a list, a module will be quantized if its name ends
            with any of the strings in `target_modules`.
        exclude_modules: A list of strings specifying the names of the linear layers to be excluded from
            quantization. A module will be excluded if any of the strings in `exclude_modules` is in its name.

    See Also:
        `ParametrizedModelConfig`
    """

    module_factory_cls: str
    module_factory_kwargs: dict[str, Any] = field(default_factory=dict)
    target_modules: str | list[str] | None = None
    exclude_modules: list[str] | None = None

    def parse_module_factory_cls(self) -> Type[nn.Linear]:
        """Returns the class of the quantization to be applied to linear layers."""
        try:
            if self.module_factory_cls in QUANTIZATION_FACTORY_REGISTRY:
                module_factory_cls = QUANTIZATION_FACTORY_REGISTRY[self.module_factory_cls]
            else:
                module_factory_cls = get_class_from_str(self.module_factory_cls)
        except Exception as e:
            raise ValueError(f"Unrecognized quantization class: {self.module_factory_cls}") from e
        return module_factory_cls

    def to_dict(self) -> dict[str, Any]:
        config_dict = asdict(self)
        # torch.dtype values are not JSON-serializable; store them as strings (e.g. "torch.bfloat16").
        for key, value in config_dict["module_factory_kwargs"].items():
            if isinstance(value, torch.dtype):
                config_dict["module_factory_kwargs"][key] = str(value)
        # Sets (e.g. in target_modules) are not JSON-serializable; convert them to lists.
        for key, value in config_dict.items():
            if isinstance(value, set):
                config_dict[key] = list(value)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "WeightQuantizationConfig":
        # Restore torch.dtype values serialized as strings, e.g. "torch.bfloat16" -> torch.bfloat16.
        for key, value in config_dict["module_factory_kwargs"].items():
            if isinstance(value, str) and value.startswith("torch."):
                dtype_name = value.split(".")[-1]
                config_dict["module_factory_kwargs"][key] = getattr(torch, dtype_name)
        return cls(**config_dict)
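

# Usage sketch (illustrative; `compute_dtype` and `quant_type` are keyword arguments of
# `bitsandbytes.nn.Linear4bit`, the chosen values and module names are assumptions):
#
#     weight_quantization_config = WeightQuantizationConfig(
#         module_factory_cls="bnb4bit",  # key in QUANTIZATION_FACTORY_REGISTRY
#         module_factory_kwargs={"compute_dtype": torch.bfloat16, "quant_type": "fp4"},
#         target_modules=["q_proj", "v_proj"],
#     )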


class ParametrizedModelConfig(PretrainedConfig):
    """
    Configuration for `ParametrizedModel` implementing a `PretrainedConfig` to be fully compatible with
    Huggingface's `PreTrainedModel` framework.

    See Also:
        - `BaseModelConfig`
        - `ParametrizationConfig`
        - `AdapterConfig`
        - `WeightQuantizationConfig`
        - `ParametrizedModel`
    """

    model_type = "parametrized_model"

    def __init__(
        self,
        base_model_config: BaseModelConfig | None = None,
        parametrization_config: ParametrizationConfig | None = None,
        adapter_config: AdapterConfig | None = None,
        weight_quantization_config: WeightQuantizationConfig | None = None,
        model_mode: Literal["train", "eval"] = "train",
        **kwargs: Any,
    ):
        """
        Initializes a `ParametrizedModelConfig`, serving as a container for `BaseModelConfig`, `ParametrizationConfig`,
        `AdapterConfig`, and `WeightQuantizationConfig`.

        Args:
            base_model_config: `BaseModelConfig`
            parametrization_config: `ParametrizationConfig`
            adapter_config: `AdapterConfig`
            weight_quantization_config: `WeightQuantizationConfig`
            model_mode: Whether to initialize the model in train or eval mode.
            **kwargs: Keyword arguments forwarded to `PretrainedConfig`.
        """
        self.base_model_config = base_model_config
        self.parametrization_config = parametrization_config
        self.adapter_config = adapter_config
        self.weight_quantization_config = weight_quantization_config
        self.model_mode = model_mode
        super().__init__(**kwargs)

    def _convert_to_dict(self, config_dict: dict[str, Any]) -> dict[str, Any]:
        # Replace the sub-config objects with their serializable dict representations.
        if self.base_model_config is not None:
            config_dict["base_model_config"] = self.base_model_config.to_dict()
        if self.parametrization_config is not None:
            config_dict["parametrization_config"] = self.parametrization_config.to_dict()
        if self.adapter_config is not None:
            config_dict["adapter_config"] = self.adapter_config.to_dict()
        if self.weight_quantization_config is not None:
            config_dict["weight_quantization_config"] = self.weight_quantization_config.to_dict()
        return config_dict

    def to_diff_dict(self):
        config_dict = super().to_diff_dict()
        return self._convert_to_dict(config_dict)

    def to_dict(self):
        config_dict = super().to_dict()
        return self._convert_to_dict(config_dict)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any], **kwargs: Any) -> PretrainedConfig:
        # Pop and parse the sub-configs before `PretrainedConfig.from_dict` sees the dict.
        sub_config_classes = {
            "base_model_config": BaseModelConfig,
            "parametrization_config": ParametrizationConfig,
            "adapter_config": AdapterConfig,
            "weight_quantization_config": WeightQuantizationConfig,
        }
        sub_configs: dict[str, Any] = {}
        for key, sub_config_cls in sub_config_classes.items():
            sub_config_dict: dict[str, Any] | None = config_dict.pop(key, None)
            sub_configs[key] = sub_config_cls.from_dict(sub_config_dict) if sub_config_dict is not None else None

        config = super().from_dict(config_dict, **kwargs)

        # With return_unused_kwargs=True, `super().from_dict` returns a (config, unused_kwargs) tuple.
        parsed_config = config[0] if kwargs.get("return_unused_kwargs") is True else config
        for key, sub_config in sub_configs.items():
            setattr(parsed_config, key, sub_config)
        return config
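

# Usage sketch composing the sub-configs defined above (illustrative):
#
#     config = ParametrizedModelConfig(
#         base_model_config=base_model_config,
#         parametrization_config=parametrization_config,
#         adapter_config=adapter_config,
#         weight_quantization_config=weight_quantization_config,
#         model_mode="train",
#     )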


class ParametrizedModel(PreTrainedModel):
    """
    Base class for parametrized models implemented as a custom Huggingface `PreTrainedModel`.
    It wraps any base model of type `PreTrainedModel` in `self.model`, whose linear layers can be
    parametrized (`parametrize`), equipped with adapters (`inject_adapters`), and quantized (`quantize`).
    The corresponding modules are accessed via `parametrized_modules`, `adapter_modules`,
    and `quantized_modules`, respectively.
    The class also provides several convenience methods to manage the parametrization: `get_target_params`,
    `get_num_params`, `get_compression_ratio`, `reset_target_params`, `compress`.

    Standard functionality (`forward`, `generate`, `save_pretrained`, `from_pretrained`) is essentially forwarded
    to the wrapped model.

    See Also:
        `ParametrizedModelConfig`
    """

    config_class = ParametrizedModelConfig

    def __init__(self, config: ParametrizedModelConfig, base_model: PreTrainedModel | None = None, **_: Any):
        """
        Initialize the `ParametrizedModel` from a given configuration or an existing base model.

        Args:
            config: `ParametrizedModelConfig` to be used.
            base_model: If provided, this base model is used instead of creating it from `config.base_model_config`.
            **_: Ignored keyword arguments to prevent unexpected keyword errors.

        See Also: `BaseModelConfig`
        """
        super().__init__(config)
        self.config = config

        # Create the base model from the config unless an existing one is provided.
        if base_model is None:
            if self.config.base_model_config is None:
                raise ValueError("Either base_model or base_model_config must be provided.")
            self.model = self.config.base_model_config.pretrained_model_cls.from_pretrained(
                config=self.config.base_model_config.pretrained_config,
                **self.config.base_model_config.pretrained_model_kwargs,
            )
        else:
            self.model = base_model

        self.train(self.config.model_mode == "train")
        logger.info(f"Base model {self.model.__class__} created.")

        # Apply the parametrization and inject the adapters (both are no-ops if not configured).
        self._parametrized_modules: dict[str, ParametrizedModule] | None = None
        self.parametrize()

        self._adapter_modules: dict[str, nn.Module] | None = None
        self.inject_adapters()

        # Quantization is only applied when `quantize` is called explicitly; here we only set up the cache.
        self._quantized_modules: dict[str, nn.Linear] | None = None

        # Populate the module caches eagerly.
        _ = self.parametrized_modules
        _ = self.adapter_modules
        _ = self.quantized_modules

        # Freeze all parameters by default.
        for param in self.parameters():
            param.requires_grad = False

    @property
    def base_model_name_or_path(self) -> str:
        """Convenience method to return the name or path of the base model."""
        return self.model.name_or_path

    def forward(self, *args, **kwargs) -> Any:
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs) -> Any:
        return self.model.generate(*args, **kwargs)

    def save_pretrained(
        self,
        save_directory: str | os.PathLike,
        state_dict: dict | None = None,
        include_filter: list[str] | None = None,
        exclude_filter: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Override of the default `save_pretrained` method to allow filtering of the saved state dict.

        Args:
            save_directory: Directory to save the model to.
            state_dict: Manual override of the state dict to be saved.
                If None, `include_filter` and `exclude_filter` are applied to `self.state_dict()`.
            include_filter: List of state dict keys to include in the state dict.
                A key matches when it ends with any of the strings in the list.
                If None, all keys are included.
            exclude_filter: List of state dict keys to exclude from the state dict.
                A key matches when it ends with any of the strings in the list.
                If None, no keys are excluded.
            **kwargs: Keyword arguments to be passed to the default `save_pretrained` method.

        See Also:
            `PreTrainedModel.save_pretrained`
        """
        if state_dict is None:
            state_dict = self.state_dict()
            if include_filter is not None:
                state_dict = {k: v for k, v in state_dict.items() if any(k.endswith(f) for f in include_filter)}
            if exclude_filter is not None:
                state_dict = {k: v for k, v in state_dict.items() if not any(k.endswith(f) for f in exclude_filter)}

        super().save_pretrained(save_directory=save_directory, state_dict=state_dict, **kwargs)
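
    # Example (sketch; assumes LoRA adapters named "default" were injected, so that the relevant
    # state dict keys end with ".lora_A.default.weight" / ".lora_B.default.weight"):
    #
    #     model.save_pretrained("out/", include_filter=["lora_A.default.weight", "lora_B.default.weight"])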

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike | None,
        *model_args: Any,
        with_init_empty_weights: bool = True,
        **kwargs: Any,
    ) -> PreTrainedModel:
        """
        Override of the default `from_pretrained` method to allow initialization with empty weights.

        Args:
            pretrained_model_name_or_path: Model name or path.
            *model_args: Arguments to be passed to the default `from_pretrained` method.
            with_init_empty_weights: Whether to initialize the model with empty weights or not.
            **kwargs: Keyword arguments to be passed to the default `from_pretrained` method.
        """
        with init_empty_weights(with_init_empty_weights):
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    @property
    def parametrized_modules(self) -> dict[str, ParametrizedModule]:
        """
        Returns a dictionary of all parametrized modules in the model.
        The returned dictionary is compatible with `self.model.named_modules()`.
        """
        if self._parametrized_modules is None:
            self._parametrized_modules = {}
            if self.config.parametrization_config is None:
                return self._parametrized_modules
            for m_name, module in self.model.named_modules():
                if isinstance(module, ParametrizedModule):
                    self._parametrized_modules[m_name] = module
        return self._parametrized_modules

    @property
    def adapter_modules(self) -> dict[str, nn.Module]:
        """
        Returns a dictionary of all adapter modules in the model.
        The returned dictionary is compatible with `self.model.named_modules()`.
        """
        if self._adapter_modules is None:
            self._adapter_modules = {}
            if self.config.adapter_config is None:
                return self._adapter_modules
            try:
                # Recover each adapter submodule name from the adapter state dict keys by replacing
                # the trailing parameter name (e.g. ".weight") with the adapter name.
                for adapter_name in self.model.active_adapters():
                    for m_name in self.model.get_adapter_state_dict(adapter_name).keys():
                        adapter_m_name = f"{m_name.rsplit('.', 1)[0]}.{adapter_name}"
                        self._adapter_modules[adapter_m_name] = self.model.get_submodule(adapter_m_name)
            except ValueError as e:
                logger.warning(e)
        return self._adapter_modules

    @property
    def quantized_modules(self) -> dict[str, nn.Linear]:
        """
        Returns a dictionary of all quantized modules in the model.
        The returned dictionary is compatible with `self.model.named_modules()`.
        """
        if self._quantized_modules is None:
            self._quantized_modules = {}
            if self.config.weight_quantization_config is None:
                return self._quantized_modules
            try:
                module_factory_cls = self.config.weight_quantization_config.parse_module_factory_cls()
            except Exception as e:
                logger.warning(f"Could not parse weight quantization config, quantization not available.\nError: {e}")
                return self._quantized_modules
            for m_name, module in self.model.named_modules():
                if isinstance(module, module_factory_cls):
                    self._quantized_modules[m_name] = module
        return self._quantized_modules

    def parametrize(self) -> None:
        """
        Parametrize the `target_modules` from `ParametrizationConfig` using `parametrized_layer.parametrize_module`.

        See Also: `ParametrizationConfig`
        """
        if self.config.parametrization_config is None:
            logger.debug("Model parametrization is disabled.")
            return

        # Resolve a potential "all-linear" shorthand in target_modules into explicit module names,
        # mirroring the peft behavior.
        config: ParametrizationConfig = _maybe_include_all_linear_layers(
            self.config.parametrization_config,
            self.model,
        )
        module_factory_cls = config.parse_module_factory_cls()

        for m_name, module in self.model.named_modules():
            # Skip excluded and non-target modules.
            if config.exclude_modules is not None and any(key in m_name for key in config.exclude_modules):
                continue
            if not check_target_module_exists(config, m_name):
                continue

            parametrization = module_factory_cls(**config.module_factory_kwargs)
            parametrize_module(module=module, parametrization=parametrization)
            logger.debug(f"Parametrized {module.__class__} module {m_name} as {parametrization.__class__}")

        # Invalidate the cache; `parametrized_modules` will re-scan the model on next access.
        self._parametrized_modules = None
        logger.info("Parametrization completed.")

    def inject_adapters(self) -> None:
        """
        Inject adapters according to `AdapterConfig` using the adapter management of `PreTrainedModel`.

        See Also: `AdapterConfig`
        """
        if self.config.adapter_config is None:
            logger.debug("Adapter injection is disabled.")
            return

        for adapter_name, peft_config in self.config.adapter_config.peft_config.items():
            self.model.add_adapter(peft_config, adapter_name=adapter_name)
        # Activate all injected adapters at once.
        self.model.set_adapter(list(self.config.adapter_config.peft_config.keys()))

        # Invalidate the cache; `adapter_modules` will re-scan the model on next access.
        self._adapter_modules = None
        logger.info("Adapters injected.")

    def quantize(self) -> None:
        """
        Quantize the `target_modules` from `WeightQuantizationConfig`.

        See Also: `WeightQuantizationConfig`
        """
        if self.config.weight_quantization_config is None:
            logger.debug("Weight quantization is disabled.")
            return

        # Resolve a potential "all-linear" shorthand in target_modules into explicit module names,
        # mirroring the peft behavior.
        config: WeightQuantizationConfig = _maybe_include_all_linear_layers(
            self.config.weight_quantization_config,
            self.model,
        )
        module_factory_cls = config.parse_module_factory_cls()

        for m_name, module in self.model.named_modules():
            # Skip excluded modules, non-targets, parametrized modules, and anything that is not a plain linear layer.
            if config.exclude_modules is not None and any(key in m_name for key in config.exclude_modules):
                continue
            if not check_target_module_exists(config, m_name) or isinstance(module, ParametrizedModule):
                continue
            if not isinstance(module, nn.Linear):
                continue

            # Create the quantized replacement with matching shape, bias, and device.
            quantized_module = module_factory_cls(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                device=module.weight.device,
                **config.module_factory_kwargs,
            )
            # Copy the weights; for bitsandbytes modules, moving to a CUDA device triggers the actual quantization.
            quantized_module.load_state_dict(module.state_dict())
            quantized_module = quantized_module.to(module.weight.device)
            quantized_module.weight.requires_grad = False
            logger.debug(f"Quantized {module.__class__} module {m_name} to {quantized_module.__class__}")

            # Swap the quantized module into its parent.
            parent_name, child_name = m_name.rsplit(".", 1)
            parent_module = self.model.get_submodule(parent_name)
            parent_module.add_module(child_name, quantized_module)

        # Invalidate the cache; `quantized_modules` will re-scan the model on next access.
        self._quantized_modules = None
        logger.info("Quantization completed.")

    def get_target_params(self) -> dict[str, nn.Parameter]:
        """
        Lifts `Parametrization.get_target_params` to the model scope.
        The returned dictionary should be compatible with `self.model.named_parameters()`.

        See Also:
            `Parametrization.get_target_params`
        """
        target_params = {}
        for m_name, module in self.parametrized_modules.items():
            for p_name, param in module.parametrization.get_target_params().items():
                target_params[f"{m_name}.parametrization.{p_name}"] = param
        return target_params

    def get_num_params(
        self, compressed: bool = False, full: bool = False, target_params: dict[str, torch.Tensor] | None = None
    ) -> int:
        """
        Lifts `Parametrization.get_num_params` to the model scope.
        Computes the (effective) number of parameters of the entire model.

        Args:
            compressed: Whether to count the number of parameters as if the parametrized modules were actually
                compressed. If `False`, the number of parameters is the same as in the original module.
            full: If `True`, all parameters of the model are counted, if `False` only those of parametrized modules.
                Default is `False`, which follows the most common convention in the compression literature.
            target_params: Count the number of parameters as if `target_params` were used instead of
                the parametrized modules' target parameters. The dictionary keys should be compatible with those of
                `self.get_target_params`.

        See Also:
            `Parametrization.get_num_params`
        """
        num_params_full = 0
        if full:
            for name, param in self.model.named_parameters():
                if "parametrization" not in name:
                    if hasattr(param, "quant_state"):
                        # bitsandbytes packs two 4-bit values per stored element, so `numel` is half
                        # of the original parameter count.
                        num_params_full += param.numel() * 2
                    else:
                        num_params_full += param.numel()

        num_params = 0
        for module_name, module in self.parametrized_modules.items():
            module_target_params = None
            if compressed and target_params is not None:
                # Select the target params of this module and strip the module prefix from their keys.
                prefix = f"{module_name}.parametrization."
                module_target_params = {
                    key[len(prefix) :]: value for key, value in target_params.items() if key.startswith(prefix)
                }
                if not module_target_params:
                    module_target_params = None

            num_params += module.parametrization.get_num_params(
                compressed=compressed, target_params=module_target_params
            )
        num_params = num_params + num_params_full
        if num_params == 0:
            # Avoid division by zero in `get_compression_ratio`.
            num_params = 1e-6
        return num_params

    def get_compression_ratio(self, full: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> float:
        """
        Convenience function to compute the compression ratio of the present model.

        See Also:
            `get_num_params`
        """
        return self.get_num_params(compressed=True, full=full, target_params=target_params) / self.get_num_params(
            full=full
        )
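
    # Worked example (illustrative): if the parametrized layers would hold 0.5M parameters after
    # compression while the original layers held 1.0M, get_compression_ratio() returns 0.5.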

    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        """
        Lifts `Parametrization.reset_target_params` to the model scope.

        Args:
            mode: The reset mode, see `Parametrization.reset_target_params`.

        See Also:
            `Parametrization.reset_target_params`
        """
        for module in self.parametrized_modules.values():
            module.parametrization.reset_target_params(mode=mode)

    def compress(self) -> None:
        """
        Compresses all parametrized modules using `Parametrization.reset_target_params(mode="compress")`.
        If no compression is possible, the module is unparametrized and removed from `parametrized_modules`.
        """
        removed_parametrized_modules = []
        for m_name, module in self.parametrized_modules.items():
            if module.parametrization.get_num_params(compressed=True) / module.parametrization.get_num_params() >= 1.0:
                # Compression would not reduce the parameter count; restore the plain module instead.
                unparametrize_module(module)
                removed_parametrized_modules.append(m_name)
                logger.debug(f"Unparametrizing {module.__class__} module {m_name}")
            else:
                module.parametrization.reset_target_params(mode="compress")
                logger.debug(f"Compressing {module.__class__} module {m_name}")
        for m_name in removed_parametrized_modules:
            self.parametrized_modules.pop(m_name)
        logger.info("Compression completed.")


# Register the custom config and model for use with the Huggingface auto classes.
ParametrizedModelConfig.register_for_auto_class()
ParametrizedModel.register_for_auto_class("AutoModel")
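
# End-to-end usage sketch (illustrative; the checkpoint name and target module names are assumptions):
#
#     config = ParametrizedModelConfig(
#         base_model_config=BaseModelConfig(
#             pretrained_model_cls="transformers.AutoModelForCausalLM",
#             pretrained_model_kwargs={"pretrained_model_name_or_path": "gpt2"},
#         ),
#         parametrization_config=ParametrizationConfig(module_factory_cls="svd", target_modules=["c_attn"]),
#     )
#     model = ParametrizedModel(config)
#     print(model.get_compression_ratio())
#     model.compress()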