|
import gradio as gr |
|
from gradio_client import Client |
|
import os |
|
import json |
|
import random |
|
from datetime import datetime |
|
import numpy as np |
|
from pydub import AudioSegment |
|
import logging |
|
import configparser |
|
|
|
|
|
# Timestamped INFO-level logging shared by every function in this module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Filesystem layout: rendered MP3s and the JSON render log live under BASE_DIR.
BASE_DIR = "/home/pi5/muzax"
MP3_DIR = os.path.join(BASE_DIR, "mp3")
JSON_DIR = os.path.join(BASE_DIR, "json")
JSON_LOG = os.path.join(JSON_DIR, "render_log.json")
INI_FILE = os.path.join(BASE_DIR, "band_styles.ini")

# Address handed to gradio_client.Client for prompt and music generation.
API_URL = "http://192.168.0.155:9999/"

# Sent to the API as "total_duration"; also the placeholder tone length in seconds.
SONG_DURATION = 120

# Every saved song is looped/trimmed to this length (milliseconds).
TARGET_DURATION_MS = 180000

# Band styles: one INI section per band. A missing file is fatal because the
# UI tabs and the random parameter ranges are all derived from it.
config = configparser.ConfigParser()
if not os.path.exists(INI_FILE):
    logger.error(f"INI file not found: {INI_FILE}")
    raise FileNotFoundError(f"INI file not found: {INI_FILE}")
config.read(INI_FILE)
ALLOWED_BANDS = config.sections()

# Create the output directories on first run.
for directory in [BASE_DIR, MP3_DIR, JSON_DIR]:
    if not os.path.exists(directory):
        # exist_ok covers the window in which another process may create the
        # directory between the check above and this call.
        os.makedirs(directory, exist_ok=True)
        logger.info(f"Created directory: {directory}")

# Seed the render log with an empty list. Exclusive-create ("x") never
# truncates a log that already exists, even if it appears concurrently.
try:
    with open(JSON_LOG, "x") as f:
        json.dump([], f)
except FileExistsError:
    pass
else:
    logger.info(f"Initialized JSON log: {JSON_LOG}")
|
|
|
def generate_random_params(band, styles=None):
    """Pick a random set of song parameters for ``band``.

    Args:
        band: Name of the style section to draw from.
        styles: Optional mapping of section name to a mapping of option
            strings. When omitted, the module-level ``config`` is used.

    Returns:
        A tuple ``(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style,
        guitar_style)``. ``bpm`` is an int between ``bpm_min`` and ``bpm_max``
        inclusive; the rest are one entry each from the corresponding
        comma-separated option.

    Raises:
        KeyError: If ``band`` or one of the required options is missing.
        ValueError: If ``bpm_min`` / ``bpm_max`` are not integers.
    """
    style = (config if styles is None else styles)[band]

    def pick(option):
        # "a, b" must yield "b", not " b": the values are matched against
        # exact choice strings downstream. Blank entries (trailing commas)
        # are dropped unless nothing else is left.
        choices = [item.strip() for item in style[option].split(",")]
        choices = [item for item in choices if item] or [""]
        return random.choice(choices)

    # Tolerate a reversed range instead of letting randint raise.
    low, high = sorted((int(style["bpm_min"]), int(style["bpm_max"])))
    bpm = random.randint(low, high)

    return (
        bpm,
        pick("drum_beat"),
        pick("synthesizer"),
        pick("rhythmic_steps"),
        pick("bass_style"),
        pick("guitar_style"),
    )
|
|
|
def generate_music_placeholder(prompt, duration, bpm, drum_beat, bass_style, guitar_style):
    """Synthesize a plain sine tone to stand in for generated music.

    Only ``duration`` and ``guitar_style`` affect the result; the remaining
    parameters are accepted but unused.

    Args:
        prompt: Unused.
        duration: Tone length in seconds (anything ``float()`` accepts).
        bpm: Unused.
        drum_beat: Unused.
        bass_style: Unused.
        guitar_style: ``"clean"`` selects 440 Hz, anything else 220 Hz.

    Returns:
        A mono, 16-bit, 44.1 kHz ``AudioSegment`` containing the tone.
    """
    sample_rate = 44100
    seconds = float(duration)

    t = np.linspace(0, seconds, int(sample_rate * seconds), endpoint=False)
    freq = 440 if guitar_style == "clean" else 220

    # Half of full scale, to leave headroom.
    sine_wave = 0.5 * np.sin(2 * np.pi * freq * t)

    # Explicit little-endian 16-bit samples to match sample_width=2 below.
    audio_samples = (sine_wave * 32767).astype("<i2")

    return AudioSegment(
        audio_samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=2,
        channels=1,
    )
