Update 8bitapp.py
Browse files- 8bitapp.py +15 -8
8bitapp.py
CHANGED
@@ -13,7 +13,6 @@ from torch.cuda.amp import autocast
|
|
13 |
import warnings
|
14 |
import random
|
15 |
from bitsandbytes.nn import Linear8bitLt
|
16 |
-
from transformers import AutoModel
|
17 |
|
18 |
# Suppress warnings for cleaner output
|
19 |
warnings.filterwarnings("ignore")
|
@@ -52,24 +51,32 @@ try:
|
|
52 |
# Load MusicGen model in FP16
|
53 |
musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
|
54 |
|
55 |
-
# Apply 8-bit quantization to
|
56 |
def quantize_to_8bit(model):
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
if isinstance(module, torch.nn.Linear):
|
59 |
# Replace with 8-bit linear layer
|
60 |
-
parent =
|
61 |
-
|
|
|
62 |
parent = getattr(parent, part)
|
63 |
-
setattr(parent,
|
64 |
module.in_features,
|
65 |
module.out_features,
|
66 |
bias=module.bias is not None,
|
67 |
has_fp16_weights=False,
|
68 |
threshold=6.0
|
69 |
))
|
|
|
|
|
70 |
return model
|
71 |
|
72 |
-
# Quantize the model
|
73 |
musicgen_model = quantize_to_8bit(musicgen_model)
|
74 |
musicgen_model.to(device)
|
75 |
|
@@ -94,7 +101,7 @@ def print_resource_usage(stage: str):
|
|
94 |
print("---------------")
|
95 |
|
96 |
# Check available GPU memory
|
97 |
-
def check_vram_availability(required_gb=
|
98 |
"""Check if sufficient VRAM is available for audio generation."""
|
99 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
100 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|
|
|
13 |
import warnings
|
14 |
import random
|
15 |
from bitsandbytes.nn import Linear8bitLt
|
|
|
16 |
|
17 |
# Suppress warnings for cleaner output
|
18 |
warnings.filterwarnings("ignore")
|
|
|
51 |
# Load MusicGen model in FP16
|
52 |
musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
|
53 |
|
54 |
+
# Apply 8-bit quantization to the language model (lm) component
def quantize_to_8bit(model):
    """Replace every ``torch.nn.Linear`` inside ``model.lm`` with a
    bitsandbytes ``Linear8bitLt`` layer, preserving pretrained parameters.

    Args:
        model: A MusicGen model exposing its transformer under the ``lm``
            attribute.

    Returns:
        The same model object, mutated in place. Actual int8 quantization of
        the copied weights happens on the subsequent ``.to(device)`` call.

    Raises:
        AttributeError: If the model has no ``lm`` attribute.
    """
    # Target the lm (language model) attribute, which contains the transformer
    if not hasattr(model, 'lm'):
        raise AttributeError("MusicGen model does not have 'lm' attribute for quantization.")
    lm = model.lm

    # Snapshot the layers to replace BEFORE mutating: swapping modules while
    # iterating named_modules() can skip or revisit entries.
    targets = [
        (name, module)
        for name, module in lm.named_modules()
        if isinstance(module, torch.nn.Linear)
    ]

    quantized_layers = 0
    for name, module in targets:
        # Walk down to the immediate parent of the layer being replaced.
        parent = lm
        name_parts = name.split('.')
        for part in name_parts[:-1]:
            parent = getattr(parent, part)

        # Build the 8-bit replacement with the same shape/bias configuration.
        replacement = Linear8bitLt(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            has_fp16_weights=False,
            threshold=6.0
        )
        # BUG FIX: the original code installed freshly-initialized 8-bit
        # layers without transferring the pretrained weights/bias, which
        # silently left the model with random parameters. Load them before
        # swapping the layer in; Linear8bitLt quantizes these values when
        # the model is later moved to the GPU.
        replacement.load_state_dict(module.state_dict())

        setattr(parent, name_parts[-1], replacement)
        quantized_layers += 1

    print(f"Quantized {quantized_layers} linear layers to 8-bit.")
    return model
|
78 |
|
79 |
+
# Quantize the model
|
80 |
musicgen_model = quantize_to_8bit(musicgen_model)
|
81 |
musicgen_model.to(device)
|
82 |
|
|
|
101 |
print("---------------")
|
102 |
|
103 |
# Check available GPU memory
|
104 |
+
def check_vram_availability(required_gb=4.0):
|
105 |
"""Check if sufficient VRAM is available for audio generation."""
|
106 |
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
107 |
allocated_vram = torch.cuda.memory_allocated() / (1024**3)
|