Spaces · Running on Zero
Commit b7070f2 · Parent: db8b2d5
Refactor inference functions to accept DEVICE and MODEL parameters for TC5, TC6, and TC7; update model loading to use GPU if available.
app.py CHANGED
```diff
@@ -11,34 +11,43 @@ from tc7 import infer as tc7infer
 from gradio_client import Client, handle_file
 import tempfile
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Load model once
 tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
-tc5.to(DEVICE)
+tc5.to(GPU_DEVICE)
 tc5.eval()
+tc5_cpu = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
+tc5_cpu.to("cpu")
+tc5_cpu.eval()
 
 # Load TC6 model
 tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
-tc6.to(DEVICE)
+tc6.to(GPU_DEVICE)
 tc6.eval()
+tc6_cpu = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
+tc6_cpu.to("cpu")
+tc6_cpu.eval()
 
 # Load TC7 model
 tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
-tc7.to(DEVICE)
+tc7.to(GPU_DEVICE)
 tc7.eval()
+tc7_cpu = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
+tc7_cpu.to("cpu")
+tc7_cpu.eval()
 
 synthesizer = Client("ryanlinjui/taiko-music-generator")
 
 
-def infer_tc5(audio, nps, bpm, offset):
+def infer_tc5(audio, nps, bpm, offset, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
     mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
     # Inference
     don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
-        tc5, mel_input, nps_input, DEVICE
+        MODEL, mel_input, nps_input, DEVICE
     )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc5infer.decode_onsets(
```
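The hunk above introduces the core pattern of this commit: each checkpoint is loaded twice, one copy moved to `GPU_DEVICE` (CUDA when available) and one pinned to the CPU, so the CPU path never depends on a device that may not be attached. A minimal sketch of that idea, using a hypothetical `DemoModel` in place of the TaikoConformer classes:

```python
import torch

# Device to use when a GPU is available; falls back to CPU otherwise.
GPU_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DemoModel(torch.nn.Module):  # stand-in for TaikoConformer5/6/7
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 3)

    def forward(self, mel):
        return self.proj(mel)

def load_pair(factory):
    """Load two independent copies: one on the GPU device, one on the CPU."""
    gpu_model = factory().to(GPU_DEVICE).eval()
    cpu_model = factory().to("cpu").eval()
    return gpu_model, cpu_model

demo, demo_cpu = load_pair(DemoModel)
```

Loading two copies trades memory for simplicity: neither copy ever has to be moved between devices at request time.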
```diff
@@ -91,7 +100,7 @@ def infer_tc5(audio, nps, bpm, offset):
     return oni_audio, plot, tja_content
 
 
-def infer_tc6(audio, nps, bpm, offset, difficulty, level):
+def infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
```
```diff
@@ -101,7 +110,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
     level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
     # Inference
     don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
-        tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
+        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
     )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc6infer.decode_onsets(
```
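Note that the conditioning tensors are also placed on `DEVICE`, so they land on the same device as whichever model copy is in use. A small illustration of that detail (function and names hypothetical, mirroring the `level_input` line above):

```python
import torch

def make_conditioning(difficulty, level, device):
    # Build scalar conditioning tensors directly on the target device,
    # equivalent to torch.tensor(...).to(device) in the hunk above.
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32, device=device)
    level_input = torch.tensor(level, dtype=torch.float32, device=device)
    return difficulty_input, level_input

diff_t, level_t = make_conditioning(3, 8.0, torch.device("cpu"))
```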
```diff
@@ -154,7 +163,7 @@ def infer_tc6(audio, nps, bpm, offset, difficulty, level):
     return oni_audio, plot, tja_content
 
 
-def infer_tc7(audio, nps, bpm, offset, difficulty, level):
+def infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, MODEL):
     audio_path = audio
     filename = audio_path.split("/")[-1]
     # Preprocess
```
```diff
@@ -164,7 +173,7 @@ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
     level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
     # Inference
     don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
-        tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
+        MODEL, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
     output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
     onsets = tc7infer.decode_onsets(
```
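Threading `DEVICE` and `MODEL` through the `infer_tc*` wrappers, rather than reading module globals, is what lets the same function body serve both the GPU and CPU copies. A self-contained sketch of that calling convention, with a hypothetical `run_inference` standing in for the `tc*infer.run_inference` helpers:

```python
import torch

@torch.inference_mode()
def run_inference(model, mel, device):
    # Hypothetical helper mirroring the tc*infer.run_inference call shape:
    # move the input to the requested device, then run the model there.
    return model(mel.to(device))

model = torch.nn.Linear(80, 3).eval()  # stand-in for tc5 or tc5_cpu
out = run_inference(model, torch.randn(1, 80), torch.device("cpu"))
```

The same wrapper then works as `run_inference(tc5, mel, GPU_DEVICE)` on the GPU path and `run_inference(tc5_cpu, mel, torch.device("cpu"))` on the CPU path.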
```diff
@@ -220,20 +229,21 @@ def infer_tc7(audio, nps, bpm, offset, difficulty, level):
 @spaces.GPU
 def run_inference_gpu(audio, model_choice, nps, bpm, offset, difficulty, level):
     if model_choice == "TC5":
-        return infer_tc5(audio, nps, bpm, offset)
+        return infer_tc5(audio, nps, bpm, offset, GPU_DEVICE, tc5)
     elif model_choice == "TC6":
-        return infer_tc6(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc6(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc6)
     else: # TC7
-        return infer_tc7(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc7(audio, nps, bpm, offset, difficulty, level, GPU_DEVICE, tc7)
 
 
 def run_inference_cpu(audio, model_choice, nps, bpm, offset, difficulty, level):
+    DEVICE = torch.device("cpu")
     if model_choice == "TC5":
-        return infer_tc5(audio, nps, bpm, offset)
+        return infer_tc5(audio, nps, bpm, offset, DEVICE, tc5_cpu)
     elif model_choice == "TC6":
-        return infer_tc6(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc6(audio, nps, bpm, offset, difficulty, level, DEVICE, tc6_cpu)
     else: # TC7
-        return infer_tc7(audio, nps, bpm, offset, difficulty, level)
+        return infer_tc7(audio, nps, bpm, offset, difficulty, level, DEVICE, tc7_cpu)
 
 
 def run_inference(with_gpu, audio, model_choice, nps, bpm, offset, difficulty, level):
```
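On ZeroGPU Spaces, only functions decorated with `@spaces.GPU` run with a GPU attached, which is why the commit keeps separate `run_inference_gpu` and `run_inference_cpu` entry points and selects between them at call time. A minimal sketch of that dispatch shape (`spaces.GPU` is the real decorator used in the diff; the worker functions here are placeholders):

```python
import spaces  # provided by the Hugging Face ZeroGPU Spaces runtime

@spaces.GPU
def run_on_gpu(x):
    # Executes inside a ZeroGPU allocation; CUDA is available here.
    return f"gpu:{x}"

def run_on_cpu(x):
    # Plain function; never requests a GPU allocation.
    return f"cpu:{x}"

def run(with_gpu, x):
    # Route to the decorated function only when the user asked for GPU.
    return run_on_gpu(x) if with_gpu else run_on_cpu(x)
```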
```diff
@@ -330,7 +340,16 @@ with gr.Blocks() as demo:
 
     run_btn.click(
         run_inference,
-        inputs=[audio_input, model_choice, nps, bpm, offset, difficulty, level],
+        inputs=[
+            with_gpu,
+            audio_input,
+            model_choice,
+            nps,
+            bpm,
+            offset,
+            difficulty,
+            level,
+        ],
         outputs=[audio_output, plot_output, tja_output],
     )
 
```
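The new first entry in `inputs` feeds the UI's GPU toggle into `run_inference` as its `with_gpu` argument, since Gradio passes `inputs` to the callback positionally. A self-contained sketch of the same wiring, with hypothetical component names:

```python
import gradio as gr

def run(with_gpu, text):
    # The first positional argument receives the checkbox value.
    return f"{'GPU' if with_gpu else 'CPU'}: {text}"

with gr.Blocks() as demo:
    with_gpu = gr.Checkbox(label="Use GPU (ZeroGPU)", value=True)
    text_in = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    run_btn = gr.Button("Run")
    run_btn.click(run, inputs=[with_gpu, text_in], outputs=[out])

# demo.launch()  # uncomment to serve the UI
```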