|
|
|
def extend_audio(audio, target_duration_ms):
    """Loop ``audio`` until it reaches ``target_duration_ms``, then trim to it.

    Args:
        audio: Sliceable, concatenable clip whose ``len()`` is its duration
            in milliseconds.
        target_duration_ms: Desired length in milliseconds.

    Returns:
        A clip of exactly ``target_duration_ms``. Longer input is truncated;
        shorter input is repeated end-to-end and the last repetition is cut.

    Raises:
        ValueError: If ``audio`` has zero length but a positive target was
            requested (looping nothing would never terminate).
    """
    current_duration = len(audio)
    if current_duration >= target_duration_ms:
        return audio[:target_duration_ms]

    if current_duration == 0:
        raise ValueError("Cannot extend empty audio by looping")

    extended_audio = audio
    while len(extended_audio) < target_duration_ms:
        # Plain `+` (not `+=`) so a mutable input is never modified in place.
        extended_audio = extended_audio + audio
    return extended_audio[:target_duration_ms]
|
|
|
def save_to_mp3(audio, filename):
    """Export ``audio`` as an MP3 inside ``MP3_DIR``.

    Args:
        audio: Object exposing ``export(path, format=...)``.
        filename: File name (not a full path) to create under ``MP3_DIR``.

    Returns:
        The full path of the written file.
    """
    destination = os.path.join(MP3_DIR, filename)

    audio.export(destination, format="mp3")
    logger.info(f"Saved MP3 to {destination}")

    return destination
|
|
|
def update_json_log(band, params, filepath, status):
    """Append one render record to the JSON render log.

    Args:
        band: Band identifier the song was generated for.
        params: JSON-serializable parameters used for the render.
        filepath: Path of the saved audio file.
        status: Human-readable outcome of the render.

    Raises:
        TypeError: If ``params`` cannot be serialized to JSON. The existing
            log file is left untouched in that case.
        OSError: If the log cannot be written.
    """
    try:
        with open(JSON_LOG, "r") as f:
            log = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as exc:
        # A deleted or half-written log should not block recording new renders.
        logger.warning(f"Render log unreadable ({exc}); starting a new one")
        log = []

    if not isinstance(log, list):
        logger.warning("Render log is not a JSON list; starting a new one")
        log = []

    render_entry = {
        "timestamp": datetime.now().isoformat(),
        "band": band,
        "parameters": params,
        "filepath": filepath,
        "status": status,
    }
    log.append(render_entry)

    # Write to a sibling file and swap it in: opening JSON_LOG with "w" would
    # truncate it first, so a failure inside json.dump would destroy the log.
    tmp_path = JSON_LOG + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(log, f, indent=2)
    os.replace(tmp_path, JSON_LOG)

    logger.info(f"Updated JSON log: {render_entry}")
