# GHOSTSONAFB / api.py
# Source: Hugging Face repository file page — commit 959a928 "Create api.py"
# by ghostai1 (verified), 13.8 kB.
import gradio as gr
from gradio_client import Client
import os
import json
import random
from datetime import datetime
import numpy as np
from pydub import AudioSegment
import logging
import configparser
# --- Logging --------------------------------------------------------------
# Timestamped, INFO-level output for every module in the process.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

# --- Filesystem layout ----------------------------------------------------
# Everything the app reads or writes lives under BASE_DIR.
BASE_DIR = "/home/pi5/muzax"
JSON_DIR = os.path.join(BASE_DIR, "json")
MP3_DIR = os.path.join(BASE_DIR, "mp3")
INI_FILE = os.path.join(BASE_DIR, "band_styles.ini")
JSON_LOG = os.path.join(JSON_DIR, "render_log.json")

# --- Remote service and timing --------------------------------------------
API_URL = "http://192.168.0.155:9999/"
SONG_DURATION = 120  # seconds requested from the generator (kept an int for the API)
TARGET_DURATION_MS = 180 * 1000  # final song length: 180 s expressed in milliseconds
# Band styles: each INI section describes one band, and the section names
# (ALLOWED_BANDS) drive which generator tabs the UI builds.
config = configparser.ConfigParser()
if os.path.exists(INI_FILE):
    config.read(INI_FILE)
else:
    # Fail fast at import time: without styles there is nothing to render.
    logger.error(f"INI file not found: {INI_FILE}")
    raise FileNotFoundError(f"INI file not found: {INI_FILE}")
ALLOWED_BANDS = config.sections()
# Make sure the output tree exists before anything tries to write into it.
for folder in (BASE_DIR, MP3_DIR, JSON_DIR):
    if os.path.exists(folder):
        continue
    os.makedirs(folder)
    logger.info(f"Created directory: {folder}")

# On first run, seed the render log with an empty JSON array so later
# readers can always json.load() it.
if not os.path.exists(JSON_LOG):
    with open(JSON_LOG, "w") as log_file:
        json.dump([], log_file)
    logger.info(f"Initialized JSON log: {JSON_LOG}")
def generate_random_params(band, styles=None):
    """Pick a random set of generation parameters for ``band``.

    ``bpm`` is drawn uniformly from the inclusive integer range
    [``bpm_min``, ``bpm_max``]. Every other field is chosen from its
    comma-separated list of options in the band's style section.

    Args:
        band: Section name of the band in the style config.
        styles: Optional mapping of band name -> style mapping. Defaults to the
            module-level ``config`` loaded from ``INI_FILE``.

    Returns:
        Tuple ``(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style,
        guitar_style)``. This is the order the UI wires its outputs.

    Raises:
        KeyError: if the band or one of the required keys is missing.
        ValueError: if the bpm bounds are not integers, or a list key has no
            non-blank options.
    """
    style = (config if styles is None else styles)[band]

    def pick(key):
        # INI lists are commonly written "a, b, c". Strip the padding (and drop
        # blank items) so the chosen value matches dropdown choices exactly.
        options = [item.strip() for item in style[key].split(",") if item.strip()]
        if not options:
            raise ValueError(f"No options configured for '{key}' in band '{band}'")
        return random.choice(options)

    bpm = random.randint(int(style["bpm_min"]), int(style["bpm_max"]))
    return (
        bpm,
        pick("drum_beat"),
        pick("synthesizer"),
        pick("rhythmic_steps"),
        pick("bass_style"),
        pick("guitar_style"),
    )
def generate_music_placeholder(prompt, duration, bpm, drum_beat, bass_style, guitar_style):
    """Synthesize a stand-in tone for when the API returns no audio.

    Produces a mono, 16-bit, 44.1 kHz sine wave lasting ``duration`` seconds.
    The tone is 440 Hz when ``guitar_style == "clean"`` and 220 Hz otherwise,
    at half of full scale. ``prompt``, ``bpm``, ``drum_beat`` and
    ``bass_style`` are accepted so the signature matches the real generator
    call, but they are not used.

    Args:
        duration: Length in seconds (anything ``float()`` accepts).
        guitar_style: Selects the tone frequency (see above).

    Returns:
        A pydub ``AudioSegment`` holding the generated tone.
    """
    sample_rate = 44100
    seconds = float(duration)
    t = np.linspace(0, seconds, int(sample_rate * seconds), endpoint=False)
    freq = 440 if guitar_style == "clean" else 220
    # Scale a half-amplitude sine into the signed 16-bit sample range.
    samples = (0.5 * np.sin(2 * np.pi * freq * t) * 32767).astype(np.int16)
    return AudioSegment(
        samples.tobytes(),
        frame_rate=sample_rate,
        sample_width=2,
        channels=1,
    )
