"""This file contains the definition of base classes. We thank the following public implementations for inspiring this code: https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py """ import copy import os from typing import Union, Callable, Tuple, Dict, Optional, List import torch def get_parameter_device(parameter: torch.nn.Module): try: return next(parameter.parameters()).device except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes( module: torch.nn.Module, ) -> List[Tuple[str, torch.Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) first_tuple = next(gen) return first_tuple[1].device def get_parameter_dtype(parameter: torch.nn.Module): try: return next(parameter.parameters()).dtype except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes( module: torch.nn.Module, ) -> List[Tuple[str, torch.Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) first_tuple = next(gen) return first_tuple[1].dtype class BaseModel(torch.nn.Module): def __init__(self): super().__init__() def save_pretrained( self, save_directory: Union[str, os.PathLike], save_function: Callable = None, state_dict: Optional[Dict[str, torch.Tensor]] = None, ): """Save a model to a directory, so that it can be re-loaded using the load_pretrained class method. Args: save_directory -> Union[str, os.PathLike]: Directory to which to save. Will be created if it doesn't exist. save_function -> Optional[Callable]: The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace `torch.save` by another method. state_dict -> Optional[Dict[str, torch.Tensor]]: The state dictionary to save. If `None`, the model's state dictionary will be saved. """ if os.path.isfile(save_directory): print(f"Provided path ({save_directory}) should be a directory, not a file") return if save_function is None: save_function = torch.save os.makedirs(save_directory, exist_ok=True) model_to_save = self # Save the model if state_dict is None: state_dict = model_to_save.state_dict() weights_name = "pytorch_model.bin" # Save the model save_function(state_dict, os.path.join(save_directory, weights_name)) print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") def load_pretrained( self, # pretrained_model_path: Union[str, os.PathLike], checkpoint, strict_loading: bool = True, torch_dtype: Optional[torch.dtype] = None, rename_keys: Optional[Dict[str, str]] = None, ): """Instantiate a pretrained pytorch model from a weights path. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`. Args: pretrained_model_path -> Union[str, os.PathLike]: Path to a pretrained model. strict_loading -> bool: Whether or not to strictly enforce that the provided weights file matches the architecture of this model. torch_dtype -> Optional[torch.dtype]: The dtype to use for the model. Defaults to `None`, which means no conversion. rename_keys -> Optional[Dict[str, str]]: A dictionary containing the keys to rename. Defaults to `None`, which means no renaming. """ # if os.path.isfile(pretrained_model_path): # model_file = pretrained_model_path # elif os.path.isdir(pretrained_model_path): # pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") # if os.path.isfile(pretrained_model_path): # model_file = pretrained_model_path # else: # raise ValueError(f"{pretrained_model_path} does not exist") # else: # raise ValueError(f"{pretrained_model_path} does not exist") # # checkpoint = torch.load(model_file, map_location="cpu") new_checkpoint = copy.deepcopy(checkpoint) if rename_keys is not None: for p_key in checkpoint: for r_key in rename_keys: if p_key.startswith(r_key): new_checkpoint[p_key.replace(r_key, rename_keys[r_key])] = ( checkpoint[p_key] ) new_checkpoint.pop(p_key) break checkpoint = new_checkpoint self.load_state_dict(checkpoint, strict=strict_loading) if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) elif torch_dtype is not None: self.to(torch_dtype) # Set model in evaluation mode to deactivate DropOut modules by default self.eval() @property def device(self): """Returns the device of the model. Returns: `torch.device`: The device of the model. """ return get_parameter_device(self) @property def dtype(self) -> torch.dtype: """Returns the dtype of the model.""" return get_parameter_dtype(self) def num_parameters( self, only_trainable: bool = False, exclude_embeddings: bool = False ) -> int: """ Get number of (optionally, trainable or non-embeddings) parameters in the module. Args: only_trainable (`bool`, *optional*, defaults to `False`): Whether or not to return only the number of trainable parameters exclude_embeddings (`bool`, *optional*, defaults to `False`): Whether or not to return only the number of non-embeddings parameters Returns: `int`: The number of parameters. """ if exclude_embeddings: embedding_param_names = [ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, torch.nn.Embedding) ] non_embedding_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] return sum( p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable ) else: return sum( p.numel() for p in self.parameters() if p.requires_grad or not only_trainable )