|
|
|
def generate_song(band, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Generate a song for ``band``, extend it to the target length and save it.

    The remote API is asked for a prompt (via the band's ``api_name`` from the
    INI config) and then for the music itself. If the music call returns no
    file path, a placeholder tone is used instead. Either way the audio is
    looped/trimmed to ``TARGET_DURATION_MS``, written as an MP3 and recorded
    in the render log.

    Args:
        band: INI section name of the band.
        bpm: Tempo in beats per minute.
        drum_beat: Drum style.
        synthesizer: Synth sound.
        rhythmic_steps: Rhythm complexity.
        bass_style: Bass style.
        guitar_style: Guitar sound.

    Returns:
        ``(filepath, status)`` on success, or ``(None, "Error: ...")`` if
        anything fails; exceptions are logged, never propagated.
    """
    try:
        client = Client(API_URL)

        params = {
            "bpm": bpm,
            "drum_beat": drum_beat,
            "synthesizer": synthesizer,
            "rhythmic_steps": rhythmic_steps,
            "bass_style": bass_style,
            "guitar_style": guitar_style,
            "api_name": config[band]["api_name"],
        }
        prompt = client.predict(**params)
        logger.info(f"Prompt for {band}: {prompt}")

        music_params = {
            "instrumental_prompt": prompt,
            "cfg_scale": 3.0,
            "top_k": 300,
            "top_p": 0.9,
            "temperature": 0.8,
            "total_duration": SONG_DURATION,
            "bpm": bpm,
            "drum_beat": drum_beat,
            "synthesizer": synthesizer,
            "rhythmic_steps": rhythmic_steps,
            "bass_style": bass_style,
            "guitar_style": guitar_style,
            "target_volume": -23.0,
            "preset": "rock",
            "vram_status": "",
            "api_name": "/generate_music",
        }
        # The API's own status text is not used; it is replaced below.
        api_filepath, _api_status, _ = client.predict(**music_params)

        # Keep the status text in step with the configured target length.
        target_seconds = TARGET_DURATION_MS // 1000

        if api_filepath:
            audio = AudioSegment.from_file(api_filepath)
            status = f"Generated with API, extended to {target_seconds} seconds"
        else:
            logger.warning("API returned no audio, using placeholder.")
            audio = generate_music_placeholder(
                prompt, SONG_DURATION, bpm, drum_beat, bass_style, guitar_style
            )
            status = f"Generated with placeholder, extended to {target_seconds} seconds"

        # Shared post-processing for both sources.
        audio = extend_audio(audio, TARGET_DURATION_MS)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = save_to_mp3(audio, f"{band}_{timestamp}.mp3")

        update_json_log(band, music_params, filepath, status)
        return filepath, status

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising,
        # but keep the traceback in the log.
        logger.exception(f"Error generating song: {str(e)}")
        return None, f"Error: {str(e)}"
|
|
|
def get_last_five_songs():
    """Return up to five of the newest render-log entries, newest first.

    Each returned dict has the keys ``timestamp``, ``band`` (underscores
    replaced by spaces, title-cased), ``filepath``, ``parameters`` and
    ``status``. Any failure while reading or processing the log is logged
    and yields an empty list.
    """

    def summarize(record):
        # Same fields as the stored record, with a display-friendly band name.
        display_band = record["band"].replace("_", " ").title()
        return {
            "timestamp": record["timestamp"],
            "band": display_band,
            "filepath": record["filepath"],
            "parameters": record["parameters"],
            "status": record["status"],
        }

    try:
        with open(JSON_LOG, "r") as handle:
            records = json.load(handle)

        records.sort(key=lambda item: item["timestamp"], reverse=True)
        newest = records[:5]
        return [summarize(record) for record in newest]
    except Exception as e:
        logger.error(f"Error reading JSON log: {str(e)}")
        return []
|
|
|
def _read_render_log():
    """Return the parsed render log, closing the file afterwards."""
    with open(JSON_LOG, "r") as f:
        return json.load(f)


def _build_band_tab(band):
    """Build the controls and event wiring for one band's tab."""
    title = band.replace("_", " ").title()
    with gr.Tab(label=title):
        gr.Markdown(f"### {title} Song Generator")

        with gr.Column():
            bpm = gr.Slider(
                minimum=60,
                maximum=180,
                value=120,
                step=1,
                label="Tempo (BPM) 🎵",
                info="Song speed in beats per minute.",
            )
            drum_beat = gr.Dropdown(
                choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                value="standard rock",
                label="Drum Beat 🥁",
                info="Drum style.",
            )
            synthesizer = gr.Dropdown(
                choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                value="none",
                label="Synthesizer 🎹",
                info="Synth sound.",
            )
            rhythmic_steps = gr.Dropdown(
                choices=["none", "syncopated steps", "steady steps", "complex steps"],
                value="steady steps",
                label="Rhythmic Steps 👣",
                info="Rhythm complexity.",
            )
            bass_style = gr.Dropdown(
                choices=["none", "slap bass", "deep bass", "melodic bass"],
                value="deep bass",
                label="Bass Style 🎸",
                info="Bass guitar style.",
            )
            guitar_style = gr.Dropdown(
                choices=["none", "distorted", "clean", "jangle"],
                value="distorted",
                label="Guitar Style 🎸",
                info="Guitar sound.",
            )

        with gr.Row():
            randomize_btn = gr.Button("Randomize", variant="secondary")
            generate_btn = gr.Button("Generate Song", variant="primary")

        audio_output = gr.Audio(
            label="Generated Song 🎵",
            type="filepath",
            interactive=False,
        )
        status_output = gr.Textbox(
            label="Status 📢",
            placeholder="Status updates here.",
            interactive=False,
        )

        # `band` is this function's own argument, so every tab's callback is
        # bound to its own band rather than to a shared loop variable.
        randomize_btn.click(
            fn=lambda: generate_random_params(band),
            outputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style],
        )
        generate_btn.click(
            fn=generate_song,
            inputs=[
                gr.State(value=band),
                bpm,
                drum_beat,
                synthesizer,
                rhythmic_steps,
                bass_style,
                guitar_style,
            ],
            outputs=[audio_output, status_output],
        )