def extend_audio(audio, target_duration_ms):
    """Return ``audio`` looped and/or trimmed to exactly ``target_duration_ms``.

    The input is meant to be a pydub ``AudioSegment``, where ``len()`` and
    slicing are in milliseconds and ``segment * n`` repeats the segment
    ``n`` times. Any sequence with the same ``len``/slice/``*`` behaviour also
    works.

    Args:
        audio: The clip to extend.
        target_duration_ms: Desired length in milliseconds.

    Returns:
        ``audio`` trimmed when it is already long enough. Otherwise, ``audio``
        repeated end-to-end and trimmed to the target length.

    Raises:
        ValueError: if ``audio`` is empty but a positive length is requested.
            Repeating an empty clip can never reach the target, and this
            previously looped forever.
    """
    current_ms = len(audio)
    if current_ms >= target_duration_ms:
        return audio[:target_duration_ms]  # Trim to exact duration
    if current_ms == 0:
        raise ValueError("Cannot extend empty audio to a positive duration")
    # Repeat once via multiplication (a single copy) rather than appending in a
    # loop, which re-copies the growing buffer on every iteration.
    repeats = int(-(-target_duration_ms // current_ms))  # ceiling division
    return (audio * repeats)[:target_duration_ms]  # Trim to exact duration
def save_to_mp3(audio, filename):
    """Write ``audio`` as an MP3 file named ``filename`` inside MP3_DIR.

    Args:
        audio: Object with pydub's ``export(path, format=...)`` method.
        filename: Bare file name; it is joined onto MP3_DIR.

    Returns:
        The full path of the written file.
    """
    out_path = os.path.join(MP3_DIR, filename)
    audio.export(out_path, format="mp3")
    logger.info(f"Saved MP3 to {out_path}")
    return out_path
def update_json_log(band, params, filepath, status):
    """Append one render record to the JSON render log at JSON_LOG.

    The log file holds a JSON array. Each call appends an object with keys
    ``timestamp`` (local ISO-8601), ``band``, ``parameters``, ``filepath`` and
    ``status``, then rewrites the file.

    Args:
        band: Band/section name the song was generated for.
        params: Parameter dict sent to the generator (stored verbatim; must be
            JSON-serializable).
        filepath: Path of the saved MP3.
        status: Human-readable outcome message.

    Raises:
        ValueError (json.JSONDecodeError): if the existing log is not valid JSON.
        OSError: if the log cannot be read or written.
    """
    try:
        with open(JSON_LOG, "r", encoding="utf-8") as f:
            log = json.load(f)
    except FileNotFoundError:
        # The log is seeded at startup but may have been removed since; start a
        # fresh one instead of failing a render that has already been saved.
        log = []
    render_entry = {
        "timestamp": datetime.now().isoformat(),
        "band": band,
        "parameters": params,
        "filepath": filepath,
        "status": status,
    }
    log.append(render_entry)
    # Write a sibling temp file and atomically swap it in, so a crash or write
    # error mid-dump cannot leave a truncated, unparseable log behind.
    tmp_path = f"{JSON_LOG}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)
    os.replace(tmp_path, JSON_LOG)
    logger.info(f"Updated JSON log: {render_entry}")
def _save_extended_song(audio, band):
    """Loop/trim ``audio`` to TARGET_DURATION_MS, save it as ``<band>_<YYYYmmdd_HHMMSS>.mp3``, and return the path."""
    audio = extend_audio(audio, TARGET_DURATION_MS)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return save_to_mp3(audio, f"{band}_{timestamp}.mp3")


def generate_song(band, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Generate a song for ``band`` through the remote API and save it as MP3.

    Steps:
      1. Ask the API endpoint named by the band's ``api_name`` INI key for a
         text prompt built from the chosen parameters.
      2. Call ``/generate_music`` with that prompt and fixed sampling settings.
      3. If no audio path comes back, synthesize a placeholder tone instead.
      4. Loop/trim the audio to TARGET_DURATION_MS, save it under MP3_DIR, and
         append a record to the JSON render log.

    Returns:
        ``(mp3_path, status_message)`` on success. If any step raises, returns
        ``(None, "Error: <message>")`` and logs the error with its traceback.
    """
    target_seconds = TARGET_DURATION_MS // 1000
    try:
        client = Client(API_URL)
        prompt = client.predict(
            bpm=bpm,
            drum_beat=drum_beat,
            synthesizer=synthesizer,
            rhythmic_steps=rhythmic_steps,
            bass_style=bass_style,
            guitar_style=guitar_style,
            api_name=config[band]["api_name"],
        )
        logger.info(f"Prompt for {band}: {prompt}")
        music_params = {
            "instrumental_prompt": prompt,
            "cfg_scale": 3.0,
            "top_k": 300,
            "top_p": 0.9,
            "temperature": 0.8,
            "total_duration": SONG_DURATION,
            "bpm": bpm,
            "drum_beat": drum_beat,
            "synthesizer": synthesizer,
            "rhythmic_steps": rhythmic_steps,
            "bass_style": bass_style,
            "guitar_style": guitar_style,
            "target_volume": -23.0,
            "preset": "rock",
            "vram_status": "",
            "api_name": "/generate_music",
        }
        # NOTE(review): the endpoint appears to return (audio_path, status, <unused>);
        # its status is replaced by our own message below. Confirm against the server.
        filepath, status, _ = client.predict(**music_params)
        if filepath:
            audio = AudioSegment.from_file(filepath)
            source = "API"
        else:
            logger.warning("API returned no audio, using placeholder.")
            audio = generate_music_placeholder(
                prompt, SONG_DURATION, bpm, drum_beat, bass_style, guitar_style
            )
            source = "placeholder"
        filepath = _save_extended_song(audio, band)
        status = f"Generated with {source}, extended to {target_seconds} seconds"
        update_json_log(band, music_params, filepath, status)
        return filepath, status
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the handler.
        logger.exception(f"Error generating song: {str(e)}")
        return None, f"Error: {str(e)}"
def get_last_five_songs(count=5, log_path=None):
    """Return the ``count`` most recent render-log entries, newest first.

    Entries are ordered by their ``timestamp`` string, descending; the ISO-8601
    strings written by the logger sort chronologically. Each entry is reshaped
    for display, with ``band`` title-cased and underscores turned into spaces.
    Records that are not objects, or that lack an expected key, are skipped so
    one bad record cannot hide the rest.

    Args:
        count: Maximum number of songs to return (default 5).
        log_path: Path of the render log; defaults to ``JSON_LOG``.

    Returns:
        A list of dicts with keys ``timestamp``, ``band``, ``filepath``,
        ``parameters`` and ``status``. The list is empty if the log cannot be
        read or parsed.
    """
    path = JSON_LOG if log_path is None else log_path
    try:
        with open(path, "r", encoding="utf-8") as f:
            log = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError
        logger.error(f"Error reading JSON log: {str(e)}")
        return []
    if not isinstance(log, list):
        logger.error(f"Error reading JSON log: expected a list, got {type(log).__name__}")
        return []

    records = [entry for entry in log if isinstance(entry, dict)]
    records.sort(key=lambda entry: str(entry.get("timestamp", "")), reverse=True)

    songs = []
    for entry in records:
        if len(songs) >= count:
            break
        try:
            songs.append(
                {
                    "timestamp": entry["timestamp"],
                    "band": entry["band"].replace("_", " ").title(),
                    "filepath": entry["filepath"],
                    "parameters": entry["parameters"],
                    "status": entry["status"],
                }
            )
        except (KeyError, AttributeError):
            # Malformed record (missing key or non-string band): skip it.
            continue
    return songs
def _read_render_log():
    """Return the parsed JSON render log, closing the file handle afterwards."""
    with open(JSON_LOG, "r", encoding="utf-8") as f:
        return json.load(f)


def _pretty_band(band):
    """Turn an INI section name like ``foo_bar`` into the display title ``Foo Bar``."""
    return band.replace("_", " ").title()


def _build_band_tab(band):
    """Add one generator tab for ``band`` to the currently open Blocks context.

    ``band`` is a parameter of this function, so the Randomize callback binds
    this tab's band. Capturing a shared loop variable would make every tab use
    the last band.
    """
    with gr.Tab(label=_pretty_band(band)):
        gr.Markdown(f"### {_pretty_band(band)} Song Generator")
        with gr.Column():
            bpm = gr.Slider(
                minimum=60,
                maximum=180,
                value=120,
                step=1,
                label="Tempo (BPM) 🎵",
                info="Song speed in beats per minute.",
            )
            drum_beat = gr.Dropdown(
                choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                value="standard rock",
                label="Drum Beat 🥁",
                info="Drum style.",
            )
            synthesizer = gr.Dropdown(
                choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                value="none",
                label="Synthesizer 🎹",
                info="Synth sound.",
            )
            rhythmic_steps = gr.Dropdown(
                choices=["none", "syncopated steps", "steady steps", "complex steps"],
                value="steady steps",
                label="Rhythmic Steps 👣",
                info="Rhythm complexity.",
            )
            bass_style = gr.Dropdown(
                choices=["none", "slap bass", "deep bass", "melodic bass"],
                value="deep bass",
                label="Bass Style 🎸",
                info="Bass guitar style.",
            )
            guitar_style = gr.Dropdown(
                choices=["none", "distorted", "clean", "jangle"],
                value="distorted",
                label="Guitar Style 🎸",
                info="Guitar sound.",
            )
            with gr.Row():
                randomize_btn = gr.Button("Randomize", variant="secondary")
                generate_btn = gr.Button("Generate Song", variant="primary")
            audio_output = gr.Audio(
                label="Generated Song 🎵",
                type="filepath",
                interactive=False,
            )
            status_output = gr.Textbox(
                label="Status 📢",
                placeholder="Status updates here.",
                interactive=False,
            )
        # Same order as generate_random_params' return tuple and generate_song's args.
        controls = [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]
        randomize_btn.click(
            fn=lambda: generate_random_params(band),
            outputs=controls,
        )
        generate_btn.click(
            fn=generate_song,
            inputs=[gr.State(value=band), *controls],
            outputs=[audio_output, status_output],
        )


def _format_song_details(song):
    """Render one recent-song record as the multi-line details text."""
    return (
        f"Band: {song['band']}\nTime: {song['timestamp']}\n"
        f"Params: {json.dumps(song['parameters'], indent=2)}\nStatus: {song['status']}"
    )


def _make_song_player(index):
    """Return a click handler that plays slot ``index`` of the recent-songs list."""
    def play_song(song_list):
        if index < len(song_list) and os.path.exists(song_list[index]["filepath"]):
            song = song_list[index]
            return song["filepath"], _format_song_details(song)
        return None, "Song unavailable."
    return play_song


def _build_recent_songs_tab():
    """Add the "Recent Songs" tab: five player slots plus a refresh button."""
    with gr.Tab("Recent Songs"):
        gr.Markdown("### Last 5 Songs")
        # Snapshot taken when the UI is built; "Refresh Songs" reloads it.
        recent_songs = gr.State(value=get_last_five_songs())
        for i in range(5):
            with gr.Group():
                gr.Markdown(f"#### Song {i+1}")
                audio_player = gr.Audio(label=f"Song {i+1}", type="filepath", interactive=False)
                info_text = gr.Textbox(label=f"Details {i+1}", interactive=False)
                play_btn = gr.Button(f"Play Song {i+1}", variant="primary")
            play_btn.click(
                fn=_make_song_player(i),
                inputs=[recent_songs],
                outputs=[audio_player, info_text],
            )
        refresh_btn = gr.Button("Refresh Songs", variant="secondary")
        refresh_btn.click(
            fn=lambda: get_last_five_songs(),
            outputs=[recent_songs],
        )


def _build_render_log_tab():
    """Add the "Render Log" tab that shows the full JSON render log."""
    with gr.Tab("Render Log"):
        gr.Markdown("### Render Log")
        # A callable value is re-evaluated each time the page loads.
        log_output = gr.JSON(label="All Renders", value=_read_render_log)
        refresh_log_btn = gr.Button("Refresh Log", variant="primary")
        refresh_log_btn.click(
            fn=_read_render_log,
            outputs=[log_output],
        )


def create_gradio_interface():
    """Build and return the Muzax Gradio Blocks app; the caller launches it.

    Builds one generator tab per band in ALLOWED_BANDS, followed by the
    "Recent Songs" and "Render Log" tabs.
    """
    css = """
    .gradio-container {background-color: #2b2b2b; color: #ffffff; font-family: Arial, sans-serif;}
    .gr-button-primary {background-color: #4a90e2; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-primary:hover {background-color: #357abd;}
    .gr-button-secondary {background-color: #4a4a4a; color: #ffffff; border: none; padding: 10px 20px; border-radius: 5px;}
    .gr-button-secondary:hover {background-color: #333333;}
    .gr-panel {background-color: #3c3c3c; border: none; border-radius: 8px; padding: 15px;}
    .gr-textbox, .gr-slider, .gr-dropdown, .gr-audio {background-color: #4a4a4a; color: #ffffff; border: none; border-radius: 5px;}
    .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #ffffff;}
    """
    with gr.Blocks(title="Muzax Rock Generator", css=css) as demo:
        gr.Markdown(
            "# Muzax Rock Song Generator\n"
            f"Create 3-minute rock songs inspired by top bands. Save MP3s to {MP3_DIR}."
        )
        with gr.Tabs():
            for band in ALLOWED_BANDS:
                _build_band_tab(band)
            _build_recent_songs_tab()
            _build_render_log_tab()
    return demo
if __name__ == "__main__":
    try:
        demo = create_gradio_interface()
        # launch() blocks the main thread while the server runs, so announce the
        # address before calling it, not after it returns at shutdown.
        logger.info("Launching Gradio on 0.0.0.0:3223 with public sharing enabled")
        # NOTE(review): share=True requests a public share link on top of binding
        # all interfaces. Confirm this exposure is intended.
        demo.launch(server_name="0.0.0.0", server_port=3223, share=True)
    except Exception as e:
        logger.exception(f"Failed to launch Gradio: {str(e)}")
        # Exit non-zero so supervisors and scripts can see the failure.
        raise SystemExit(1) from e