|
import os |
|
import torch |
|
import torchaudio |
|
import psutil |
|
import time |
|
import sys |
|
import numpy as np |
|
import gc |
|
import gradio as gr |
|
from pydub import AudioSegment |
|
from audiocraft.models import MusicGen |
|
from torch.cuda.amp import autocast |
|
|
|
|
|
# Configure PyTorch's CUDA caching allocator through its environment variable.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:32")

# Version mismatches only produce warnings; the script keeps running either way.
numpy_matches = np.__version__ == "1.23.5"
if not numpy_matches:
    print(f"WARNING: NumPy version {np.__version__} is being used. This script was tested with numpy==1.23.5, but proceeding anyway.")

torch_matches = torch.__version__.startswith(("2.1.0", "2.3.1"))
if not torch_matches:
    print(f"WARNING: PyTorch version {torch.__version__} may not be compatible. Expected torch==2.1.0 or 2.3.1.")
|
|
|
|
|
# The app is GPU-only: refuse to start rather than fall back to the CPU.
if not torch.cuda.is_available():
    print("ERROR: CUDA is required for GPU rendering. CPU rendering is disabled to avoid slow performance.")
    sys.exit(1)

device = "cuda"
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
|
|
|
|
|
# Load the MusicGen weights once at import time; generate_music() reuses the
# resulting module-level `musicgen_model` for every request.
print("Loading MusicGen model into VRAM...")

# The checkpoint directory defaults to the original deployment location and can
# be overridden with the MUSICGEN_MODEL_PATH environment variable.
local_model_path = os.environ.get(
    "MUSICGEN_MODEL_PATH",
    "/home/ubuntu/ghostai_music_generator/models/musicgen-medium",
)

if not os.path.exists(local_model_path):
    print(f"ERROR: Local model path {local_model_path} does not exist. Please ensure the model weights are downloaded.")
    sys.exit(1)

try:
    # Only the load itself is guarded, so the error message below is accurate.
    musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
except Exception as e:
    print(f"ERROR: Failed to load MusicGen model: {e}")
    print("Please ensure the model weights are in the correct path and dependencies are installed.")
    sys.exit(1)
|
|
|
|
|
def print_resource_usage(stage: str):
    """Print the current CUDA memory figures under a banner naming ``stage``.

    Reports ``torch.cuda.memory_allocated()`` and
    ``torch.cuda.memory_reserved()``, each divided by 1024**3.
    """
    bytes_per_unit = 1024**3

    print(f"--- {stage} ---")
    allocated = torch.cuda.memory_allocated() / bytes_per_unit
    print(f"GPU Memory Allocated: {allocated:.2f} GB")
    reserved = torch.cuda.memory_reserved() / bytes_per_unit
    print(f"GPU Memory Reserved: {reserved:.2f} GB")
    print("---------------")
|
|
|
|
|
def set_classic_rock_prompt():
    """Return the preset text prompt for the "Classic Rock" genre."""
    prompt = (
        "Classic rock with bluesy electric guitars, steady drums, groovy bass, "
        "Hammond organ fills, and a Led Zeppelin-inspired raw energy, "
        "maintaining a cohesive structure with dynamic solos and powerful choruses."
    )
    return prompt
|
|
|
def set_alternative_rock_prompt():
    """Return the preset text prompt for the "Alternative Rock" genre."""
    prompt = (
        "Alternative rock with distorted guitar riffs, punchy drums, melodic basslines, "
        "atmospheric synths, and a Nirvana-inspired grunge vibe, "
        "featuring introspective verses and explosive choruses."
    )
    return prompt
|
|
|
def set_detroit_techno_prompt():
    """Return the preset text prompt for the "Detroit Techno" genre."""
    prompt = (
        "Detroit techno with deep pulsing synths, driving basslines, crisp hi-hats, "
        "atmospheric pads, and a rhythmic groove inspired by Juan Atkins, "
        "maintaining a hypnotic and energetic flow."
    )
    return prompt
|
|
|
def set_deep_house_prompt():
    """Return the preset text prompt for the "Deep House" genre."""
    prompt = (
        "Deep house with warm analog synth chords, soulful vocal chops, deep basslines, "
        "crisp hi-hats, and a laid-back groove inspired by Larry Heard, "
        "creating a consistent hypnotic vibe with smooth transitions."
    )
    return prompt