def _build_recent_song_slot(index, recent_songs):
    """Build the player, details box and play button for one recent-song slot."""
    number = index + 1
    with gr.Group():
        gr.Markdown(f"#### Song {number}")
        audio_player = gr.Audio(label=f"Song {number}", type="filepath", interactive=False)
        info_text = gr.Textbox(label=f"Details {number}", interactive=False)
        play_btn = gr.Button(f"Play Song {number}", variant="primary")

        def play_song(song_list):
            # Fewer than `number` songs may be logged, and a logged file may
            # have been removed since; both cases fall through to "unavailable".
            if song_list and index < len(song_list):
                song = song_list[index]
                path = song["filepath"]
                if path and os.path.exists(path):
                    return (
                        path,
                        f"Band: {song['band']}\nTime: {song['timestamp']}\nParams: {json.dumps(song['parameters'], indent=2)}\nStatus: {song['status']}",
                    )
            return None, "Song unavailable."

        play_btn.click(
            fn=play_song,
            inputs=[recent_songs],
            outputs=[audio_player, info_text],
        )


def _build_recent_songs_tab():
    """Build the tab listing the five most recent renders."""
    with gr.Tab("Recent Songs"):
        gr.Markdown("### Last 5 Songs")
        recent_songs = gr.State(value=get_last_five_songs())

        for i in range(5):
            _build_recent_song_slot(i, recent_songs)

        refresh_btn = gr.Button("Refresh Songs", variant="secondary")
        refresh_btn.click(
            fn=get_last_five_songs,
            outputs=[recent_songs],
        )


def _build_render_log_tab():
    """Build the tab showing the full render log as JSON."""
    with gr.Tab("Render Log"):
        gr.Markdown("### Render Log")
        log_output = gr.JSON(label="All Renders", value=_read_render_log)
        refresh_log_btn = gr.Button("Refresh Log", variant="primary")
        refresh_log_btn.click(
            fn=_read_render_log,
            outputs=[log_output],
        )


def create_gradio_interface():
    """Create the Gradio Blocks app.

    The app has one tab per band in ``ALLOWED_BANDS`` (parameter controls plus
    Randomize / Generate buttons), a "Recent Songs" tab with five player
    slots, and a "Render Log" tab showing the whole JSON log.

    Returns:
        The assembled ``gr.Blocks`` instance, not yet launched.
    """
    css = """
    .gradio-container {background-color: #2b2b2b; color: #ffffff; font-family: Arial, sans-serif;}
    .gr-button-primary {background-color: #4a90e2; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-primary:hover {background-color: #357abd;}
    .gr-button-secondary {background-color: #4a4a4a; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-secondary:hover {background-color: #333333;}
    .gr-panel {background-color: #3c3c3c; border: none; border-radius: 8px; padding: 15px;}
    .gr-textbox, .gr-slider, .gr-dropdown, .gr-audio {background-color: #4a4a4a; color: #ffffff; border: none; border-radius: 5px;}
    .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #ffffff;}
    """

    with gr.Blocks(title="Muzax Rock Generator", css=css) as demo:
        gr.Markdown(
            """
            # Muzax Rock Song Generator
            Create 3-minute rock songs inspired by top bands. Save MP3s to /home/pi5/muzax/mp3.
            """
        )

        with gr.Tabs():
            for band in ALLOWED_BANDS:
                _build_band_tab(band)

            _build_recent_songs_tab()
            _build_render_log_tab()

    return demo
|
|
|
if __name__ == "__main__":
    try:
        demo = create_gradio_interface()
        # launch() blocks the main thread while the server runs, so the
        # start-up message has to be logged before the call, not after it.
        logger.info("Launching Gradio on 0.0.0.0:3223 with public sharing enabled")
        demo.launch(server_name="0.0.0.0", server_port=3223, share=True)
    except Exception as e:
        logger.exception(f"Failed to launch Gradio: {str(e)}")
        # Exit non-zero so a supervisor can tell the launch failed.
        raise SystemExit(1) from e