Create app.py
app.py
ADDED
@@ -0,0 +1,109 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from diffusers import DiffusionPipeline
import soundfile as sf
import numpy as np

# Load text tokenizer and embedding model (umt5-base)
def load_text_processor():
    tokenizer = AutoTokenizer.from_pretrained("./umt5-base")
    text_model = AutoModel.from_pretrained(
        "./umt5-base",
        use_safetensors=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return tokenizer, text_model

# Load the transformer backbone (phantomstep_transformer)
def load_transformer():
    transformer = DiffusionPipeline.from_pretrained(
        "./phantomstep_transformer",
        use_safetensors=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return transformer

# Load the DCAE for audio encoding/decoding (phantomstep_dcae)
def load_dcae():
    dcae = DiffusionPipeline.from_pretrained(
        "./phantomstep_dcae",
        use_safetensors=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return dcae

# Load the vocoder for audio synthesis (phantomstep_vocoder)
def load_vocoder():
    vocoder = DiffusionPipeline.from_pretrained(
        "./phantomstep_vocoder",
        use_safetensors=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return vocoder

# Generate music from a text prompt
def generate_music(prompt, duration=20, seed=42):
    torch.manual_seed(seed)

    # Load all components
    tokenizer, text_model = load_text_processor()
    transformer = load_transformer()
    dcae = load_dcae()
    vocoder = load_vocoder()

    # Step 1: Process text prompt to embeddings
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(text_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        embeddings = text_model(**inputs).last_hidden_state.mean(dim=1)

    # Step 2: Pass embeddings through transformer
    transformer_output = transformer(
        embeddings,
        num_inference_steps=50,
        audio_length_in_s=duration
    ).audios[0]

    # Step 3: Decode audio features with DCAE
    dcae_output = dcae(
        transformer_output,
        num_inference_steps=50,
        audio_length_in_s=duration
    ).audios[0]

    # Step 4: Synthesize final audio with vocoder
    audio = vocoder(
        dcae_output,
        num_inference_steps=50,
        audio_length_in_s=duration
    ).audios[0]

    # Save audio to a file
    output_path = "output.wav"
    sf.write(output_path, audio, 22050)  # 22kHz sample rate
    return output_path

# Gradio interface
with gr.Blocks(title="PhantomStep: Text-to-Music Generation 🎵") as demo:
    gr.Markdown("# PhantomStep by GhostAI 🚀")
    gr.Markdown("Enter a text prompt to generate music! 🎶")

    prompt_input = gr.Textbox(label="Text Prompt", placeholder="A jazzy piano melody with a fast tempo")
    duration_input = gr.Slider(label="Duration (seconds)", minimum=10, maximum=60, value=20, step=1)
    seed_input = gr.Number(label="Random Seed", value=42, precision=0)
    generate_button = gr.Button("Generate Music")

    audio_output = gr.Audio(label="Generated Music")

    generate_button.click(
        fn=generate_music,
        inputs=[prompt_input, duration_input, seed_input],
        outputs=audio_output
    )

demo.launch()
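
Note on the structure above: generate_music calls all four load_* helpers on every button click, so each request reloads every checkpoint. A minimal sketch of one way to avoid that, assuming the loaders in app.py behave as written, is to cache the components after the first load and reuse them inside the handler:

from functools import lru_cache

@lru_cache(maxsize=1)
def get_components():
    # Load every model exactly once; later calls return the cached tuple.
    tokenizer, text_model = load_text_processor()
    return tokenizer, text_model, load_transformer(), load_dcae(), load_vocoder()

def generate_music(prompt, duration=20, seed=42):
    torch.manual_seed(seed)
    tokenizer, text_model, transformer, dcae, vocoder = get_components()
    # ... rest of the pipeline as in app.py above ...

The lru_cache wrapper is just one option; loading the models at module level when the Space starts would work equally well. The point is simply to keep the from_pretrained calls out of the click handler.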