import os
import sys
import gc
import time
import random
import logging
import traceback
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torchaudio
import gradio as gr
from pydub import AudioSegment
from audiocraft.models import MusicGen
from torch.cuda.amp import autocast
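
# GhostAI Music Generator: a Gradio front end for a local MusicGen "medium"
# checkpoint. Tracks are rendered in chunks of up to 30 s, stitched with short
# crossfades, then RMS-normalized, limited, EQ'd, faded, and exported as MP3.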

warnings.filterwarnings("ignore")

# Cap the CUDA caching allocator's split size to reduce VRAM fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

# Prefer deterministic cuDNN kernels over autotuned (faster) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Log to a timestamped file under ./logs and mirror everything to stdout.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    logger.error("CUDA is required for GPU rendering. CPU rendering is disabled.")
    sys.exit(1)
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)} (CUDA 12)")
logger.info("Using precision: float16 for model, float32 for CPU processing")


def clean_memory():
    """Release cached CUDA memory and return the current VRAM allocation in MB."""
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    vram_mb = torch.cuda.memory_allocated() / 1024**2
    logger.info(f"Memory cleaned: VRAM allocated = {vram_mb:.2f} MB")
    logger.debug(f"VRAM summary: {torch.cuda.memory_summary()}")
    return vram_mb


clean_memory()

try:
    logger.info("Loading MusicGen medium model into VRAM...")
    local_model_path = "./models/musicgen-medium"
    if not os.path.exists(local_model_path):
        logger.error(f"Local model path {local_model_path} does not exist.")
        logger.error("Please download the MusicGen medium model weights and place them in the correct directory.")
        sys.exit(1)
    musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
    musicgen_model.set_generation_params(
        duration=30,
        two_step_cfg=False
    )
    logger.info("MusicGen medium model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load MusicGen model: {e}")
    logger.error(traceback.format_exc())
    sys.exit(1)


def check_disk_space(path="."):
    # os.statvfs is POSIX-only; this check will not work on Windows.
    stat = os.statvfs(path)
    free_space = stat.f_bavail * stat.f_frsize / (1024**3)
    if free_space < 1.0:
        logger.warning(f"Low disk space ({free_space:.2f} GB). Ensure at least 1 GB free.")
    return free_space >= 1.0
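
# --- Audio post-processing helpers -------------------------------------------
# Each chunk and the final mix pass through some combination of: stereo
# balancing (gate quiet samples, match L/R RMS), RMS normalization with a hard
# peak limit, a broad 20 Hz-20 kHz band-limit, and fade in/out on the final mix.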


def balance_stereo(audio_segment, noise_threshold=-60, sample_rate=16000):
    logger.debug(f"Balancing stereo for segment with sample rate {sample_rate}")
    samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
    if audio_segment.channels == 2:
        stereo_samples = samples.reshape(-1, 2)
        # Zero out samples whose level falls below the noise threshold.
        db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10)
        mask = db_samples > noise_threshold
        stereo_samples = stereo_samples * mask
        left_nonzero = stereo_samples[:, 0][stereo_samples[:, 0] != 0]
        right_nonzero = stereo_samples[:, 1][stereo_samples[:, 1] != 0]
        left_rms = np.sqrt(np.mean(left_nonzero**2)) if len(left_nonzero) > 0 else 0
        right_rms = np.sqrt(np.mean(right_nonzero**2)) if len(right_nonzero) > 0 else 0
        # Scale each channel toward the average RMS so neither side dominates.
        if left_rms > 0 and right_rms > 0:
            avg_rms = (left_rms + right_rms) / 2
            stereo_samples[:, 0] = stereo_samples[:, 0] * (avg_rms / left_rms)
            stereo_samples[:, 1] = stereo_samples[:, 1] * (avg_rms / right_rms)
        # Clip before casting so the gain change cannot wrap around int16.
        balanced_samples = np.clip(stereo_samples.flatten(), -32768, 32767).astype(np.int16)
        balanced_segment = AudioSegment(
            balanced_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=audio_segment.sample_width,
            channels=2
        )
        logger.debug("Stereo balancing completed")
        return balanced_segment
    logger.debug("Segment is not stereo, returning unchanged")
    return audio_segment


def calculate_rms(segment):
    samples = np.array(segment.get_array_of_samples(), dtype=np.float32)
    rms = np.sqrt(np.mean(samples**2))
    logger.debug(f"Calculated RMS: {rms}")
    return rms


def rms_normalize(segment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=16000):
    logger.debug(f"Normalizing RMS for segment with target {target_rms_db} dBFS")
    # Target RMS expressed in int16 sample units (0 dBFS == 32767).
    target_rms = 10 ** (target_rms_db / 20) * 32767
    current_rms = calculate_rms(segment)
    if current_rms > 0:
        gain_factor = target_rms / current_rms
        segment = segment.apply_gain(20 * np.log10(gain_factor))
    segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate)
    logger.debug("RMS normalization completed")
    return segment


def hard_limit(audio_segment, limit_db=-3.0, sample_rate=16000):
    logger.debug(f"Applying hard limit at {limit_db} dBFS")
    # Simple sample clipper, not a lookahead limiter: anything above the
    # ceiling is clipped flat.
    limit = 10 ** (limit_db / 20.0) * 32767
    samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
    samples = np.clip(samples, -limit, limit).astype(np.int16)
    limited_segment = AudioSegment(
        samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=audio_segment.sample_width,
        channels=audio_segment.channels
    )
    logger.debug("Hard limit applied")
    return limited_segment


def apply_eq(segment, sample_rate=16000):
    logger.debug(f"Applying EQ with sample rate {sample_rate}")
    # Band-limit to the audible range: high-pass at 20 Hz, low-pass at 20 kHz.
    segment = segment.high_pass_filter(20)
    segment = segment.low_pass_filter(20000)
    logger.debug("EQ applied")
    return segment


def apply_fade(segment, fade_in_duration=500, fade_out_duration=500):
    logger.debug(f"Applying fade: in={fade_in_duration}ms, out={fade_out_duration}ms")
    segment = segment.fade_in(fade_in_duration)
    segment = segment.fade_out(fade_out_duration)
    logger.debug("Fade applied")
    return segment
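
# --- Genre prompt builders ----------------------------------------------------
# Each builder maps the UI dropdowns (drums, synth, rhythmic steps, bass,
# guitar) onto a MusicGen text prompt; any dropdown left at "none" falls back
# to a genre-appropriate default phrase.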


def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("strong rhythmic steps" if bpm > 120 else "groovy rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", groovy basslines"
    guitar = f", {guitar_style} guitar riffs" if guitar_style != "none" else ", syncopated guitar riffs"
    prompt = f"Instrumental funk rock{bass}{guitar}{drum}{synth}, Red Hot Chili Peppers-inspired vibe with dynamic energy and funky breakdowns, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated RHCP prompt: {prompt}")
    return prompt


def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("intense rhythmic steps" if bpm > 120 else "grungy rhythmic pulse")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", melodic basslines"
    guitar = f", {guitar_style} guitar riffs" if guitar_style != "none" else ", raw distorted guitar riffs"
    prompt = f"Instrumental grunge{bass}{guitar}{drum}{synth}, Nirvana-inspired angst-filled sound with quiet-loud dynamics, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Nirvana prompt: {prompt}")
    return prompt


def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("soulful rhythmic steps" if bpm > 120 else "driving rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", deep bass"
    guitar = f", {guitar_style} guitar leads" if guitar_style != "none" else ", soulful guitar leads"
    prompt = f"Instrumental grunge{bass}{guitar}{drum}{synth}, Pearl Jam-inspired emotional intensity with soaring choruses, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Pearl Jam prompt: {prompt}")
    return prompt


def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("heavy rhythmic steps" if bpm > 120 else "sludgy rhythmic groove")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ""
    guitar = f", {guitar_style} guitar riffs" if guitar_style != "none" else ", heavy sludgy guitar riffs"
    prompt = f"Instrumental grunge{bass}{guitar}{drum}{synth}, Soundgarden-inspired dark, psychedelic edge, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Soundgarden prompt: {prompt}")
    return prompt


def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    styles = ["anthemic", "gritty", "melodic", "fast-paced", "driving"]
    moods = ["energetic", "introspective", "rebellious", "uplifting"]
    style = random.choice(styles)
    mood = random.choice(moods)
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("powerful rhythmic steps" if bpm > 120 else "catchy rhythmic groove")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ""
    guitar = f", {guitar_style} guitar riffs" if guitar_style != "none" else f", {style} guitar riffs"
    prompt = f"Instrumental alternative rock{bass}{guitar}{drum}{synth}, Foo Fighters-inspired {mood} vibe with powerful choruses, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Foo Fighters prompt: {prompt}")
    return prompt


def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("dynamic rhythmic steps" if bpm > 120 else "dreamy rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ""
    guitar = f", {guitar_style} guitar textures" if guitar_style != "none" else ", dreamy guitar textures"
    prompt = f"Instrumental alternative rock{bass}{guitar}{drum}{synth}, Smashing Pumpkins-inspired blend of melancholy and aggression, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Smashing Pumpkins prompt: {prompt}")
    return prompt


def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("complex rhythmic steps" if bpm > 120 else "intricate rhythmic pulse")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ", atmospheric synths"
    bass = f", {bass_style}" if bass_style != "none" else ""
    guitar = f", {guitar_style} guitar layers" if guitar_style != "none" else ", intricate guitar layers"
    prompt = f"Instrumental experimental rock{bass}{guitar}{drum}{synth}, Radiohead-inspired blend of introspective and innovative soundscapes, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Radiohead prompt: {prompt}")
    return prompt


def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("bluesy rhythmic steps" if bpm > 120 else "steady rhythmic groove")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", groovy bass"
    guitar = f", {guitar_style} electric guitars" if guitar_style != "none" else ", bluesy electric guitars"
    prompt = f"Instrumental classic rock{bass}{guitar}{drum}{synth}, Led Zeppelin-inspired raw energy with dynamic solos, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Classic Rock prompt: {prompt}")
    return prompt


def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("quirky rhythmic steps" if bpm > 120 else "energetic rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", melodic basslines"
    guitar = f", {guitar_style} guitar riffs" if guitar_style != "none" else ", distorted guitar riffs"
    prompt = f"Instrumental alternative rock{bass}{guitar}{drum}{synth}, Pixies-inspired quirky, energetic vibe, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Alternative Rock prompt: {prompt}")
    return prompt


def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("sharp rhythmic steps" if bpm > 120 else "moody rhythmic pulse")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", driving basslines"
    guitar = f", {guitar_style} guitars" if guitar_style != "none" else ", jangly guitars"
    prompt = f"Instrumental post-punk{bass}{guitar}{drum}{synth}, Joy Division-inspired moody, atmospheric sound with a steady, hypnotic beat, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Post-Punk prompt: {prompt}")
    return prompt


def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("catchy rhythmic steps" if bpm > 120 else "jangly rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ""
    guitar = f", {guitar_style} guitars" if guitar_style != "none" else ", jangly guitars"
    prompt = f"Instrumental indie rock{bass}{guitar}{drum}{synth}, Arctic Monkeys-inspired blend of catchy riffs, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Indie Rock prompt: {prompt}")
    return prompt


def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("aggressive rhythmic steps" if bpm > 120 else "funky rhythmic groove")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ""
    synth = f", {synthesizer} accents" if synthesizer != "none" else ""
    bass = f", {bass_style}" if bass_style != "none" else ", slap bass"
    guitar = f", {guitar_style} guitar chords" if guitar_style != "none" else ", funky guitar chords"
    prompt = f"Instrumental funk rock{bass}{guitar}{drum}{synth}, Rage Against the Machine-inspired mix of groove and aggression, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Funk Rock prompt: {prompt}")
    return prompt


def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("pulsing rhythmic steps" if bpm > 120 else "deep rhythmic groove")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ", crisp hi-hats and a steady four-on-the-floor kick drum"
    synth = f", {synthesizer} accents" if synthesizer != "none" else ", deep pulsing synths with a repetitive, hypnotic pattern"
    bass = f", {bass_style}" if bass_style != "none" else ", driving basslines with a consistent, groovy pulse"
    guitar = f", {guitar_style} guitars" if guitar_style != "none" else ""
    prompt = f"Instrumental Detroit techno{bass}{guitar}{drum}{synth}, Juan Atkins-inspired rhythmic groove with a steady, repetitive beat, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Detroit Techno prompt: {prompt}")
    return prompt


def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    rhythm = rhythmic_steps if rhythmic_steps != "none" else ("soulful rhythmic steps" if bpm > 120 else "laid-back rhythmic flow")
    drum = f", {drum_beat} drums" if drum_beat != "none" else ", steady four-on-the-floor kick drum with soft hi-hats"
    synth = f", {synthesizer} accents" if synthesizer != "none" else ", warm analog synth chords with a repetitive, hypnotic progression"
    bass = f", {bass_style}" if bass_style != "none" else ", deep basslines with a consistent, groovy pulse"
    guitar = f", {guitar_style} guitars" if guitar_style != "none" else ""
    prompt = f"Instrumental deep house{bass}{guitar}{drum}{synth}, Larry Heard-inspired laid-back groove with a steady, repetitive beat, {rhythm} at {bpm} BPM."
    logger.debug(f"Generated Deep House prompt: {prompt}")
    return prompt

# Per-genre sampling presets; any choice other than "default" overrides the
# CFG/top-k/top-p/temperature sliders.
PRESETS = {
    "default": {"cfg_scale": 2.0, "top_k": 150, "top_p": 0.9, "temperature": 0.8},
    "rock": {"cfg_scale": 2.5, "top_k": 140, "top_p": 0.9, "temperature": 0.9},
    "techno": {"cfg_scale": 1.8, "top_k": 160, "top_p": 0.85, "temperature": 0.7},
    "grunge": {"cfg_scale": 2.0, "top_k": 150, "top_p": 0.9, "temperature": 0.85},
    "indie": {"cfg_scale": 2.2, "top_k": 145, "top_p": 0.9, "temperature": 0.8}
}


def get_latest_log():
    log_files = sorted(Path(log_dir).glob("musicgen_log_*.log"), key=os.path.getmtime, reverse=True)
    if not log_files:
        logger.warning("No log files found")
        return "No log files found."
    try:
        with open(log_files[0], "r") as f:
            content = f.read()
        logger.info(f"Retrieved latest log file: {log_files[0]}")
        return content
    except Exception as e:
        logger.error(f"Failed to read log file {log_files[0]}: {e}")
        return f"Error reading log file: {e}"


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, bpm: int, drum_beat: str, synthesizer: str, rhythmic_steps: str, bass_style: str, guitar_style: str, target_volume: float, preset: str, vram_status: str):
    global musicgen_model
    if not instrumental_prompt.strip():
        logger.warning("Empty instrumental prompt provided")
        return None, "⚠️ Please enter a valid instrumental prompt!", vram_status
    try:
        logger.info("Starting music generation...")
        start_time = time.time()
        max_duration = 30                # longest single MusicGen call, seconds
        total_duration = min(max(total_duration, 30), 120)
        processing_sample_rate = 16000   # rate used for chunk post-processing
        output_sample_rate = 32000       # rate of the exported track
        audio_segments = []
        overlap_duration = 0.3           # crossfade length between chunks, seconds
        remaining_duration = total_duration

        if preset != "default":
            preset_params = PRESETS.get(preset, PRESETS["default"])
            cfg_scale = preset_params["cfg_scale"]
            top_k = preset_params["top_k"]
            top_p = preset_params["top_p"]
            temperature = preset_params["temperature"]
            logger.info(f"Applied preset {preset}: cfg_scale={cfg_scale}, top_k={top_k}, top_p={top_p}, temperature={temperature}")

        if not check_disk_space():
            logger.error("Insufficient disk space")
            return None, "⚠️ Insufficient disk space. Free up at least 1 GB.", vram_status

        logger.info(f"Generating audio for {total_duration}s with seed=42")
        seed = 42
        base_prompt = instrumental_prompt
        clean_memory()
        vram_status = f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"

        while remaining_duration > 0:
            current_duration = min(max_duration, remaining_duration)
            chunk_num = len(audio_segments) + 1
            logger.info(f"Generating chunk {chunk_num} ({current_duration}s, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB)")

            musicgen_model.set_generation_params(
                duration=current_duration,
                use_sampling=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                cfg_coef=cfg_scale
            )

            try:
                with torch.no_grad():
                    with autocast(dtype=torch.float16):
                        torch.manual_seed(seed)
                        np.random.seed(seed)
                        torch.cuda.manual_seed_all(seed)
                        clean_memory()
                        if not audio_segments:
                            logger.debug("Generating first chunk")
                            audio_segment = musicgen_model.generate([base_prompt], progress=True)[0].cpu()
                        else:
                            # Seed the next chunk with the tail of the previous one
                            # so generate_continuation() picks up where it left off.
                            logger.debug("Generating continuation chunk")
                            prev_segment = audio_segments[-1]
                            prev_segment = balance_stereo(prev_segment, noise_threshold=-60, sample_rate=processing_sample_rate)
                            temp_wav_path = f"temp_prev_{int(time.time()*1000)}.wav"
                            logger.debug(f"Exporting previous segment to {temp_wav_path}")
                            prev_segment.export(temp_wav_path, format="wav")
                            prev_audio, prev_sr = torchaudio.load(temp_wav_path)
                            if prev_sr != processing_sample_rate:
                                logger.debug(f"Resampling from {prev_sr} to {processing_sample_rate}")
                                prev_audio = torchaudio.transforms.Resample(prev_sr, processing_sample_rate)(prev_audio)
                            prev_audio = prev_audio.to(device)
                            os.remove(temp_wav_path)
                            logger.debug(f"Deleted temporary file {temp_wav_path}")
                            audio_segment = musicgen_model.generate_continuation(
                                prompt=prev_audio[:, -int(processing_sample_rate * overlap_duration):],
                                prompt_sample_rate=processing_sample_rate,
                                descriptions=[base_prompt],
                                progress=True
                            )[0].cpu()
                            del prev_audio
                            clean_memory()
            except Exception as e:
                logger.error(f"Error in chunk {chunk_num} generation: {e}")
                logger.error(traceback.format_exc())
                raise e

            logger.debug(f"Generated audio segment shape: {audio_segment.shape}")
            audio_segment = audio_segment.to(dtype=torch.float32)
            if audio_segment.dim() == 1:
                logger.debug("Converting mono to stereo")
                audio_segment = torch.stack([audio_segment, audio_segment], dim=0)
            elif audio_segment.dim() == 2 and audio_segment.shape[0] != 2:
                logger.debug("Adjusting to stereo")
                audio_segment = torch.cat([audio_segment, audio_segment], dim=0)

            if audio_segment.shape[0] != 2:
                logger.error(f"Expected stereo audio with shape (2, samples), got shape {audio_segment.shape}")
                raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_segment.shape}")

            temp_wav_path = f"temp_audio_{int(time.time()*1000)}.wav"
            logger.debug(f"Saving audio segment to {temp_wav_path}")
            torchaudio.save(temp_wav_path, audio_segment, output_sample_rate, bits_per_sample=16)
            segment = AudioSegment.from_wav(temp_wav_path)
            os.remove(temp_wav_path)
            logger.debug(f"Deleted temporary file {temp_wav_path}")
            segment = segment - 15   # leave 15 dB of headroom before processing
            if segment.frame_rate != processing_sample_rate:
                logger.debug(f"Setting segment sample rate to {processing_sample_rate}")
                segment = segment.set_frame_rate(processing_sample_rate)
            segment = balance_stereo(segment, noise_threshold=-60, sample_rate=processing_sample_rate)
            segment = rms_normalize(segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
            segment = apply_eq(segment, sample_rate=processing_sample_rate)
            audio_segments.append(segment)

            del audio_segment
            clean_memory()
            vram_status = f"VRAM after chunk {chunk_num}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
            time.sleep(0.1)
            remaining_duration -= current_duration

        logger.info("Combining audio chunks...")
        final_segment = audio_segments[0][:min(max_duration, total_duration) * 1000]
        overlap_ms = int(overlap_duration * 1000)

        for i in range(1, len(audio_segments)):
            current_segment = audio_segments[i]
            current_segment = current_segment[:min(max_duration, total_duration - (i * max_duration)) * 1000]

            if overlap_ms > 0 and len(current_segment) > overlap_ms:
                logger.debug(f"Applying crossfade between chunks {i} and {i+1}")
                prev_overlap = final_segment[-overlap_ms:]
                curr_overlap = current_segment[:overlap_ms]
                prev_samples = np.array(prev_overlap.get_array_of_samples(), dtype=np.float32).reshape(-1, 2)
                curr_samples = np.array(curr_overlap.get_array_of_samples(), dtype=np.float32).reshape(-1, 2)
                num_samples = min(len(prev_samples), len(curr_samples))
                prev_samples = prev_samples[:num_samples]
                curr_samples = curr_samples[:num_samples]
                # Monotonic raised-cosine ramps: the outgoing chunk fades from
                # 1 to 0 while the incoming chunk fades from 0 to 1.
                fade_in = 0.5 * (1 - np.cos(np.pi * np.arange(num_samples) / max(num_samples - 1, 1)))
                fade_out = 1.0 - fade_in
                blended_samples = prev_samples * fade_out[:, None] + curr_samples * fade_in[:, None]
                blended_segment = AudioSegment(
                    np.clip(blended_samples, -32768, 32767).astype(np.int16).tobytes(),
                    frame_rate=processing_sample_rate,
                    sample_width=2,
                    channels=2
                )
                blended_segment = rms_normalize(blended_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
                final_segment = final_segment[:-overlap_ms] + blended_segment + current_segment[overlap_ms:]
            else:
                logger.debug(f"Concatenating chunk {i+1} without crossfade")
                final_segment += current_segment

        final_segment = final_segment[:total_duration * 1000]
        logger.info("Post-processing final track...")
        final_segment = rms_normalize(final_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
        final_segment = apply_eq(final_segment, sample_rate=processing_sample_rate)
        final_segment = apply_fade(final_segment)
        final_segment = balance_stereo(final_segment, noise_threshold=-60, sample_rate=processing_sample_rate)
        final_segment = final_segment - 10
        final_segment = final_segment.set_frame_rate(output_sample_rate)

        mp3_path = f"output_adjusted_volume_{int(time.time())}.mp3"
        logger.info("⚠️ WARNING: Audio is set to safe levels (~ -23 dBFS RMS, -3 dBFS peak). Start playback at LOW volume (10-20%) and adjust gradually.")
        logger.info("VERIFY: Open the file in Audacity to check for static. RMS should be ~ -23 dBFS, peaks ≤ -3 dBFS. Report any static or issues.")
        try:
            logger.debug(f"Exporting final audio to {mp3_path}")
            final_segment.export(
                mp3_path,
                format="mp3",
                bitrate="96k",
                tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
            )
            logger.info(f"Final audio saved to {mp3_path}")
        except Exception as e:
            logger.error(f"Error exporting MP3: {e}")
            fallback_path = f"fallback_output_{int(time.time())}.mp3"
            try:
                final_segment.export(fallback_path, format="mp3", bitrate="96k")
                logger.info(f"Final audio saved to fallback: {fallback_path}")
                mp3_path = fallback_path
            except Exception as fallback_e:
                logger.error(f"Failed to save fallback MP3: {fallback_e}")
                raise e

        vram_status = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        logger.info(f"Generation completed in {time.time() - start_time:.2f} seconds")
        return mp3_path, "✅ Done! Generated static-free track with adjusted volume levels.", vram_status
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        logger.error(traceback.format_exc())
        return None, f"❌ Generation failed: {e}", vram_status
    finally:
        clean_memory()


def clear_inputs():
    logger.info("Clearing input fields")
    return "", 2.0, 150, 0.9, 0.8, 30, 120, "none", "none", "none", "none", "none", -23.0, "default", ""

css = """
body {
    background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
    color: #E0E0E0;
    font-family: 'Orbitron', sans-serif;
}
.header-container {
    text-align: center;
    padding: 10px 20px;
    background: rgba(0, 0, 0, 0.9);
    border-bottom: 1px solid #00FF9F;
}
#ghost-logo {
    font-size: 40px;
    animation: glitch-ghost 1.5s infinite;
}
h1 {
    color: #A100FF;
    font-size: 24px;
    animation: glitch-text 2s infinite;
}
p {
    color: #E0E0E0;
    font-size: 12px;
}
.input-container, .settings-container, .output-container, .logs-container {
    max-width: 1200px;
    margin: 20px auto;
    padding: 20px;
    background: rgba(28, 37, 38, 0.8);
    border-radius: 10px;
}
.textbox {
    background: #1A1A1A;
    border: 1px solid #A100FF;
    color: #E0E0E0;
}
.genre-buttons {
    display: flex;
    justify-content: center;
    flex-wrap: wrap;
    gap: 15px;
}
.genre-btn, button {
    background: linear-gradient(45deg, #A100FF, #00FF9F);
    border: none;
    color: #0A0A0A;
    padding: 10px 20px;
    border-radius: 5px;
}
.gradio-container {
    padding: 20px;
}
.group-container {
    margin-bottom: 20px;
    padding: 15px;
    border: 1px solid #00FF9F;
    border-radius: 8px;
}
@keyframes glitch-ghost {
    0% { transform: translate(0, 0); opacity: 1; }
    20% { transform: translate(-5px, 2px); opacity: 0.8; }
    100% { transform: translate(0, 0); opacity: 1; }
}
@keyframes glitch-text {
    0% { transform: translate(0, 0); }
    20% { transform: translate(-2px, 1px); }
    100% { transform: translate(0, 0); }
}
@font-face {
    font-family: 'Orbitron';
    src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
}
"""

logger.info("Building Gradio interface...")
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
    <div class="header-container">
        <div id="ghost-logo">👻</div>
        <h1>GhostAI Music Generator 🎹</h1>
        <p>Summon the Sound of the Unknown</p>
    </div>
    """)

    with gr.Column(elem_classes="input-container"):
        gr.Markdown("### 🎸 Prompt Settings")
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt ✍️",
            placeholder="Click a genre button or type your own instrumental prompt",
            lines=4,
            elem_classes="textbox"
        )
        with gr.Row(elem_classes="genre-buttons"):
            rhcp_btn = gr.Button("Red Hot Chili Peppers 🌶️", elem_classes="genre-btn")
            nirvana_btn = gr.Button("Nirvana Grunge 🎸", elem_classes="genre-btn")
            pearl_jam_btn = gr.Button("Pearl Jam Grunge 🦪", elem_classes="genre-btn")
            soundgarden_btn = gr.Button("Soundgarden Grunge 🌑", elem_classes="genre-btn")
            foo_fighters_btn = gr.Button("Foo Fighters 🎤", elem_classes="genre-btn")
            smashing_pumpkins_btn = gr.Button("Smashing Pumpkins 🎃", elem_classes="genre-btn")
            radiohead_btn = gr.Button("Radiohead 🧠", elem_classes="genre-btn")
            classic_rock_btn = gr.Button("Classic Rock 🎸", elem_classes="genre-btn")
            alternative_rock_btn = gr.Button("Alternative Rock 🎵", elem_classes="genre-btn")
            post_punk_btn = gr.Button("Post-Punk 🎤", elem_classes="genre-btn")
            indie_rock_btn = gr.Button("Indie Rock 🎤", elem_classes="genre-btn")
            funk_rock_btn = gr.Button("Funk Rock 🎺", elem_classes="genre-btn")
            detroit_techno_btn = gr.Button("Detroit Techno 🎛️", elem_classes="genre-btn")
            deep_house_btn = gr.Button("Deep House 🏠", elem_classes="genre-btn")

    with gr.Column(elem_classes="settings-container"):
        gr.Markdown("### ⚙️ API Settings")
        with gr.Group(elem_classes="group-container"):
            cfg_scale = gr.Slider(
                label="CFG Scale 🎯",
                minimum=1.0,
                maximum=10.0,
                value=2.0,
                step=0.1,
                info="Controls how closely the music follows the prompt."
            )
            top_k = gr.Slider(
                label="Top-K Sampling 🔢",
                minimum=10,
                maximum=500,
                value=150,
                step=10,
                info="Limits sampling to the top k most likely tokens."
            )
            top_p = gr.Slider(
                label="Top-P Sampling 🎰",
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                info="Keeps tokens with cumulative probability above p."
            )
            temperature = gr.Slider(
                label="Temperature 🔥",
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                info="Controls randomness; lower values reduce noise."
            )
            total_duration = gr.Dropdown(
                label="Song Length ⏳ (seconds)",
                choices=[30, 60, 90, 120],
                value=30,
                info="Select the total duration of the track."
            )
            bpm = gr.Slider(
                label="Tempo 🎵 (BPM)",
                minimum=60,
                maximum=180,
                value=120,
                step=1,
                info="Beats per minute to set the track's tempo."
            )
            drum_beat = gr.Dropdown(
                label="Drum Beat 🥁",
                choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                value="none",
                info="Select a drum beat style to influence the rhythm."
            )
            synthesizer = gr.Dropdown(
                label="Synthesizer 🎹",
                choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                value="none",
                info="Select a synthesizer style for electronic accents."
            )
            rhythmic_steps = gr.Dropdown(
                label="Rhythmic Steps 👣",
                choices=["none", "syncopated steps", "steady steps", "complex steps"],
                value="none",
                info="Select a rhythmic step style to enhance the beat."
            )
            bass_style = gr.Dropdown(
                label="Bass Style 🎸",
                choices=["none", "slap bass", "deep bass", "melodic bass"],
                value="none",
                info="Select a bass style to shape the low end."
            )
            guitar_style = gr.Dropdown(
                label="Guitar Style 🎸",
                choices=["none", "distorted", "clean", "jangle"],
                value="none",
                info="Select a guitar style to define the riffs."
            )
            target_volume = gr.Slider(
                label="Target Volume 🎚️ (dBFS RMS)",
                minimum=-30.0,
                maximum=-20.0,
                value=-23.0,
                step=1.0,
                info="Adjust output loudness (-23 dBFS is standard, -20 dBFS is louder, -30 dBFS is quieter)."
            )
            preset = gr.Dropdown(
                label="Preset Configuration 🎛️",
                choices=["default", "rock", "techno", "grunge", "indie"],
                value="default",
                info="Select a preset optimized for specific genres."
            )

    with gr.Row(elem_classes="action-buttons"):
        gen_btn = gr.Button("Generate Music 🚀")
        clr_btn = gr.Button("Clear Inputs 🧹")

    with gr.Column(elem_classes="output-container"):
        gr.Markdown("### 🎧 Output")
        out_audio = gr.Audio(label="Generated Instrumental Track 🎵", type="filepath")
        status = gr.Textbox(label="Status 📢", interactive=False)
        vram_status = gr.Textbox(label="VRAM Usage 📊", interactive=False, value="")

    with gr.Column(elem_classes="logs-container"):
        gr.Markdown("### 📋 Logs")
        log_output = gr.Textbox(label="Last Log File Contents", lines=20, interactive=False)
        log_btn = gr.Button("View Last Log 🔍")

    genre_inputs = [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]
    rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    nirvana_btn.click(set_nirvana_grunge_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    pearl_jam_btn.click(set_pearl_jam_grunge_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    soundgarden_btn.click(set_soundgarden_grunge_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    foo_fighters_btn.click(set_foo_fighters_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    radiohead_btn.click(set_radiohead_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    classic_rock_btn.click(set_classic_rock_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    alternative_rock_btn.click(set_alternative_rock_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    post_punk_btn.click(set_post_punk_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    indie_rock_btn.click(set_indie_rock_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    funk_rock_btn.click(set_funk_rock_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    detroit_techno_btn.click(set_detroit_techno_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    deep_house_btn.click(set_deep_house_prompt, inputs=genre_inputs, outputs=instrumental_prompt)
    gen_btn.click(
        generate_music,
        inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, vram_status],
        outputs=[out_audio, status, vram_status]
    )
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, vram_status]
    )
    log_btn.click(
        get_latest_log,
        inputs=None,
        outputs=log_output
    )
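
# Serve on all interfaces at port 9999. Disabling the FastAPI docs endpoints
# after launch is best-effort: depending on the Gradio version those routes
# may already be registered, hence the broad try/except below.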

logger.info("Launching Gradio UI at http://localhost:9999...")
app = demo.launch(
    server_name="0.0.0.0",
    server_port=9999,
    share=False,
    inbrowser=False,
    show_error=True
)
try:
    fastapi_app = demo._server.app
    fastapi_app.docs_url = None
    fastapi_app.redoc_url = None
    fastapi_app.openapi_url = None
except Exception as e:
    logger.error(f"Failed to configure FastAPI app: {e}")