import os
import re
import json
import struct
from typing import Dict, Any, Union, Optional

import torch
from safetensors.torch import load_file


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Optional[Dict[str, Any]] = None):
    """
    Memory-efficient save: write the safetensors header, then stream each
    tensor's bytes to disk one at a time instead of serializing the whole
    state dict in memory first.
    """

    # dtype -> safetensors type string; the float8 entries fall back to a
    # None key on PyTorch builds without float8 support
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    # build the JSON header: each entry records dtype, shape and byte offsets
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)  # pad so the data section is aligned

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))  # 8-byte little-endian header size
        f.write(hjson)

        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.dim() == 0:
                v = v.unsqueeze(0)  # scalar: add a dimension so view() works
            if v.is_cuda:
                # copy the GPU tensor to CPU as raw bytes, then stream to disk
                with torch.cuda.device(v.device):
                    v.contiguous().view(torch.uint8).cpu().numpy().tofile(f)
            else:
                v.contiguous().view(torch.uint8).numpy().tofile(f)


class MemoryEfficientSafeOpen:
    """Minimal safetensors reader that avoids mmap by reading tensors one at a time."""

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # data offsets are relative to the end of the header
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:  # empty tensor
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable for frombuffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # float8 types need special handling because older PyTorch lacks them
        if metadata["dtype"] in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # reinterpret the raw bytes as the target dtype and restore the shape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if this PyTorch build supports them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        if dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")


def load_safetensors(
    path: str,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, torch.Tensor]:
    """Load a safetensors file, optionally without mmap and/or cast to dtype."""
    if disable_mmap:
        # use the experimental non-mmap loader above
        state_dict = {}
        with MemoryEfficientSafeOpen(path) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
        return state_dict
    else:
        try:
            state_dict = load_file(path, device=device)
        except Exception:
            # some device strings are rejected by safetensors; fall back to
            # loading on the default device to prevent an invalid-device error
            state_dict = load_file(path)
        if dtype is not None:
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(dtype=dtype)
        return state_dict
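
# A minimal round-trip sketch of the API above. The tensor names, the file
# name "example.safetensors", and this helper itself are illustrative, not
# part of the original module: save a state dict with mem_eff_save_file,
# then read it back both directly and via load_safetensors.
def _example_round_trip():  # hypothetical demo helper
    tensors = {"linear.weight": torch.randn(4, 4), "linear.bias": torch.zeros(4)}
    mem_eff_save_file(tensors, "example.safetensors", metadata={"format": "pt"})

    with MemoryEfficientSafeOpen("example.safetensors") as f:
        assert f.metadata()["format"] == "pt"
        restored = {key: f.get_tensor(key) for key in f.keys()}

    # load_safetensors gives the same result; disable_mmap=True routes through
    # MemoryEfficientSafeOpen instead of safetensors' mmap-based loader
    state_dict = load_safetensors("example.safetensors", device="cpu", disable_mmap=True)
    assert all(torch.equal(restored[k], state_dict[k]) for k in tensors)
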
""" device = torch.device(device) # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix basename = os.path.basename(file_path) match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) if match: prefix = basename[: match.start(2)] count = int(match.group(3)) state_dict = {} for i in range(count): filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors" filepath = os.path.join(os.path.dirname(file_path), filename) if os.path.exists(filepath): state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap)) else: raise FileNotFoundError(f"File {filepath} not found") else: state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap) return state_dict