Commit · 812b01c
Parent(s): 7948b62
Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline
- Added TaikoLoss class for custom loss calculation with NPS penalties.
- Developed TaikoConformer7 model architecture using Conformer and CNN layers.
- Created preprocessing functions to handle audio data and generate labels.
- Implemented training script with data loading, model training, and validation.
- Integrated TensorBoard logging for loss and energy comparisons during training.
- Added support for sliding NPS labels in preprocessing and loss calculation.
- .gitignore +1 -0
- app.py +300 -0
- requirements.txt +14 -0
- tc5/__init__.py +0 -0
- tc5/config.py +25 -0
- tc5/dataset.py +21 -0
- tc5/infer.py +356 -0
- tc5/loss.py +65 -0
- tc5/model.py +133 -0
- tc5/preprocess.py +215 -0
- tc5/train.py +323 -0
- tc6/__init__.py +0 -0
- tc6/config.py +25 -0
- tc6/dataset.py +21 -0
- tc6/infer.py +354 -0
- tc6/loss.py +65 -0
- tc6/model.py +166 -0
- tc6/preprocess.py +258 -0
- tc6/train.py +336 -0
- tc7/__init__.py +0 -0
- tc7/config.py +27 -0
- tc7/dataset.py +21 -0
- tc7/infer.py +354 -0
- tc7/loss.py +94 -0
- tc7/model.py +166 -0
- tc7/preprocess.py +400 -0
- tc7/train.py +300 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
app.py
ADDED
@@ -0,0 +1,300 @@
import gradio as gr
import torch
from tc5.config import SAMPLE_RATE, HOP_LENGTH
from tc5.model import TaikoConformer5
from tc5 import infer as tc5infer
from tc6.model import TaikoConformer6
from tc6 import infer as tc6infer
from tc7.model import TaikoConformer7
from tc7 import infer as tc7infer
from gradio_client import Client, handle_file
import tempfile

DEVICE = torch.device("cpu")

# Load TC5 model once
tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
tc5.to(DEVICE)
tc5.eval()

# Load TC6 model
tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
tc6.to(DEVICE)
tc6.eval()

# Load TC7 model
tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
tc7.to(DEVICE)
tc7.eval()

synthesizer = Client("ryanlinjui/taiko-music-generator")


def infer_tc5(audio, nps, bpm):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
        tc5, mel_input, nps_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc5infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc5infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename)

    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def infer_tc6(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc6infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
        tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc6infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc6infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename)

    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def infer_tc7(audio, nps, bpm, difficulty, level):
    audio_path = audio
    filename = audio_path.split("/")[-1]
    # Preprocess
    mel_input = tc7infer.preprocess_audio(audio_path)
    nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
    difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
    level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
    # Inference
    don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
        tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
    )
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
    onsets = tc7infer.decode_onsets(
        don_energy,
        ka_energy,
        drumroll_energy,
        output_frame_hop_sec,
        threshold=0.3,
        min_distance_frames=3,
    )
    # Generate plot
    plot = tc7infer.plot_results(
        mel_input,
        don_energy,
        ka_energy,
        drumroll_energy,
        onsets,
        output_frame_hop_sec,
    )
    # Generate TJA content
    tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename)

    # Write TJA content to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
        temp_tja_file.write(tja_content.encode("utf-8"))
        tja_path = temp_tja_file.name

    result = synthesizer.predict(
        param_0=handle_file(tja_path),
        param_1=handle_file(audio_path),
        param_2="達人譜面 / Master",
        param_3=16,
        param_4=5,
        param_5=5,
        param_6=5,
        param_7=5,
        param_8=5,
        param_9=5,
        param_10=5,
        param_11=5,
        param_12=5,
        param_13=5,
        param_14=5,
        param_15=5,
        api_name="/handle",
    )

    oni_audio = result[1]

    return oni_audio, plot, tja_content


def run_inference(audio, model_choice, nps, bpm, difficulty, level):
    if model_choice == "TC5":
        return infer_tc5(audio, nps, bpm)
    elif model_choice == "TC6":
        return infer_tc6(audio, nps, bpm, difficulty, level)
    else:  # TC7
        return infer_tc7(audio, nps, bpm, difficulty, level)


with gr.Blocks() as demo:
    gr.Markdown("# Taiko Conformer 5/6/7 Demo")
    with gr.Row():
        audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")

    with gr.Row():
        model_choice = gr.Dropdown(
            choices=["TC5", "TC6", "TC7"],
            value="TC7",
            label="Model Selection",
            info="Choose between TaikoConformer 5, 6 or 7",
        )

    with gr.Row():
        nps = gr.Slider(
            value=5.0,
            minimum=0.5,
            maximum=11.0,
            step=0.5,
            label="NPS (Notes Per Second)",
        )
        bpm = gr.Slider(
            value=240,
            minimum=160,
            maximum=640,
            step=1,
            label="BPM (Used by TJA Quantization)",
        )

    with gr.Row():
        difficulty = gr.Slider(
            value=3.0,
            minimum=1.0,
            maximum=3.0,
            step=1.0,
            label="Difficulty",
            visible=False,
            info="1=Normal, 2=Hard, 3=Oni",
        )
        level = gr.Slider(
            value=8.0,
            minimum=1.0,
            maximum=10.0,
            step=1.0,
            label="Level",
            visible=False,
            info="Difficulty level from 1 to 10",
        )

    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    plot_output = gr.Plot(label="Onset/Energy Plot")
    tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
    run_btn = gr.Button("Run Inference")

    # Update visibility of the TC6/TC7-specific controls based on model selection
    def update_visibility(model_choice):
        if model_choice in ("TC6", "TC7"):
            return gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    model_choice.change(
        update_visibility, inputs=[model_choice], outputs=[difficulty, level]
    )

    run_btn.click(
        run_inference,
        inputs=[audio_input, model_choice, nps, bpm, difficulty, level],
        outputs=[audio_output, plot_output, tja_output],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch
torchaudio
datasets
huggingface_hub
librosa
soundfile
matplotlib
tensorboard
black
tqdm
safetensors
accelerate
tja
spaces
tc5/__init__.py
ADDED
File without changes
tc5/config.py
ADDED
@@ -0,0 +1,25 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256  # ~86 fps
TIME_SUB = 1
CNN_CH = 128
N_HEADS = 4
D_MODEL = 256
FF_DIM = 512
N_LAYERS = 4
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 4
LR = 3e-4
EPOCHS = 30
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
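
For reference, the "~86 fps" comment follows directly from these constants; a minimal sketch of the arithmetic:

# Each mel frame advances HOP_LENGTH samples, so the frame rate is
# SAMPLE_RATE / HOP_LENGTH and the frame duration is its inverse.
fps = 22050 / 256      # ≈ 86.13 frames per second
hop_sec = 256 / 22050  # ≈ 0.0116 s (~11.6 ms) per frame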
tc5/dataset.py
ADDED
@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

# ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
# ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
# ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
# ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
# ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
# ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
# ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
# ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

# good = list(range(len(ds)))
# good.remove(1079)  # item 1079 has a file problem
# ds = ds.select(good)

# for local testing
ds = (
    load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
    .with_format("torch")
    .select(range(10))
)
tc5/infer.py
ADDED
@@ -0,0 +1,356 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path, nps=5.0):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # normalize audio

    nps_tensor = torch.tensor(nps, dtype=torch.float32)

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    # mel shape is (N_MELS, T_mel); the batch dim is added later in run_inference
    return mel, nps_tensor


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # type mapping: 1 = Don, 2 = Ka, 5 = Drumroll

        # Check the channels in descending order of energy so the strongest
        # channel wins when more than one peaks above threshold at this frame.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        for onset_type in sorted_types_by_energy:
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            else:  # onset_type == 5
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                # Iterating in descending-energy order and breaking after the
                # first hit ensures at most one onset type per frame.
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                break  # only one onset type per frame

    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # length of energy arrays (model output time dimension)

    # Time axes. The CNN in model.py does not stride in time, so the model
    # output has one frame per mel frame (T_out = T_mel) and each output frame
    # spans HOP_LENGTH / SAMPLE_RATE seconds. With TIME_SUB = 1 this equals the
    # hop_sec passed to decode_onsets; if TIME_SUB were ever > 1, the label
    # frame rate would differ and this mapping would need revisiting.
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot the mel spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # energies are sigmoid outputs, bounded by 1

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuses the logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to a 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Save to file if a path is provided, otherwise return the figure
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0 = no note, 1 = Don, 2 = Ka, 3 = Big Don, 4 = Big Ka,
    # 5 = Drumroll start, 8 = Drumroll end
    # Model output types: 1 = Don, 2 = Ka, 5 = Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: build the measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a sorted list of (measure, slot)
    drumroll_list = sorted(drumroll_slots)
    # Group into contiguous regions (allowing a gap of up to 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering the measure boundary
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow a gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place the 8 in the next slot (or the next measure if at the end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots stay 0 (the grid is initialized to 0)

    # Step 5: generate the TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to a file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
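
To make the slot quantization in write_tja concrete, a worked example with the default bpm=160 and quantize=96 (the onset time is chosen for illustration):

bpm, quantize = 160, 96
sec_per_measure = 60 / bpm * 4                  # 1.5 s per 4/4 measure
t = 2.0                                         # an onset at 2.0 s
measure_idx = int(t // sec_per_measure)         # 1 (the second measure)
slot = round((t % sec_per_measure) / sec_per_measure * quantize)  # 32
# The onset lands at character 32 of the 96-character line for measure 1,
# i.e. one third of the way through the measure (0.5 s into 1.5 s).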
tc5/loss.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then apply masking
        # and reduction manually
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing the 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels and lengths.
                batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
                batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)

        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)

        # Stack true labels to match the structure of pred_energies (B, T, 3)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)

        B, T, _ = pred_energies.shape

        # Create a mask from batch['lengths'] to ignore padded parts of sequences.
        # batch['lengths'] gives the actual length of each sequence in the batch.
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].unsqueeze(1)
        # Expand the mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)

        # Calculate the element-wise MSE loss
        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Apply the mask to the loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by
            # the number of valid elements
            total_loss = masked_loss.sum()
            num_valid_elements = mask_3d.sum()  # total number of unmasked values
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements
                # (e.g., an empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
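
A minimal, self-contained sketch of how this loss is exercised; the shapes follow the docstring above, and the random tensors are placeholders:

import torch
from tc5.loss import TaikoEnergyLoss

criterion = TaikoEnergyLoss(reduction="mean")
B, T = 2, 8
outputs = {"presence": torch.rand(B, T, 3)}   # predicted don/ka/drumroll energies
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "lengths": torch.tensor([8, 5]),          # frames 5..7 of sample 2 are padding
}
loss = criterion(outputs, batch)              # scalar; padded frames are masked out
print(loss.item())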
tc5/model.py
ADDED
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling (stride 2 in frequency, 1 in time)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to the model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second
        self.film = nn.Linear(1, 2 * D_MODEL)

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, Drumroll energy
            nn.Sigmoid(),  # output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) valid lengths after the CNN
            notes_per_second: (B,) conditioning values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to the model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1)  # (B, 1)
        gamma_beta = self.film(nps)  # (B, 2*D_MODEL)
        gamma, beta = gamma_beta.chunk(2, dim=-1)
        x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer5().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )

    output = model(input_mel, conformer_lengths, notes_per_second_input)
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
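
For intuition, the FiLM step in forward() is feature-wise affine modulation: the scalar notes-per-second value is mapped to a (gamma, beta) pair and applied per channel at every time step, so the note-density condition rescales and shifts each model dimension. A minimal standalone sketch (D_MODEL shrunk to 4 for readability):

import torch
import torch.nn as nn

D = 4                                           # stands in for D_MODEL
film = nn.Linear(1, 2 * D)
x = torch.randn(1, 3, D)                        # (B, T, D_MODEL) features
nps = torch.tensor([[5.0]])                     # (B, 1) condition
gamma, beta = film(nps).chunk(2, dim=-1)        # each (B, D)
x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)  # scale and shift every frame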
tc5/preprocess.py
ADDED
@@ -0,0 +1,215 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer5


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # Apply SpecAugment. Time masking is not used because the model should not
    # be asked to predict notes whose audio evidence has been masked out.
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequences of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Exponential decay tail parameters
    tail_length = 40  # number of frames in the decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # N_TYPES in config is assumed to cover all note types (e.g., 7)
        if typ < 1 or typ > N_TYPES:  # filter out invalid types
            continue

        num_valid_notes += 1

        f = int(round(t_start.item() * fps))
        idx = f // TIME_SUB
        if 0 <= idx < T_sub:
            # Apply the exponential decay kernel to the corresponding energy channel
            # Types 1 and 3 are Don
            if typ == 1 or typ == 3:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        don_labels[target_idx] = max(
                            don_labels[target_idx].item(), val.item()
                        )
            # Types 2 and 4 are Ka
            elif typ == 2 or typ == 4:
                for i, val in enumerate(tail_kernel):
                    target_idx = idx + i
                    if 0 <= target_idx < T_sub:
                        ka_labels[target_idx] = max(
                            ka_labels[target_idx].item(), val.item()
                        )
            # Types 5, 6, and 7 are Drumroll
            elif typ >= 5 and typ <= 7:
                f_end = int(round(t_end.item() * fps))
                idx_end = f_end // TIME_SUB

                for dr in range(idx, idx_end):
                    if 0 <= dr < T_sub:
                        drumroll_labels[dr] = 1.0

                for i, val in enumerate(tail_kernel):
                    target_idx = idx_end + i
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(), val.item()
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
    )

    return {
        "mel": mel,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "nps": torch.tensor(nps, dtype=torch.float32),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    # Extract the energy-based labels
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]

    nps_list = [b["nps"] for b in batch]  # extract NPS
    durations_list = [b["duration_seconds"] for b in batch]  # extract durations

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    # Reshape for the CNN: (B, 1, N_MELS, T_max)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)

    # Run a throwaway CNN forward pass to infer the output time dimension
    dummy_model_for_shape_inference = TaikoConformer5()
    dummy_cnn = dummy_model_for_shape_inference.cnn
    with torch.no_grad():
        cnn_out = dummy_cnn(reshaped_mels)  # reshaped_mels already has a batch dim
        _, _, _, T_cnn = cnn_out.shape

    def pad_or_truncate(labels, out_len):
        # Pad with 0 (no energy) or truncate so labels match the CNN output length
        if labels.shape[0] < out_len:
            pad = torch.zeros(
                out_len - labels.shape[0], dtype=labels.dtype, device=labels.device
            )
            return torch.cat([labels, pad], dim=0)
        return labels[:out_len]

    padded_don_labels = [pad_or_truncate(d, T_cnn) for d in don_labels_list]
    padded_ka_labels = [pad_or_truncate(k, T_cnn) for k in ka_labels_list]
    padded_drumroll_labels = [pad_or_truncate(dr, T_cnn) for dr in drumroll_labels_list]

    # Conformer input lengths: lengths of mel sequences after CNN subsampling.
    # The CNN does not stride in time, so T_cnn is effectively T_mel_padded and
    # the ceil(T / TIME_SUB) below is a no-op while TIME_SUB = 1; it would need
    # review if the Conformer input were ever subsampled by TIME_SUB.
    conformer_input_lengths = [
        math.ceil(mels_list[i].shape[0] / TIME_SUB) for i in range(len(batch))
    ]
    conformer_input_lengths = torch.tensor(
        [min(l, T_cnn) for l in conformer_input_lengths], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": torch.stack(padded_don_labels),
        "ka_labels": torch.stack(padded_ka_labels),
        "drumroll_labels": torch.stack(padded_drumroll_labels),
        "lengths": conformer_input_lengths,  # for the Conformer model
        "nps": torch.stack(nps_list),
        "durations": torch.stack(durations_list),
    }
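
For intuition about the soft labels built above: each onset writes a decaying tail rather than a single spike, starting at 1.0 and falling by a factor of e every decay_rate frames (it halves roughly every 5.5 frames, since 8 · ln 2 ≈ 5.55). A minimal sketch of the kernel's leading values:

import torch

tail_kernel = torch.exp(-torch.arange(0, 40, dtype=torch.float32) / 8.0)
print(tail_kernel[:5])  # tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])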
tc5/train.py
ADDED
@@ -0,0 +1,323 @@
from accelerate.utils import set_seed

set_seed(1024)


import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    SAMPLE_RATE,
)
from .model import TaikoConformer5
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoEnergyLoss
from huggingface_hub import upload_folder


# --- Helper function to log energy plots ---
def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,  # actual valid length of the sequence (before padding)
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Energies should be 1D tensors for the single sample; only the first
    valid_length frames are plotted.
    """
    # Move data to the CPU, convert to numpy, and keep only the valid part
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave space for the suptitle
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    global ds

    # Hop seconds for model output frames. This assumes the model output time
    # dimension corresponds to the mel spectrogram time dimension.
    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10  # increased patience a bit
    pat_count = 0

    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
    )

    model = TaikoConformer5().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)

    # Adjust scheduler steps for gradient accumulation
    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            # Unpack the energy-based labels
            don_labels = batch["don_labels"].to(DEVICE)
            ka_labels = batch["ka_labels"].to(DEVICE)
            drumroll_labels = batch["drumroll_labels"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)  # for the Conformer model output
            nps = batch["nps"].to(DEVICE)

            output_dict = model(mel, lengths, nps)
            # output_dict["presence"] is (B, T_out, 3): don, ka, drumroll energies
            pred_energies_batch = output_dict["presence"]  # (B, T_out, 3)

            loss_input_batch = {
                "don_labels": don_labels,
                "ka_labels": ka_labels,
                "drumroll_labels": drumroll_labels,
                "lengths": lengths,  # pass lengths for masking within the loss
            }
            loss = criterion(output_dict, loss_input_batch)

            (loss / GRAD_ACCUM_STEPS).backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_epoch_loss += loss.item()

            # Log a plot for the first sample of the first batch in each training epoch
            if idx == 0:
                first_sample_pred_don = pred_energies_batch[0, :, 0]
                first_sample_pred_ka = pred_energies_batch[0, :, 1]
                first_sample_pred_drumroll = pred_energies_batch[0, :, 2]

                first_sample_true_don = don_labels[0, :]
                first_sample_true_ka = ka_labels[0, :]
                first_sample_true_drumroll = drumroll_labels[0, :]

                first_sample_length = lengths[0].item()  # valid length of this sample

                log_energy_plots_to_tensorboard(
                    writer,
                    "Train/Sample_0",
                    epoch,
                    first_sample_pred_don,
                    first_sample_pred_ka,
                    first_sample_pred_drumroll,
                    first_sample_true_don,
                    first_sample_true_ka,
                    first_sample_true_drumroll,
                    first_sample_length,
                    output_frame_hop_sec,
                )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        # Validation
        model.eval()
        total_val_loss = 0.0
        # Removed storage for classification logits/labels and confusion matrix components

        with torch.no_grad():
            for val_idx, batch in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {epoch}")
            ):
                mel = batch["mel"].to(DEVICE)
                don_labels = batch["don_labels"].to(DEVICE)
                ka_labels = batch["ka_labels"].to(DEVICE)
                drumroll_labels = batch["drumroll_labels"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)  # ground truth NPS from the batch

                output_dict = model(mel, lengths, nps)
                pred_energies_val_batch = output_dict["presence"]  # (B, T_out, 3)

                val_loss_input_batch = {
                    "don_labels": don_labels,
                    "ka_labels": ka_labels,
                    "drumroll_labels": drumroll_labels,
                    "lengths": lengths,
                }
                val_loss = criterion(output_dict, val_loss_input_batch)
                total_val_loss += val_loss.item()

                # Log a plot for the first sample of the first batch in each validation epoch
                if val_idx == 0:
                    first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
                    first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
                    first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]

                    first_val_sample_true_don = don_labels[0, :]
                    first_val_sample_true_ka = ka_labels[0, :]
                    first_val_sample_true_drumroll = drumroll_labels[0, :]

                    first_val_sample_length = lengths[0].item()

                    log_energy_plots_to_tensorboard(
                        writer,
                        "Eval/Sample_0",
                        epoch,
                        first_val_sample_pred_don,
                        first_val_sample_pred_ka,
                        first_val_sample_pred_drumroll,
                        first_val_sample_true_don,
                        first_val_sample_true_ka,
                        first_val_sample_true_drumroll,
                        first_val_sample_length,
                        output_frame_hop_sec,
                    )

                # Log ground truth NPS for reference during validation if needed
                # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        # Log the learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Log ground truth NPS from the last validation batch (or mean over epoch)
        if "nps" in batch:  # check if nps is in the last batch
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
+
best_val_loss = avg_val_loss
|
296 |
+
pat_count = 0
|
297 |
+
torch.save(model.state_dict(), "best_model.pt") # Changed model save name
|
298 |
+
print(f"Saved new best model to best_model.pt at epoch {epoch}")
|
299 |
+
else:
|
300 |
+
pat_count += 1
|
301 |
+
if pat_count >= patience:
|
302 |
+
print("Early stopping!")
|
303 |
+
break
|
304 |
+
writer.close()
|
305 |
+
|
306 |
+
model_id = "JacobLinCool/taiko-conformer-5"
|
307 |
+
try:
|
308 |
+
model.push_to_hub(model_id, commit_message="Upload trained model")
|
309 |
+
upload_folder(
|
310 |
+
repo_id=model_id,
|
311 |
+
folder_path="runs",
|
312 |
+
path_in_repo=".",
|
313 |
+
commit_message="Upload training logs",
|
314 |
+
ignore_patterns=["*.txt", "*.json", "*.csv"],
|
315 |
+
)
|
316 |
+
print(f"Model and logs uploaded to {model_id}")
|
317 |
+
except Exception as e:
|
318 |
+
print(f"Error uploading to Hugging Face Hub: {e}")
|
319 |
+
print("Make sure you have the correct permissions and try again.")
|
320 |
+
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
main()
|
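One detail worth spelling out in the loop above: the loss is scaled by 1/GRAD_ACCUM_STEPS before backward(), so the accumulated gradient matches a single large batch, and OneCycleLR is sized in optimizer steps rather than batches. A minimal sketch of that bookkeeping, where the loader length of 1000 is a made-up figure for illustration only:

import math

# Hypothetical numbers, purely to illustrate the scheduler sizing above.
EPOCHS, GRAD_ACCUM_STEPS = 200, 8
num_batches_per_epoch = 1000  # stand-in for len(train_loader)

# One optimizer/scheduler step per GRAD_ACCUM_STEPS batches, plus the
# final partial accumulation window at the end of each epoch (hence ceil).
steps_per_epoch = math.ceil(num_batches_per_epoch / GRAD_ACCUM_STEPS)
total_steps = EPOCHS * steps_per_epoch
print(steps_per_epoch, total_steps)  # 125 25000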
tc6/__init__.py
ADDED
File without changes
tc6/config.py
ADDED
@@ -0,0 +1,25 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256
TIME_SUB = 1
CNN_CH = 256
N_HEADS = 8
D_MODEL = 512
FF_DIM = 1024
N_LAYERS = 6
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
LR = 3e-4
EPOCHS = 200
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
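A quick sanity check of the time resolution these constants imply (plain arithmetic, nothing repo-specific): with TIME_SUB = 1 and a CNN frontend that does not stride in time, one model output frame corresponds to one mel hop.

SAMPLE_RATE, HOP_LENGTH = 22050, 256

fps = SAMPLE_RATE / HOP_LENGTH      # frames per second of mel/model output
hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
print(f"{fps:.2f} fps, {hop_sec * 1000:.2f} ms per frame")
# ~86.13 fps, ~11.61 ms per frame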
tc6/dataset.py
ADDED
@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

good = list(range(len(ds)))
good.remove(1079)  # 1079 has file problem
ds = ds.select(good)

# for local test
# ds = (
#     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
#     .with_format("torch")
#     .select(range(10))
# )
tc6/infer.py
ADDED
@@ -0,0 +1,354 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # mel is (N_MELS, T_mel)


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Find which energy is max and whether it is a peak above threshold.
        # Sort by energy value descending so higher energy wins ties.
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        detected_this_frame = False
        for onset_type in sorted_types_by_energy:
            current_energy_series = None
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            elif onset_type == 5:
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                # The "highest energy wins" check is implicitly handled by
                # iterating `sorted_types_by_energy` and breaking after the
                # first detection.
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                detected_this_frame = True
                break  # Only one onset type per frame

    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes. The CNN frontend does not stride in time, so the conformer
    # (and thus the regressor output) has one frame per mel frame; with
    # TIME_SUB = 1 the hop_sec passed to decode_onsets equals
    # HOP_LENGTH / SAMPLE_RATE, and both time axes line up.
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Save to file if a path is provided, otherwise return the figure
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a sorted list of (measure, slot)
    drumroll_list = sorted(list(drumroll_slots))
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering measure wrap
            slot_dist = None
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1 and last_s <= quantize - 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in next slot (or next measure if at end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Middle slots stay 0 (already 0 by default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
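End to end, the functions above chain into a full chart generator. A hedged sketch of that flow — the audio path and the nps/difficulty/level conditioning values are placeholders chosen for illustration, not values taken from this repo:

import torch
from tc6.config import SAMPLE_RATE, HOP_LENGTH
from tc6.model import TaikoConformer6
from tc6.infer import preprocess_audio, run_inference, decode_onsets, write_tja

device = torch.device("cpu")
model = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6").to(device)

mel = preprocess_audio("song.wav")  # placeholder path
don, ka, drumroll = run_inference(
    model,
    mel,
    torch.tensor(8.0),  # requested notes-per-second (placeholder)
    torch.tensor(3.0),  # difficulty id; 3 = oni in preprocess.py's mapping
    torch.tensor(9.0),  # level (placeholder)
    device,
)
hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
onsets = decode_onsets(don, ka, drumroll, hop_sec, threshold=0.5)
write_tja(onsets, out_path="song.tja", bpm=160, audio="song.wav")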
tc6/loss.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class TaikoEnergyLoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output, containing 'presence' tensor.
                outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels and lengths.
                batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
                batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)

        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)

        # Stack true labels to match the structure of pred_energies (B, T, 3)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)

        B, T, _ = pred_energies.shape

        # Create a mask from batch['lengths'] to ignore padded parts of sequences.
        # mask_2d shape: (B, T)
        mask_2d = (
            torch.arange(T, device=pred_energies.device).expand(B, T)
            < batch["lengths"].unsqueeze(1)
        )
        # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)

        # Calculate element-wise MSE loss
        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Apply the mask to the loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by their count
            total_loss = masked_loss.sum()
            num_valid_elements = mask_3d.sum()  # Total number of unmasked float values
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
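A small self-contained check of the masking behavior above (all shapes and values invented for the example): frames beyond each sequence's valid length must not contribute to the mean.

import torch
from tc6.loss import TaikoEnergyLoss

criterion = TaikoEnergyLoss(reduction="mean")
B, T = 2, 6
outputs = {"presence": torch.rand(B, T, 3)}
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "lengths": torch.tensor([6, 3]),  # second sample: frames 3..5 are padding
}
loss = criterion(outputs, batch)
print(loss.item())  # mean MSE over the (6 + 3) * 3 valid elements only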
tc6/model.py
ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer6(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(1, 2 * D_MODEL)  # Assuming difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # Assuming level is a single scalar

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) lengths after CNN
            notes_per_second: (B,) stream of control values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)  # Corrected from 4 to 3
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer6().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
tc6/preprocess.py
ADDED
@@ -0,0 +1,258 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer6


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Type 1 and 3 are Don, Type 2 and 4 are Ka
        if typ == 1 or typ == 3 or typ == 2 or typ == 4:
            exact_hit_time_sub = exact_frame_start / TIME_SUB

            current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels

            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)

            if (
                abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
            ):  # Tolerance for float precision
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1

                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac

                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

        # Type 5, 6, 7 are Drumroll
        elif typ >= 5 and typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Improved drumroll body
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)

            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Improved drumroll tail (starts from exact_end_time_sub)
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1

                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail

                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    difficulty_id = (
        0
        if difficulty == "easy"
        else (
            1
            if difficulty == "normal"
            else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
        )  # Assuming 4 for edit/ura
    )
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    # Simulate CNN to get output time length (T_cnn)
    dummy_model = TaikoConformer6()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)

    # For conformer input lengths: based on original mel shape (before CNN)
    conformer_input_length = min(math.ceil(T / TIME_SUB), T_cnn)

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(conformer_input_length, dtype=torch.long),  # for conformer
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
    T_max = padded_mels.shape[1]

    # Pad labels to T_max
    def pad_label(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = torch.stack([pad_label(l, T_max) for l in don_labels_list])
    ka_labels = torch.stack([pad_label(l, T_max) for l in ka_labels_list])
    drumroll_labels = torch.stack([pad_label(l, T_max) for l in drumroll_labels_list])
    lengths = torch.tensor(
        [min(l.item(), T_max) for l in lengths_list], dtype=torch.long
    )

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "lengths": lengths,  # for conformer
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
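The tail_kernel turns each hit into a short exponential decay instead of a single 1.0 frame, which gives the regression targets temporal support. Its first few values, computed with the same formula as in preprocess above:

import torch

tail_length, decay_rate = 40, 8.0
tail_kernel = torch.exp(
    -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
)
print(tail_kernel[:5])
# tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])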
tc6/train.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from accelerate.utils import set_seed
|
2 |
+
|
3 |
+
set_seed(1024)
|
4 |
+
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
from tqdm import tqdm
|
11 |
+
from datasets import concatenate_datasets
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import numpy as np
|
14 |
+
from .config import (
|
15 |
+
BATCH_SIZE,
|
16 |
+
DEVICE,
|
17 |
+
EPOCHS,
|
18 |
+
LR,
|
19 |
+
GRAD_ACCUM_STEPS,
|
20 |
+
HOP_LENGTH,
|
21 |
+
SAMPLE_RATE,
|
22 |
+
)
|
23 |
+
from .model import TaikoConformer6
|
24 |
+
from .dataset import ds
|
25 |
+
from .preprocess import preprocess, collate_fn
|
26 |
+
from .loss import TaikoEnergyLoss
|
27 |
+
from huggingface_hub import upload_folder
|
28 |
+
|
29 |
+
|
30 |
+
# --- Helper function to log energy plots ---
|
31 |
+
def log_energy_plots_to_tensorboard(
|
32 |
+
writer,
|
33 |
+
tag_prefix,
|
34 |
+
epoch,
|
35 |
+
pred_don,
|
36 |
+
pred_ka,
|
37 |
+
pred_drumroll,
|
38 |
+
true_don,
|
39 |
+
true_ka,
|
40 |
+
true_drumroll,
|
41 |
+
valid_length, # Actual valid length of the sequence (before padding)
|
42 |
+
hop_sec,
|
43 |
+
):
|
44 |
+
"""
|
45 |
+
Logs a plot of predicted vs. true energies for one sample to TensorBoard.
|
46 |
+
Energies should be 1D numpy arrays for the single sample, up to valid_length.
|
47 |
+
"""
|
48 |
+
# Ensure data is on CPU and converted to numpy, and select only the valid part
|
49 |
+
pred_don = pred_don[:valid_length].detach().cpu().numpy()
|
50 |
+
pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
|
51 |
+
pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
|
52 |
+
true_don = true_don[:valid_length].cpu().numpy()
|
53 |
+
true_ka = true_ka[:valid_length].cpu().numpy()
|
54 |
+
true_drumroll = true_drumroll[:valid_length].cpu().numpy()
|
55 |
+
|
56 |
+
time_axis = np.arange(valid_length) * hop_sec
|
57 |
+
|
58 |
+
fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
|
59 |
+
fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
|
60 |
+
|
61 |
+
axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
|
62 |
+
axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
|
63 |
+
axs[0].set_ylabel("Don Energy")
|
64 |
+
axs[0].legend()
|
65 |
+
axs[0].grid(True)
|
66 |
+
|
67 |
+
axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
|
68 |
+
axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
|
69 |
+
axs[1].set_ylabel("Ka Energy")
|
70 |
+
axs[1].legend()
|
71 |
+
axs[1].grid(True)
|
72 |
+
|
73 |
+
axs[2].plot(
|
74 |
+
time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
|
75 |
+
)
|
76 |
+
axs[2].plot(
|
77 |
+
time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
|
78 |
+
)
|
79 |
+
axs[2].set_ylabel("Drumroll Energy")
|
80 |
+
axs[2].set_xlabel("Time (s)")
|
81 |
+
axs[2].legend()
|
82 |
+
axs[2].grid(True)
|
83 |
+
|
84 |
+
plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
|
85 |
+
writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
|
86 |
+
plt.close(fig)
|
87 |
+
|
88 |
+
|
89 |
+
def main():
|
90 |
+
global ds
|
91 |
+
|
92 |
+
# Calculate hop seconds for model output frames
|
93 |
+
# This assumes the model output time dimension corresponds to the mel spectrogram time dimension
|
94 |
+
output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
|
95 |
+
|
96 |
+
best_val_loss = float("inf")
|
97 |
+
patience = 10 # Increased patience a bit
|
98 |
+
pat_count = 0
|
99 |
+
|
100 |
+
ds_oni = ds.map(
|
101 |
+
preprocess,
|
102 |
+
remove_columns=ds.column_names,
|
103 |
+
fn_kwargs={"difficulty": "oni"},
|
104 |
+
writer_batch_size=10,
|
105 |
+
)
|
106 |
+
ds_hard = ds.map(
|
107 |
+
preprocess,
|
108 |
+
remove_columns=ds.column_names,
|
109 |
+
fn_kwargs={"difficulty": "hard"},
|
110 |
+
writer_batch_size=10,
|
111 |
+
)
|
112 |
+
ds_normal = ds.map(
|
113 |
+
preprocess,
|
114 |
+
remove_columns=ds.column_names,
|
115 |
+
fn_kwargs={"difficulty": "normal"},
|
116 |
+
writer_batch_size=10,
|
117 |
+
)
|
118 |
+
ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
|
119 |
+
|
120 |
+
ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
|
121 |
+
# ds_train_test.push_to_hub("JacobLinCool/taiko-conformer-6-ds")
|
122 |
+
train_loader = DataLoader(
|
123 |
+
ds_train_test["train"],
|
124 |
+
batch_size=BATCH_SIZE,
|
125 |
+
shuffle=True,
|
126 |
+
collate_fn=collate_fn,
|
127 |
+
num_workers=16,
|
128 |
+
persistent_workers=True,
|
129 |
+
prefetch_factor=4,
|
130 |
+
)
|
131 |
+
val_loader = DataLoader(
|
132 |
+
ds_train_test["test"],
|
133 |
+
batch_size=BATCH_SIZE,
|
134 |
+
shuffle=False,
|
135 |
+
collate_fn=collate_fn,
|
136 |
+
num_workers=16,
|
137 |
+
persistent_workers=True,
|
138 |
+
prefetch_factor=4,
|
139 |
+
)
|
140 |
+
|
141 |
+
model = TaikoConformer6().to(DEVICE)
|
142 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
143 |
+
|
144 |
+
criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)
|
145 |
+
|
146 |
+
# Adjust scheduler steps for gradient accumulation
|
147 |
+
num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
|
148 |
+
total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
|
149 |
+
|
150 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
151 |
+
optimizer, max_lr=LR, total_steps=total_optimizer_steps
|
152 |
+
)
|
153 |
+
|
154 |
+
writer = SummaryWriter()
|
155 |
+
|
156 |
+
for epoch in range(1, EPOCHS + 1):
|
157 |
+
model.train()
|
158 |
+
total_epoch_loss = 0.0
|
159 |
+
optimizer.zero_grad()
|
160 |
+
|
161 |
+
for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
|
162 |
+
mel = batch["mel"].to(DEVICE)
|
163 |
+
# Unpack new energy-based labels
|
164 |
+
don_labels = batch["don_labels"].to(DEVICE)
|
165 |
+
ka_labels = batch["ka_labels"].to(DEVICE)
|
166 |
+
drumroll_labels = batch["drumroll_labels"].to(DEVICE)
|
167 |
+
lengths = batch["lengths"].to(
|
168 |
+
DEVICE
|
169 |
+
) # These are for the Conformer model output
|
170 |
+
nps = batch["nps"].to(DEVICE)
|
171 |
+
difficulty = batch["difficulty"].to(DEVICE) # Add difficulty
|
172 |
+
level = batch["level"].to(DEVICE) # Add level
|
173 |
+
|
174 |
+
output_dict = model(
|
175 |
+
mel, lengths, nps, difficulty, level
|
176 |
+
) # Pass difficulty and level
|
177 |
+
# output_dict["presence"] is now (B, T_out, 3) for don, ka, drumroll energies
|
178 |
+
pred_energies_batch = output_dict["presence"] # (B, T_out, 3)
|
179 |
+
|
180 |
+
loss_input_batch = {
|
181 |
+
"don_labels": don_labels,
|
182 |
+
"ka_labels": ka_labels,
|
183 |
+
"drumroll_labels": drumroll_labels,
|
184 |
+
"lengths": lengths, # Pass lengths for masking within the loss function
|
185 |
+
}
|
186 |
+
loss = criterion(output_dict, loss_input_batch)
|
187 |
+
|
188 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
189 |
+
|
190 |
+
if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
|
191 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
192 |
+
optimizer.step()
|
193 |
+
scheduler.step()
|
194 |
+
optimizer.zero_grad()
|
195 |
+
|
196 |
+
total_epoch_loss += loss.item()
|
197 |
+
|
198 |
+
# Log plot for the first sample of the first batch in each training epoch
|
199 |
+
if idx == 0:
|
200 |
+
first_sample_pred_don = pred_energies_batch[0, :, 0]
|
201 |
+
first_sample_pred_ka = pred_energies_batch[0, :, 1]
|
202 |
+
first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
|
203 |
+
|
204 |
+
first_sample_true_don = don_labels[0, :]
|
205 |
+
first_sample_true_ka = ka_labels[0, :]
|
206 |
+
first_sample_true_drumroll = drumroll_labels[0, :]
|
207 |
+
|
208 |
+
first_sample_length = lengths[
|
209 |
+
0
|
210 |
+
].item() # Get the valid length of the first sample
|
211 |
+
|
212 |
+
log_energy_plots_to_tensorboard(
|
213 |
+
writer,
|
214 |
+
"Train/Sample_0",
|
215 |
+
epoch,
|
216 |
+
first_sample_pred_don,
|
217 |
+
first_sample_pred_ka,
|
218 |
+
first_sample_pred_drumroll,
|
219 |
+
first_sample_true_don,
|
220 |
+
first_sample_true_ka,
|
221 |
+
first_sample_true_drumroll,
|
222 |
+
first_sample_length,
|
223 |
+
output_frame_hop_sec,
|
224 |
+
)
|
225 |
+
|
226 |
+
avg_train_loss = total_epoch_loss / len(train_loader)
|
227 |
+
writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
|
228 |
+
|
229 |
+
# Validation
|
230 |
+
model.eval()
|
231 |
+
total_val_loss = 0.0
|
232 |
+
        # Removed storage for classification logits/labels and confusion matrix components

        with torch.no_grad():
            for val_idx, batch in enumerate(
                tqdm(val_loader, desc=f"Val Epoch {epoch}")
            ):
                mel = batch["mel"].to(DEVICE)
                don_labels = batch["don_labels"].to(DEVICE)
                ka_labels = batch["ka_labels"].to(DEVICE)
                drumroll_labels = batch["drumroll_labels"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)  # Ground truth NPS from batch
                difficulty = batch["difficulty"].to(DEVICE)  # Add difficulty
                level = batch["level"].to(DEVICE)  # Add level

                output_dict = model(
                    mel, lengths, nps, difficulty, level
                )  # Pass difficulty and level
                pred_energies_val_batch = output_dict["presence"]  # (B, T_out, 3)

                val_loss_input_batch = {
                    "don_labels": don_labels,
                    "ka_labels": ka_labels,
                    "drumroll_labels": drumroll_labels,
                    "lengths": lengths,
                }
                val_loss = criterion(output_dict, val_loss_input_batch)
                total_val_loss += val_loss.item()

                # Log plot for the first sample of the first batch in each validation epoch
                if val_idx == 0:
                    first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
                    first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
                    first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]

                    first_val_sample_true_don = don_labels[0, :]
                    first_val_sample_true_ka = ka_labels[0, :]
                    first_val_sample_true_drumroll = drumroll_labels[0, :]

                    first_val_sample_length = lengths[0].item()

                    log_energy_plots_to_tensorboard(
                        writer,
                        "Eval/Sample_0",
                        epoch,
                        first_val_sample_pred_don,
                        first_val_sample_pred_ka,
                        first_val_sample_pred_drumroll,
                        first_val_sample_true_don,
                        first_val_sample_true_ka,
                        first_val_sample_true_drumroll,
                        first_val_sample_length,
                        output_frame_hop_sec,
                    )

                # Log ground truth NPS for reference during validation if needed
                # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        # Log learning rate
        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        # Log ground truth NPS from the last validation batch (or mean over epoch)
        if "nps" in batch:  # Check if nps is in the last batch
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")  # Changed model save name
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break
    writer.close()

    model_id = "JacobLinCool/taiko-conformer-6"
    try:
        model.push_to_hub(model_id, commit_message="Upload trained model")
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo=".",
            commit_message="Upload training logs",
            ignore_patterns=["*.txt", "*.json", "*.csv"],
        )
        print(f"Model and logs uploaded to {model_id}")
    except Exception as e:
        print(f"Error uploading to Hugging Face Hub: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()
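A minimal sketch of reloading the checkpoint this loop writes (not part of the commit itself; it assumes TaikoConformer6() takes no constructor arguments, as TaikoConformer7 below does):

import torch
from tc6.model import TaikoConformer6

model = TaikoConformer6()
# best_model.pt is the state dict saved by the early-stopping branch above
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()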
tc7/__init__.py
ADDED
File without changes
tc7/config.py
ADDED
@@ -0,0 +1,27 @@
import torch

# ─── 1) CONFIG ─────────────────────────────────────────────────────
SAMPLE_RATE = 22050
N_MELS = 80
HOP_LENGTH = 256
TIME_SUB = 1
CNN_CH = 256
N_HEADS = 8
D_MODEL = 512
FF_DIM = 1024
N_LAYERS = 6
DEPTHWISE_CONV_KERNEL_SIZE = 31
DROPOUT = 0.1
HIDDEN_DIM = 64
N_TYPES = 7
BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
LR = 3e-4
EPOCHS = 200
NPS_PENALTY_WEIGHT_ALPHA = 0.3
NPS_PENALTY_WEIGHT_BETA = 1.0
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
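With SAMPLE_RATE = 22050 and HOP_LENGTH = 256, one spectrogram (and model output) frame covers 256 / 22050 ≈ 11.6 ms, about 86 frames per second. A quick sanity check, not part of the commit:

from tc7.config import SAMPLE_RATE, HOP_LENGTH

hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per output frame
print(f"{hop_sec * 1000:.2f} ms per frame, {1 / hop_sec:.1f} frames per second")
# 11.61 ms per frame, 86.1 frames per second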
tc7/dataset.py
ADDED
@@ -0,0 +1,21 @@
from datasets import load_dataset, concatenate_datasets

ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")

good = list(range(len(ds)))
good.remove(1079)  # 1079 has file problem
ds = ds.select(good)

# for local test
# ds = (
#     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
#     .with_format("torch")
#     .select(range(10))
# )
tc7/infer.py
ADDED
@@ -0,0 +1,354 @@
import time
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
import torch.profiler


# --- PREPROCESSING (match training) ---
def preprocess_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    wav = wav.mean(dim=0)  # mono
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    wav = wav / (wav.abs().max() + 1e-8)  # Normalize audio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        hop_length=HOP_LENGTH,
        n_fft=2048,
    )
    mel = mel_transform(wav)
    return mel  # mel is (N_MELS, T_mel)


# --- INFERENCE ---
def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
    model.eval()
    with torch.no_grad():
        mel = mel_input.to(device).unsqueeze(0)  # (1, N_MELS, T_mel)
        nps = nps_input.to(device).unsqueeze(0)  # (1,)
        difficulty = difficulty_input.to(device).unsqueeze(0)  # (1,)
        level = level_input.to(device).unsqueeze(0)  # (1,)

        mel_cnn_input = mel.unsqueeze(1)  # (1, 1, N_MELS, T_mel)

        conformer_lengths = torch.tensor(
            [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
        )

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                *(
                    [torch.profiler.ProfilerActivity.CUDA]
                    if device.type == "cuda"
                    else []
                ),
            ],
            record_shapes=True,
            profile_memory=True,
            with_stack=False,
            with_flops=True,
        ) as prof:
            out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
        print(
            prof.key_averages().table(
                sort_by=(
                    "self_cuda_memory_usage"
                    if device.type == "cuda"
                    else "self_cpu_time_total"
                ),
                row_limit=20,
            )
        )

        energies = out_dict["presence"].squeeze(0).cpu().numpy()

        don_energy = energies[:, 0]
        ka_energy = energies[:, 1]
        drumroll_energy = energies[:, 2]

        return don_energy, ka_energy, drumroll_energy


# --- DECODE TO ONSETS ---
def decode_onsets(
    don_energy,
    ka_energy,
    drumroll_energy,
    hop_sec,
    threshold=0.5,
    min_distance_frames=3,
):
    results = []
    T_out = len(don_energy)
    last_onset_frame = -min_distance_frames

    for i in range(1, T_out - 1):  # Iterate considering neighbors for peak detection
        if i < last_onset_frame + min_distance_frames:
            continue

        e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
        energies_at_i = {
            1: e_don,
            2: e_ka,
            5: e_drum,
        }  # Type mapping: 1:Don, 2:Ka, 5:Drumroll

        # Find which energy is max and if it's a peak above threshold
        # Sort by energy value descending to prioritize higher energy in case of ties for peak condition
        sorted_types_by_energy = sorted(
            energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
        )

        detected_this_frame = False
        for onset_type in sorted_types_by_energy:
            current_energy_series = None
            if onset_type == 1:
                current_energy_series = don_energy
            elif onset_type == 2:
                current_energy_series = ka_energy
            elif onset_type == 5:
                current_energy_series = drumroll_energy

            energy_val = current_energy_series[i]

            if (
                energy_val > threshold
                and energy_val > current_energy_series[i - 1]
                and energy_val > current_energy_series[i + 1]
            ):
                # Check if this energy is the highest among the three at this frame
                # This check is implicitly handled by iterating `sorted_types_by_energy`
                # and breaking after the first detection.
                results.append((i * hop_sec, onset_type))
                last_onset_frame = i
                detected_this_frame = True
                break  # Only one onset type per frame

    return results


# --- VISUALIZATION ---
def plot_results(
    mel_spectrogram,
    don_energy,
    ka_energy,
    drumroll_energy,
    onsets,
    hop_sec,
    out_path=None,
):
    # mel_spectrogram is (N_MELS, T_mel)
    T_mel = mel_spectrogram.shape[1]
    T_out = len(don_energy)  # Length of energy arrays (model output time dimension)

    # Time axes
    time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
    # hop_sec for model output is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE
    # However, the model output T_out is related to T_mel (input to CNN).
    # If CNN does not change time dimension, T_out = T_mel.
    # If TIME_SUB is used for label generation, T_out = T_mel / TIME_SUB.
    # The `lengths` passed to conformer in `run_inference` is T_mel.
    # The output of conformer `x` has shape (B, T, D_MODEL). This T is `lengths`.
    # So, T_out from model is T_mel.
    # The `hop_sec` for onsets should be based on the model output frame rate.
    # If model output T_out corresponds to T_mel, then hop_sec for plotting energies is HOP_LENGTH / SAMPLE_RATE.
    # The `hop_sec` passed to `decode_onsets` is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE.
    # This seems inconsistent. Let's clarify: `TIME_SUB` is used in `preprocess.py` to determine `T_sub` for labels.
    # The model's CNN output time dimension T_cnn is used to pad/truncate labels in `collate_fn`.
    # In `model.py`, the CNN does not stride in time. So T_cnn_out = T_mel_input_to_CNN.
    # The `lengths` for the conformer is based on this T_cnn_out.
    # So, the output of the regressor `presence` (B, T_cnn_out, 3) has T_cnn_out time steps.
    # Each step corresponds to `HOP_LENGTH / SAMPLE_RATE` seconds if TIME_SUB is not involved in downsampling features for the conformer.
    # Let's assume the `hop_sec` used for `decode_onsets` is correct for interpreting model output frames.
    time_axis_energies = np.arange(T_out) * hop_sec

    fig, ax1 = plt.subplots(figsize=(100, 10))

    # Plot Mel Spectrogram on ax1
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
    )
    img = ax1.imshow(
        mel_db.numpy(),
        aspect="auto",
        origin="lower",
        cmap="magma",
        extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
    )
    ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
    ax1.set_xlabel("Time (s)")
    ax1.set_ylabel("Mel Bin")
    fig.colorbar(img, ax=ax1, format="%+2.0f dB")

    # Create a second y-axis for energies
    ax2 = ax1.twinx()
    ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
    ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
    ax2.plot(
        time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
    )
    ax2.set_ylabel("Energy")
    ax2.set_ylim(0, 1.2)  # Assuming energies are somewhat normalized or bounded

    # Overlay onsets from decode_onsets (t is already in seconds)
    labeled_types = set()
    # Group drumrolls into segments (reuse logic from write_tja)
    drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
    drumroll_times.sort()
    drumroll_segments = []
    if drumroll_times:
        seg_start = drumroll_times[0]
        prev = drumroll_times[0]
        for t in drumroll_times[1:]:
            if t - prev <= hop_sec * 6:  # up to 5-frame gap
                prev = t
            else:
                drumroll_segments.append((seg_start, prev))
                seg_start = t
                prev = t
        drumroll_segments.append((seg_start, prev))
    # Plot Don/Ka onsets as vertical lines
    for t_sec, typ in onsets:
        if typ == 5:
            continue  # skip drumroll onsets
        color_map = {1: "darkred", 2: "darkblue"}
        label_map = {1: "Don Onset", 2: "Ka Onset"}
        line_color = color_map.get(typ, "black")
        line_label = label_map.get(typ, f"Type {typ} Onset")
        if typ not in labeled_types:
            ax1.axvline(
                t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
            )
            labeled_types.add(typ)
        else:
            ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
    # Plot drumroll segments as shaded regions
    for seg_start, seg_end in drumroll_segments:
        ax1.axvspan(
            seg_start,
            seg_end + hop_sec,
            color="green",
            alpha=0.2,
            label="Drumroll Segment" if "drumroll" not in labeled_types else None,
        )
        labeled_types.add("drumroll")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="upper right")

    fig.tight_layout()

    # Return plot as image buffer or save to file if path provided
    if out_path:
        plt.savefig(out_path)
        print(f"Saved plot to {out_path}")
        plt.close(fig)
        return out_path
    else:
        # Return plot as in-memory buffer
        return fig


def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
    # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
    # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
    sec_per_beat = 60 / bpm
    beats_per_measure = 4  # Assuming 4/4 time signature
    sec_per_measure = sec_per_beat * beats_per_measure
    # Step 1: Map onsets to (measure_idx, slot, typ)
    slot_events = []
    for t, typ in onsets:
        measure_idx = int(t // sec_per_measure)
        t_in_measure = t % sec_per_measure
        slot = int(round(t_in_measure / sec_per_measure * quantize))
        if slot >= quantize:
            slot = quantize - 1
        slot_events.append((measure_idx, slot, typ))
    # Step 2: Build measure/slot grid
    if slot_events:
        max_measure_idx = max(m for m, _, _ in slot_events)
    else:
        max_measure_idx = -1
    measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
    # Step 3: Place Don/Ka, collect drumrolls
    drumroll_slots = set()
    for m, s, typ in slot_events:
        if typ in [1, 2]:
            measures[m][s] = typ
        elif typ == 5:
            drumroll_slots.add((m, s))
    # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
    # Flatten all slots to a list of (measure, slot) sorted
    drumroll_list = sorted(list(drumroll_slots))
    # Group into contiguous regions (allowing a gap of 5 slots)
    grouped = []
    group = []
    for ms in drumroll_list:
        if not group:
            group = [ms]
        else:
            last_m, last_s = group[-1]
            m, s = ms
            # Calculate slot distance, considering measure wrap
            slot_dist = None
            if m == last_m:
                slot_dist = s - last_s
            elif m == last_m + 1 and last_s <= quantize - 1:
                slot_dist = (quantize - 1 - last_s) + s + 1
            else:
                slot_dist = None
            # Allow gap of up to 5 slots (slot_dist <= 6)
            if slot_dist is not None and 1 <= slot_dist <= 6:
                group.append(ms)
            else:
                grouped.append(group)
                group = [ms]
    if group:
        grouped.append(group)
    # Mark 5 (start) and 8 (end) for each group
    for region in grouped:
        if len(region) == 1:
            m, s = region[0]
            measures[m][s] = 5
            # Place 8 in next slot (or next measure if at end)
            if s < quantize - 1:
                measures[m][s + 1] = 8
            elif m < max_measure_idx:
                measures[m + 1][0] = 8
        else:
            m_start, s_start = region[0]
            m_end, s_end = region[-1]
            measures[m_start][s_start] = 5
            measures[m_end][s_end] = 8
            # Fill 0 for middle slots (already 0 by default)
    # Step 5: Generate TJA content
    tja_content = []
    tja_content.append(f"TITLE:{audio} (TC7, {time.strftime('%Y-%m-%d %H:%M:%S')})")
    tja_content.append(f"BPM:{bpm}")
    tja_content.append(f"WAVE:{audio}")
    tja_content.append("OFFSET:0")
    tja_content.append("COURSE:Oni\nLEVEL:9\n")
    tja_content.append("#START")
    for i in range(max_measure_idx + 1):
        notes = measures.get(i, [0] * quantize)
        line = "".join(str(n) for n in notes)
        tja_content.append(line + ",")
    tja_content.append("#END")

    tja_string = "\n".join(tja_content)

    # If out_path is provided, also write to file
    if out_path:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(tja_string)
        print(f"TJA chart saved to {out_path}")

    return tja_string
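A minimal end-to-end sketch of how these helpers chain together. The model loading mirrors app.py; the file name and the NPS/difficulty/level control values below are placeholders, not values from this commit:

import torch
from tc7.config import SAMPLE_RATE, HOP_LENGTH
from tc7.infer import preprocess_audio, run_inference, decode_onsets, write_tja
from tc7.model import TaikoConformer7

device = torch.device("cpu")
model = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7").to(device)

mel = preprocess_audio("song.wav")  # placeholder path; returns (N_MELS, T_mel)
don, ka, drumroll = run_inference(
    model,
    mel,
    torch.tensor(8.0),  # target notes-per-second (assumed value)
    torch.tensor(3.0),  # difficulty id, 3 = oni in preprocess.py's mapping
    torch.tensor(9.0),  # level (assumed value)
    device,
)
hop_sec = HOP_LENGTH / SAMPLE_RATE  # seconds per model output frame
onsets = decode_onsets(don, ka, drumroll, hop_sec, threshold=0.5)
tja = write_tja(onsets, out_path="song.tja", bpm=160, audio="song.wav")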
tc7/loss.py
ADDED
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn


class TaikoLoss(nn.Module):
    def __init__(
        self,
        reduction="mean",
        nps_penalty_weight_alpha=0.3,
        nps_penalty_weight_beta=1.0,
    ):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction
        self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
        self.nps_penalty_weight_beta = nps_penalty_weight_beta

    def forward(self, outputs, batch):
        """
        Calculates the MSE loss for energy-based predictions, with a two-level penalty
        based on sliding NPS values.
        - A heavier penalty if sliding_nps is 0.
        - A continuous penalty if sliding_nps > 0.

        Args:
            outputs (dict): Model output, containing 'presence' tensor.
                            outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels, lengths,
                          and sliding_nps_labels.
                          batch['sliding_nps_labels'] shape: (B, T)
        Returns:
            torch.Tensor: The calculated loss.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        true_don = batch["don_labels"]  # (B, T)
        true_ka = batch["ka_labels"]  # (B, T)
        true_drumroll = batch["drumroll_labels"]  # (B, T)
        true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2).to(
            pred_energies.device
        )  # (B, T, 3)

        B, T, _ = pred_energies.shape

        # Create a mask based on batch['lengths'] to ignore padded parts of sequences
        # batch['lengths'] gives the actual length of each sequence in the batch
        # mask shape: (B, T)
        mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
            "lengths"
        ].to(pred_energies.device).unsqueeze(1)
        # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
        mask_3d = mask_2d.unsqueeze(2)  # (B, T, 1)

        # Calculate element-wise MSE loss
        mse_loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)

        # Calculate two-level Sliding NPS penalty
        sliding_nps = batch["sliding_nps_labels"].to(pred_energies.device)  # (B, T)

        penalty_coefficients = torch.zeros_like(sliding_nps)  # (B, T)

        is_zero_nps = sliding_nps == 0.0
        is_not_zero_nps = ~is_zero_nps

        # Apply heavy penalty where sliding_nps is 0
        penalty_coefficients[is_zero_nps] = self.nps_penalty_weight_beta

        # Apply continuous penalty where sliding_nps > 0
        penalty_coefficients[is_not_zero_nps] = self.nps_penalty_weight_alpha * (
            1 - sliding_nps[is_not_zero_nps]
        )

        # Apply penalty factor to the MSE loss
        loss_elementwise = mse_loss_elementwise * (
            1 + penalty_coefficients.unsqueeze(2)
        )

        # Apply the mask to the combined loss
        masked_loss = loss_elementwise * mask_3d

        if self.reduction == "mean":
            # Sum the loss over all valid (unmasked) elements and divide by the number of valid time steps
            total_loss = masked_loss.sum()
            # mask_3d is (B, T, 1), so this counts valid time steps; each step
            # contributes three channel terms to the sum above
            num_valid_elements = mask_3d.sum()
            if num_valid_elements > 0:
                return total_loss / num_valid_elements
            else:
                # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
                return torch.tensor(
                    0.0, device=pred_energies.device, requires_grad=True
                )
        elif self.reduction == "sum":
            return masked_loss.sum()
        else:  # 'none' or any other case
            return masked_loss
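A quick shape and gradient sanity check for TaikoLoss with dummy tensors (a sketch; the batch size, sequence length, and the padded second sequence are assumptions):

import torch
from tc7.loss import TaikoLoss

B, T = 2, 100
outputs = {"presence": torch.rand(B, T, 3, requires_grad=True)}
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "sliding_nps_labels": torch.rand(B, T),
    "lengths": torch.tensor([100, 80]),  # frames past 80 in sample 2 are masked out
}
criterion = TaikoLoss(nps_penalty_weight_alpha=0.3, nps_penalty_weight_beta=1.0)
loss = criterion(outputs, batch)  # scalar under reduction="mean"
loss.backward()
print(loss.item())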
tc7/model.py
ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second, difficulty, and level
        self.film_nps = nn.Linear(1, 2 * D_MODEL)
        self.film_difficulty = nn.Linear(
            1, 2 * D_MODEL
        )  # Assuming difficulty is a single scalar
        self.film_level = nn.Linear(1, 2 * D_MODEL)  # Assuming level is a single scalar

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: torch.Tensor,
        notes_per_second: torch.Tensor,
        difficulty: torch.Tensor,
        level: torch.Tensor,
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel)
            lengths: (B,) lengths after CNN
            notes_per_second: (B,) stream of control values
            difficulty: (B,) difficulty values
            level: (B,) level values
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3)  # Corrected from 4 to 3
                'lengths': lengths
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_nps = self.film_nps(nps)  # (B, 2*D_MODEL)
        gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
        x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)

        diff = difficulty.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_diff = self.film_difficulty(diff)  # (B, 2*D_MODEL)
        gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
        x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)

        lvl = level.unsqueeze(-1).float()  # (B, 1)
        gamma_beta_lvl = self.film_level(lvl)  # (B, 2*D_MODEL)
        gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
        x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)

        # Conformer encoder
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer7().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)

    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)

    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )
    difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example difficulty
    level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )  # Example level

    output = model(
        input_mel,
        conformer_lengths,
        notes_per_second_input,
        difficulty_input,
        level_input,
    )
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
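For reference, the three FiLM blocks in forward() all apply the same affine modulation: each scalar condition c (NPS, difficulty id, or level) is mapped by a linear layer to a per-channel scale and shift that are broadcast over all time steps,

\tilde{x}_t = \gamma(c) \odot x_t + \beta(c), \qquad [\gamma(c);\ \beta(c)] = W c + b \in \mathbb{R}^{2 D_{\mathrm{MODEL}}}

which is exactly the chunk-into-gamma/beta-and-broadcast pattern repeated for film_nps, film_difficulty, and film_level.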
tc7/preprocess.py
ADDED
@@ -0,0 +1,400 @@
import math
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import FrequencyMasking
from tja import parse_tja, PyParsingMode
from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
from .model import TaikoConformer7


mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    n_fft=2048,
)


freq_mask = FrequencyMasking(freq_mask_param=15)


def preprocess(example, difficulty="oni"):
    wav_tensor = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    # 1) load & resample
    if sr != SAMPLE_RATE:
        wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
    # normalize audio
    wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
    # add random Gaussian noise
    if torch.rand(1).item() < 0.5:
        wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
    # 2) mel: (1, N_MELS, T)
    mel = mel_transform(wav_tensor).unsqueeze(0)
    # apply SpecAugment
    mel = freq_mask(mel)
    _, _, T = mel.shape
    # 3) build label sequence of length ceil(T / TIME_SUB)
    T_sub = math.ceil(T / TIME_SUB)

    # Initialize energy-based labels for Don, Ka, Drumroll
    don_labels = torch.zeros(T_sub, dtype=torch.float32)
    ka_labels = torch.zeros(T_sub, dtype=torch.float32)
    drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
    sliding_nps_labels = torch.zeros(
        T_sub, dtype=torch.float32
    )  # New label for sliding NPS

    # Define exponential decay tail parameters
    tail_length = 40  # number of frames for decay tail
    decay_rate = 8.0  # decay rate parameter, adjust as needed
    tail_kernel = torch.exp(
        -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
    )

    fps = SAMPLE_RATE / HOP_LENGTH
    num_valid_notes = 0
    for onset in example[difficulty]:
        typ, t_start, t_end, *_ = onset

        # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
        if typ < 1 or typ > N_TYPES:  # Filter out invalid types
            continue

        num_valid_notes += 1

        exact_frame_start = t_start.item() * fps

        # Type 1 and 3 are Don, Type 2 and 4 are Ka
        if typ == 1 or typ == 3 or typ == 2 or typ == 4:
            exact_hit_time_sub = exact_frame_start / TIME_SUB

            current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels

            start_points_info = []
            rounded_hit_time_sub = round(exact_hit_time_sub)

            if (
                abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
            ):  # Tolerance for float precision
                idx_single = int(rounded_hit_time_sub)
                if 0 <= idx_single < T_sub:
                    start_points_info.append({"idx": idx_single, "weight": 1.0})
            else:
                idx_floor = math.floor(exact_hit_time_sub)
                idx_ceil = idx_floor + 1

                frac = exact_hit_time_sub - idx_floor
                weight_ceil = frac
                weight_floor = 1.0 - frac

                if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
                    start_points_info.append({"idx": idx_floor, "weight": weight_floor})
                if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
                    start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})

            for point_info in start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        current_labels[target_idx] = max(
                            current_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

        # Type 5, 6, 7 are Drumroll
        elif typ >= 5 and typ <= 7:
            exact_frame_end = t_end.item() * fps
            exact_start_time_sub = exact_frame_start / TIME_SUB
            exact_end_time_sub = exact_frame_end / TIME_SUB

            # Improved drumroll body
            body_loop_start_idx = math.floor(exact_start_time_sub)
            body_loop_end_idx = math.ceil(exact_end_time_sub)

            for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
                if 0 <= dr_idx < T_sub:
                    drumroll_labels[dr_idx] = 1.0

            # Improved drumroll tail (starts from exact_end_time_sub)
            tail_start_points_info = []
            rounded_end_time_sub = round(exact_end_time_sub)
            if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
                idx_single_tail = int(rounded_end_time_sub)
                if 0 <= idx_single_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_single_tail, "weight": 1.0}
                    )
            else:
                idx_floor_tail = math.floor(exact_end_time_sub)
                idx_ceil_tail = idx_floor_tail + 1

                frac_tail = exact_end_time_sub - idx_floor_tail
                weight_ceil_tail = frac_tail
                weight_floor_tail = 1.0 - frac_tail

                if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_floor_tail, "weight": weight_floor_tail}
                    )
                if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
                    tail_start_points_info.append(
                        {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
                    )

            for point_info in tail_start_points_info:
                start_idx = point_info["idx"]
                weight = point_info["weight"]
                for k_idx, kernel_val in enumerate(tail_kernel):
                    target_idx = start_idx + k_idx
                    if 0 <= target_idx < T_sub:
                        drumroll_labels[target_idx] = max(
                            drumroll_labels[target_idx].item(),
                            weight * kernel_val.item(),
                        )

    # Calculate sliding window NPS
    note_events = (
        []
    )  # Store tuples of (time_sec, type_is_drumroll_start_or_end, duration_if_drumroll)
    for onset in example[difficulty]:
        typ, t_start_tensor, t_end_tensor, *_ = onset
        t_start = t_start_tensor.item()
        t_end = t_end_tensor.item()

        if typ in [1, 2, 3, 4]:  # Don or Ka
            note_events.append(
                (t_start, False, 0)
            )  # False indicates not a drumroll event, duration 0
        elif typ >= 5 and typ <= 7:  # Drumroll
            note_events.append(
                (t_start, True, t_end - t_start)
            )  # True indicates drumroll start, store duration
            # We don't explicitly need a drumroll end event for this calculation method

    note_events.sort(key=lambda x: x[0])  # Sort by time

    window_duration_seconds = 0.5
    # drumroll_nps_rate = 10.0  # Removed: Will use adaptive rate

    # Step 1: Calculate base_sliding_nps_labels (Don/Ka only)
    base_don_ka_sliding_nps = torch.zeros(T_sub, dtype=torch.float32)
    time_step_duration_sec = TIME_SUB / fps  # Duration of one T_sub segment

    for k_idx in range(T_sub):
        k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
        k_window_start_sec = k_window_end_sec - window_duration_seconds

        current_don_ka_count = 0.0
        for event_t, is_drumroll, _ in note_events:
            if not is_drumroll:  # Don or Ka hit
                if k_window_start_sec <= event_t < k_window_end_sec:
                    current_don_ka_count += 1
        base_don_ka_sliding_nps[k_idx] = current_don_ka_count / window_duration_seconds

    # Step 2: Calculate adaptive_drumroll_rates_for_all_events
    adaptive_drumroll_rates_for_all_events = []
    for event_t, is_drumroll, drumroll_dur in note_events:
        if is_drumroll:
            drumroll_start_sec = event_t
            drumroll_end_sec = event_t + drumroll_dur

            slice_start_idx = math.floor(drumroll_start_sec / time_step_duration_sec)
            slice_end_idx = math.ceil(drumroll_end_sec / time_step_duration_sec)

            slice_start_idx = max(0, slice_start_idx)
            slice_end_idx = min(T_sub, slice_end_idx)

            max_nps_in_drumroll_period = 0.0
            if slice_start_idx < slice_end_idx:
                relevant_base_nps_values = base_don_ka_sliding_nps[
                    slice_start_idx:slice_end_idx
                ]
                if relevant_base_nps_values.numel() > 0:
                    max_nps_in_drumroll_period = torch.max(
                        relevant_base_nps_values
                    ).item()

            rate = max(5.0, max_nps_in_drumroll_period)
            adaptive_drumroll_rates_for_all_events.append(rate)
        else:
            adaptive_drumroll_rates_for_all_events.append(0.0)  # Placeholder

    # Step 3: Calculate final sliding_nps_labels using adaptive rates
    # sliding_nps_labels is already initialized with zeros earlier in the function.
    for k_idx in range(T_sub):
        k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
        k_window_start_sec = k_window_end_sec - window_duration_seconds

        current_window_total_nps_contribution = 0.0
        for event_idx, (event_t, is_drumroll, drumroll_dur) in enumerate(note_events):
            if is_drumroll:
                drumroll_start_sec = event_t
                drumroll_end_sec = event_t + drumroll_dur

                overlap_start = max(k_window_start_sec, drumroll_start_sec)
                overlap_end = min(k_window_end_sec, drumroll_end_sec)

                if overlap_end > overlap_start:
                    overlap_duration = overlap_end - overlap_start
                    current_adaptive_rate = adaptive_drumroll_rates_for_all_events[
                        event_idx
                    ]
                    current_window_total_nps_contribution += (
                        overlap_duration * current_adaptive_rate
                    )
            else:  # Don or Ka hit
                if k_window_start_sec <= event_t < k_window_end_sec:
                    current_window_total_nps_contribution += (
                        1  # Each hit contributes 1 to the count
                    )

        sliding_nps_labels[k_idx] = (
            current_window_total_nps_contribution / window_duration_seconds
        )

    # Normalize sliding_nps_labels to 0-1 range
    if T_sub > 0:  # Ensure there are elements to normalize
        min_nps_val = torch.min(sliding_nps_labels)
        max_nps_val = torch.max(sliding_nps_labels)
        denominator = max_nps_val - min_nps_val
        if denominator > 1e-6:  # Use a small epsilon for float comparison
            sliding_nps_labels = (sliding_nps_labels - min_nps_val) / denominator
        else:
            # If all values are (nearly) the same
            if max_nps_val > 1e-6:  # If the constant value is positive
                sliding_nps_labels = torch.ones_like(sliding_nps_labels)
            else:  # If the constant value is zero (or very close to it)
                sliding_nps_labels = torch.zeros_like(sliding_nps_labels)

    duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
    nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0

    parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
    chart = next(
        (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
    )
    difficulty_id = (
        0
        if difficulty == "easy"
        else (
            1
            if difficulty == "normal"
            else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
        )  # Assuming 4 for edit/ura
    )
    level = chart.level if chart else 0

    # --- CNN shape inference and label padding/truncation ---
    # Simulate CNN to get output time length (T_cnn)
    dummy_model = TaikoConformer7()
    with torch.no_grad():
        cnn_out = dummy_model.cnn(mel.unsqueeze(0))  # (1, C, F, T_cnn)
        _, _, _, T_cnn = cnn_out.shape

    # Pad or truncate labels to T_cnn
    def pad_or_truncate(label, out_len):
        if label.shape[0] < out_len:
            pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
            return torch.cat([label, pad], dim=0)
        else:
            return label[:out_len]

    don_labels = pad_or_truncate(don_labels, T_cnn)
    ka_labels = pad_or_truncate(ka_labels, T_cnn)
    drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)
    sliding_nps_labels = pad_or_truncate(sliding_nps_labels, T_cnn)  # Pad new label

    # For conformer input lengths: this should be T_cnn
    conformer_sequence_length = T_cnn  # This is the actual sequence length after CNN

    print(
        f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
    )

    return {
        "mel": mel,  # (1, N_MELS, T)
        "don_labels": don_labels,  # (T_cnn,)
        "ka_labels": ka_labels,  # (T_cnn,)
        "drumroll_labels": drumroll_labels,  # (T_cnn,)
        "sliding_nps_labels": sliding_nps_labels,  # Add new label (T_cnn,)
        "nps": torch.tensor(nps, dtype=torch.float32),
        "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
        "level": torch.tensor(level, dtype=torch.long),
        "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
        "length": torch.tensor(
            conformer_sequence_length, dtype=torch.long
        ),  # Use T_cnn for conformer and loss masking
    }


def collate_fn(batch):
    mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch]  # (T, N_MELS)
    don_labels_list = [b["don_labels"] for b in batch]
    ka_labels_list = [b["ka_labels"] for b in batch]
    drumroll_labels_list = [b["drumroll_labels"] for b in batch]
    sliding_nps_labels_list = [b["sliding_nps_labels"] for b in batch]  # New label list
    nps_list = [b["nps"] for b in batch]
    difficulty_list = [b["difficulty"] for b in batch]
    level_list = [b["level"] for b in batch]
    durations_list = [b["duration_seconds"] for b in batch]
    lengths_list = [b["length"] for b in batch]  # These are T_cnn_i for each example

    # Pad mels
    padded_mels = nn.utils.rnn.pad_sequence(
        mels_list, batch_first=True
    )  # (B, T_max_mel, N_MELS)
    reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
    # T_max_mel_batch = padded_mels.shape[1]  # Max mel length in batch, not used for label padding anymore

    # Determine max sequence length for labels (max T_cnn in batch)
    max_label_len = 0
    if lengths_list:  # handle empty batch case
        max_label_len = max(l.item() for l in lengths_list)

    # Pad labels to max_label_len (max_t_cnn_in_batch)
    def pad_label_to_max_len(label_tensor, target_len):
        current_len = label_tensor.shape[0]
        if current_len < target_len:
            padding_size = target_len - current_len
            # Ensure padding is created on the same device as the label_tensor
            padding = torch.zeros(
                padding_size, dtype=label_tensor.dtype, device=label_tensor.device
            )
            return torch.cat((label_tensor, padding), dim=0)
        elif (
            current_len > target_len
        ):  # Should ideally not happen if lengths_list is correct
            return label_tensor[:target_len]
        return label_tensor

    don_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in don_labels_list]
    )
    ka_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in ka_labels_list]
    )
    drumroll_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in drumroll_labels_list]
    )
    sliding_nps_labels = torch.stack(
        [pad_label_to_max_len(l, max_label_len) for l in sliding_nps_labels_list]
    )  # Pad new labels

    actual_lengths = torch.tensor([l.item() for l in lengths_list], dtype=torch.long)

    return {
        "mel": reshaped_mels,
        "don_labels": don_labels,
        "ka_labels": ka_labels,
        "drumroll_labels": drumroll_labels,
        "sliding_nps_labels": sliding_nps_labels,  # Add new batched labels
        "lengths": actual_lengths,  # for conformer and loss masking (T_cnn_i for each item)
        "nps": torch.stack(nps_list),
        "difficulty": torch.stack(difficulty_list),
        "level": torch.stack(level_list),
        "durations": torch.stack(durations_list),
    }
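For intuition on the soft labels: each hit writes an exponentially decaying energy tail of tail_length = 40 frames with decay_rate = 8.0, so the target falls to 1/e ≈ 0.37 about 8 frames after the onset. A standalone sketch of the same kernel (same constants as preprocess() above):

import torch

tail_length, decay_rate = 40, 8.0
tail_kernel = torch.exp(-torch.arange(0, tail_length, dtype=torch.float32) / decay_rate)
print(tail_kernel[:5])  # tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])
print(tail_kernel[8].item())  # ~0.3679, i.e. exp(-1)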
tc7/train.py
ADDED
@@ -0,0 +1,300 @@
from accelerate.utils import set_seed

set_seed(1024)

import math
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
from .config import (
    BATCH_SIZE,
    DEVICE,
    EPOCHS,
    LR,
    GRAD_ACCUM_STEPS,
    HOP_LENGTH,
    NPS_PENALTY_WEIGHT_ALPHA,
    NPS_PENALTY_WEIGHT_BETA,
    SAMPLE_RATE,
)
from .model import TaikoConformer7
from .dataset import ds
from .preprocess import preprocess, collate_fn
from .loss import TaikoLoss
from huggingface_hub import upload_folder


def log_energy_plots_to_tensorboard(
    writer,
    tag_prefix,
    epoch,
    pred_don,
    pred_ka,
    pred_drumroll,
    true_don,
    true_ka,
    true_drumroll,
    valid_length,
    hop_sec,
):
    """
    Logs a plot of predicted vs. true energies for one sample to TensorBoard.
    Energies should be 1D numpy arrays for the single sample, up to valid_length.
    """
    pred_don = pred_don[:valid_length].detach().cpu().numpy()
    pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
    pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
    true_don = true_don[:valid_length].cpu().numpy()
    true_ka = true_ka[:valid_length].cpu().numpy()
    true_drumroll = true_drumroll[:valid_length].cpu().numpy()

    time_axis = np.arange(valid_length) * hop_sec

    fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
    fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)

    axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
    axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
    axs[0].set_ylabel("Don Energy")
    axs[0].legend()
    axs[0].grid(True)

    axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
    axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
    axs[1].set_ylabel("Ka Energy")
    axs[1].legend()
    axs[1].grid(True)

    axs[2].plot(
        time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
    )
    axs[2].plot(
        time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
    )
    axs[2].set_ylabel("Drumroll Energy")
    axs[2].set_xlabel("Time (s)")
    axs[2].legend()
    axs[2].grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
    plt.close(fig)


def main():
    global ds

    output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE

    best_val_loss = float("inf")
    patience = 10
    pat_count = 0

    ds_oni = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "oni"},
        writer_batch_size=10,
    )
    ds_hard = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "hard"},
        writer_batch_size=10,
    )
    ds_normal = ds.map(
        preprocess,
        remove_columns=ds.column_names,
        fn_kwargs={"difficulty": "normal"},
        writer_batch_size=10,
    )
    ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])

    ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
    train_loader = DataLoader(
        ds_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )
    val_loader = DataLoader(
        ds_train_test["test"],
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        persistent_workers=True,
        prefetch_factor=4,
    )

    model = TaikoConformer7().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    criterion = TaikoLoss(
        reduction="mean",
        nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
        nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
    ).to(DEVICE)

    num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, total_steps=total_optimizer_steps
    )

    writer = SummaryWriter()

    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_epoch_loss = 0.0
        optimizer.zero_grad()

        for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
            mel = batch["mel"].to(DEVICE)
            lengths = batch["lengths"].to(DEVICE)
            nps = batch["nps"].to(DEVICE)
            difficulty = batch["difficulty"].to(DEVICE)
            level = batch["level"].to(DEVICE)

            outputs = model(mel, lengths, nps, difficulty, level)
            loss = criterion(outputs, batch)

            total_epoch_loss += loss.item()

            loss = loss / GRAD_ACCUM_STEPS
            loss.backward()

            if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            writer.add_scalar(
                "Loss/Train_Step",
                loss.item() * GRAD_ACCUM_STEPS,
                epoch * len(train_loader) + idx,
            )
            writer.add_scalar(
                "LR", scheduler.get_last_lr()[0], epoch * len(train_loader) + idx
            )

            if idx < 3:
                if mel.size(0) > 0:
                    pred_don = outputs["presence"][0, :, 0]
                    pred_ka = outputs["presence"][0, :, 1]
                    pred_drumroll = outputs["presence"][0, :, 2]
                    true_don = batch["don_labels"][0]
                    true_ka = batch["ka_labels"][0]
                    true_drumroll = batch["drumroll_labels"][0]
                    valid_length = batch["lengths"][0].item()

                    log_energy_plots_to_tensorboard(
                        writer,
                        f"Train_Sample_Batch_{idx}_Sample_0",
                        epoch,
                        pred_don,
                        pred_ka,
                        pred_drumroll,
                        true_don,
                        true_ka,
                        true_drumroll,
                        valid_length,
                        output_frame_hop_sec,
                    )

        avg_train_loss = total_epoch_loss / len(train_loader)
        writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)

        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
                mel = batch["mel"].to(DEVICE)
                lengths = batch["lengths"].to(DEVICE)
                nps = batch["nps"].to(DEVICE)
                difficulty = batch["difficulty"].to(DEVICE)
                level = batch["level"].to(DEVICE)

                outputs = model(mel, lengths, nps, difficulty, level)
                loss = criterion(outputs, batch)
                total_val_loss += loss.item()

                if idx < 3:
                    if mel.size(0) > 0:
                        pred_don = outputs["presence"][0, :, 0]
                        pred_ka = outputs["presence"][0, :, 1]
                        pred_drumroll = outputs["presence"][0, :, 2]
                        true_don = batch["don_labels"][0]
                        true_ka = batch["ka_labels"][0]
                        true_drumroll = batch["drumroll_labels"][0]
                        valid_length = batch["lengths"][0].item()

                        log_energy_plots_to_tensorboard(
                            writer,
                            f"Val_Sample_Batch_{idx}_Sample_0",
                            epoch,
                            pred_don,
                            pred_ka,
                            pred_drumroll,
                            true_don,
                            true_ka,
                            true_drumroll,
                            valid_length,
                            output_frame_hop_sec,
                        )

        avg_val_loss = total_val_loss / len(val_loader)
        writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)

        current_lr = optimizer.param_groups[0]["lr"]
        writer.add_scalar("LR/learning_rate", current_lr, epoch)

        if "nps" in batch:
            writer.add_scalar(
                "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
            )

        print(
            f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            pat_count = 0
            torch.save(model.state_dict(), "best_model.pt")
            print(f"Saved new best model to best_model.pt at epoch {epoch}")
        else:
            pat_count += 1
            if pat_count >= patience:
                print("Early stopping!")
                break
    writer.close()

    model_id = "JacobLinCool/taiko-conformer-7"
    try:
        model.push_to_hub(
            model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}"
        )
        upload_folder(
            repo_id=model_id,
            folder_path="runs",
            path_in_repo="runs",
            commit_message="Upload TensorBoard logs",
        )
    except Exception as e:
        print(f"Error uploading model or logs: {e}")
        print("Make sure you have the correct permissions and try again.")


if __name__ == "__main__":
    main()
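One detail worth noting in this loop: OneCycleLR is sized in optimizer steps, not batches, since optimizer.step() only fires every GRAD_ACCUM_STEPS batches. A quick check with an assumed loader length (len(train_loader) is not known from this commit):

import math

GRAD_ACCUM_STEPS, EPOCHS = 8, 200  # from tc7/config.py
batches_per_epoch = 1234  # assumed len(train_loader)
steps_per_epoch = math.ceil(batches_per_epoch / GRAD_ACCUM_STEPS)  # 155
total_steps = EPOCHS * steps_per_epoch  # 31000, the total_steps given to OneCycleLR
print(steps_per_epoch, total_steps)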