ghostai1 committed on
Commit
c3a53a3
·
verified ·
1 Parent(s): ebff558

Update 8bitapp.py

Browse files
Files changed (1) hide show
  1. 8bitapp.py +15 -8
8bitapp.py CHANGED
@@ -13,7 +13,6 @@ from torch.cuda.amp import autocast
13
  import warnings
14
  import random
15
  from bitsandbytes.nn import Linear8bitLt
16
- from transformers import AutoModel
17
 
18
  # Suppress warnings for cleaner output
19
  warnings.filterwarnings("ignore")
@@ -52,24 +51,32 @@ try:
52
  # Load MusicGen model in FP16
53
  musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
54
 
55
- # Apply 8-bit quantization to linear layers
56
  def quantize_to_8bit(model):
57
- for name, module in model.named_modules():
 
 
 
 
 
58
  if isinstance(module, torch.nn.Linear):
59
  # Replace with 8-bit linear layer
60
- parent = model
61
- for part in name.split('.')[:-1]:
 
62
  parent = getattr(parent, part)
63
- setattr(parent, name.split('.')[-1], Linear8bitLt(
64
  module.in_features,
65
  module.out_features,
66
  bias=module.bias is not None,
67
  has_fp16_weights=False,
68
  threshold=6.0
69
  ))
 
 
70
  return model
71
 
72
- # Quantize the model (apply to relevant transformer layers)
73
  musicgen_model = quantize_to_8bit(musicgen_model)
74
  musicgen_model.to(device)
75
 
@@ -94,7 +101,7 @@ def print_resource_usage(stage: str):
94
  print("---------------")
95
 
96
  # Check available GPU memory
97
- def check_vram_availability(required_gb=3.5):
98
  """Check if sufficient VRAM is available for audio generation."""
99
  total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
100
  allocated_vram = torch.cuda.memory_allocated() / (1024**3)
 
13
  import warnings
14
  import random
15
  from bitsandbytes.nn import Linear8bitLt
 
16
 
17
  # Suppress warnings for cleaner output
18
  warnings.filterwarnings("ignore")
 
51
  # Load MusicGen model in FP16
52
  musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
53
 
54
+ # Apply 8-bit quantization to the language model (lm) component
55
def quantize_to_8bit(model):
    """Replace every ``torch.nn.Linear`` inside ``model.lm`` with a
    bitsandbytes ``Linear8bitLt`` layer, preserving the trained weights.

    Quantization to int8 actually happens when the model is later moved to
    a CUDA device; here we only swap the layer objects and carry the
    original parameters over.

    Args:
        model: A MusicGen model exposing its transformer under the ``lm``
            attribute (assumed layout — confirm against the audiocraft
            version in use).

    Returns:
        The same model object, with its linear layers replaced in place.

    Raises:
        AttributeError: If the model has no ``lm`` attribute.
    """
    if not hasattr(model, 'lm'):
        raise AttributeError("MusicGen model does not have 'lm' attribute for quantization.")
    lm = model.lm

    # Materialize the replacement targets FIRST: mutating the module tree
    # while iterating the live named_modules() generator can skip or
    # revisit modules.
    targets = [
        (name, module)
        for name, module in lm.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]

    for name, module in targets:
        # Walk down to the direct parent so we can replace the child attr.
        parent = lm
        name_parts = name.split('.')
        for part in name_parts[:-1]:
            parent = getattr(parent, part)

        quantized = Linear8bitLt(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            has_fp16_weights=False,
            threshold=6.0
        )
        # Bug fix: copy the trained parameters into the new layer. The
        # original code dropped them entirely, leaving every quantized
        # layer randomly initialized. Assigning .data before the .to(device)
        # call is the documented bitsandbytes flow (int8 conversion occurs
        # on the device move).
        quantized.weight.data = module.weight.data
        if module.bias is not None:
            quantized.bias.data = module.bias.data
        setattr(parent, name_parts[-1], quantized)

    print(f"Quantized {len(targets)} linear layers to 8-bit.")
    return model
78
 
79
+ # Quantize the model
80
  musicgen_model = quantize_to_8bit(musicgen_model)
81
  musicgen_model.to(device)
82
 
 
101
  print("---------------")
102
 
103
  # Check available GPU memory
104
+ def check_vram_availability(required_gb=4.0):
105
  """Check if sufficient VRAM is available for audio generation."""
106
  total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
107
  allocated_vram = torch.cuda.memory_allocated() / (1024**3)