import json
from typing import Dict

from huggingface_hub import HfApi


class ModelMemoryCalculator:
    def __init__(self):
        self.hf_api = HfApi()
        self.cache = {}  # Cache results to avoid repeated API calls

    def get_model_memory_requirements(self, model_id: str) -> Dict:
        """
        Calculate memory requirements for a given HuggingFace model.

        Args:
            model_id: HuggingFace model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

        Returns:
            Dict with memory information including:
            - total_params: Total parameter count
            - memory_fp32_gb: Memory in GB at FP32 precision
            - memory_fp16_gb: Memory in GB at FP16 precision
            - memory_bf16_gb: Memory in GB at BF16 precision
            - safetensors_files: List of safetensor files and their sizes
        """
        if model_id in self.cache:
            return self.cache[model_id]

        try:
            print(f"Fetching model info for {model_id}...")
            # Get model info (also validates that the repo exists and is accessible)
            model_info = self.hf_api.model_info(model_id)
            print("Model info retrieved successfully")

            # Get safetensors metadata for the whole repo
            print("Fetching safetensors metadata...")
            safetensors_metadata = self.hf_api.get_safetensors_metadata(model_id)
            print(f"Found {len(safetensors_metadata.files_metadata)} safetensor files")

            total_params = 0
            safetensors_files = []

            # Iterate through all safetensor files in the repo
            for filename, file_metadata in safetensors_metadata.files_metadata.items():
                file_params = 0
                file_size_bytes = 0

                # Calculate parameters from per-tensor metadata
                for tensor_name, tensor_info in file_metadata.tensors.items():
                    # Number of elements in the tensor
                    tensor_params = 1
                    for dim in tensor_info.shape:
                        tensor_params *= dim
                    file_params += tensor_params

                    # Byte size based on the stored dtype
                    bytes_per_param = self._get_bytes_per_param(tensor_info.dtype)
                    file_size_bytes += tensor_params * bytes_per_param

                total_params += file_params
                safetensors_files.append({
                    'filename': filename,
                    'parameters': file_params,
                    'size_bytes': file_size_bytes,
                    'size_mb': file_size_bytes / (1024 * 1024),
                })

            # Calculate memory requirements for different precisions
            memory_requirements = {
                'model_id': model_id,
                'total_params': total_params,
                'total_params_billions': total_params / 1e9,
                'memory_fp32_gb': (total_params * 4) / (1024**3),  # 4 bytes per param
                'memory_fp16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_bf16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_int8_gb': (total_params * 1) / (1024**3),  # 1 byte per param
                'safetensors_files': safetensors_files,
                'estimated_inference_memory_fp16_gb': self._estimate_inference_memory(total_params, 'fp16'),
                'estimated_inference_memory_bf16_gb': self._estimate_inference_memory(total_params, 'bf16'),
            }

            # Cache the result
            self.cache[model_id] = memory_requirements
            return memory_requirements

        except Exception as e:
            return {
                'error': str(e),
                'model_id': model_id,
                'total_params': 0,
                'memory_fp32_gb': 0,
                'memory_fp16_gb': 0,
                'memory_bf16_gb': 0,
            }

    def _get_bytes_per_param(self, dtype: str) -> int:
        """Get bytes per parameter for different data types."""
        dtype_map = {
            'F32': 4, 'float32': 4,
            'F16': 2, 'float16': 2,
            'BF16': 2, 'bfloat16': 2,
            'I8': 1, 'int8': 1,
            'I32': 4, 'int32': 4,
            'I64': 8, 'int64': 8,
        }
        return dtype_map.get(dtype, 4)  # Default to 4 bytes (FP32)

    def _estimate_inference_memory(self, total_params: int, precision: str) -> float:
        """
        Estimate memory requirements during inference.

        This includes model weights + activations + intermediate tensors.
""" bytes_per_param = 2 if precision in ['fp16', 'bf16'] else 4 # Model weights model_memory = (total_params * bytes_per_param) / (1024**3) # Estimate activation memory (rough approximation) # For diffusion models, activations can be 1.5-3x model size during inference activation_multiplier = 2.0 total_inference_memory = model_memory * (1 + activation_multiplier) return total_inference_memory def get_memory_recommendation(self, model_id: str, available_vram_gb: float) -> Dict: """ Get memory recommendations based on available VRAM. Args: model_id: HuggingFace model identifier available_vram_gb: Available VRAM in GB Returns: Dict with recommendations for precision, offloading, etc. """ memory_info = self.get_model_memory_requirements(model_id) if 'error' in memory_info: return {'error': memory_info['error']} recommendations = { 'model_id': model_id, 'available_vram_gb': available_vram_gb, 'model_memory_fp16_gb': memory_info['memory_fp16_gb'], 'estimated_inference_memory_fp16_gb': memory_info['estimated_inference_memory_fp16_gb'], 'recommendations': [] } inference_memory_fp16 = memory_info['estimated_inference_memory_fp16_gb'] inference_memory_bf16 = memory_info['estimated_inference_memory_bf16_gb'] # Determine recommendations if available_vram_gb >= inference_memory_bf16: recommendations['recommendations'].append("āœ… Full model can fit in VRAM with BF16 precision") recommendations['recommended_precision'] = 'bfloat16' recommendations['cpu_offload'] = False recommendations['attention_slicing'] = False elif available_vram_gb >= inference_memory_fp16: recommendations['recommendations'].append("āœ… Full model can fit in VRAM with FP16 precision") recommendations['recommended_precision'] = 'float16' recommendations['cpu_offload'] = False recommendations['attention_slicing'] = False elif available_vram_gb >= memory_info['memory_fp16_gb']: recommendations['recommendations'].append("āš ļø Model weights fit, but may need memory optimizations") recommendations['recommended_precision'] = 'float16' recommendations['cpu_offload'] = False recommendations['attention_slicing'] = True recommendations['vae_slicing'] = True else: recommendations['recommendations'].append("šŸ”„ Requires CPU offloading and memory optimizations") recommendations['recommended_precision'] = 'float16' recommendations['cpu_offload'] = True recommendations['sequential_offload'] = True recommendations['attention_slicing'] = True recommendations['vae_slicing'] = True return recommendations def format_memory_info(self, model_id: str) -> str: """Format memory information for display.""" info = self.get_model_memory_requirements(model_id) if 'error' in info: return f"āŒ Error calculating memory for {model_id}: {info['error']}" output = f""" šŸ“Š **Memory Requirements for {model_id}** šŸ”¢ **Parameters**: {info['total_params_billions']:.2f}B parameters šŸ’¾ **Model Memory**: • FP32: {info['memory_fp32_gb']:.2f} GB • FP16/BF16: {info['memory_fp16_gb']:.2f} GB • INT8: {info['memory_int8_gb']:.2f} GB šŸš€ **Estimated Inference Memory**: • FP16: {info['estimated_inference_memory_fp16_gb']:.2f} GB • BF16: {info['estimated_inference_memory_bf16_gb']:.2f} GB šŸ“ **SafeTensor Files**: {len(info['safetensors_files'])} files """ return output.strip() # Example usage and testing if __name__ == "__main__": calculator = ModelMemoryCalculator() # Test with FLUX.1-schnell model_id = "black-forest-labs/FLUX.1-schnell" print(f"Testing memory calculation for {model_id}...") memory_info = calculator.get_model_memory_requirements(model_id) 
    print(json.dumps(memory_info, indent=2))

    # Test recommendations
    print("\n" + "=" * 50)
    print("MEMORY RECOMMENDATIONS")
    print("=" * 50)

    vram_options = [8, 16, 24, 40]
    for vram in vram_options:
        rec = calculator.get_memory_recommendation(model_id, vram)
        print(f"\nšŸŽÆ For {vram}GB VRAM:")
        if 'recommendations' in rec:
            for r in rec['recommendations']:
                print(f"  {r}")

    # Format for display
    print("\n" + "=" * 50)
    print("FORMATTED OUTPUT")
    print("=" * 50)
    print(calculator.format_memory_info(model_id))