|
|
|
def set_smooth_jazz_prompt():
    """Return the preset text prompt for the "Smooth Jazz" genre."""
    prompt = (
        "Smooth jazz with warm saxophone leads, expressive Rhodes piano chords, "
        "soft bossa nova drums, upright bass, and a George Benson-inspired improvisational feel, "
        "maintaining a cohesive and relaxing vibe."
    )
    return prompt
|
|
|
def set_bebop_jazz_prompt():
    """Return the preset text prompt for the "Bebop Jazz" genre."""
    prompt = (
        "Bebop jazz with fast-paced saxophone solos, intricate piano runs, walking basslines, "
        "complex drum patterns, and a Charlie Parker-inspired improvisational style, "
        "featuring dynamic shifts and virtuosic performances."
    )
    return prompt
|
|
|
def set_baroque_classical_prompt():
    """Return the preset text prompt for the "Baroque Classical" genre."""
    prompt = (
        "Baroque classical with harpsichord, delicate violin, cello, flute, "
        "and a Vivaldi-inspired melodic structure, "
        "featuring intricate counterpoint and elegant ornamentation, "
        "maintaining a consistent baroque elegance."
    )
    return prompt
|
|
|
def set_romantic_classical_prompt():
    """Return the preset text prompt for the "Romantic Classical" genre."""
    prompt = (
        "Romantic classical with lush strings, expressive piano, dramatic brass, "
        "subtle woodwinds, and a Chopin-inspired melodic flow, "
        "building emotional intensity with sweeping crescendos and delicate pianissimos."
    )
    return prompt
|
|
|
def set_boom_bap_hiphop_prompt():
    """Return the preset text prompt for the "Boom Bap Hip-Hop" genre."""
    prompt = (
        "Boom bap hip-hop with gritty sampled drums, deep basslines, jazzy piano loops, "
        "vinyl scratches, and a J Dilla-inspired rhythmic groove, "
        "maintaining a consistent head-nodding vibe."
    )
    return prompt
|
|
|
def set_trap_hiphop_prompt():
    """Return the preset text prompt for the "Trap Hip-Hop" genre."""
    prompt = (
        "Trap hip-hop with hard-hitting 808 bass, snappy snares, rapid hi-hats, "
        "eerie synth melodies, and a modern Atlanta-inspired sound, "
        "featuring catchy hooks and energetic drops."
    )
    return prompt
|
|
|
def set_pop_rock_prompt():
    """Return the preset text prompt for the "Pop Rock" genre."""
    prompt = (
        "Pop rock with catchy electric guitar riffs, uplifting synths, steady drums, "
        "melodic basslines, and a Coldplay-inspired anthemic feel, "
        "featuring bright intros and powerful choruses."
    )
    return prompt
|
|
|
def set_fusion_jazz_prompt():
    """Return the preset text prompt for the "Fusion Jazz" genre."""
    prompt = (
        "Fusion jazz with electric piano, funky basslines, intricate drum patterns, "
        "soaring trumpet, and a Herbie Hancock-inspired groove, "
        "blending jazz improvisation with rock and funk elements."
    )
    return prompt
|
|
|
def set_edm_prompt():
    """Return the preset text prompt for the "EDM" genre."""
    prompt = (
        "EDM with high-energy synth leads, pounding basslines, four-on-the-floor kicks, "
        "euphoric breakdowns, and a festival-ready drop, "
        "inspired by artists like Avicii and Calvin Harris."
    )
    return prompt
|
|
|
def set_indie_folk_prompt():
    """Return the preset text prompt for the "Indie Folk" genre."""
    prompt = (
        "Indie folk with acoustic guitars, heartfelt vocals, gentle percussion, warm bass, "
        "and a Bon Iver-inspired intimate atmosphere, "
        "featuring layered harmonies and emotional crescendos."
    )
    return prompt
|
|
|
|
|
def apply_chorus(segment):
    """Overlay a quieter copy of ``segment`` onto itself at position 20.

    The copy is ``segment - 6`` (a 6 dB gain reduction under pydub's operator
    semantics), re-stamped with the source frame rate before being overlaid.
    """
    source_rate = segment.frame_rate
    quieter_copy = (segment - 6).set_frame_rate(source_rate)
    return segment.overlay(quieter_copy, position=20)
|
|
|
def apply_eq(segment):
    """Band-limit ``segment``: low-pass at 8000, then high-pass at 80."""
    without_highs = segment.low_pass_filter(8000)
    return without_highs.high_pass_filter(80)
|
|
|
def apply_limiter(segment, max_db=-3.0):
    """Reduce the gain of ``segment`` so its ``dBFS`` does not exceed ``max_db``.

    The comparison uses ``segment.dBFS`` as a whole-segment level; a segment
    already at or below ``max_db`` is returned unchanged.
    NOTE(review): pydub's ``dBFS`` appears to be an average (RMS) level, not a
    peak, so this is a static gain trim rather than a true limiter - confirm.
    """
    current_db = segment.dBFS
    if current_db > max_db:
        return segment - (current_db - max_db)
    return segment
|
|
|
def apply_final_gain(segment, target_db=-12.0):
    """Shift the gain of ``segment`` so that its ``dBFS`` equals ``target_db``.

    Args:
        segment: Audio segment exposing ``dBFS`` and supporting ``+ <gain>``.
        target_db: Desired ``dBFS`` level after the adjustment.

    Returns:
        The gain-adjusted segment. A silent segment (``dBFS == -inf``) is
        returned unchanged.
    """
    current_db = segment.dBFS
    # Silence has no finite level: the adjustment would be +inf, which cannot
    # be applied as a gain, so leave such a segment alone.
    if current_db == float("-inf"):
        return segment
    return segment + (target_db - current_db)
|
|
|
def apply_fade(segment, fade_in_duration=2000, fade_out_duration=2000):
    """Apply a fade-in followed by a fade-out to ``segment`` and return it.

    Both durations are forwarded unchanged to ``fade_in`` / ``fade_out``.
    """
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
|
|
|
|
|
def _to_stereo(audio_chunk):
    """Return ``audio_chunk`` as a float32 CPU tensor shaped (2, samples)."""
    audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
    if audio_chunk.dim() == 1:
        audio_chunk = audio_chunk.unsqueeze(0)
    elif audio_chunk.dim() > 2:
        # Fold any leading dims into the channel axis; a bare view(2, -1)
        # would split one signal in half and present the halves as channels.
        audio_chunk = audio_chunk.reshape(-1, audio_chunk.shape[-1])

    if audio_chunk.shape[0] != 2:
        # Mono (or an unexpected channel count): duplicate the first channel.
        audio_chunk = audio_chunk[:1, :].repeat(2, 1)

    if audio_chunk.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
    return audio_chunk


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int):
    """Generate an instrumental track from a text prompt and save it as MP3.

    The track is produced as 2 chunks (up to 30 s total) or 3 chunks, each
    generated independently from the same prompt, joined with a crossfade,
    trimmed to ``total_duration`` and post-processed.

    Args:
        instrumental_prompt: Text description passed to the model.
        cfg_scale: Forwarded to the model as ``cfg_coef``.
        top_k: Top-k sampling parameter (coerced to ``int``).
        top_p: Top-p sampling parameter.
        temperature: Sampling temperature.
        total_duration: Requested length in seconds, clamped to 10..90.
        crossfade_duration: Crossfade between chunks, in milliseconds.

    Returns:
        ``(mp3_path, status)`` on success, or ``(None, message)`` when the
        prompt is empty or generation fails.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!"

    try:
        start_time = time.time()

        # UI sliders may deliver floats; these values must be integers.
        top_k = int(top_k)
        crossfade_duration = int(crossfade_duration)
        total_duration = min(max(int(total_duration), 10), 90)

        num_chunks = 2 if total_duration <= 30 else 3
        chunk_duration = total_duration / num_chunks

        # Every chunk carries one full crossfade of extra audio so the joined
        # track is never shorter than total_duration. A 1 s cap here left
        # 3-chunk tracks short when the crossfade exceeded 1000 ms.
        overlap_duration = crossfade_duration / 1000.0
        generation_duration = chunk_duration + overlap_duration

        sample_rate = musicgen_model.sample_rate

        # Fixed seeds: the same inputs reproduce the same track.
        torch.manual_seed(42)
        np.random.seed(42)

        # The sampling parameters are identical for every chunk.
        musicgen_model.set_generation_params(
            duration=generation_duration,
            use_sampling=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            cfg_coef=cfg_scale,
        )

        segments = []
        # Intermediate WAVs live in a private temp dir that is removed even on
        # failure and cannot collide with a concurrent request.
        with tempfile.TemporaryDirectory(prefix="ghostai_chunks_") as work_dir:
            for i in range(num_chunks):
                print(f"Generating chunk {i+1}/{num_chunks} on GPU (prompt: {instrumental_prompt})...")
                print_resource_usage(f"Before Chunk {i+1} Generation")

                with torch.no_grad(), autocast():
                    audio_chunk = musicgen_model.generate([instrumental_prompt], progress=True)[0]

                audio_chunk = _to_stereo(audio_chunk)

                # Load the chunk straight from WAV; an MP3 round-trip here
                # would add a lossy encode before the final export.
                wav_path = os.path.join(work_dir, f"chunk_{i}.wav")
                torchaudio.save(wav_path, audio_chunk, sample_rate, bits_per_sample=24)
                segments.append(AudioSegment.from_wav(wav_path))

                del audio_chunk
                torch.cuda.empty_cache()
                gc.collect()
                time.sleep(0.5)
                print_resource_usage(f"After Chunk {i+1} Generation")

        print("Combining audio chunks...")
        final_segment = segments[0]
        for next_segment in segments[1:]:
            # Each appended chunk gets +1 (dB in pydub) before the crossfade.
            final_segment = final_segment.append(next_segment + 1, crossfade=crossfade_duration)

        final_segment = final_segment[:total_duration * 1000]

        print("Post-processing final track...")
        final_segment = apply_eq(final_segment)
        final_segment = apply_chorus(final_segment)
        final_segment = apply_limiter(final_segment, max_db=-3.0)
        # pydub's headroom is a positive dB distance below full scale; the
        # previous -6.0 asked for a peak above full scale and clipped.
        final_segment = final_segment.normalize(headroom=6.0)
        final_segment = apply_final_gain(final_segment, target_db=-12.0)

        mp3_path = "output_cleaned.mp3"
        final_segment.export(
            mp3_path,
            format="mp3",
            bitrate="320k",
            tags={"title": "GhostAI Instrumental", "artist": "GhostAI"},
        )
        print(f"Saved final audio to {mp3_path}")

        print_resource_usage("After Final Generation")
        print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")

        return mp3_path, "✅ Done!"
    except Exception as e:
        # Boundary handler: the UI shows the message, the server log keeps it.
        print(f"ERROR: Music generation failed: {e}")
        return None, f"❌ Generation failed: {e}"
    finally:
        torch.cuda.empty_cache()
        gc.collect()
|
|
|
def clear_inputs():
    """Return the reset value for each input, in the order the UI wires them.

    Order: prompt, CFG scale, top-k, top-p, temperature, total duration,
    crossfade duration.
    """
    defaults = {
        "instrumental_prompt": "",
        "cfg_scale": 3.0,
        "top_k": 300,
        "top_p": 0.95,
        "temperature": 1.0,
        "total_duration": 30,
        "crossfade_duration": 500,
    }
    return tuple(defaults.values())
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...). It defines the dark gradient page
# background, the header/logo glitch animations, the three container cards
# (input, settings, output), the gradient buttons and the Orbitron web font.
css = """
body {
background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
color: #E0E0E0;
font-family: 'Orbitron', sans-serif;
margin: 0;
padding: 0;
}
.header-container {
text-align: center;
padding: 15px 20px;
background: rgba(0, 0, 0, 0.9);
border-bottom: 1px solid #00FF9F;
box-shadow: 0 0 10px rgba(161, 0, 255, 0.3);
}
#ghost-logo {
font-size: 60px;
display: block;
margin: 0 auto;
animation: glitch-ghost 1.5s infinite;
text-shadow: 0 0 10px #A100FF, 0 0 20px #00FF9F;
}
h1 {
color: #A100FF;
font-size: 28px;
margin: 5px 0;
text-shadow: 0 0 5px #A100FF, 0 0 10px #00FF9F;
animation: glitch-text 2s infinite;
}
p {
color: #E0E0E0;
font-size: 14px;
margin: 5px 0;
}
.input-container {
max-width: 1000px;
margin: 20px auto;
padding: 20px;
background: rgba(28, 37, 38, 0.8);
border-radius: 10px;
box-shadow: 0 0 15px rgba(0, 255, 159, 0.3);
}
.textbox {
background: #1A1A1A;
border: 1px solid #A100FF;
color: #E0E0E0;
border-radius: 5px;
padding: 10px;
margin-bottom: 20px;
}
.genre-buttons {
display: flex;
justify-content: center;
gap: 15px;
margin-bottom: 20px;
}
.genre-btn {
background: linear-gradient(45deg, #A100FF, #00FF9F);
border: none;
color: #0A0A0A;
font-weight: bold;
padding: 10px 20px;
border-radius: 5px;
transition: transform 0.3s ease, box-shadow 0.3s ease;
}
.genre-btn:hover {
transform: scale(1.05);
box-shadow: 0 0 15px #00FF9F;
}
.settings-container {
max-width: 1000px;
margin: 20px auto;
padding: 20px;
background: rgba(28, 37, 38, 0.8);
border-radius: 10px;
box-shadow: 0 0 15px rgba(0, 255, 159, 0.3);
}
.action-buttons {
display: flex;
justify-content: center;
gap: 20px;
margin-top: 20px;
}
button {
background: linear-gradient(45deg, #A100FF, #00FF9F);
border: none;
color: #0A0A0A;
font-weight: bold;
padding: 12px 24px;
border-radius: 5px;
transition: transform 0.3s ease, box-shadow 0.3s ease;
}
button:hover {
transform: scale(1.05);
box-shadow: 0 0 15px #00FF9F;
}
.output-container {
max-width: 1000px;
margin: 20px auto;
padding: 20px;
background: rgba(28, 37, 38, 0.8);
border-radius: 10px;
box-shadow: 0 0 15px rgba(0, 255, 159, 0.3);
text-align: center;
}
@keyframes glitch-ghost {
0% { transform: translate(0, 0); opacity: 1; }
20% { transform: translate(-5px, 2px); opacity: 0.8; }
40% { transform: translate(5px, -2px); opacity: 0.6; }
60% { transform: translate(-3px, 1px); opacity: 0.9; }
80% { transform: translate(3px, -1px); opacity: 0.7; }
100% { transform: translate(0, 0); opacity: 1; }
}
@keyframes glitch-text {
0% { transform: translate(0, 0); text-shadow: 0 0 10px #A100FF, 0 0 20px #00FF9F; }
20% { transform: translate(-2px, 1px); text-shadow: 0 0 15px #00FF9F, 0 0 25px #A100FF; }
40% { transform: translate(2px, -1px); text-shadow: 0 0 10px #A100FF, 0 0 30px #00FF9F; }
60% { transform: translate(-1px, 2px); text-shadow: 0 0 15px #00FF9F, 0 0 20px #A100FF; }
80% { transform: translate(1px, -2px); text-shadow: 0 0 10px #A100FF, 0 0 25px #00FF9F; }
100% { transform: translate(0, 0); text-shadow: 0 0 10px #A100FF, 0 0 20px #00FF9F; }
}
@font-face {
font-family: 'Orbitron';
src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
}
"""
|
|
|
|
|
# Genre preset buttons in display order: (button label, prompt factory).
# Clicking a button replaces the prompt textbox with the factory's return value.
genre_presets = [
    ("Classic Rock", set_classic_rock_prompt),
    ("Alternative Rock", set_alternative_rock_prompt),
    ("Detroit Techno", set_detroit_techno_prompt),
    ("Deep House", set_deep_house_prompt),
    ("Smooth Jazz", set_smooth_jazz_prompt),
    ("Bebop Jazz", set_bebop_jazz_prompt),
    ("Baroque Classical", set_baroque_classical_prompt),
    ("Romantic Classical", set_romantic_classical_prompt),
    ("Boom Bap Hip-Hop", set_boom_bap_hiphop_prompt),
    ("Trap Hip-Hop", set_trap_hiphop_prompt),
    ("Pop Rock", set_pop_rock_prompt),
    ("Fusion Jazz", set_fusion_jazz_prompt),
    ("EDM", set_edm_prompt),
    ("Indie Folk", set_indie_folk_prompt),
]

# Build the UI: header, prompt + genre presets, generation settings, output.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
<div class="header-container">
<div id="ghost-logo">👻</div>
<h1>GhostAI Music Generator</h1>
<p>Summon the Sound of the Unknown</p>
</div>
""")

    with gr.Column(elem_classes="input-container"):
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt",
            placeholder="Click a genre button below or type your own instrumental prompt",
            lines=4,
            elem_classes="textbox",
        )
        with gr.Row(elem_classes="genre-buttons"):
            genre_buttons = []
            for label, prompt_fn in genre_presets:
                genre_buttons.append((gr.Button(label, elem_classes="genre-btn"), prompt_fn))

    with gr.Column(elem_classes="settings-container"):
        cfg_scale = gr.Slider(
            label="Guidance Scale (CFG)",
            minimum=1.0,
            maximum=10.0,
            value=3.0,
            step=0.1,
            info="Higher values make the instrumental more closely follow the prompt, but may reduce diversity.",
        )
        top_k = gr.Slider(
            label="Top-K Sampling",
            minimum=10,
            maximum=500,
            value=300,
            step=10,
            info="Limits sampling to the top k most likely tokens. Higher values increase diversity.",
        )
        # step=0.05 keeps the 0.95 default on the slider grid; with step=0.1
        # the default could not be selected again once the slider was moved.
        # NOTE(review): audiocraft appears to ignore top_k whenever top_p > 0
        # - confirm, and consider saying so in the UI text.
        top_p = gr.Slider(
            label="Top-P Sampling (Nucleus Sampling)",
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            info="Samples from the smallest set of tokens whose cumulative probability reaches p. Higher values increase diversity.",
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            value=1.0,
            step=0.1,
            info="Controls randomness. Higher values make output more diverse but less predictable.",
        )
        total_duration = gr.Slider(
            label="Total Duration (seconds)",
            minimum=10,
            maximum=90,
            value=30,
            step=1,
            info="Total duration of the track (10 to 90 seconds).",
        )
        crossfade_duration = gr.Slider(
            label="Crossfade Duration (ms)",
            minimum=100,
            maximum=2000,
            value=500,
            step=100,
            info="Crossfade duration between chunks for smoother transitions.",
        )
        with gr.Row(elem_classes="action-buttons"):
            gen_btn = gr.Button("Generate Music")
            clr_btn = gr.Button("Clear Inputs")

    with gr.Column(elem_classes="output-container"):
        out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
        status = gr.Textbox(label="Status", interactive=False)

    # Event wiring. The settings list order must match both the parameter order
    # of generate_music() and the tuple returned by clear_inputs().
    for button, prompt_fn in genre_buttons:
        button.click(prompt_fn, inputs=None, outputs=[instrumental_prompt])

    generation_inputs = [
        instrumental_prompt,
        cfg_scale,
        top_k,
        top_p,
        temperature,
        total_duration,
        crossfade_duration,
    ]
    gen_btn.click(
        generate_music,
        inputs=generation_inputs,
        outputs=[out_audio, status],
    )
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=generation_inputs,
    )
|
|
|
|
|
# Start the server on all interfaces, port 9999. prevent_thread_lock=True makes
# launch() return immediately, so the hardening step below runs while the server
# is up rather than after it has shut down.
app = demo.launch(
    server_name="0.0.0.0",
    server_port=9999,
    share=False,
    inbrowser=False,
    show_error=True,
    prevent_thread_lock=True,
)

# Best effort: blank the FastAPI documentation URLs on the running app.
# NOTE(review): changing these attributes after the app exists may not remove
# routes that were already registered - confirm against the installed
# Gradio/FastAPI versions.
try:
    # launch() returns (FastAPI app, local URL, share URL).
    fastapi_app = app[0]
    fastapi_app.docs_url = None
    fastapi_app.redoc_url = None
    fastapi_app.openapi_url = None
except Exception as e:
    # Keep serving, but report the failure instead of hiding it.
    print(f"WARNING: Could not disable API docs endpoints: {e}")

# Keep the main thread alive until the server is interrupted.
demo.block_thread()