JacobLinCool committed on
Commit 812b01c · 1 Parent(s): 7948b62

Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline


- Added TaikoLoss class for custom loss calculation with NPS penalties.
- Developed TaikoConformer7 model architecture using Conformer and CNN layers.
- Created preprocessing functions to handle audio data and generate labels.
- Implemented training script with data loading, model training, and validation.
- Integrated TensorBoard logging for loss and energy comparisons during training.
- Added support for sliding NPS labels in preprocessing and loss calculation.

Files changed (27)
  1. .gitignore +1 -0
  2. app.py +300 -0
  3. requirements.txt +14 -0
  4. tc5/__init__.py +0 -0
  5. tc5/config.py +25 -0
  6. tc5/dataset.py +21 -0
  7. tc5/infer.py +356 -0
  8. tc5/loss.py +65 -0
  9. tc5/model.py +133 -0
  10. tc5/preprocess.py +215 -0
  11. tc5/train.py +323 -0
  12. tc6/__init__.py +0 -0
  13. tc6/config.py +25 -0
  14. tc6/dataset.py +21 -0
  15. tc6/infer.py +354 -0
  16. tc6/loss.py +65 -0
  17. tc6/model.py +166 -0
  18. tc6/preprocess.py +258 -0
  19. tc6/train.py +336 -0
  20. tc7/__init__.py +0 -0
  21. tc7/config.py +27 -0
  22. tc7/dataset.py +21 -0
  23. tc7/infer.py +354 -0
  24. tc7/loss.py +94 -0
  25. tc7/model.py +166 -0
  26. tc7/preprocess.py +400 -0
  27. tc7/train.py +300 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
app.py ADDED
@@ -0,0 +1,300 @@
1
+ import gradio as gr
2
+ import torch
3
+ from tc5.config import SAMPLE_RATE, HOP_LENGTH
4
+ from tc5.model import TaikoConformer5
5
+ from tc5 import infer as tc5infer
6
+ from tc6.model import TaikoConformer6
7
+ from tc6 import infer as tc6infer
8
+ from tc7.model import TaikoConformer7
9
+ from tc7 import infer as tc7infer
10
+ from gradio_client import Client, handle_file
11
+ import tempfile
12
+
13
+ DEVICE = torch.device("cpu")
14
+
15
+ # Load model once
16
+ tc5 = TaikoConformer5.from_pretrained("JacobLinCool/taiko-conformer-5")
17
+ tc5.to(DEVICE)
18
+ tc5.eval()
19
+
20
+ # Load TC6 model
21
+ tc6 = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6")
22
+ tc6.to(DEVICE)
23
+ tc6.eval()
24
+
25
+ # Load TC7 model
26
+ tc7 = TaikoConformer7.from_pretrained("JacobLinCool/taiko-conformer-7")
27
+ tc7.to(DEVICE)
28
+ tc7.eval()
29
+
30
+ synthesizer = Client("ryanlinjui/taiko-music-generator")
31
+
32
+
33
+ def infer_tc5(audio, nps, bpm):
34
+ audio_path = audio
35
+ filename = audio_path.split("/")[-1]
36
+ # Preprocess
37
+ mel_input, nps_input = tc5infer.preprocess_audio(audio_path, nps)
38
+ # Inference
39
+ don_energy, ka_energy, drumroll_energy = tc5infer.run_inference(
40
+ tc5, mel_input, nps_input, DEVICE
41
+ )
42
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
43
+ onsets = tc5infer.decode_onsets(
44
+ don_energy,
45
+ ka_energy,
46
+ drumroll_energy,
47
+ output_frame_hop_sec,
48
+ threshold=0.3,
49
+ min_distance_frames=3,
50
+ )
51
+ # Generate plot
52
+ plot = tc5infer.plot_results(
53
+ mel_input,
54
+ don_energy,
55
+ ka_energy,
56
+ drumroll_energy,
57
+ onsets,
58
+ output_frame_hop_sec,
59
+ )
60
+ # Generate TJA content
61
+ tja_content = tc5infer.write_tja(onsets, bpm=bpm, audio=filename)
62
+
63
+ # write TJA content to a temporary file
64
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
65
+ temp_tja_file.write(tja_content.encode("utf-8"))
66
+ tja_path = temp_tja_file.name
67
+
68
+ result = synthesizer.predict(
69
+ param_0=handle_file(tja_path),
70
+ param_1=handle_file(audio_path),
71
+ param_2="達人譜面 / Master",
72
+ param_3=16,
73
+ param_4=5,
74
+ param_5=5,
75
+ param_6=5,
76
+ param_7=5,
77
+ param_8=5,
78
+ param_9=5,
79
+ param_10=5,
80
+ param_11=5,
81
+ param_12=5,
82
+ param_13=5,
83
+ param_14=5,
84
+ param_15=5,
85
+ api_name="/handle",
86
+ )
87
+
88
+ oni_audio = result[1]
89
+
90
+ return oni_audio, plot, tja_content
91
+
92
+
93
+ def infer_tc6(audio, nps, bpm, difficulty, level):
94
+ audio_path = audio
95
+ filename = audio_path.split("/")[-1]
96
+ # Preprocess
97
+ mel_input = tc6infer.preprocess_audio(audio_path)
98
+ nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
99
+ difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
100
+ level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
101
+ # Inference
102
+ don_energy, ka_energy, drumroll_energy = tc6infer.run_inference(
103
+ tc6, mel_input, nps_input, difficulty_input, level_input, DEVICE
104
+ )
105
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
106
+ onsets = tc6infer.decode_onsets(
107
+ don_energy,
108
+ ka_energy,
109
+ drumroll_energy,
110
+ output_frame_hop_sec,
111
+ threshold=0.3,
112
+ min_distance_frames=3,
113
+ )
114
+ # Generate plot
115
+ plot = tc6infer.plot_results(
116
+ mel_input,
117
+ don_energy,
118
+ ka_energy,
119
+ drumroll_energy,
120
+ onsets,
121
+ output_frame_hop_sec,
122
+ )
123
+ # Generate TJA content
124
+ tja_content = tc6infer.write_tja(onsets, bpm=bpm, audio=filename)
125
+
126
+ # write TJA content to a temporary file
127
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
128
+ temp_tja_file.write(tja_content.encode("utf-8"))
129
+ tja_path = temp_tja_file.name
130
+
131
+ result = synthesizer.predict(
132
+ param_0=handle_file(tja_path),
133
+ param_1=handle_file(audio_path),
134
+ param_2="達人譜面 / Master",
135
+ param_3=16,
136
+ param_4=5,
137
+ param_5=5,
138
+ param_6=5,
139
+ param_7=5,
140
+ param_8=5,
141
+ param_9=5,
142
+ param_10=5,
143
+ param_11=5,
144
+ param_12=5,
145
+ param_13=5,
146
+ param_14=5,
147
+ param_15=5,
148
+ api_name="/handle",
149
+ )
150
+
151
+ oni_audio = result[1]
152
+
153
+ return oni_audio, plot, tja_content
154
+
155
+
156
+ def infer_tc7(audio, nps, bpm, difficulty, level):
157
+ audio_path = audio
158
+ filename = audio_path.split("/")[-1]
159
+ # Preprocess
160
+ mel_input = tc7infer.preprocess_audio(audio_path)
161
+ nps_input = torch.tensor(nps, dtype=torch.float32).to(DEVICE)
162
+ difficulty_input = torch.tensor(difficulty, dtype=torch.float32).to(DEVICE)
163
+ level_input = torch.tensor(level, dtype=torch.float32).to(DEVICE)
164
+ # Inference
165
+ don_energy, ka_energy, drumroll_energy = tc7infer.run_inference(
166
+ tc7, mel_input, nps_input, difficulty_input, level_input, DEVICE
167
+ )
168
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
169
+ onsets = tc7infer.decode_onsets(
170
+ don_energy,
171
+ ka_energy,
172
+ drumroll_energy,
173
+ output_frame_hop_sec,
174
+ threshold=0.3,
175
+ min_distance_frames=3,
176
+ )
177
+ # Generate plot
178
+ plot = tc7infer.plot_results(
179
+ mel_input,
180
+ don_energy,
181
+ ka_energy,
182
+ drumroll_energy,
183
+ onsets,
184
+ output_frame_hop_sec,
185
+ )
186
+ # Generate TJA content
187
+ tja_content = tc7infer.write_tja(onsets, bpm=bpm, audio=filename)
188
+
189
+ # write TJA content to a temporary file
190
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".tja") as temp_tja_file:
191
+ temp_tja_file.write(tja_content.encode("utf-8"))
192
+ tja_path = temp_tja_file.name
193
+
194
+ result = synthesizer.predict(
195
+ param_0=handle_file(tja_path),
196
+ param_1=handle_file(audio_path),
197
+ param_2="達人譜面 / Master",
198
+ param_3=16,
199
+ param_4=5,
200
+ param_5=5,
201
+ param_6=5,
202
+ param_7=5,
203
+ param_8=5,
204
+ param_9=5,
205
+ param_10=5,
206
+ param_11=5,
207
+ param_12=5,
208
+ param_13=5,
209
+ param_14=5,
210
+ param_15=5,
211
+ api_name="/handle",
212
+ )
213
+
214
+ oni_audio = result[1]
215
+
216
+ return oni_audio, plot, tja_content
217
+
218
+
219
+ def run_inference(audio, model_choice, nps, bpm, difficulty, level):
220
+ if model_choice == "TC5":
221
+ return infer_tc5(audio, nps, bpm)
222
+ elif model_choice == "TC6":
223
+ return infer_tc6(audio, nps, bpm, difficulty, level)
224
+ else: # TC7
225
+ return infer_tc7(audio, nps, bpm, difficulty, level)
226
+
227
+
228
+ with gr.Blocks() as demo:
229
+ gr.Markdown("# Taiko Conformer 5/6/7 Demo")
230
+ with gr.Row():
231
+ audio_input = gr.Audio(sources="upload", type="filepath", label="Input Audio")
232
+
233
+ with gr.Row():
234
+ model_choice = gr.Dropdown(
235
+ choices=["TC5", "TC6", "TC7"],
236
+ value="TC7",
237
+ label="Model Selection",
238
+ info="Choose between TaikoConformer 5, 6 or 7",
239
+ )
240
+
241
+ with gr.Row():
242
+ nps = gr.Slider(
243
+ value=5.0,
244
+ minimum=0.5,
245
+ maximum=11.0,
246
+ step=0.5,
247
+ label="NPS (Notes Per Second)",
248
+ )
249
+ bpm = gr.Slider(
250
+ value=240,
251
+ minimum=160,
252
+ maximum=640,
253
+ step=1,
254
+ label="BPM (Used by TJA Quantization)",
255
+ )
256
+
257
+ with gr.Row():
258
+ difficulty = gr.Slider(
259
+ value=3.0,
260
+ minimum=1.0,
261
+ maximum=3.0,
262
+ step=1.0,
263
+ label="Difficulty",
264
+ visible=False,
265
+ info="1=Normal, 2=Hard, 3=Oni",
266
+ )
267
+ level = gr.Slider(
268
+ value=8.0,
269
+ minimum=1.0,
270
+ maximum=10.0,
271
+ step=1.0,
272
+ label="Level",
273
+ visible=False,
274
+ info="Difficulty level from 1 to 10",
275
+ )
276
+
277
+ audio_output = gr.Audio(label="Generated Audio", type="filepath")
278
+ plot_output = gr.Plot(label="Onset/Energy Plot")
279
+ tja_output = gr.Textbox(label="TJA File Content", show_copy_button=True)
280
+ run_btn = gr.Button("Run Inference")
281
+
282
+ # Update visibility of TC6/TC7-specific controls based on model selection
283
+ def update_visibility(model_choice):
284
+ if model_choice == "TC7" or model_choice == "TC6":
285
+ return gr.update(visible=True), gr.update(visible=True)
286
+ else:
287
+ return gr.update(visible=False), gr.update(visible=False)
288
+
289
+ model_choice.change(
290
+ update_visibility, inputs=[model_choice], outputs=[difficulty, level]
291
+ )
292
+
293
+ run_btn.click(
294
+ run_inference,
295
+ inputs=[audio_input, model_choice, nps, bpm, difficulty, level],
296
+ outputs=[audio_output, plot_output, tja_output],
297
+ )
298
+
299
+ if __name__ == "__main__":
300
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchaudio
+ datasets
+ huggingface_hub
+ librosa
+ soundfile
+ matplotlib
+ tensorboard
+ black
+ tqdm
+ safetensors
+ accelerate
+ tja
+ spaces
tc5/__init__.py ADDED
File without changes
tc5/config.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+ # ─── 1) CONFIG ─────────────────────────────────────────────────────
+ SAMPLE_RATE = 22050
+ N_MELS = 80
+ HOP_LENGTH = 256  # ~86 fps
+ TIME_SUB = 1
+ CNN_CH = 128
+ N_HEADS = 4
+ D_MODEL = 256
+ FF_DIM = 512
+ N_LAYERS = 4
+ DEPTHWISE_CONV_KERNEL_SIZE = 31
+ DROPOUT = 0.1
+ HIDDEN_DIM = 64
+ N_TYPES = 7
+ BATCH_SIZE = 4
+ GRAD_ACCUM_STEPS = 4
+ LR = 3e-4
+ EPOCHS = 30
+ DEVICE = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available() else "cpu"
+ )
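Note (illustrative, not part of the diff): the frame rate implied by SAMPLE_RATE and HOP_LENGTH above works out as follows, matching the "~86 fps" comment on HOP_LENGTH.

    SAMPLE_RATE = 22050
    HOP_LENGTH = 256

    fps = SAMPLE_RATE / HOP_LENGTH      # ~86.1 frames per second
    hop_sec = HOP_LENGTH / SAMPLE_RATE  # ~0.0116 s (~11.6 ms) per frame
    print(f"{fps:.1f} fps, {hop_sec * 1000:.1f} ms per frame")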
tc5/dataset.py ADDED
@@ -0,0 +1,21 @@
+ from datasets import load_dataset, concatenate_datasets
+
+ # ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
+ # ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
+ # ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
+ # ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
+ # ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
+ # ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
+ # ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
+ # ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")
+
+ # good = list(range(len(ds)))
+ # good.remove(1079)  # 1079 has file problem
+ # ds = ds.select(good)
+
+ # for local test
+ ds = (
+     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
+     .with_format("torch")
+     .select(range(10))
+ )
tc5/infer.py ADDED
@@ -0,0 +1,356 @@
1
+ import time
2
+ import torch
3
+ import torchaudio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
7
+ import torch.profiler
8
+
9
+
10
+ # --- PREPROCESSING (match training) ---
11
+ def preprocess_audio(audio_path, nps=5.0):
12
+ wav, sr = torchaudio.load(audio_path)
13
+ wav = wav.mean(dim=0) # mono
14
+ if sr != SAMPLE_RATE:
15
+ wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
16
+ wav = wav / (wav.abs().max() + 1e-8) # Normalize audio
17
+
18
+ nps_tensor = torch.tensor(nps, dtype=torch.float32)
19
+
20
+ mel_transform = torchaudio.transforms.MelSpectrogram(
21
+ sample_rate=SAMPLE_RATE,
22
+ n_mels=N_MELS,
23
+ hop_length=HOP_LENGTH,
24
+ n_fft=2048,
25
+ )
26
+ mel = mel_transform(wav)
27
+ # mel shape is (n_mels, T_mel), unsqueeze for batch later in run_inference
28
+ return mel, nps_tensor # mel is (N_MELS, T_mel)
29
+
30
+
31
+ # --- INFERENCE ---
32
+ def run_inference(model, mel_input, nps_input, device):
33
+ model.eval()
34
+ with torch.no_grad():
35
+ mel = mel_input.to(device).unsqueeze(0) # (1, N_MELS, T_mel)
36
+ nps = nps_input.to(device).unsqueeze(0) # (1,)
37
+
38
+ mel_cnn_input = mel.unsqueeze(1) # (1, 1, N_MELS, T_mel)
39
+
40
+ conformer_lengths = torch.tensor(
41
+ [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
42
+ )
43
+
44
+ with torch.profiler.profile(
45
+ activities=[
46
+ torch.profiler.ProfilerActivity.CPU,
47
+ *(
48
+ [torch.profiler.ProfilerActivity.CUDA]
49
+ if device.type == "cuda"
50
+ else []
51
+ ),
52
+ ],
53
+ record_shapes=True,
54
+ profile_memory=True,
55
+ with_stack=False,
56
+ with_flops=True,
57
+ ) as prof:
58
+ out_dict = model(mel_cnn_input, conformer_lengths, nps)
59
+ print(
60
+ prof.key_averages().table(
61
+ sort_by=(
62
+ "self_cuda_memory_usage"
63
+ if device.type == "cuda"
64
+ else "self_cpu_time_total"
65
+ ),
66
+ row_limit=20,
67
+ )
68
+ )
69
+
70
+ energies = out_dict["presence"].squeeze(0).cpu().numpy()
71
+
72
+ don_energy = energies[:, 0]
73
+ ka_energy = energies[:, 1]
74
+ drumroll_energy = energies[:, 2]
75
+
76
+ return don_energy, ka_energy, drumroll_energy
77
+
78
+
79
+ # --- DECODE TO ONSETS ---
80
+ def decode_onsets(
81
+ don_energy,
82
+ ka_energy,
83
+ drumroll_energy,
84
+ hop_sec,
85
+ threshold=0.5,
86
+ min_distance_frames=3,
87
+ ):
88
+ results = []
89
+ T_out = len(don_energy)
90
+ last_onset_frame = -min_distance_frames
91
+
92
+ for i in range(1, T_out - 1): # Iterate considering neighbors for peak detection
93
+ if i < last_onset_frame + min_distance_frames:
94
+ continue
95
+
96
+ e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
97
+ energies_at_i = {
98
+ 1: e_don,
99
+ 2: e_ka,
100
+ 5: e_drum,
101
+ } # Type mapping: 1:Don, 2:Ka, 5:Drumroll
102
+
103
+ # Find which energy is max and if it's a peak above threshold
104
+ # Sort by energy value descending to prioritize higher energy in case of ties for peak condition
105
+ sorted_types_by_energy = sorted(
106
+ energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
107
+ )
108
+
109
+ detected_this_frame = False
110
+ for onset_type in sorted_types_by_energy:
111
+ current_energy_series = None
112
+ if onset_type == 1:
113
+ current_energy_series = don_energy
114
+ elif onset_type == 2:
115
+ current_energy_series = ka_energy
116
+ elif onset_type == 5:
117
+ current_energy_series = drumroll_energy
118
+
119
+ energy_val = current_energy_series[i]
120
+
121
+ if (
122
+ energy_val > threshold
123
+ and energy_val > current_energy_series[i - 1]
124
+ and energy_val > current_energy_series[i + 1]
125
+ ):
126
+ # Check if this energy is the highest among the three at this frame
127
+ # This check is implicitly handled by iterating `sorted_types_by_energy`
128
+ # and breaking after the first detection.
129
+ results.append((i * hop_sec, onset_type))
130
+ last_onset_frame = i
131
+ detected_this_frame = True
132
+ break # Only one onset type per frame
133
+
134
+ return results
135
+
136
+
137
+ # --- VISUALIZATION ---
138
+ def plot_results(
139
+ mel_spectrogram,
140
+ don_energy,
141
+ ka_energy,
142
+ drumroll_energy,
143
+ onsets,
144
+ hop_sec,
145
+ out_path=None,
146
+ ):
147
+ # mel_spectrogram is (N_MELS, T_mel)
148
+ T_mel = mel_spectrogram.shape[1]
149
+ T_out = len(don_energy) # Length of energy arrays (model output time dimension)
150
+
151
+ # Time axes
152
+ time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
153
+ # hop_sec for model output is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE
154
+ # However, the model output T_out is related to T_mel (input to CNN).
155
+ # If CNN does not change time dimension, T_out = T_mel.
156
+ # If TIME_SUB is used for label generation, T_out = T_mel / TIME_SUB.
157
+ # The `lengths` passed to conformer in `run_inference` is T_mel.
158
+ # The output of conformer `x` has shape (B, T, D_MODEL). This T is `lengths`.
159
+ # So, T_out from model is T_mel.
160
+ # The `hop_sec` for onsets should be based on the model output frame rate.
161
+ # If model output T_out corresponds to T_mel, then hop_sec for plotting energies is HOP_LENGTH / SAMPLE_RATE.
162
+ # The `hop_sec` passed to `decode_onsets` is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE.
163
+ # This seems inconsistent. Let's clarify: `TIME_SUB` is used in `preprocess.py` to determine `T_sub` for labels.
164
+ # The model's CNN output time dimension T_cnn is used to pad/truncate labels in `collate_fn`.
165
+ # In `model.py`, the CNN does not stride in time. So T_cnn_out = T_mel_input_to_CNN.
166
+ # The `lengths` for the conformer is based on this T_cnn_out.
167
+ # So, the output of the regressor `presence` (B, T_cnn_out, 3) has T_cnn_out time steps.
168
+ # Each step corresponds to `HOP_LENGTH / SAMPLE_RATE` seconds if TIME_SUB is not involved in downsampling features for the conformer.
169
+ # Let's assume the `hop_sec` used for `decode_onsets` is correct for interpreting model output frames.
170
+ time_axis_energies = np.arange(T_out) * hop_sec
171
+
172
+ fig, ax1 = plt.subplots(figsize=(100, 10))
173
+
174
+ # Plot Mel Spectrogram on ax1
175
+ mel_db = torchaudio.functional.amplitude_to_DB(
176
+ mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
177
+ )
178
+ img = ax1.imshow(
179
+ mel_db.numpy(),
180
+ aspect="auto",
181
+ origin="lower",
182
+ cmap="magma",
183
+ extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
184
+ )
185
+ ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
186
+ ax1.set_xlabel("Time (s)")
187
+ ax1.set_ylabel("Mel Bin")
188
+ fig.colorbar(img, ax=ax1, format="%+2.0f dB")
189
+
190
+ # Create a second y-axis for energies
191
+ ax2 = ax1.twinx()
192
+ ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
193
+ ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
194
+ ax2.plot(
195
+ time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
196
+ )
197
+ ax2.set_ylabel("Energy")
198
+ ax2.set_ylim(0, 1.2) # Assuming energies are somewhat normalized or bounded
199
+
200
+ # Overlay onsets from decode_onsets (t is already in seconds)
201
+ labeled_types = set()
202
+ # Group drumrolls into segments (reuse logic from write_tja)
203
+ drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
204
+ drumroll_times.sort()
205
+ drumroll_segments = []
206
+ if drumroll_times:
207
+ seg_start = drumroll_times[0]
208
+ prev = drumroll_times[0]
209
+ for t in drumroll_times[1:]:
210
+ if t - prev <= hop_sec * 6: # up to 5-frame gap
211
+ prev = t
212
+ else:
213
+ drumroll_segments.append((seg_start, prev))
214
+ seg_start = t
215
+ prev = t
216
+ drumroll_segments.append((seg_start, prev))
217
+ # Plot Don/Ka onsets as vertical lines
218
+ for t_sec, typ in onsets:
219
+ if typ == 5:
220
+ continue # skip drumroll onsets
221
+ color_map = {1: "darkred", 2: "darkblue"}
222
+ label_map = {1: "Don Onset", 2: "Ka Onset"}
223
+ line_color = color_map.get(typ, "black")
224
+ line_label = label_map.get(typ, f"Type {typ} Onset")
225
+ if typ not in labeled_types:
226
+ ax1.axvline(
227
+ t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
228
+ )
229
+ labeled_types.add(typ)
230
+ else:
231
+ ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
232
+ # Plot drumroll segments as shaded regions
233
+ for seg_start, seg_end in drumroll_segments:
234
+ ax1.axvspan(
235
+ seg_start,
236
+ seg_end + hop_sec,
237
+ color="green",
238
+ alpha=0.2,
239
+ label="Drumroll Segment" if "drumroll" not in labeled_types else None,
240
+ )
241
+ labeled_types.add("drumroll")
242
+
243
+ # Combine legends from both axes
244
+ lines, labels = ax1.get_legend_handles_labels()
245
+ lines2, labels2 = ax2.get_legend_handles_labels()
246
+ ax2.legend(lines + lines2, labels + labels2, loc="upper right")
247
+
248
+ fig.tight_layout()
249
+
250
+ # Return plot as image buffer or save to file if path provided
251
+ if out_path:
252
+ plt.savefig(out_path)
253
+ print(f"Saved plot to {out_path}")
254
+ plt.close(fig)
255
+ return out_path
256
+ else:
257
+ # Return plot as in-memory buffer
258
+ return fig
259
+
260
+
261
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
262
+ # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
263
+ # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
264
+ sec_per_beat = 60 / bpm
265
+ beats_per_measure = 4 # Assuming 4/4 time signature
266
+ sec_per_measure = sec_per_beat * beats_per_measure
267
+ # Step 1: Map onsets to (measure_idx, slot, typ)
268
+ slot_events = []
269
+ for t, typ in onsets:
270
+ measure_idx = int(t // sec_per_measure)
271
+ t_in_measure = t % sec_per_measure
272
+ slot = int(round(t_in_measure / sec_per_measure * quantize))
273
+ if slot >= quantize:
274
+ slot = quantize - 1
275
+ slot_events.append((measure_idx, slot, typ))
276
+ # Step 2: Build measure/slot grid
277
+ if slot_events:
278
+ max_measure_idx = max(m for m, _, _ in slot_events)
279
+ else:
280
+ max_measure_idx = -1
281
+ measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
282
+ # Step 3: Place Don/Ka, collect drumrolls
283
+ drumroll_slots = set()
284
+ for m, s, typ in slot_events:
285
+ if typ in [1, 2]:
286
+ measures[m][s] = typ
287
+ elif typ == 5:
288
+ drumroll_slots.add((m, s))
289
+ # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
290
+ # Flatten all slots to a list of (measure, slot) sorted
291
+ drumroll_list = sorted(list(drumroll_slots))
292
+ # Group into contiguous regions (allowing a gap of 5 slots)
293
+ grouped = []
294
+ group = []
295
+ for ms in drumroll_list:
296
+ if not group:
297
+ group = [ms]
298
+ else:
299
+ last_m, last_s = group[-1]
300
+ m, s = ms
301
+ # Calculate slot distance, considering measure wrap
302
+ slot_dist = None
303
+ if m == last_m:
304
+ slot_dist = s - last_s
305
+ elif m == last_m + 1 and last_s <= quantize - 1:
306
+ slot_dist = (quantize - 1 - last_s) + s + 1
307
+ else:
308
+ slot_dist = None
309
+ # Allow gap of up to 5 slots (slot_dist <= 6)
310
+ if slot_dist is not None and 1 <= slot_dist <= 6:
311
+ group.append(ms)
312
+ else:
313
+ grouped.append(group)
314
+ group = [ms]
315
+ if group:
316
+ grouped.append(group)
317
+ # Mark 5 (start) and 8 (end) for each group
318
+ for region in grouped:
319
+ if len(region) == 1:
320
+ m, s = region[0]
321
+ measures[m][s] = 5
322
+ # Place 8 in next slot (or next measure if at end)
323
+ if s < quantize - 1:
324
+ measures[m][s + 1] = 8
325
+ elif m < max_measure_idx:
326
+ measures[m + 1][0] = 8
327
+ else:
328
+ m_start, s_start = region[0]
329
+ m_end, s_end = region[-1]
330
+ measures[m_start][s_start] = 5
331
+ measures[m_end][s_end] = 8
332
+ # Fill 0 for middle slots (already 0 by default)
333
+
334
+ # Step 5: Generate TJA content
335
+ tja_content = []
336
+ tja_content.append(f"TITLE:{audio} (TC5, {time.strftime('%Y-%m-%d %H:%M:%S')})")
337
+ tja_content.append(f"BPM:{bpm}")
338
+ tja_content.append(f"WAVE:{audio}")
339
+ tja_content.append("OFFSET:0")
340
+ tja_content.append("COURSE:Oni\nLEVEL:9\n")
341
+ tja_content.append("#START")
342
+ for i in range(max_measure_idx + 1):
343
+ notes = measures.get(i, [0] * quantize)
344
+ line = "".join(str(n) for n in notes)
345
+ tja_content.append(line + ",")
346
+ tja_content.append("#END")
347
+
348
+ tja_string = "\n".join(tja_content)
349
+
350
+ # If out_path is provided, also write to file
351
+ if out_path:
352
+ with open(out_path, "w", encoding="utf-8") as f:
353
+ f.write(tja_string)
354
+ print(f"TJA chart saved to {out_path}")
355
+
356
+ return tja_string
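Note (illustrative, not part of the diff): the quantize=96 grid in write_tja above gives the slot resolution below; at the app's default BPM of 240 a slot (~10.4 ms) is slightly finer than the ~11.6 ms model frame hop, so quantization loses little timing precision.

    def slot_ms(bpm: float, quantize: int = 96, beats_per_measure: int = 4) -> float:
        """Duration of one TJA grid slot in milliseconds (mirrors the write_tja arithmetic)."""
        sec_per_measure = 60 / bpm * beats_per_measure
        return sec_per_measure / quantize * 1000

    print(slot_ms(160))  # ~15.6 ms per slot at write_tja's default bpm=160
    print(slot_ms(240))  # ~10.4 ms per slot at the app's default bpm=240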
tc5/loss.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class TaikoEnergyLoss(nn.Module):
6
+ def __init__(self, reduction="mean"):
7
+ super().__init__()
8
+ # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
9
+ self.mse_loss = nn.MSELoss(reduction="none")
10
+ self.reduction = reduction
11
+
12
+ def forward(self, outputs, batch):
13
+ """
14
+ Calculates the MSE loss for energy-based predictions.
15
+
16
+ Args:
17
+ outputs (dict): Model output, containing 'presence' tensor.
18
+ outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
19
+ batch (dict): Batch data from collate_fn, containing true labels and lengths.
20
+ batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
21
+ batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
22
+ Returns:
23
+ torch.Tensor: The calculated loss.
24
+ """
25
+ pred_energies = outputs["presence"] # (B, T, 3)
26
+
27
+ true_don = batch["don_labels"] # (B, T)
28
+ true_ka = batch["ka_labels"] # (B, T)
29
+ true_drumroll = batch["drumroll_labels"] # (B, T)
30
+
31
+ # Stack true labels to match the structure of pred_energies (B, T, 3)
32
+ true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)
33
+
34
+ B, T, _ = pred_energies.shape
35
+
36
+ # Create a mask based on batch['lengths'] to ignore padded parts of sequences
37
+ # batch['lengths'] gives the actual length of each sequence in the batch
38
+ # mask shape: (B, T)
39
+ mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
40
+ "lengths"
41
+ ].unsqueeze(1)
42
+ # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
43
+ mask_3d = mask_2d.unsqueeze(2)
44
+
45
+ # Calculate element-wise MSE loss
46
+ loss_elementwise = self.mse_loss(pred_energies, true_energies) # (B, T, 3)
47
+
48
+ # Apply the mask to the loss
49
+ masked_loss = loss_elementwise * mask_3d
50
+
51
+ if self.reduction == "mean":
52
+ # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements
53
+ total_loss = masked_loss.sum()
54
+ num_valid_elements = mask_3d.sum() # Total number of unmasked float values
55
+ if num_valid_elements > 0:
56
+ return total_loss / num_valid_elements
57
+ else:
58
+ # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
59
+ return torch.tensor(
60
+ 0.0, device=pred_energies.device, requires_grad=True
61
+ )
62
+ elif self.reduction == "sum":
63
+ return masked_loss.sum()
64
+ else: # 'none' or any other case
65
+ return masked_loss
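Note (illustrative, not part of the diff): a minimal usage sketch for TaikoEnergyLoss, with tensor shapes taken from the docstring above; the values are random and only show how outputs and batch are expected to be laid out.

    import torch
    from tc5.loss import TaikoEnergyLoss

    B, T = 2, 100
    outputs = {"presence": torch.rand(B, T, 3)}  # predicted don/ka/drumroll energies
    batch = {
        "don_labels": torch.rand(B, T),
        "ka_labels": torch.rand(B, T),
        "drumroll_labels": torch.rand(B, T),
        "lengths": torch.tensor([100, 60]),  # second sample is padding beyond frame 60
    }
    criterion = TaikoEnergyLoss(reduction="mean")
    loss = criterion(outputs, batch)  # masked MSE averaged over valid elements only
    print(loss.item())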
tc5/model.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchaudio.models import Conformer
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from .config import (
6
+ N_MELS,
7
+ CNN_CH,
8
+ N_HEADS,
9
+ D_MODEL,
10
+ FF_DIM,
11
+ N_LAYERS,
12
+ DROPOUT,
13
+ DEPTHWISE_CONV_KERNEL_SIZE,
14
+ HIDDEN_DIM,
15
+ DEVICE,
16
+ )
17
+
18
+
19
+ class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
20
+ def __init__(self):
21
+ super().__init__()
22
+ # 1) CNN frontend: frequency-only pooling
23
+ self.cnn = nn.Sequential(
24
+ nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
25
+ nn.BatchNorm2d(CNN_CH),
26
+ nn.GELU(),
27
+ nn.Dropout2d(DROPOUT),
28
+ nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
29
+ nn.BatchNorm2d(CNN_CH),
30
+ nn.GELU(),
31
+ nn.Dropout2d(DROPOUT),
32
+ )
33
+ feat_dim = CNN_CH * (N_MELS // 4)
34
+
35
+ # 2) Linear projection to model dimension
36
+ self.proj = nn.Linear(feat_dim, D_MODEL)
37
+
38
+ # 3) FiLM conditioning for notes_per_second
39
+ self.film = nn.Linear(1, 2 * D_MODEL)
40
+
41
+ # 4) Conformer encoder
42
+ self.encoder = Conformer(
43
+ input_dim=D_MODEL,
44
+ num_heads=N_HEADS,
45
+ ffn_dim=FF_DIM,
46
+ num_layers=N_LAYERS,
47
+ depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
48
+ dropout=DROPOUT,
49
+ use_group_norm=False,
50
+ convolution_first=False,
51
+ )
52
+
53
+ # 5) Presence regressor head
54
+ self.presence_regressor = nn.Sequential(
55
+ nn.Dropout(DROPOUT),
56
+ nn.Linear(D_MODEL, HIDDEN_DIM),
57
+ nn.GELU(),
58
+ nn.Dropout(DROPOUT),
59
+ nn.Linear(HIDDEN_DIM, 3), # Don, Ka, DrumRoll energy
60
+ nn.Sigmoid(), # Output between 0 and 1
61
+ )
62
+
63
+ # 6) Initialize weights
64
+ for m in self.modules():
65
+ if isinstance(m, nn.Conv2d):
66
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
67
+ elif isinstance(m, nn.Linear):
68
+ nn.init.xavier_uniform_(m.weight)
69
+ if m.bias is not None:
70
+ nn.init.zeros_(m.bias)
71
+
72
+ def forward(
73
+ self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
74
+ ):
75
+ """
76
+ Args:
77
+ mel: (B, 1, N_MELS, T_mel)
78
+ lengths: (B,) lengths after CNN
79
+ notes_per_second: (B,) stream of control values
80
+ Returns:
81
+ Dict with:
82
+ 'presence': (B, T_cnn_out, 3)
83
+ 'lengths': lengths
84
+ """
85
+ # CNN frontend
86
+ x = self.cnn(mel) # (B, C, F, T)
87
+ B, C, F, T = x.size()
88
+ x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
89
+
90
+ # Project to model dimension
91
+ x = self.proj(x) # (B, T, D_MODEL)
92
+
93
+ # FiLM conditioning
94
+ nps = notes_per_second.unsqueeze(-1) # (B, 1)
95
+ gamma_beta = self.film(nps) # (B, 2*D_MODEL)
96
+ gamma, beta = gamma_beta.chunk(2, dim=-1)
97
+ x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)
98
+
99
+ # Conformer encoder
100
+ x, _ = self.encoder(x, lengths=lengths)
101
+
102
+ # Presence prediction
103
+ presence = self.presence_regressor(x)
104
+ return {"presence": presence, "lengths": lengths}
105
+
106
+
107
+ if __name__ == "__main__":
108
+ model = TaikoConformer5().to(device=DEVICE)
109
+ print(model)
110
+
111
+ for name, param in model.named_parameters():
112
+ if param.requires_grad:
113
+ print(f"{name}: {param.numel():,}")
114
+
115
+ params = sum(p.numel() for p in model.parameters() if p.requires_grad)
116
+ print(f"Total parameters: {params / 1e6:.2f}M")
117
+
118
+ batch_size = 4
119
+ mel_time_steps = 1024
120
+ input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
121
+
122
+ conformer_lengths = torch.tensor(
123
+ [mel_time_steps] * batch_size, dtype=torch.long
124
+ ).to(DEVICE)
125
+
126
+ notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
127
+ DEVICE
128
+ )
129
+
130
+ output = model(input_mel, conformer_lengths, notes_per_second_input)
131
+ print("Output shapes:")
132
+ for key, value in output.items():
133
+ print(f"{key}: {value.shape}")
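Note (illustrative, not part of the diff): the feat_dim = CNN_CH * (N_MELS // 4) line above follows from the two stride-(2, 1) convolutions halving the frequency axis twice (80 -> 40 -> 20) while leaving the time axis untouched.

    N_MELS, CNN_CH, D_MODEL = 80, 128, 256

    freq_after_cnn = N_MELS // 2 // 2   # 80 -> 40 -> 20 mel bins after the two convs
    feat_dim = CNN_CH * freq_after_cnn  # 128 * 20 = 2560 features per time step
    print(feat_dim, "->", D_MODEL)      # flattened features are projected to D_MODEL by self.proj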
tc5/preprocess.py ADDED
@@ -0,0 +1,215 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ from torchaudio.transforms import FrequencyMasking
6
+ from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
7
+ from .model import TaikoConformer5
8
+
9
+
10
+ mel_transform = torchaudio.transforms.MelSpectrogram(
11
+ sample_rate=SAMPLE_RATE,
12
+ n_mels=N_MELS,
13
+ hop_length=HOP_LENGTH,
14
+ n_fft=2048,
15
+ )
16
+
17
+
18
+ freq_mask = FrequencyMasking(freq_mask_param=15)
19
+
20
+
21
+ def preprocess(example, difficulty="oni"):
22
+ wav_tensor = example["audio"]["array"]
23
+ sr = example["audio"]["sampling_rate"]
24
+ # 1) load & resample
25
+ if sr != SAMPLE_RATE:
26
+ wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
27
+ # normalize audio
28
+ wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
29
+ # add random Gaussian noise
30
+ if torch.rand(1).item() < 0.5:
31
+ wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
32
+ # 2) mel: (1, N_MELS, T)
33
+ mel = mel_transform(wav_tensor).unsqueeze(0)
34
+ # apply SpecAugment
35
+ # we don't use time masking since we don't want the model to predict notes when they are masked
36
+ mel = freq_mask(mel)
37
+ _, _, T = mel.shape
38
+ # 3) build label sequence of length ceil(T / TIME_SUB)
39
+ T_sub = math.ceil(T / TIME_SUB)
40
+
41
+ # Initialize energy-based labels for Don, Ka, Drumroll
42
+ don_labels = torch.zeros(T_sub, dtype=torch.float32)
43
+ ka_labels = torch.zeros(T_sub, dtype=torch.float32)
44
+ drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
45
+
46
+ # Define exponential decay tail parameters
47
+ tail_length = 40 # number of frames for decay tail
48
+ decay_rate = 8.0 # decay rate parameter, adjust as needed
49
+ tail_kernel = torch.exp(
50
+ -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
51
+ )
52
+
53
+ fps = SAMPLE_RATE / HOP_LENGTH
54
+ num_valid_notes = 0
55
+ for onset in example[difficulty]:
56
+ typ, t_start, t_end, *_ = onset
57
+
58
+ # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
59
+ if typ < 1 or typ > N_TYPES: # Filter out invalid types
60
+ continue
61
+
62
+ num_valid_notes += 1
63
+
64
+ f = int(round(t_start.item() * fps))
65
+ idx = f // TIME_SUB
66
+ if 0 <= idx < T_sub:
67
+ # Apply exponential decay kernel to the corresponding energy channel
68
+ # Type 1 and 3 are Don
69
+ if typ == 1 or typ == 3:
70
+ for i, val in enumerate(tail_kernel):
71
+ target_idx = idx + i
72
+ if 0 <= target_idx < T_sub:
73
+ don_labels[target_idx] = max(
74
+ don_labels[target_idx].item(), val.item()
75
+ )
76
+ # Type 2 and 4 are Ka
77
+ elif typ == 2 or typ == 4:
78
+ for i, val in enumerate(tail_kernel):
79
+ target_idx = idx + i
80
+ if 0 <= target_idx < T_sub:
81
+ ka_labels[target_idx] = max(
82
+ ka_labels[target_idx].item(), val.item()
83
+ )
84
+ # Type 5, 6, 7 are Drumroll
85
+ elif typ >= 5 and typ <= 7:
86
+ f_end = int(round(t_end.item() * fps))
87
+ idx_end = f_end // TIME_SUB
88
+
89
+ for dr in range(idx, idx_end):
90
+ if 0 <= dr < T_sub:
91
+ drumroll_labels[dr] = 1.0
92
+
93
+ for i, val in enumerate(tail_kernel):
94
+ target_idx = idx_end + i
95
+ if 0 <= target_idx < T_sub:
96
+ drumroll_labels[target_idx] = max(
97
+ drumroll_labels[target_idx].item(), val.item()
98
+ )
99
+
100
+ duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
101
+ nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
102
+ print(
103
+ f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}"
104
+ )
105
+
106
+ return {
107
+ "mel": mel,
108
+ "don_labels": don_labels,
109
+ "ka_labels": ka_labels,
110
+ "drumroll_labels": drumroll_labels,
111
+ "nps": torch.tensor(nps, dtype=torch.float32),
112
+ "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
113
+ }
114
+
115
+
116
+ def collate_fn(batch):
117
+ mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch] # (T, N_MELS)
118
+ # Extract new energy-based labels
119
+ don_labels_list = [b["don_labels"] for b in batch]
120
+ ka_labels_list = [b["ka_labels"] for b in batch]
121
+ drumroll_labels_list = [b["drumroll_labels"] for b in batch]
122
+
123
+ nps_list = [b["nps"] for b in batch] # Extract NPS
124
+ durations_list = [b["duration_seconds"] for b in batch] # Extract durations
125
+
126
+ # Pad mels
127
+ padded_mels = nn.utils.rnn.pad_sequence(
128
+ mels_list, batch_first=True
129
+ ) # (B, T_max, N_MELS)
130
+ # Reshape for CNN: (B, 1, N_MELS, T_max)
131
+ reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
132
+
133
+ # Simulate CNN time downsampling to get output lengths
134
+ dummy_model_for_shape_inference = TaikoConformer5()
135
+ dummy_cnn = dummy_model_for_shape_inference.cnn
136
+ with torch.no_grad():
137
+ cnn_out = dummy_cnn(reshaped_mels) # Use reshaped_mels that has batch dim
138
+ _, _, _, T_cnn = cnn_out.shape
139
+
140
+ padded_don_labels = []
141
+ padded_ka_labels = []
142
+ padded_drumroll_labels = []
143
+ # lengths = [] # This was for original presence/type labels, conformer_input_lengths is used for model
144
+
145
+ for i in range(len(batch)):
146
+ d_labels = don_labels_list[i]
147
+ k_labels = ka_labels_list[i]
148
+ dr_labels = drumroll_labels_list[i]
149
+
150
+ item_original_T_sub = d_labels.shape[
151
+ 0
152
+ ] # Assuming all label types have same original length
153
+ out_len = T_cnn # Target length for labels is T_cnn
154
+
155
+ # Pad or truncate don_labels
156
+ if item_original_T_sub < out_len:
157
+ pad_d = torch.full(
158
+ (out_len - item_original_T_sub,),
159
+ 0, # Pad with 0 for energy labels
160
+ dtype=d_labels.dtype,
161
+ device=d_labels.device,
162
+ )
163
+ padded_d = torch.cat([d_labels, pad_d], dim=0)
164
+ else:
165
+ padded_d = d_labels[:out_len]
166
+ padded_don_labels.append(padded_d)
167
+
168
+ # Pad or truncate ka_labels
169
+ if item_original_T_sub < out_len:
170
+ pad_k = torch.full(
171
+ (out_len - item_original_T_sub,),
172
+ 0, # Pad with 0 for energy labels
173
+ dtype=k_labels.dtype,
174
+ device=k_labels.device,
175
+ )
176
+ padded_k = torch.cat([k_labels, pad_k], dim=0)
177
+ else:
178
+ padded_k = k_labels[:out_len]
179
+ padded_ka_labels.append(padded_k)
180
+
181
+ # Pad or truncate drumroll_labels
182
+ if item_original_T_sub < out_len:
183
+ pad_dr = torch.full(
184
+ (out_len - item_original_T_sub,),
185
+ 0, # Pad with 0 for energy labels
186
+ dtype=dr_labels.dtype,
187
+ device=dr_labels.device,
188
+ )
189
+ padded_dr = torch.cat([dr_labels, pad_dr], dim=0)
190
+ else:
191
+ padded_dr = dr_labels[:out_len]
192
+ padded_drumroll_labels.append(padded_dr)
193
+
194
+ # For Conformer input lengths: lengths of mel sequences after CNN subsampling
195
+ # (Assuming CNN does not subsample in time, T_cnn is effectively T_mel_padded)
196
+ # The `lengths` for the Conformer should be based on the mel input to the conformer part.
197
+ # The existing calculation for conformer_input_lengths seems to relate to TIME_SUB.
198
+ # If the Conformer input itself is not subsampled by TIME_SUB, this might need review.
199
+ # For now, keeping the existing conformer_input_lengths logic as it's outside the scope of label change.
200
+ conformer_input_lengths = [
201
+ math.ceil(mels_list[i].shape[0] / TIME_SUB) for i in range(len(batch))
202
+ ]
203
+ conformer_input_lengths = torch.tensor(
204
+ [min(l, T_cnn) for l in conformer_input_lengths], dtype=torch.long
205
+ )
206
+
207
+ return {
208
+ "mel": reshaped_mels,
209
+ "don_labels": torch.stack(padded_don_labels),
210
+ "ka_labels": torch.stack(padded_ka_labels),
211
+ "drumroll_labels": torch.stack(padded_drumroll_labels),
212
+ "lengths": conformer_input_lengths, # These are for the Conformer model
213
+ "nps": torch.stack(nps_list),
214
+ "durations": torch.stack(durations_list),
215
+ }
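Note (illustrative, not part of the diff): the exponential tail kernel used for the energy labels above starts at 1.0 on the onset frame, decays to about 0.61 after 4 frames and roughly 0.008 by frame 39, so each hit is smeared into a short decaying tail rather than a single-frame spike.

    import torch

    tail_length, decay_rate = 40, 8.0
    tail_kernel = torch.exp(-torch.arange(0, tail_length, dtype=torch.float32) / decay_rate)
    print(tail_kernel[:5])  # tensor([1.0000, 0.8825, 0.7788, 0.6873, 0.6065])
    print(tail_kernel[-1])  # tensor(0.0076)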
tc5/train.py ADDED
@@ -0,0 +1,323 @@
1
+ from accelerate.utils import set_seed
2
+
3
+ set_seed(1024)
4
+
5
+
6
+ import math
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from tqdm import tqdm
11
+ from datasets import concatenate_datasets
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from .config import (
15
+ BATCH_SIZE,
16
+ DEVICE,
17
+ EPOCHS,
18
+ LR,
19
+ GRAD_ACCUM_STEPS,
20
+ HOP_LENGTH,
21
+ SAMPLE_RATE,
22
+ )
23
+ from .model import TaikoConformer5
24
+ from .dataset import ds
25
+ from .preprocess import preprocess, collate_fn
26
+ from .loss import TaikoEnergyLoss
27
+ from huggingface_hub import upload_folder
28
+
29
+
30
+ # --- Helper function to log energy plots ---
31
+ def log_energy_plots_to_tensorboard(
32
+ writer,
33
+ tag_prefix,
34
+ epoch,
35
+ pred_don,
36
+ pred_ka,
37
+ pred_drumroll,
38
+ true_don,
39
+ true_ka,
40
+ true_drumroll,
41
+ valid_length, # Actual valid length of the sequence (before padding)
42
+ hop_sec,
43
+ ):
44
+ """
45
+ Logs a plot of predicted vs. true energies for one sample to TensorBoard.
46
+ Energies should be 1D numpy arrays for the single sample, up to valid_length.
47
+ """
48
+ # Ensure data is on CPU and converted to numpy, and select only the valid part
49
+ pred_don = pred_don[:valid_length].detach().cpu().numpy()
50
+ pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
51
+ pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
52
+ true_don = true_don[:valid_length].cpu().numpy()
53
+ true_ka = true_ka[:valid_length].cpu().numpy()
54
+ true_drumroll = true_drumroll[:valid_length].cpu().numpy()
55
+
56
+ time_axis = np.arange(valid_length) * hop_sec
57
+
58
+ fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
59
+ fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
60
+
61
+ axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
62
+ axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
63
+ axs[0].set_ylabel("Don Energy")
64
+ axs[0].legend()
65
+ axs[0].grid(True)
66
+
67
+ axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
68
+ axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
69
+ axs[1].set_ylabel("Ka Energy")
70
+ axs[1].legend()
71
+ axs[1].grid(True)
72
+
73
+ axs[2].plot(
74
+ time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
75
+ )
76
+ axs[2].plot(
77
+ time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
78
+ )
79
+ axs[2].set_ylabel("Drumroll Energy")
80
+ axs[2].set_xlabel("Time (s)")
81
+ axs[2].legend()
82
+ axs[2].grid(True)
83
+
84
+ plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
85
+ writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
86
+ plt.close(fig)
87
+
88
+
89
+ def main():
90
+ global ds
91
+
92
+ # Calculate hop seconds for model output frames
93
+ # This assumes the model output time dimension corresponds to the mel spectrogram time dimension
94
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
95
+
96
+ best_val_loss = float("inf")
97
+ patience = 10 # Increased patience a bit
98
+ pat_count = 0
99
+
100
+ ds_oni = ds.map(
101
+ preprocess,
102
+ remove_columns=ds.column_names,
103
+ fn_kwargs={"difficulty": "oni"},
104
+ writer_batch_size=10,
105
+ )
106
+ ds_hard = ds.map(
107
+ preprocess,
108
+ remove_columns=ds.column_names,
109
+ fn_kwargs={"difficulty": "hard"},
110
+ writer_batch_size=10,
111
+ )
112
+ ds_normal = ds.map(
113
+ preprocess,
114
+ remove_columns=ds.column_names,
115
+ fn_kwargs={"difficulty": "normal"},
116
+ writer_batch_size=10,
117
+ )
118
+ ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
119
+
120
+ ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
121
+ train_loader = DataLoader(
122
+ ds_train_test["train"],
123
+ batch_size=BATCH_SIZE,
124
+ shuffle=True,
125
+ collate_fn=collate_fn,
126
+ num_workers=2,
127
+ )
128
+ val_loader = DataLoader(
129
+ ds_train_test["test"],
130
+ batch_size=BATCH_SIZE,
131
+ shuffle=False,
132
+ collate_fn=collate_fn,
133
+ num_workers=2,
134
+ )
135
+
136
+ model = TaikoConformer5().to(DEVICE)
137
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
138
+
139
+ criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)
140
+
141
+ # Adjust scheduler steps for gradient accumulation
142
+ num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
143
+ total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
144
+
145
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
146
+ optimizer, max_lr=LR, total_steps=total_optimizer_steps
147
+ )
148
+
149
+ writer = SummaryWriter()
150
+
151
+ for epoch in range(1, EPOCHS + 1):
152
+ model.train()
153
+ total_epoch_loss = 0.0
154
+ optimizer.zero_grad()
155
+
156
+ for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
157
+ mel = batch["mel"].to(DEVICE)
158
+ # Unpack new energy-based labels
159
+ don_labels = batch["don_labels"].to(DEVICE)
160
+ ka_labels = batch["ka_labels"].to(DEVICE)
161
+ drumroll_labels = batch["drumroll_labels"].to(DEVICE)
162
+ lengths = batch["lengths"].to(
163
+ DEVICE
164
+ ) # These are for the Conformer model output
165
+ nps = batch["nps"].to(DEVICE)
166
+
167
+ output_dict = model(mel, lengths, nps)
168
+ # output_dict["presence"] is now (B, T_out, 3) for don, ka, drumroll energies
169
+ pred_energies_batch = output_dict["presence"] # (B, T_out, 3)
170
+
171
+ loss_input_batch = {
172
+ "don_labels": don_labels,
173
+ "ka_labels": ka_labels,
174
+ "drumroll_labels": drumroll_labels,
175
+ "lengths": lengths, # Pass lengths for masking within the loss function
176
+ }
177
+ loss = criterion(output_dict, loss_input_batch)
178
+
179
+ (loss / GRAD_ACCUM_STEPS).backward()
180
+
181
+ if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
182
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
183
+ optimizer.step()
184
+ scheduler.step()
185
+ optimizer.zero_grad()
186
+
187
+ total_epoch_loss += loss.item()
188
+
189
+ # Log plot for the first sample of the first batch in each training epoch
190
+ if idx == 0:
191
+ first_sample_pred_don = pred_energies_batch[0, :, 0]
192
+ first_sample_pred_ka = pred_energies_batch[0, :, 1]
193
+ first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
194
+
195
+ first_sample_true_don = don_labels[0, :]
196
+ first_sample_true_ka = ka_labels[0, :]
197
+ first_sample_true_drumroll = drumroll_labels[0, :]
198
+
199
+ first_sample_length = lengths[
200
+ 0
201
+ ].item() # Get the valid length of the first sample
202
+
203
+ log_energy_plots_to_tensorboard(
204
+ writer,
205
+ "Train/Sample_0",
206
+ epoch,
207
+ first_sample_pred_don,
208
+ first_sample_pred_ka,
209
+ first_sample_pred_drumroll,
210
+ first_sample_true_don,
211
+ first_sample_true_ka,
212
+ first_sample_true_drumroll,
213
+ first_sample_length,
214
+ output_frame_hop_sec,
215
+ )
216
+
217
+ avg_train_loss = total_epoch_loss / len(train_loader)
218
+ writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
219
+
220
+ # Validation
221
+ model.eval()
222
+ total_val_loss = 0.0
223
+ # Removed storage for classification logits/labels and confusion matrix components
224
+
225
+ with torch.no_grad():
226
+ for val_idx, batch in enumerate(
227
+ tqdm(val_loader, desc=f"Val Epoch {epoch}")
228
+ ):
229
+ mel = batch["mel"].to(DEVICE)
230
+ don_labels = batch["don_labels"].to(DEVICE)
231
+ ka_labels = batch["ka_labels"].to(DEVICE)
232
+ drumroll_labels = batch["drumroll_labels"].to(DEVICE)
233
+ lengths = batch["lengths"].to(DEVICE)
234
+ nps = batch["nps"].to(DEVICE) # Ground truth NPS from batch
235
+
236
+ output_dict = model(mel, lengths, nps)
237
+ pred_energies_val_batch = output_dict["presence"] # (B, T_out, 3)
238
+
239
+ val_loss_input_batch = {
240
+ "don_labels": don_labels,
241
+ "ka_labels": ka_labels,
242
+ "drumroll_labels": drumroll_labels,
243
+ "lengths": lengths,
244
+ }
245
+ val_loss = criterion(output_dict, val_loss_input_batch)
246
+ total_val_loss += val_loss.item()
247
+
248
+ # Log plot for the first sample of the first batch in each validation epoch
249
+ if val_idx == 0:
250
+ first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
251
+ first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
252
+ first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
253
+
254
+ first_val_sample_true_don = don_labels[0, :]
255
+ first_val_sample_true_ka = ka_labels[0, :]
256
+ first_val_sample_true_drumroll = drumroll_labels[0, :]
257
+
258
+ first_val_sample_length = lengths[0].item()
259
+
260
+ log_energy_plots_to_tensorboard(
261
+ writer,
262
+ "Eval/Sample_0",
263
+ epoch,
264
+ first_val_sample_pred_don,
265
+ first_val_sample_pred_ka,
266
+ first_val_sample_pred_drumroll,
267
+ first_val_sample_true_don,
268
+ first_val_sample_true_ka,
269
+ first_val_sample_true_drumroll,
270
+ first_val_sample_length,
271
+ output_frame_hop_sec,
272
+ )
273
+
274
+ # Log ground truth NPS for reference during validation if needed
275
+ # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)
276
+
277
+ avg_val_loss = total_val_loss / len(val_loader)
278
+ writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
279
+
280
+ # Log learning rate
281
+ current_lr = optimizer.param_groups[0]["lr"]
282
+ writer.add_scalar("LR/learning_rate", current_lr, epoch)
283
+
284
+ # Log ground truth NPS from the last validation batch (or mean over epoch)
285
+ if "nps" in batch: # Check if nps is in the last batch
286
+ writer.add_scalar(
287
+ "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
288
+ )
289
+
290
+ print(
291
+ f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
292
+ )
293
+
294
+ if avg_val_loss < best_val_loss:
295
+ best_val_loss = avg_val_loss
296
+ pat_count = 0
297
+ torch.save(model.state_dict(), "best_model.pt") # Changed model save name
298
+ print(f"Saved new best model to best_model.pt at epoch {epoch}")
299
+ else:
300
+ pat_count += 1
301
+ if pat_count >= patience:
302
+ print("Early stopping!")
303
+ break
304
+ writer.close()
305
+
306
+ model_id = "JacobLinCool/taiko-conformer-5"
307
+ try:
308
+ model.push_to_hub(model_id, commit_message="Upload trained model")
309
+ upload_folder(
310
+ repo_id=model_id,
311
+ folder_path="runs",
312
+ path_in_repo=".",
313
+ commit_message="Upload training logs",
314
+ ignore_patterns=["*.txt", "*.json", "*.csv"],
315
+ )
316
+ print(f"Model and logs uploaded to {model_id}")
317
+ except Exception as e:
318
+ print(f"Error uploading to Hugging Face Hub: {e}")
319
+ print("Make sure you have the correct permissions and try again.")
320
+
321
+
322
+ if __name__ == "__main__":
323
+ main()
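Note (illustrative, not part of the diff): with the tc5 config values, gradient accumulation gives an effective batch of BATCH_SIZE * GRAD_ACCUM_STEPS = 16 samples per optimizer step, and OneCycleLR's total_steps is counted in optimizer steps, not loader iterations. The loader length below is a made-up placeholder; the real value depends on the dataset split.

    import math

    BATCH_SIZE, GRAD_ACCUM_STEPS, EPOCHS = 4, 4, 30
    effective_batch = BATCH_SIZE * GRAD_ACCUM_STEPS  # 16 samples per optimizer step

    train_loader_len = 250  # placeholder for len(train_loader) in the script
    steps_per_epoch = math.ceil(train_loader_len / GRAD_ACCUM_STEPS)
    total_optimizer_steps = EPOCHS * steps_per_epoch  # what OneCycleLR's total_steps is set to
    print(effective_batch, steps_per_epoch, total_optimizer_steps)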
tc6/__init__.py ADDED
File without changes
tc6/config.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+ # ─── 1) CONFIG ─────────────────────────────────────────────────────
+ SAMPLE_RATE = 22050
+ N_MELS = 80
+ HOP_LENGTH = 256
+ TIME_SUB = 1
+ CNN_CH = 256
+ N_HEADS = 8
+ D_MODEL = 512
+ FF_DIM = 1024
+ N_LAYERS = 6
+ DEPTHWISE_CONV_KERNEL_SIZE = 31
+ DROPOUT = 0.1
+ HIDDEN_DIM = 64
+ N_TYPES = 7
+ BATCH_SIZE = 2
+ GRAD_ACCUM_STEPS = 8
+ LR = 3e-4
+ EPOCHS = 200
+ DEVICE = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available() else "cpu"
+ )
tc6/dataset.py ADDED
@@ -0,0 +1,21 @@
+ from datasets import load_dataset, concatenate_datasets
+
+ ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
+ ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
+ ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
+ ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
+ ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
+ ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
+ ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
+ ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")
+
+ good = list(range(len(ds)))
+ good.remove(1079)  # 1079 has file problem
+ ds = ds.select(good)
+
+ # for local test
+ # ds = (
+ #     load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
+ #     .with_format("torch")
+ #     .select(range(10))
+ # )
tc6/infer.py ADDED
@@ -0,0 +1,354 @@
1
+ import time
2
+ import torch
3
+ import torchaudio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
7
+ import torch.profiler
8
+
9
+
10
+ # --- PREPROCESSING (match training) ---
11
+ def preprocess_audio(audio_path):
12
+ wav, sr = torchaudio.load(audio_path)
13
+ wav = wav.mean(dim=0) # mono
14
+ if sr != SAMPLE_RATE:
15
+ wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
16
+ wav = wav / (wav.abs().max() + 1e-8) # Normalize audio
17
+
18
+ mel_transform = torchaudio.transforms.MelSpectrogram(
19
+ sample_rate=SAMPLE_RATE,
20
+ n_mels=N_MELS,
21
+ hop_length=HOP_LENGTH,
22
+ n_fft=2048,
23
+ )
24
+ mel = mel_transform(wav)
25
+ return mel # mel is (N_MELS, T_mel)
26
+
27
+
28
+ # --- INFERENCE ---
29
+ def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
30
+ model.eval()
31
+ with torch.no_grad():
32
+ mel = mel_input.to(device).unsqueeze(0) # (1, N_MELS, T_mel)
33
+ nps = nps_input.to(device).unsqueeze(0) # (1,)
34
+ difficulty = difficulty_input.to(device).unsqueeze(0) # (1,)
35
+ level = level_input.to(device).unsqueeze(0) # (1,)
36
+
37
+ mel_cnn_input = mel.unsqueeze(1) # (1, 1, N_MELS, T_mel)
38
+
39
+ conformer_lengths = torch.tensor(
40
+ [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
41
+ )
42
+
43
+ with torch.profiler.profile(
44
+ activities=[
45
+ torch.profiler.ProfilerActivity.CPU,
46
+ *(
47
+ [torch.profiler.ProfilerActivity.CUDA]
48
+ if device.type == "cuda"
49
+ else []
50
+ ),
51
+ ],
52
+ record_shapes=True,
53
+ profile_memory=True,
54
+ with_stack=False,
55
+ with_flops=True,
56
+ ) as prof:
57
+ out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
58
+ print(
59
+ prof.key_averages().table(
60
+ sort_by=(
61
+ "self_cuda_memory_usage"
62
+ if device.type == "cuda"
63
+ else "self_cpu_time_total"
64
+ ),
65
+ row_limit=20,
66
+ )
67
+ )
68
+
69
+ energies = out_dict["presence"].squeeze(0).cpu().numpy()
70
+
71
+ don_energy = energies[:, 0]
72
+ ka_energy = energies[:, 1]
73
+ drumroll_energy = energies[:, 2]
74
+
75
+ return don_energy, ka_energy, drumroll_energy
76
+
77
+
78
+ # --- DECODE TO ONSETS ---
79
+ def decode_onsets(
80
+ don_energy,
81
+ ka_energy,
82
+ drumroll_energy,
83
+ hop_sec,
84
+ threshold=0.5,
85
+ min_distance_frames=3,
86
+ ):
87
+ results = []
88
+ T_out = len(don_energy)
89
+ last_onset_frame = -min_distance_frames
90
+
91
+ for i in range(1, T_out - 1): # Iterate considering neighbors for peak detection
92
+ if i < last_onset_frame + min_distance_frames:
93
+ continue
94
+
95
+ e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
96
+ energies_at_i = {
97
+ 1: e_don,
98
+ 2: e_ka,
99
+ 5: e_drum,
100
+ } # Type mapping: 1:Don, 2:Ka, 5:Drumroll
101
+
102
+ # Find which energy is max and if it's a peak above threshold
103
+ # Sort by energy value descending to prioritize higher energy in case of ties for peak condition
104
+ sorted_types_by_energy = sorted(
105
+ energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
106
+ )
107
+
108
+ detected_this_frame = False
109
+ for onset_type in sorted_types_by_energy:
110
+ current_energy_series = None
111
+ if onset_type == 1:
112
+ current_energy_series = don_energy
113
+ elif onset_type == 2:
114
+ current_energy_series = ka_energy
115
+ elif onset_type == 5:
116
+ current_energy_series = drumroll_energy
117
+
118
+ energy_val = current_energy_series[i]
119
+
120
+ if (
121
+ energy_val > threshold
122
+ and energy_val > current_energy_series[i - 1]
123
+ and energy_val > current_energy_series[i + 1]
124
+ ):
125
+ # Check if this energy is the highest among the three at this frame
126
+ # This check is implicitly handled by iterating `sorted_types_by_energy`
127
+ # and breaking after the first detection.
128
+ results.append((i * hop_sec, onset_type))
129
+ last_onset_frame = i
130
+ detected_this_frame = True
131
+ break # Only one onset type per frame
132
+
133
+ return results
134
+
135
+
136
+ # --- VISUALIZATION ---
137
+ def plot_results(
138
+ mel_spectrogram,
139
+ don_energy,
140
+ ka_energy,
141
+ drumroll_energy,
142
+ onsets,
143
+ hop_sec,
144
+ out_path=None,
145
+ ):
146
+ # mel_spectrogram is (N_MELS, T_mel)
147
+ T_mel = mel_spectrogram.shape[1]
148
+ T_out = len(don_energy) # Length of energy arrays (model output time dimension)
149
+
150
+ # Time axes
151
+ time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
152
+ # hop_sec for model output is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE
153
+ # However, the model output T_out is related to T_mel (input to CNN).
154
+ # If CNN does not change time dimension, T_out = T_mel.
155
+ # If TIME_SUB is used for label generation, T_out = T_mel / TIME_SUB.
156
+ # The `lengths` passed to conformer in `run_inference` is T_mel.
157
+ # The output of conformer `x` has shape (B, T, D_MODEL). This T is `lengths`.
158
+ # So, T_out from model is T_mel.
159
+ # The `hop_sec` for onsets should be based on the model output frame rate.
160
+ # If model output T_out corresponds to T_mel, then hop_sec for plotting energies is HOP_LENGTH / SAMPLE_RATE.
161
+ # The `hop_sec` passed to `decode_onsets` is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE.
162
+ # This seems inconsistent. Let's clarify: `TIME_SUB` is used in `preprocess.py` to determine `T_sub` for labels.
163
+ # The model's CNN output time dimension T_cnn is used to pad/truncate labels in `collate_fn`.
164
+ # In `model.py`, the CNN does not stride in time. So T_cnn_out = T_mel_input_to_CNN.
165
+ # The `lengths` for the conformer is based on this T_cnn_out.
166
+ # So, the output of the regressor `presence` (B, T_cnn_out, 3) has T_cnn_out time steps.
167
+ # Each step corresponds to `HOP_LENGTH / SAMPLE_RATE` seconds if TIME_SUB is not involved in downsampling features for the conformer.
168
+ # Let's assume the `hop_sec` used for `decode_onsets` is correct for interpreting model output frames.
169
+ time_axis_energies = np.arange(T_out) * hop_sec
170
+
171
+ fig, ax1 = plt.subplots(figsize=(100, 10))
172
+
173
+ # Plot Mel Spectrogram on ax1
174
+ mel_db = torchaudio.functional.amplitude_to_DB(
175
+ mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
176
+ )
177
+ img = ax1.imshow(
178
+ mel_db.numpy(),
179
+ aspect="auto",
180
+ origin="lower",
181
+ cmap="magma",
182
+ extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
183
+ )
184
+ ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
185
+ ax1.set_xlabel("Time (s)")
186
+ ax1.set_ylabel("Mel Bin")
187
+ fig.colorbar(img, ax=ax1, format="%+2.0f dB")
188
+
189
+ # Create a second y-axis for energies
190
+ ax2 = ax1.twinx()
191
+ ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
192
+ ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
193
+ ax2.plot(
194
+ time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
195
+ )
196
+ ax2.set_ylabel("Energy")
197
+ ax2.set_ylim(0, 1.2) # Assuming energies are somewhat normalized or bounded
198
+
199
+ # Overlay onsets from decode_onsets (t is already in seconds)
200
+ labeled_types = set()
201
+ # Group drumrolls into segments (reuse logic from write_tja)
202
+ drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
203
+ drumroll_times.sort()
204
+ drumroll_segments = []
205
+ if drumroll_times:
206
+ seg_start = drumroll_times[0]
207
+ prev = drumroll_times[0]
208
+ for t in drumroll_times[1:]:
209
+ if t - prev <= hop_sec * 6: # up to 5-frame gap
210
+ prev = t
211
+ else:
212
+ drumroll_segments.append((seg_start, prev))
213
+ seg_start = t
214
+ prev = t
215
+ drumroll_segments.append((seg_start, prev))
216
+ # Plot Don/Ka onsets as vertical lines
217
+ for t_sec, typ in onsets:
218
+ if typ == 5:
219
+ continue # skip drumroll onsets
220
+ color_map = {1: "darkred", 2: "darkblue"}
221
+ label_map = {1: "Don Onset", 2: "Ka Onset"}
222
+ line_color = color_map.get(typ, "black")
223
+ line_label = label_map.get(typ, f"Type {typ} Onset")
224
+ if typ not in labeled_types:
225
+ ax1.axvline(
226
+ t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
227
+ )
228
+ labeled_types.add(typ)
229
+ else:
230
+ ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
231
+ # Plot drumroll segments as shaded regions
232
+ for seg_start, seg_end in drumroll_segments:
233
+ ax1.axvspan(
234
+ seg_start,
235
+ seg_end + hop_sec,
236
+ color="green",
237
+ alpha=0.2,
238
+ label="Drumroll Segment" if "drumroll" not in labeled_types else None,
239
+ )
240
+ labeled_types.add("drumroll")
241
+
242
+ # Combine legends from both axes
243
+ lines, labels = ax1.get_legend_handles_labels()
244
+ lines2, labels2 = ax2.get_legend_handles_labels()
245
+ ax2.legend(lines + lines2, labels + labels2, loc="upper right")
246
+
247
+ fig.tight_layout()
248
+
249
+ # Return plot as image buffer or save to file if path provided
250
+ if out_path:
251
+ plt.savefig(out_path)
252
+ print(f"Saved plot to {out_path}")
253
+ plt.close(fig)
254
+ return out_path
255
+ else:
256
+ # Return plot as in-memory buffer
257
+ return fig
258
+
259
+
260
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
261
+ # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
+ # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
+ sec_per_beat = 60 / bpm
264
+ beats_per_measure = 4 # Assuming 4/4 time signature
265
+ sec_per_measure = sec_per_beat * beats_per_measure
266
+ # Step 1: Map onsets to (measure_idx, slot, typ)
267
+ slot_events = []
268
+ for t, typ in onsets:
269
+ measure_idx = int(t // sec_per_measure)
270
+ t_in_measure = t % sec_per_measure
271
+ slot = int(round(t_in_measure / sec_per_measure * quantize))
272
+ if slot >= quantize:
273
+ slot = quantize - 1
274
+ slot_events.append((measure_idx, slot, typ))
275
+ # Step 2: Build measure/slot grid
276
+ if slot_events:
277
+ max_measure_idx = max(m for m, _, _ in slot_events)
278
+ else:
279
+ max_measure_idx = -1
280
+ measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
281
+ # Step 3: Place Don/Ka, collect drumrolls
282
+ drumroll_slots = set()
283
+ for m, s, typ in slot_events:
284
+ if typ in [1, 2]:
285
+ measures[m][s] = typ
286
+ elif typ == 5:
287
+ drumroll_slots.add((m, s))
288
+ # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
289
+ # Flatten all slots to a list of (measure, slot) sorted
290
+ drumroll_list = sorted(list(drumroll_slots))
291
+ # Group into contiguous regions (allowing a gap of 5 slots)
292
+ grouped = []
293
+ group = []
294
+ for ms in drumroll_list:
295
+ if not group:
296
+ group = [ms]
297
+ else:
298
+ last_m, last_s = group[-1]
299
+ m, s = ms
300
+ # Calculate slot distance, considering measure wrap
301
+ slot_dist = None
302
+ if m == last_m:
303
+ slot_dist = s - last_s
304
+ elif m == last_m + 1 and last_s <= quantize - 1:
305
+ slot_dist = (quantize - 1 - last_s) + s + 1
306
+ else:
307
+ slot_dist = None
308
+ # Allow gap of up to 5 slots (slot_dist <= 6)
309
+ if slot_dist is not None and 1 <= slot_dist <= 6:
310
+ group.append(ms)
311
+ else:
312
+ grouped.append(group)
313
+ group = [ms]
314
+ if group:
315
+ grouped.append(group)
316
+ # Mark 5 (start) and 8 (end) for each group
317
+ for region in grouped:
318
+ if len(region) == 1:
319
+ m, s = region[0]
320
+ measures[m][s] = 5
321
+ # Place 8 in next slot (or next measure if at end)
322
+ if s < quantize - 1:
323
+ measures[m][s + 1] = 8
324
+ elif m < max_measure_idx:
325
+ measures[m + 1][0] = 8
326
+ else:
327
+ m_start, s_start = region[0]
328
+ m_end, s_end = region[-1]
329
+ measures[m_start][s_start] = 5
330
+ measures[m_end][s_end] = 8
331
+ # Fill 0 for middle slots (already 0 by default)
332
+ # Step 5: Generate TJA content
333
+ tja_content = []
334
+ tja_content.append(f"TITLE:{audio} (TC6, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
+ tja_content.append(f"BPM:{bpm}")
336
+ tja_content.append(f"WAVE:{audio}")
337
+ tja_content.append("OFFSET:0")
338
+ tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
+ tja_content.append("#START")
340
+ for i in range(max_measure_idx + 1):
341
+ notes = measures.get(i, [0] * quantize)
342
+ line = "".join(str(n) for n in notes)
343
+ tja_content.append(line + ",")
344
+ tja_content.append("#END")
345
+
346
+ tja_string = "\n".join(tja_content)
347
+
348
+ # If out_path is provided, also write to file
349
+ if out_path:
350
+ with open(out_path, "w", encoding="utf-8") as f:
351
+ f.write(tja_string)
352
+ print(f"TJA chart saved to {out_path}")
353
+
354
+ return tja_string
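
Taken together, the helpers in tc6/infer.py form a small offline inference pipeline: preprocess_audio builds the mel spectrogram, run_inference produces the three energy curves, decode_onsets peak-picks them, and write_tja serialises the result. A minimal sketch of how they might be chained, assuming a local audio file named song.ogg and the published tc6 checkpoint; the nps/difficulty/level values are illustrative conditioning inputs, not prescribed settings:

```python
# Sketch only: wiring the tc6 inference helpers end to end.
import torch
from tc6.config import SAMPLE_RATE, HOP_LENGTH
from tc6.model import TaikoConformer6
from tc6.infer import preprocess_audio, run_inference, decode_onsets, write_tja

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TaikoConformer6.from_pretrained("JacobLinCool/taiko-conformer-6").to(device).eval()

mel = preprocess_audio("song.ogg")              # (N_MELS, T_mel)
nps = torch.tensor(6.0)                         # illustrative notes-per-second target
difficulty = torch.tensor(3.0)                  # oni
level = torch.tensor(9.0)

don, ka, drumroll = run_inference(model, mel, nps, difficulty, level, device)
hop_sec = HOP_LENGTH / SAMPLE_RATE              # seconds per model output frame
onsets = decode_onsets(don, ka, drumroll, hop_sec, threshold=0.3, min_distance_frames=3)
tja = write_tja(onsets, bpm=160, audio="song.ogg")
```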
tc6/loss.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class TaikoEnergyLoss(nn.Module):
6
+ def __init__(self, reduction="mean"):
7
+ super().__init__()
8
+ # Use 'none' reduction to get element-wise losses, then manually apply masking and reduction
9
+ self.mse_loss = nn.MSELoss(reduction="none")
10
+ self.reduction = reduction
11
+
12
+ def forward(self, outputs, batch):
13
+ """
14
+ Calculates the MSE loss for energy-based predictions.
15
+
16
+ Args:
17
+ outputs (dict): Model output, containing 'presence' tensor.
18
+ outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
19
+ batch (dict): Batch data from collate_fn, containing true labels and lengths.
20
+ batch['don_labels'], batch['ka_labels'], batch['drumroll_labels'] shape: (B, T)
21
+ batch['lengths'] shape: (B,) - valid sequence lengths for time dimension T.
22
+ Returns:
23
+ torch.Tensor: The calculated loss.
24
+ """
25
+ pred_energies = outputs["presence"] # (B, T, 3)
26
+
27
+ true_don = batch["don_labels"] # (B, T)
28
+ true_ka = batch["ka_labels"] # (B, T)
29
+ true_drumroll = batch["drumroll_labels"] # (B, T)
30
+
31
+ # Stack true labels to match the structure of pred_energies (B, T, 3)
32
+ true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2)
33
+
34
+ B, T, _ = pred_energies.shape
35
+
36
+ # Create a mask based on batch['lengths'] to ignore padded parts of sequences
37
+ # batch['lengths'] gives the actual length of each sequence in the batch
38
+ # mask shape: (B, T)
39
+ mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
40
+ "lengths"
41
+ ].unsqueeze(1)
42
+ # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
43
+ mask_3d = mask_2d.unsqueeze(2)
44
+
45
+ # Calculate element-wise MSE loss
46
+ loss_elementwise = self.mse_loss(pred_energies, true_energies) # (B, T, 3)
47
+
48
+ # Apply the mask to the loss
49
+ masked_loss = loss_elementwise * mask_3d
50
+
51
+ if self.reduction == "mean":
52
+ # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements
53
+ total_loss = masked_loss.sum()
54
+ num_valid_elements = mask_3d.sum() # Total number of unmasked float values
55
+ if num_valid_elements > 0:
56
+ return total_loss / num_valid_elements
57
+ else:
58
+ # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
59
+ return torch.tensor(
60
+ 0.0, device=pred_energies.device, requires_grad=True
61
+ )
62
+ elif self.reduction == "sum":
63
+ return masked_loss.sum()
64
+ else: # 'none' or any other case
65
+ return masked_loss
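
Because the mask is built from batch["lengths"], padded frames contribute nothing to the mean. A small sketch with dummy tensors, assuming the shapes documented above, makes the normalisation explicit:

```python
# Sketch only: TaikoEnergyLoss on dummy data to illustrate the length mask.
import torch
from tc6.loss import TaikoEnergyLoss

criterion = TaikoEnergyLoss(reduction="mean")
B, T = 2, 6
outputs = {"presence": torch.rand(B, T, 3)}     # predicted don/ka/drumroll energies
batch = {
    "don_labels": torch.zeros(B, T),
    "ka_labels": torch.zeros(B, T),
    "drumroll_labels": torch.zeros(B, T),
    "lengths": torch.tensor([6, 3]),            # frames 3..5 of the second sample are padding
}
loss = criterion(outputs, batch)
# Mean is taken over (6 + 3) * 3 = 27 valid values, not over B * T * 3 = 36.
```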
tc6/model.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchaudio.models import Conformer
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from .config import (
6
+ N_MELS,
7
+ CNN_CH,
8
+ N_HEADS,
9
+ D_MODEL,
10
+ FF_DIM,
11
+ N_LAYERS,
12
+ DROPOUT,
13
+ DEPTHWISE_CONV_KERNEL_SIZE,
14
+ HIDDEN_DIM,
15
+ DEVICE,
16
+ )
17
+
18
+
19
+ class TaikoConformer6(nn.Module, PyTorchModelHubMixin):
20
+ def __init__(self):
21
+ super().__init__()
22
+ # 1) CNN frontend: frequency-only pooling
23
+ self.cnn = nn.Sequential(
24
+ nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
25
+ nn.BatchNorm2d(CNN_CH),
26
+ nn.GELU(),
27
+ nn.Dropout2d(DROPOUT),
28
+ nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
29
+ nn.BatchNorm2d(CNN_CH),
30
+ nn.GELU(),
31
+ nn.Dropout2d(DROPOUT),
32
+ )
33
+ feat_dim = CNN_CH * (N_MELS // 4)
34
+
35
+ # 2) Linear projection to model dimension
36
+ self.proj = nn.Linear(feat_dim, D_MODEL)
37
+
38
+ # 3) FiLM conditioning for notes_per_second, difficulty, and level
39
+ self.film_nps = nn.Linear(1, 2 * D_MODEL)
40
+ self.film_difficulty = nn.Linear(
41
+ 1, 2 * D_MODEL
42
+ ) # Assuming difficulty is a single scalar
43
+ self.film_level = nn.Linear(1, 2 * D_MODEL) # Assuming level is a single scalar
44
+
45
+ # 4) Conformer encoder
46
+ self.encoder = Conformer(
47
+ input_dim=D_MODEL,
48
+ num_heads=N_HEADS,
49
+ ffn_dim=FF_DIM,
50
+ num_layers=N_LAYERS,
51
+ depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
52
+ dropout=DROPOUT,
53
+ use_group_norm=False,
54
+ convolution_first=False,
55
+ )
56
+
57
+ # 5) Presence regressor head
58
+ self.presence_regressor = nn.Sequential(
59
+ nn.Dropout(DROPOUT),
60
+ nn.Linear(D_MODEL, HIDDEN_DIM),
61
+ nn.GELU(),
62
+ nn.Dropout(DROPOUT),
63
+ nn.Linear(HIDDEN_DIM, 3), # Don, Ka, DrumRoll energy
64
+ nn.Sigmoid(), # Output between 0 and 1
65
+ )
66
+
67
+ # 6) Initialize weights
68
+ for m in self.modules():
69
+ if isinstance(m, nn.Conv2d):
70
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
71
+ elif isinstance(m, nn.Linear):
72
+ nn.init.xavier_uniform_(m.weight)
73
+ if m.bias is not None:
74
+ nn.init.zeros_(m.bias)
75
+
76
+ def forward(
77
+ self,
78
+ mel: torch.Tensor,
79
+ lengths: torch.Tensor,
80
+ notes_per_second: torch.Tensor,
81
+ difficulty: torch.Tensor,
82
+ level: torch.Tensor,
83
+ ):
84
+ """
85
+ Args:
86
+ mel: (B, 1, N_MELS, T_mel)
87
+ lengths: (B,) lengths after CNN
88
+ notes_per_second: (B,) stream of control values
89
+ difficulty: (B,) difficulty values
90
+ level: (B,) level values
91
+ Returns:
92
+ Dict with:
93
+ 'presence': (B, T_cnn_out, 3) # Corrected from 4 to 3
94
+ 'lengths': lengths
95
+ """
96
+ # CNN frontend
97
+ x = self.cnn(mel) # (B, C, F, T)
98
+ B, C, F, T = x.size()
99
+ x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
100
+
101
+ # Project to model dimension
102
+ x = self.proj(x) # (B, T, D_MODEL)
103
+
104
+ # FiLM conditioning
105
+ nps = notes_per_second.unsqueeze(-1).float() # (B, 1)
106
+ gamma_beta_nps = self.film_nps(nps) # (B, 2*D_MODEL)
107
+ gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
108
+ x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)
109
+
110
+ diff = difficulty.unsqueeze(-1).float() # (B, 1)
111
+ gamma_beta_diff = self.film_difficulty(diff) # (B, 2*D_MODEL)
112
+ gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
113
+ x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)
114
+
115
+ lvl = level.unsqueeze(-1).float() # (B, 1)
116
+ gamma_beta_lvl = self.film_level(lvl) # (B, 2*D_MODEL)
117
+ gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
118
+ x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)
119
+
120
+ # Conformer encoder
121
+ x, _ = self.encoder(x, lengths=lengths)
122
+
123
+ # Presence prediction
124
+ presence = self.presence_regressor(x)
125
+ return {"presence": presence, "lengths": lengths}
126
+
127
+
128
+ if __name__ == "__main__":
129
+ model = TaikoConformer6().to(device=DEVICE)
130
+ print(model)
131
+
132
+ for name, param in model.named_parameters():
133
+ if param.requires_grad:
134
+ print(f"{name}: {param.numel():,}")
135
+
136
+ params = sum(p.numel() for p in model.parameters() if p.requires_grad)
137
+ print(f"Total parameters: {params / 1e6:.2f}M")
138
+
139
+ batch_size = 4
140
+ mel_time_steps = 1024
141
+ input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
142
+
143
+ conformer_lengths = torch.tensor(
144
+ [mel_time_steps] * batch_size, dtype=torch.long
145
+ ).to(DEVICE)
146
+
147
+ notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
148
+ DEVICE
149
+ )
150
+ difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
151
+ DEVICE
152
+ ) # Example difficulty
153
+ level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
154
+ DEVICE
155
+ ) # Example level
156
+
157
+ output = model(
158
+ input_mel,
159
+ conformer_lengths,
160
+ notes_per_second_input,
161
+ difficulty_input,
162
+ level_input,
163
+ )
164
+ print("Output shapes:")
165
+ for key, value in output.items():
166
+ print(f"{key}: {value.shape}")
tc6/preprocess.py ADDED
@@ -0,0 +1,258 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ from torchaudio.transforms import FrequencyMasking
6
+ from tja import parse_tja, PyParsingMode
7
+ from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
8
+ from .model import TaikoConformer6
9
+
10
+
11
+ mel_transform = torchaudio.transforms.MelSpectrogram(
12
+ sample_rate=SAMPLE_RATE,
13
+ n_mels=N_MELS,
14
+ hop_length=HOP_LENGTH,
15
+ n_fft=2048,
16
+ )
17
+
18
+
19
+ freq_mask = FrequencyMasking(freq_mask_param=15)
20
+
21
+
22
+ def preprocess(example, difficulty="oni"):
23
+ wav_tensor = example["audio"]["array"]
24
+ sr = example["audio"]["sampling_rate"]
25
+ # 1) load & resample
26
+ if sr != SAMPLE_RATE:
27
+ wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
28
+ # normalize audio
29
+ wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
30
+ # add random Gaussian noise
31
+ if torch.rand(1).item() < 0.5:
32
+ wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
33
+ # 2) mel: (1, N_MELS, T)
34
+ mel = mel_transform(wav_tensor).unsqueeze(0)
35
+ # apply SpecAugment
36
+ mel = freq_mask(mel)
37
+ _, _, T = mel.shape
38
+ # 3) build label sequence of length ceil(T / TIME_SUB)
39
+ T_sub = math.ceil(T / TIME_SUB)
40
+
41
+ # Initialize energy-based labels for Don, Ka, Drumroll
42
+ don_labels = torch.zeros(T_sub, dtype=torch.float32)
43
+ ka_labels = torch.zeros(T_sub, dtype=torch.float32)
44
+ drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
45
+
46
+ # Define exponential decay tail parameters
47
+ tail_length = 40 # number of frames for decay tail
48
+ decay_rate = 8.0 # decay rate parameter, adjust as needed
49
+ tail_kernel = torch.exp(
50
+ -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
51
+ )
52
+
53
+ fps = SAMPLE_RATE / HOP_LENGTH
54
+ num_valid_notes = 0
55
+ for onset in example[difficulty]:
56
+ typ, t_start, t_end, *_ = onset
57
+
58
+ # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
59
+ if typ < 1 or typ > N_TYPES: # Filter out invalid types
60
+ continue
61
+
62
+ num_valid_notes += 1
63
+
64
+ exact_frame_start = t_start.item() * fps
65
+
66
+ # Type 1 and 3 are Don, Type 2 and 4 are Ka
67
+ if typ == 1 or typ == 3 or typ == 2 or typ == 4:
68
+ exact_hit_time_sub = exact_frame_start / TIME_SUB
69
+
70
+ current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels
71
+
72
+ start_points_info = []
73
+ rounded_hit_time_sub = round(exact_hit_time_sub)
74
+
75
+ if (
76
+ abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
77
+ ): # Tolerance for float precision
78
+ idx_single = int(rounded_hit_time_sub)
79
+ if 0 <= idx_single < T_sub:
80
+ start_points_info.append({"idx": idx_single, "weight": 1.0})
81
+ else:
82
+ idx_floor = math.floor(exact_hit_time_sub)
83
+ idx_ceil = idx_floor + 1
84
+
85
+ frac = exact_hit_time_sub - idx_floor
86
+ weight_ceil = frac
87
+ weight_floor = 1.0 - frac
88
+
89
+ if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
90
+ start_points_info.append({"idx": idx_floor, "weight": weight_floor})
91
+ if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
92
+ start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})
93
+
94
+ for point_info in start_points_info:
95
+ start_idx = point_info["idx"]
96
+ weight = point_info["weight"]
97
+ for k_idx, kernel_val in enumerate(tail_kernel):
98
+ target_idx = start_idx + k_idx
99
+ if 0 <= target_idx < T_sub:
100
+ current_labels[target_idx] = max(
101
+ current_labels[target_idx].item(),
102
+ weight * kernel_val.item(),
103
+ )
104
+
105
+ # Type 5, 6, 7 are Drumroll
106
+ elif typ >= 5 and typ <= 7:
107
+ exact_frame_end = t_end.item() * fps
108
+ exact_start_time_sub = exact_frame_start / TIME_SUB
109
+ exact_end_time_sub = exact_frame_end / TIME_SUB
110
+
111
+ # Improved drumroll body
112
+ body_loop_start_idx = math.floor(exact_start_time_sub)
113
+ body_loop_end_idx = math.ceil(exact_end_time_sub)
114
+
115
+ for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
116
+ if 0 <= dr_idx < T_sub:
117
+ drumroll_labels[dr_idx] = 1.0
118
+
119
+ # Improved drumroll tail (starts from exact_end_time_sub)
120
+ tail_start_points_info = []
121
+ rounded_end_time_sub = round(exact_end_time_sub)
122
+ if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
123
+ idx_single_tail = int(rounded_end_time_sub)
124
+ if 0 <= idx_single_tail < T_sub:
125
+ tail_start_points_info.append(
126
+ {"idx": idx_single_tail, "weight": 1.0}
127
+ )
128
+ else:
129
+ idx_floor_tail = math.floor(exact_end_time_sub)
130
+ idx_ceil_tail = idx_floor_tail + 1
131
+
132
+ frac_tail = exact_end_time_sub - idx_floor_tail
133
+ weight_ceil_tail = frac_tail
134
+ weight_floor_tail = 1.0 - frac_tail
135
+
136
+ if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
137
+ tail_start_points_info.append(
138
+ {"idx": idx_floor_tail, "weight": weight_floor_tail}
139
+ )
140
+ if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
141
+ tail_start_points_info.append(
142
+ {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
143
+ )
144
+
145
+ for point_info in tail_start_points_info:
146
+ start_idx = point_info["idx"]
147
+ weight = point_info["weight"]
148
+ for k_idx, kernel_val in enumerate(tail_kernel):
149
+ target_idx = start_idx + k_idx
150
+ if 0 <= target_idx < T_sub:
151
+ drumroll_labels[target_idx] = max(
152
+ drumroll_labels[target_idx].item(),
153
+ weight * kernel_val.item(),
154
+ )
155
+
156
+ duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
157
+ nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
158
+
159
+ parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
160
+ chart = next(
161
+ (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
162
+ )
163
+ difficulty_id = (
164
+ 0
165
+ if difficulty == "easy"
166
+ else (
167
+ 1
168
+ if difficulty == "normal"
169
+ else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
170
+ ) # Assuming 4 for edit/ura
171
+ )
172
+ level = chart.level if chart else 0
173
+
174
+ # --- CNN shape inference and label padding/truncation ---
175
+ # Simulate CNN to get output time length (T_cnn)
176
+ dummy_model = TaikoConformer6()
177
+ with torch.no_grad():
178
+ cnn_out = dummy_model.cnn(mel.unsqueeze(0)) # (1, C, F, T_cnn)
179
+ _, _, _, T_cnn = cnn_out.shape
180
+
181
+ # Pad or truncate labels to T_cnn
182
+ def pad_or_truncate(label, out_len):
183
+ if label.shape[0] < out_len:
184
+ pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
185
+ return torch.cat([label, pad], dim=0)
186
+ else:
187
+ return label[:out_len]
188
+
189
+ don_labels = pad_or_truncate(don_labels, T_cnn)
190
+ ka_labels = pad_or_truncate(ka_labels, T_cnn)
191
+ drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)
192
+
193
+ # For conformer input lengths: based on original mel shape (before CNN)
194
+ conformer_input_length = min(math.ceil(T / TIME_SUB), T_cnn)
195
+
196
+ print(
197
+ f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
198
+ )
199
+
200
+ return {
201
+ "mel": mel, # (1, N_MELS, T)
202
+ "don_labels": don_labels, # (T_cnn,)
203
+ "ka_labels": ka_labels, # (T_cnn,)
204
+ "drumroll_labels": drumroll_labels, # (T_cnn,)
205
+ "nps": torch.tensor(nps, dtype=torch.float32),
206
+ "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
207
+ "level": torch.tensor(level, dtype=torch.long),
208
+ "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
209
+ "length": torch.tensor(
210
+ conformer_input_length, dtype=torch.long
211
+ ), # for conformer
212
+ }
213
+
214
+
215
+ def collate_fn(batch):
216
+ mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch] # (T, N_MELS)
217
+ don_labels_list = [b["don_labels"] for b in batch]
218
+ ka_labels_list = [b["ka_labels"] for b in batch]
219
+ drumroll_labels_list = [b["drumroll_labels"] for b in batch]
220
+ nps_list = [b["nps"] for b in batch]
221
+ difficulty_list = [b["difficulty"] for b in batch]
222
+ level_list = [b["level"] for b in batch]
223
+ durations_list = [b["duration_seconds"] for b in batch]
224
+ lengths_list = [b["length"] for b in batch]
225
+
226
+ # Pad mels
227
+ padded_mels = nn.utils.rnn.pad_sequence(
228
+ mels_list, batch_first=True
229
+ ) # (B, T_max, N_MELS)
230
+ reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
231
+ T_max = padded_mels.shape[1]
232
+
233
+ # Pad labels to T_max
234
+ def pad_label(label, out_len):
235
+ if label.shape[0] < out_len:
236
+ pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
237
+ return torch.cat([label, pad], dim=0)
238
+ else:
239
+ return label[:out_len]
240
+
241
+ don_labels = torch.stack([pad_label(l, T_max) for l in don_labels_list])
242
+ ka_labels = torch.stack([pad_label(l, T_max) for l in ka_labels_list])
243
+ drumroll_labels = torch.stack([pad_label(l, T_max) for l in drumroll_labels_list])
244
+ lengths = torch.tensor(
245
+ [min(l.item(), T_max) for l in lengths_list], dtype=torch.long
246
+ )
247
+
248
+ return {
249
+ "mel": reshaped_mels,
250
+ "don_labels": don_labels,
251
+ "ka_labels": ka_labels,
252
+ "drumroll_labels": drumroll_labels,
253
+ "lengths": lengths, # for conformer
254
+ "nps": torch.stack(nps_list),
255
+ "difficulty": torch.stack(difficulty_list),
256
+ "level": torch.stack(level_list),
257
+ "durations": torch.stack(durations_list),
258
+ }
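
The Don/Ka targets are not one-hot frames: each hit is split linearly between its two neighbouring label frames and followed by an exponential tail. A small sketch, reusing the tail_length/decay_rate values above, for a single hit at fractional frame 10.4:

```python
# Sketch only: soft-label construction for one Don hit, mirroring preprocess().
import math
import torch

T_sub = 20
labels = torch.zeros(T_sub)
tail_kernel = torch.exp(-torch.arange(0, 40, dtype=torch.float32) / 8.0)

exact = 10.4                                    # onset position in label frames
floor_idx = math.floor(exact)
frac = exact - floor_idx
for idx, weight in ((floor_idx, 1.0 - frac), (floor_idx + 1, frac)):
    for k, kv in enumerate(tail_kernel):
        t = idx + k
        if 0 <= t < T_sub:
            labels[t] = max(labels[t].item(), weight * kv.item())

# labels[10] ~= 0.60, labels[11] ~= 0.53, then a smooth exponential decay.
```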
tc6/train.py ADDED
@@ -0,0 +1,336 @@
1
+ from accelerate.utils import set_seed
2
+
3
+ set_seed(1024)
4
+
5
+
6
+ import math
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from tqdm import tqdm
11
+ from datasets import concatenate_datasets
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from .config import (
15
+ BATCH_SIZE,
16
+ DEVICE,
17
+ EPOCHS,
18
+ LR,
19
+ GRAD_ACCUM_STEPS,
20
+ HOP_LENGTH,
21
+ SAMPLE_RATE,
22
+ )
23
+ from .model import TaikoConformer6
24
+ from .dataset import ds
25
+ from .preprocess import preprocess, collate_fn
26
+ from .loss import TaikoEnergyLoss
27
+ from huggingface_hub import upload_folder
28
+
29
+
30
+ # --- Helper function to log energy plots ---
31
+ def log_energy_plots_to_tensorboard(
32
+ writer,
33
+ tag_prefix,
34
+ epoch,
35
+ pred_don,
36
+ pred_ka,
37
+ pred_drumroll,
38
+ true_don,
39
+ true_ka,
40
+ true_drumroll,
41
+ valid_length, # Actual valid length of the sequence (before padding)
42
+ hop_sec,
43
+ ):
44
+ """
45
+ Logs a plot of predicted vs. true energies for one sample to TensorBoard.
46
+ Energies should be 1D numpy arrays for the single sample, up to valid_length.
47
+ """
48
+ # Ensure data is on CPU and converted to numpy, and select only the valid part
49
+ pred_don = pred_don[:valid_length].detach().cpu().numpy()
50
+ pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
51
+ pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
52
+ true_don = true_don[:valid_length].cpu().numpy()
53
+ true_ka = true_ka[:valid_length].cpu().numpy()
54
+ true_drumroll = true_drumroll[:valid_length].cpu().numpy()
55
+
56
+ time_axis = np.arange(valid_length) * hop_sec
57
+
58
+ fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
59
+ fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
60
+
61
+ axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
62
+ axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
63
+ axs[0].set_ylabel("Don Energy")
64
+ axs[0].legend()
65
+ axs[0].grid(True)
66
+
67
+ axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
68
+ axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
69
+ axs[1].set_ylabel("Ka Energy")
70
+ axs[1].legend()
71
+ axs[1].grid(True)
72
+
73
+ axs[2].plot(
74
+ time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
75
+ )
76
+ axs[2].plot(
77
+ time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
78
+ )
79
+ axs[2].set_ylabel("Drumroll Energy")
80
+ axs[2].set_xlabel("Time (s)")
81
+ axs[2].legend()
82
+ axs[2].grid(True)
83
+
84
+ plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust layout to make space for suptitle
85
+ writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
86
+ plt.close(fig)
87
+
88
+
89
+ def main():
90
+ global ds
91
+
92
+ # Calculate hop seconds for model output frames
93
+ # This assumes the model output time dimension corresponds to the mel spectrogram time dimension
94
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
95
+
96
+ best_val_loss = float("inf")
97
+ patience = 10 # Increased patience a bit
98
+ pat_count = 0
99
+
100
+ ds_oni = ds.map(
101
+ preprocess,
102
+ remove_columns=ds.column_names,
103
+ fn_kwargs={"difficulty": "oni"},
104
+ writer_batch_size=10,
105
+ )
106
+ ds_hard = ds.map(
107
+ preprocess,
108
+ remove_columns=ds.column_names,
109
+ fn_kwargs={"difficulty": "hard"},
110
+ writer_batch_size=10,
111
+ )
112
+ ds_normal = ds.map(
113
+ preprocess,
114
+ remove_columns=ds.column_names,
115
+ fn_kwargs={"difficulty": "normal"},
116
+ writer_batch_size=10,
117
+ )
118
+ ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
119
+
120
+ ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
121
+ # ds_train_test.push_to_hub("JacobLinCool/taiko-conformer-6-ds")
122
+ train_loader = DataLoader(
123
+ ds_train_test["train"],
124
+ batch_size=BATCH_SIZE,
125
+ shuffle=True,
126
+ collate_fn=collate_fn,
127
+ num_workers=16,
128
+ persistent_workers=True,
129
+ prefetch_factor=4,
130
+ )
131
+ val_loader = DataLoader(
132
+ ds_train_test["test"],
133
+ batch_size=BATCH_SIZE,
134
+ shuffle=False,
135
+ collate_fn=collate_fn,
136
+ num_workers=16,
137
+ persistent_workers=True,
138
+ prefetch_factor=4,
139
+ )
140
+
141
+ model = TaikoConformer6().to(DEVICE)
142
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
143
+
144
+ criterion = TaikoEnergyLoss(reduction="mean").to(DEVICE)
145
+
146
+ # Adjust scheduler steps for gradient accumulation
147
+ num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
148
+ total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
149
+
150
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
151
+ optimizer, max_lr=LR, total_steps=total_optimizer_steps
152
+ )
153
+
154
+ writer = SummaryWriter()
155
+
156
+ for epoch in range(1, EPOCHS + 1):
157
+ model.train()
158
+ total_epoch_loss = 0.0
159
+ optimizer.zero_grad()
160
+
161
+ for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
162
+ mel = batch["mel"].to(DEVICE)
163
+ # Unpack new energy-based labels
164
+ don_labels = batch["don_labels"].to(DEVICE)
165
+ ka_labels = batch["ka_labels"].to(DEVICE)
166
+ drumroll_labels = batch["drumroll_labels"].to(DEVICE)
167
+ lengths = batch["lengths"].to(
168
+ DEVICE
169
+ ) # These are for the Conformer model output
170
+ nps = batch["nps"].to(DEVICE)
171
+ difficulty = batch["difficulty"].to(DEVICE) # Add difficulty
172
+ level = batch["level"].to(DEVICE) # Add level
173
+
174
+ output_dict = model(
175
+ mel, lengths, nps, difficulty, level
176
+ ) # Pass difficulty and level
177
+ # output_dict["presence"] is now (B, T_out, 3) for don, ka, drumroll energies
178
+ pred_energies_batch = output_dict["presence"] # (B, T_out, 3)
179
+
180
+ loss_input_batch = {
181
+ "don_labels": don_labels,
182
+ "ka_labels": ka_labels,
183
+ "drumroll_labels": drumroll_labels,
184
+ "lengths": lengths, # Pass lengths for masking within the loss function
185
+ }
186
+ loss = criterion(output_dict, loss_input_batch)
187
+
188
+ (loss / GRAD_ACCUM_STEPS).backward()
189
+
190
+ if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
191
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
192
+ optimizer.step()
193
+ scheduler.step()
194
+ optimizer.zero_grad()
195
+
196
+ total_epoch_loss += loss.item()
197
+
198
+ # Log plot for the first sample of the first batch in each training epoch
199
+ if idx == 0:
200
+ first_sample_pred_don = pred_energies_batch[0, :, 0]
201
+ first_sample_pred_ka = pred_energies_batch[0, :, 1]
202
+ first_sample_pred_drumroll = pred_energies_batch[0, :, 2]
203
+
204
+ first_sample_true_don = don_labels[0, :]
205
+ first_sample_true_ka = ka_labels[0, :]
206
+ first_sample_true_drumroll = drumroll_labels[0, :]
207
+
208
+ first_sample_length = lengths[
209
+ 0
210
+ ].item() # Get the valid length of the first sample
211
+
212
+ log_energy_plots_to_tensorboard(
213
+ writer,
214
+ "Train/Sample_0",
215
+ epoch,
216
+ first_sample_pred_don,
217
+ first_sample_pred_ka,
218
+ first_sample_pred_drumroll,
219
+ first_sample_true_don,
220
+ first_sample_true_ka,
221
+ first_sample_true_drumroll,
222
+ first_sample_length,
223
+ output_frame_hop_sec,
224
+ )
225
+
226
+ avg_train_loss = total_epoch_loss / len(train_loader)
227
+ writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
228
+
229
+ # Validation
230
+ model.eval()
231
+ total_val_loss = 0.0
232
+ # Removed storage for classification logits/labels and confusion matrix components
233
+
234
+ with torch.no_grad():
235
+ for val_idx, batch in enumerate(
236
+ tqdm(val_loader, desc=f"Val Epoch {epoch}")
237
+ ):
238
+ mel = batch["mel"].to(DEVICE)
239
+ don_labels = batch["don_labels"].to(DEVICE)
240
+ ka_labels = batch["ka_labels"].to(DEVICE)
241
+ drumroll_labels = batch["drumroll_labels"].to(DEVICE)
242
+ lengths = batch["lengths"].to(DEVICE)
243
+ nps = batch["nps"].to(DEVICE) # Ground truth NPS from batch
244
+ difficulty = batch["difficulty"].to(DEVICE) # Add difficulty
245
+ level = batch["level"].to(DEVICE) # Add level
246
+
247
+ output_dict = model(
248
+ mel, lengths, nps, difficulty, level
249
+ ) # Pass difficulty and level
250
+ pred_energies_val_batch = output_dict["presence"] # (B, T_out, 3)
251
+
252
+ val_loss_input_batch = {
253
+ "don_labels": don_labels,
254
+ "ka_labels": ka_labels,
255
+ "drumroll_labels": drumroll_labels,
256
+ "lengths": lengths,
257
+ }
258
+ val_loss = criterion(output_dict, val_loss_input_batch)
259
+ total_val_loss += val_loss.item()
260
+
261
+ # Log plot for the first sample of the first batch in each validation epoch
262
+ if val_idx == 0:
263
+ first_val_sample_pred_don = pred_energies_val_batch[0, :, 0]
264
+ first_val_sample_pred_ka = pred_energies_val_batch[0, :, 1]
265
+ first_val_sample_pred_drumroll = pred_energies_val_batch[0, :, 2]
266
+
267
+ first_val_sample_true_don = don_labels[0, :]
268
+ first_val_sample_true_ka = ka_labels[0, :]
269
+ first_val_sample_true_drumroll = drumroll_labels[0, :]
270
+
271
+ first_val_sample_length = lengths[0].item()
272
+
273
+ log_energy_plots_to_tensorboard(
274
+ writer,
275
+ "Eval/Sample_0",
276
+ epoch,
277
+ first_val_sample_pred_don,
278
+ first_val_sample_pred_ka,
279
+ first_val_sample_pred_drumroll,
280
+ first_val_sample_true_don,
281
+ first_val_sample_true_ka,
282
+ first_val_sample_true_drumroll,
283
+ first_val_sample_length,
284
+ output_frame_hop_sec,
285
+ )
286
+
287
+ # Log ground truth NPS for reference during validation if needed
288
+ # writer.add_scalar("NPS/GT_Val_Batch_Avg", nps.mean().item(), epoch * len(val_loader) + idx)
289
+
290
+ avg_val_loss = total_val_loss / len(val_loader)
291
+ writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
292
+
293
+ # Log learning rate
294
+ current_lr = optimizer.param_groups[0]["lr"]
295
+ writer.add_scalar("LR/learning_rate", current_lr, epoch)
296
+
297
+ # Log ground truth NPS from the last validation batch (or mean over epoch)
298
+ if "nps" in batch: # Check if nps is in the last batch
299
+ writer.add_scalar(
300
+ "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
301
+ )
302
+
303
+ print(
304
+ f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
305
+ )
306
+
307
+ if avg_val_loss < best_val_loss:
308
+ best_val_loss = avg_val_loss
309
+ pat_count = 0
310
+ torch.save(model.state_dict(), "best_model.pt") # Changed model save name
311
+ print(f"Saved new best model to best_model.pt at epoch {epoch}")
312
+ else:
313
+ pat_count += 1
314
+ if pat_count >= patience:
315
+ print("Early stopping!")
316
+ break
317
+ writer.close()
318
+
319
+ model_id = "JacobLinCool/taiko-conformer-6"
320
+ try:
321
+ model.push_to_hub(model_id, commit_message="Upload trained model")
322
+ upload_folder(
323
+ repo_id=model_id,
324
+ folder_path="runs",
325
+ path_in_repo=".",
326
+ commit_message="Upload training logs",
327
+ ignore_patterns=["*.txt", "*.json", "*.csv"],
328
+ )
329
+ print(f"Model and logs uploaded to {model_id}")
330
+ except Exception as e:
331
+ print(f"Error uploading to Hugging Face Hub: {e}")
332
+ print("Make sure you have the correct permissions and try again.")
333
+
334
+
335
+ if __name__ == "__main__":
336
+ main()
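
With BATCH_SIZE = 2 and GRAD_ACCUM_STEPS = 8 the optimizer effectively sees batches of 16. A stripped-down sketch of the accumulation pattern used in the loop above, with a toy model standing in for the Conformer:

```python
# Sketch only: the gradient-accumulation pattern from tc6/train.py.
import torch

GRAD_ACCUM_STEPS = 8
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
batches = [(torch.randn(2, 10), torch.randn(2, 1)) for _ in range(32)]

optimizer.zero_grad()
for idx, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / GRAD_ACCUM_STEPS).backward()        # scale so the summed grads match one large batch
    if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(batches):
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```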
tc7/__init__.py ADDED
File without changes
tc7/config.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+
3
+ # ─── 1) CONFIG ─────────────────────────────────────────────────────
4
+ SAMPLE_RATE = 22050
5
+ N_MELS = 80
6
+ HOP_LENGTH = 256
7
+ TIME_SUB = 1
8
+ CNN_CH = 256
9
+ N_HEADS = 8
10
+ D_MODEL = 512
11
+ FF_DIM = 1024
12
+ N_LAYERS = 6
13
+ DEPTHWISE_CONV_KERNEL_SIZE = 31
14
+ DROPOUT = 0.1
15
+ HIDDEN_DIM = 64
16
+ N_TYPES = 7
17
+ BATCH_SIZE = 2
18
+ GRAD_ACCUM_STEPS = 8
19
+ LR = 3e-4
20
+ EPOCHS = 200
21
+ NPS_PENALTY_WEIGHT_ALPHA = 0.3
22
+ NPS_PENALTY_WEIGHT_BETA = 1.0
23
+ DEVICE = (
24
+ "cuda"
25
+ if torch.cuda.is_available()
26
+ else "mps" if torch.backends.mps.is_available() else "cpu"
27
+ )
tc7/dataset.py ADDED
@@ -0,0 +1,21 @@
1
+ from datasets import load_dataset, concatenate_datasets
2
+
3
+ ds1 = load_dataset("JacobLinCool/taiko-2023-1.1", split="train")
4
+ ds2 = load_dataset("JacobLinCool/taiko-2023-1.2", split="train")
5
+ ds3 = load_dataset("JacobLinCool/taiko-2023-1.3", split="train")
6
+ ds4 = load_dataset("JacobLinCool/taiko-2023-1.4", split="train")
7
+ ds5 = load_dataset("JacobLinCool/taiko-2023-1.5", split="train")
8
+ ds6 = load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
9
+ ds7 = load_dataset("JacobLinCool/taiko-2023-1.7", split="train")
10
+ ds = concatenate_datasets([ds1, ds2, ds3, ds4, ds5, ds6, ds7]).with_format("torch")
11
+
12
+ good = list(range(len(ds)))
13
+ good.remove(1079) # sample 1079 has a file problem
14
+ ds = ds.select(good)
15
+
16
+ # for local test
17
+ # ds = (
18
+ # load_dataset("JacobLinCool/taiko-2023-1.6", split="train")
19
+ # .with_format("torch")
20
+ # .select(range(10))
21
+ # )
tc7/infer.py ADDED
@@ -0,0 +1,354 @@
1
+ import time
2
+ import torch
3
+ import torchaudio
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from .config import SAMPLE_RATE, N_MELS, HOP_LENGTH
7
+ import torch.profiler
8
+
9
+
10
+ # --- PREPROCESSING (match training) ---
11
+ def preprocess_audio(audio_path):
12
+ wav, sr = torchaudio.load(audio_path)
13
+ wav = wav.mean(dim=0) # mono
14
+ if sr != SAMPLE_RATE:
15
+ wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
16
+ wav = wav / (wav.abs().max() + 1e-8) # Normalize audio
17
+
18
+ mel_transform = torchaudio.transforms.MelSpectrogram(
19
+ sample_rate=SAMPLE_RATE,
20
+ n_mels=N_MELS,
21
+ hop_length=HOP_LENGTH,
22
+ n_fft=2048,
23
+ )
24
+ mel = mel_transform(wav)
25
+ return mel # mel is (N_MELS, T_mel)
26
+
27
+
28
+ # --- INFERENCE ---
29
+ def run_inference(model, mel_input, nps_input, difficulty_input, level_input, device):
30
+ model.eval()
31
+ with torch.no_grad():
32
+ mel = mel_input.to(device).unsqueeze(0) # (1, N_MELS, T_mel)
33
+ nps = nps_input.to(device).unsqueeze(0) # (1,)
34
+ difficulty = difficulty_input.to(device).unsqueeze(0) # (1,)
35
+ level = level_input.to(device).unsqueeze(0) # (1,)
36
+
37
+ mel_cnn_input = mel.unsqueeze(1) # (1, 1, N_MELS, T_mel)
38
+
39
+ conformer_lengths = torch.tensor(
40
+ [mel_cnn_input.shape[-1]], dtype=torch.long, device=device
41
+ )
42
+
43
+ with torch.profiler.profile(
44
+ activities=[
45
+ torch.profiler.ProfilerActivity.CPU,
46
+ *(
47
+ [torch.profiler.ProfilerActivity.CUDA]
48
+ if device.type == "cuda"
49
+ else []
50
+ ),
51
+ ],
52
+ record_shapes=True,
53
+ profile_memory=True,
54
+ with_stack=False,
55
+ with_flops=True,
56
+ ) as prof:
57
+ out_dict = model(mel_cnn_input, conformer_lengths, nps, difficulty, level)
58
+ print(
59
+ prof.key_averages().table(
60
+ sort_by=(
61
+ "self_cuda_memory_usage"
62
+ if device.type == "cuda"
63
+ else "self_cpu_time_total"
64
+ ),
65
+ row_limit=20,
66
+ )
67
+ )
68
+
69
+ energies = out_dict["presence"].squeeze(0).cpu().numpy()
70
+
71
+ don_energy = energies[:, 0]
72
+ ka_energy = energies[:, 1]
73
+ drumroll_energy = energies[:, 2]
74
+
75
+ return don_energy, ka_energy, drumroll_energy
76
+
77
+
78
+ # --- DECODE TO ONSETS ---
79
+ def decode_onsets(
80
+ don_energy,
81
+ ka_energy,
82
+ drumroll_energy,
83
+ hop_sec,
84
+ threshold=0.5,
85
+ min_distance_frames=3,
86
+ ):
87
+ results = []
88
+ T_out = len(don_energy)
89
+ last_onset_frame = -min_distance_frames
90
+
91
+ for i in range(1, T_out - 1): # Iterate considering neighbors for peak detection
92
+ if i < last_onset_frame + min_distance_frames:
93
+ continue
94
+
95
+ e_don, e_ka, e_drum = don_energy[i], ka_energy[i], drumroll_energy[i]
96
+ energies_at_i = {
97
+ 1: e_don,
98
+ 2: e_ka,
99
+ 5: e_drum,
100
+ } # Type mapping: 1:Don, 2:Ka, 5:Drumroll
101
+
102
+ # Find which energy is max and if it's a peak above threshold
103
+ # Sort by energy value descending to prioritize higher energy in case of ties for peak condition
104
+ sorted_types_by_energy = sorted(
105
+ energies_at_i.keys(), key=lambda x: energies_at_i[x], reverse=True
106
+ )
107
+
108
+ detected_this_frame = False
109
+ for onset_type in sorted_types_by_energy:
110
+ current_energy_series = None
111
+ if onset_type == 1:
112
+ current_energy_series = don_energy
113
+ elif onset_type == 2:
114
+ current_energy_series = ka_energy
115
+ elif onset_type == 5:
116
+ current_energy_series = drumroll_energy
117
+
118
+ energy_val = current_energy_series[i]
119
+
120
+ if (
121
+ energy_val > threshold
122
+ and energy_val > current_energy_series[i - 1]
123
+ and energy_val > current_energy_series[i + 1]
124
+ ):
125
+ # Check if this energy is the highest among the three at this frame
126
+ # This check is implicitly handled by iterating `sorted_types_by_energy`
127
+ # and breaking after the first detection.
128
+ results.append((i * hop_sec, onset_type))
129
+ last_onset_frame = i
130
+ detected_this_frame = True
131
+ break # Only one onset type per frame
132
+
133
+ return results
134
+
135
+
136
+ # --- VISUALIZATION ---
137
+ def plot_results(
138
+ mel_spectrogram,
139
+ don_energy,
140
+ ka_energy,
141
+ drumroll_energy,
142
+ onsets,
143
+ hop_sec,
144
+ out_path=None,
145
+ ):
146
+ # mel_spectrogram is (N_MELS, T_mel)
147
+ T_mel = mel_spectrogram.shape[1]
148
+ T_out = len(don_energy) # Length of energy arrays (model output time dimension)
149
+
150
+ # Time axes
151
+ time_axis_mel = np.arange(T_mel) * (HOP_LENGTH / SAMPLE_RATE)
152
+ # hop_sec for model output is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE
153
+ # However, the model output T_out is related to T_mel (input to CNN).
154
+ # If CNN does not change time dimension, T_out = T_mel.
155
+ # If TIME_SUB is used for label generation, T_out = T_mel / TIME_SUB.
156
+ # The `lengths` passed to conformer in `run_inference` is T_mel.
157
+ # The output of conformer `x` has shape (B, T, D_MODEL). This T is `lengths`.
158
+ # So, T_out from model is T_mel.
159
+ # The `hop_sec` for onsets should be based on the model output frame rate.
160
+ # If model output T_out corresponds to T_mel, then hop_sec for plotting energies is HOP_LENGTH / SAMPLE_RATE.
161
+ # The `hop_sec` passed to `decode_onsets` is (HOP_LENGTH * TIME_SUB) / SAMPLE_RATE.
162
+ # This seems inconsistent. Let's clarify: `TIME_SUB` is used in `preprocess.py` to determine `T_sub` for labels.
163
+ # The model's CNN output time dimension T_cnn is used to pad/truncate labels in `collate_fn`.
164
+ # In `model.py`, the CNN does not stride in time. So T_cnn_out = T_mel_input_to_CNN.
165
+ # The `lengths` for the conformer is based on this T_cnn_out.
166
+ # So, the output of the regressor `presence` (B, T_cnn_out, 3) has T_cnn_out time steps.
167
+ # Each step corresponds to `HOP_LENGTH / SAMPLE_RATE` seconds if TIME_SUB is not involved in downsampling features for the conformer.
168
+ # Let's assume the `hop_sec` used for `decode_onsets` is correct for interpreting model output frames.
169
+ time_axis_energies = np.arange(T_out) * hop_sec
170
+
171
+ fig, ax1 = plt.subplots(figsize=(100, 10))
172
+
173
+ # Plot Mel Spectrogram on ax1
174
+ mel_db = torchaudio.functional.amplitude_to_DB(
175
+ mel_spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0
176
+ )
177
+ img = ax1.imshow(
178
+ mel_db.numpy(),
179
+ aspect="auto",
180
+ origin="lower",
181
+ cmap="magma",
182
+ extent=[time_axis_mel[0], time_axis_mel[-1], 0, N_MELS],
183
+ )
184
+ ax1.set_title("Mel Spectrogram with Predicted Energies and Onsets")
185
+ ax1.set_xlabel("Time (s)")
186
+ ax1.set_ylabel("Mel Bin")
187
+ fig.colorbar(img, ax=ax1, format="%+2.0f dB")
188
+
189
+ # Create a second y-axis for energies
190
+ ax2 = ax1.twinx()
191
+ ax2.plot(time_axis_energies, don_energy, label="Don Energy", color="red")
192
+ ax2.plot(time_axis_energies, ka_energy, label="Ka Energy", color="blue")
193
+ ax2.plot(
194
+ time_axis_energies, drumroll_energy, label="Drumroll Energy", color="green"
195
+ )
196
+ ax2.set_ylabel("Energy")
197
+ ax2.set_ylim(0, 1.2) # Assuming energies are somewhat normalized or bounded
198
+
199
+ # Overlay onsets from decode_onsets (t is already in seconds)
200
+ labeled_types = set()
201
+ # Group drumrolls into segments (reuse logic from write_tja)
202
+ drumroll_times = [t_sec for t_sec, typ in onsets if typ == 5]
203
+ drumroll_times.sort()
204
+ drumroll_segments = []
205
+ if drumroll_times:
206
+ seg_start = drumroll_times[0]
207
+ prev = drumroll_times[0]
208
+ for t in drumroll_times[1:]:
209
+ if t - prev <= hop_sec * 6: # up to 5-frame gap
210
+ prev = t
211
+ else:
212
+ drumroll_segments.append((seg_start, prev))
213
+ seg_start = t
214
+ prev = t
215
+ drumroll_segments.append((seg_start, prev))
216
+ # Plot Don/Ka onsets as vertical lines
217
+ for t_sec, typ in onsets:
218
+ if typ == 5:
219
+ continue # skip drumroll onsets
220
+ color_map = {1: "darkred", 2: "darkblue"}
221
+ label_map = {1: "Don Onset", 2: "Ka Onset"}
222
+ line_color = color_map.get(typ, "black")
223
+ line_label = label_map.get(typ, f"Type {typ} Onset")
224
+ if typ not in labeled_types:
225
+ ax1.axvline(
226
+ t_sec, color=line_color, linestyle="--", alpha=0.9, label=line_label
227
+ )
228
+ labeled_types.add(typ)
229
+ else:
230
+ ax1.axvline(t_sec, color=line_color, linestyle="--", alpha=0.9)
231
+ # Plot drumroll segments as shaded regions
232
+ for seg_start, seg_end in drumroll_segments:
233
+ ax1.axvspan(
234
+ seg_start,
235
+ seg_end + hop_sec,
236
+ color="green",
237
+ alpha=0.2,
238
+ label="Drumroll Segment" if "drumroll" not in labeled_types else None,
239
+ )
240
+ labeled_types.add("drumroll")
241
+
242
+ # Combine legends from both axes
243
+ lines, labels = ax1.get_legend_handles_labels()
244
+ lines2, labels2 = ax2.get_legend_handles_labels()
245
+ ax2.legend(lines + lines2, labels + labels2, loc="upper right")
246
+
247
+ fig.tight_layout()
248
+
249
+ # Return plot as image buffer or save to file if path provided
250
+ if out_path:
251
+ plt.savefig(out_path)
252
+ print(f"Saved plot to {out_path}")
253
+ plt.close(fig)
254
+ return out_path
255
+ else:
256
+ # Return plot as in-memory buffer
257
+ return fig
258
+
259
+
260
+ def write_tja(onsets, out_path=None, bpm=160, quantize=96, audio="audio.wav"):
261
+ # TJA types: 0:no note, 1:Don, 2:Ka, 3:BigDon, 4:BigKa, 5:DrumrollStart, 8:DrumrollEnd
262
+ # Model output types: 1:Don, 2:Ka, 5:Drumroll (interpreted as start/single)
263
+ sec_per_beat = 60 / bpm
264
+ beats_per_measure = 4 # Assuming 4/4 time signature
265
+ sec_per_measure = sec_per_beat * beats_per_measure
266
+ # Step 1: Map onsets to (measure_idx, slot, typ)
267
+ slot_events = []
268
+ for t, typ in onsets:
269
+ measure_idx = int(t // sec_per_measure)
270
+ t_in_measure = t % sec_per_measure
271
+ slot = int(round(t_in_measure / sec_per_measure * quantize))
272
+ if slot >= quantize:
273
+ slot = quantize - 1
274
+ slot_events.append((measure_idx, slot, typ))
275
+ # Step 2: Build measure/slot grid
276
+ if slot_events:
277
+ max_measure_idx = max(m for m, _, _ in slot_events)
278
+ else:
279
+ max_measure_idx = -1
280
+ measures = {i: [0] * quantize for i in range(max_measure_idx + 1)}
281
+ # Step 3: Place Don/Ka, collect drumrolls
282
+ drumroll_slots = set()
283
+ for m, s, typ in slot_events:
284
+ if typ in [1, 2]:
285
+ measures[m][s] = typ
286
+ elif typ == 5:
287
+ drumroll_slots.add((m, s))
288
+ # Step 4: Process drumrolls into contiguous regions, mark 5 (start) and 8 (end)
289
+ # Flatten all slots to a list of (measure, slot) sorted
290
+ drumroll_list = sorted(list(drumroll_slots))
291
+ # Group into contiguous regions (allowing a gap of 5 slots)
292
+ grouped = []
293
+ group = []
294
+ for ms in drumroll_list:
295
+ if not group:
296
+ group = [ms]
297
+ else:
298
+ last_m, last_s = group[-1]
299
+ m, s = ms
300
+ # Calculate slot distance, considering measure wrap
301
+ slot_dist = None
302
+ if m == last_m:
303
+ slot_dist = s - last_s
304
+ elif m == last_m + 1 and last_s <= quantize - 1:
305
+ slot_dist = (quantize - 1 - last_s) + s + 1
306
+ else:
307
+ slot_dist = None
308
+ # Allow gap of up to 5 slots (slot_dist <= 6)
309
+ if slot_dist is not None and 1 <= slot_dist <= 6:
310
+ group.append(ms)
311
+ else:
312
+ grouped.append(group)
313
+ group = [ms]
314
+ if group:
315
+ grouped.append(group)
316
+ # Mark 5 (start) and 8 (end) for each group
317
+ for region in grouped:
318
+ if len(region) == 1:
319
+ m, s = region[0]
320
+ measures[m][s] = 5
321
+ # Place 8 in next slot (or next measure if at end)
322
+ if s < quantize - 1:
323
+ measures[m][s + 1] = 8
324
+ elif m < max_measure_idx:
325
+ measures[m + 1][0] = 8
326
+ else:
327
+ m_start, s_start = region[0]
328
+ m_end, s_end = region[-1]
329
+ measures[m_start][s_start] = 5
330
+ measures[m_end][s_end] = 8
331
+ # Fill 0 for middle slots (already 0 by default)
332
+ # Step 5: Generate TJA content
333
+ tja_content = []
334
+ tja_content.append(f"TITLE:{audio} (TC7, {time.strftime('%Y-%m-%d %H:%M:%S')})")
335
+ tja_content.append(f"BPM:{bpm}")
336
+ tja_content.append(f"WAVE:{audio}")
337
+ tja_content.append("OFFSET:0")
338
+ tja_content.append("COURSE:Oni\nLEVEL:9\n")
339
+ tja_content.append("#START")
340
+ for i in range(max_measure_idx + 1):
341
+ notes = measures.get(i, [0] * quantize)
342
+ line = "".join(str(n) for n in notes)
343
+ tja_content.append(line + ",")
344
+ tja_content.append("#END")
345
+
346
+ tja_string = "\n".join(tja_content)
347
+
348
+ # If out_path is provided, also write to file
349
+ if out_path:
350
+ with open(out_path, "w", encoding="utf-8") as f:
351
+ f.write(tja_string)
352
+ print(f"TJA chart saved to {out_path}")
353
+
354
+ return tja_string
tc7/loss.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class TaikoLoss(nn.Module):
6
+ def __init__(
7
+ self,
8
+ reduction="mean",
9
+ nps_penalty_weight_alpha=0.3,
10
+ nps_penalty_weight_beta=1.0,
11
+ ):
12
+ super().__init__()
13
+ self.mse_loss = nn.MSELoss(reduction="none")
14
+ self.reduction = reduction
15
+ self.nps_penalty_weight_alpha = nps_penalty_weight_alpha
16
+ self.nps_penalty_weight_beta = nps_penalty_weight_beta
17
+
18
+ def forward(self, outputs, batch):
19
+ """
20
+ Calculates the MSE loss for energy-based predictions, with a two-level penalty
21
+ based on sliding NPS values.
22
+ - A heavier penalty if sliding_nps is 0.
23
+ - A continuous penalty if sliding_nps > 0.
24
+
25
+ Args:
26
+ outputs (dict): Model output, containing 'presence' tensor.
27
+ outputs['presence'] shape: (B, T, 3) for don, ka, drumroll energies.
28
+ batch (dict): Batch data from collate_fn, containing true labels, lengths,
29
+ and sliding_nps_labels.
30
+ batch['sliding_nps_labels'] shape: (B, T)
31
+ Returns:
32
+ torch.Tensor: The calculated loss.
33
+ """
34
+ pred_energies = outputs["presence"] # (B, T, 3)
35
+ true_don = batch["don_labels"] # (B, T)
36
+ true_ka = batch["ka_labels"] # (B, T)
37
+ true_drumroll = batch["drumroll_labels"] # (B, T)
38
+ true_energies = torch.stack([true_don, true_ka, true_drumroll], dim=2).to(
39
+ pred_energies.device
40
+ ) # (B, T, 3)
41
+
42
+ B, T, _ = pred_energies.shape
43
+
44
+ # Create a mask based on batch['lengths'] to ignore padded parts of sequences
45
+ # batch['lengths'] gives the actual length of each sequence in the batch
46
+ # mask shape: (B, T)
47
+ mask_2d = torch.arange(T, device=pred_energies.device).expand(B, T) < batch[
48
+ "lengths"
49
+ ].to(pred_energies.device).unsqueeze(1)
50
+ # Expand mask to (B, T, 1) to broadcast across the 3 energy channels
51
+ mask_3d = mask_2d.unsqueeze(2) # (B, T, 1)
52
+
53
+ # Calculate element-wise MSE loss
54
+ mse_loss_elementwise = self.mse_loss(pred_energies, true_energies) # (B, T, 3)
55
+
56
+ # Calculate two-level Sliding NPS penalty
57
+ sliding_nps = batch["sliding_nps_labels"].to(pred_energies.device) # (B, T)
58
+
59
+ penalty_coefficients = torch.zeros_like(sliding_nps) # (B, T)
60
+
61
+ is_zero_nps = sliding_nps == 0.0
62
+ is_not_zero_nps = ~is_zero_nps
63
+
64
+ # Apply heavy penalty where sliding_nps is 0
65
+ penalty_coefficients[is_zero_nps] = self.nps_penalty_weight_beta
66
+
67
+ # Apply continuous penalty where sliding_nps > 0
68
+ penalty_coefficients[is_not_zero_nps] = self.nps_penalty_weight_alpha * (
69
+ 1 - sliding_nps[is_not_zero_nps]
70
+ )
71
+
72
+ # Apply penalty factor to the MSE loss
73
+ loss_elementwise = mse_loss_elementwise * (
74
+ 1 + penalty_coefficients.unsqueeze(2)
75
+ )
76
+
77
+ # Apply the mask to the combined loss
78
+ masked_loss = loss_elementwise * mask_3d
79
+
80
+ if self.reduction == "mean":
81
+ # Sum the loss over all valid (unmasked) elements and divide by the number of valid elements
82
+ total_loss = masked_loss.sum()
83
+ num_valid_elements = mask_3d.sum() # Total number of unmasked float values
84
+ if num_valid_elements > 0:
85
+ return total_loss / num_valid_elements
86
+ else:
87
+ # Avoid division by zero if there are no valid elements (e.g., empty batch or all lengths are 0)
88
+ return torch.tensor(
89
+ 0.0, device=pred_energies.device, requires_grad=True
90
+ )
91
+ elif self.reduction == "sum":
92
+ return masked_loss.sum()
93
+ else: # 'none' or any other case
94
+ return masked_loss
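A small shape-check sketch for TaikoLoss with dummy tensors (not part of the commit). It shows the keys the loss expects from the model output and from the collated batch; frames past each sequence's length are masked out of the reduction.

criterion = TaikoLoss(reduction="mean", nps_penalty_weight_alpha=0.3, nps_penalty_weight_beta=1.0)
B, T = 2, 128
outputs = {"presence": torch.rand(B, T, 3)}  # don, ka, drumroll energies in [0, 1]
batch = {
    "don_labels": torch.rand(B, T),
    "ka_labels": torch.rand(B, T),
    "drumroll_labels": torch.rand(B, T),
    "sliding_nps_labels": torch.rand(B, T),  # normalized to [0, 1] during preprocessing
    "lengths": torch.tensor([128, 96]),
}
loss = criterion(outputs, batch)  # scalar tensor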
tc7/model.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchaudio.models import Conformer
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+ from .config import (
6
+ N_MELS,
7
+ CNN_CH,
8
+ N_HEADS,
9
+ D_MODEL,
10
+ FF_DIM,
11
+ N_LAYERS,
12
+ DROPOUT,
13
+ DEPTHWISE_CONV_KERNEL_SIZE,
14
+ HIDDEN_DIM,
15
+ DEVICE,
16
+ )
17
+
18
+
19
+ class TaikoConformer7(nn.Module, PyTorchModelHubMixin):
20
+ def __init__(self):
21
+ super().__init__()
22
+ # 1) CNN frontend: frequency-only pooling
23
+ self.cnn = nn.Sequential(
24
+ nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
25
+ nn.BatchNorm2d(CNN_CH),
26
+ nn.GELU(),
27
+ nn.Dropout2d(DROPOUT),
28
+ nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
29
+ nn.BatchNorm2d(CNN_CH),
30
+ nn.GELU(),
31
+ nn.Dropout2d(DROPOUT),
32
+ )
33
+ feat_dim = CNN_CH * (N_MELS // 4)
34
+
35
+ # 2) Linear projection to model dimension
36
+ self.proj = nn.Linear(feat_dim, D_MODEL)
37
+
38
+ # 3) FiLM conditioning for notes_per_second, difficulty, and level
39
+ self.film_nps = nn.Linear(1, 2 * D_MODEL)
40
+ self.film_difficulty = nn.Linear(
41
+ 1, 2 * D_MODEL
42
+ ) # Assuming difficulty is a single scalar
43
+ self.film_level = nn.Linear(1, 2 * D_MODEL) # Assuming level is a single scalar
44
+
45
+ # 4) Conformer encoder
46
+ self.encoder = Conformer(
47
+ input_dim=D_MODEL,
48
+ num_heads=N_HEADS,
49
+ ffn_dim=FF_DIM,
50
+ num_layers=N_LAYERS,
51
+ depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
52
+ dropout=DROPOUT,
53
+ use_group_norm=False,
54
+ convolution_first=False,
55
+ )
56
+
57
+ # 5) Presence regressor head
58
+ self.presence_regressor = nn.Sequential(
59
+ nn.Dropout(DROPOUT),
60
+ nn.Linear(D_MODEL, HIDDEN_DIM),
61
+ nn.GELU(),
62
+ nn.Dropout(DROPOUT),
63
+ nn.Linear(HIDDEN_DIM, 3), # Don, Ka, DrumRoll energy
64
+ nn.Sigmoid(), # Output between 0 and 1
65
+ )
66
+
67
+ # 6) Initialize weights
68
+ for m in self.modules():
69
+ if isinstance(m, nn.Conv2d):
70
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
71
+ elif isinstance(m, nn.Linear):
72
+ nn.init.xavier_uniform_(m.weight)
73
+ if m.bias is not None:
74
+ nn.init.zeros_(m.bias)
75
+
76
+ def forward(
77
+ self,
78
+ mel: torch.Tensor,
79
+ lengths: torch.Tensor,
80
+ notes_per_second: torch.Tensor,
81
+ difficulty: torch.Tensor,
82
+ level: torch.Tensor,
83
+ ):
84
+ """
85
+ Args:
86
+ mel: (B, 1, N_MELS, T_mel)
87
+ lengths: (B,) lengths after CNN
88
+ notes_per_second: (B,) per-sample notes-per-second conditioning value
89
+ difficulty: (B,) difficulty values
90
+ level: (B,) level values
91
+ Returns:
92
+ Dict with:
93
+ 'presence': (B, T_cnn_out, 3) # Don, Ka, Drumroll energies
94
+ 'lengths': lengths
95
+ """
96
+ # CNN frontend
97
+ x = self.cnn(mel) # (B, C, F, T)
98
+ B, C, F, T = x.size()
99
+ x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
100
+
101
+ # Project to model dimension
102
+ x = self.proj(x) # (B, T, D_MODEL)
103
+
104
+ # FiLM conditioning
105
+ nps = notes_per_second.unsqueeze(-1).float() # (B, 1)
106
+ gamma_beta_nps = self.film_nps(nps) # (B, 2*D_MODEL)
107
+ gamma_nps, beta_nps = gamma_beta_nps.chunk(2, dim=-1)
108
+ x = gamma_nps.unsqueeze(1) * x + beta_nps.unsqueeze(1)
109
+
110
+ diff = difficulty.unsqueeze(-1).float() # (B, 1)
111
+ gamma_beta_diff = self.film_difficulty(diff) # (B, 2*D_MODEL)
112
+ gamma_diff, beta_diff = gamma_beta_diff.chunk(2, dim=-1)
113
+ x = gamma_diff.unsqueeze(1) * x + beta_diff.unsqueeze(1)
114
+
115
+ lvl = level.unsqueeze(-1).float() # (B, 1)
116
+ gamma_beta_lvl = self.film_level(lvl) # (B, 2*D_MODEL)
117
+ gamma_lvl, beta_lvl = gamma_beta_lvl.chunk(2, dim=-1)
118
+ x = gamma_lvl.unsqueeze(1) * x + beta_lvl.unsqueeze(1)
119
+
120
+ # Conformer encoder
121
+ x, _ = self.encoder(x, lengths=lengths)
122
+
123
+ # Presence prediction
124
+ presence = self.presence_regressor(x)
125
+ return {"presence": presence, "lengths": lengths}
126
+
127
+
128
+ if __name__ == "__main__":
129
+ model = TaikoConformer7().to(device=DEVICE)
130
+ print(model)
131
+
132
+ for name, param in model.named_parameters():
133
+ if param.requires_grad:
134
+ print(f"{name}: {param.numel():,}")
135
+
136
+ params = sum(p.numel() for p in model.parameters() if p.requires_grad)
137
+ print(f"Total parameters: {params / 1e6:.2f}M")
138
+
139
+ batch_size = 4
140
+ mel_time_steps = 1024
141
+ input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
142
+
143
+ conformer_lengths = torch.tensor(
144
+ [mel_time_steps] * batch_size, dtype=torch.long
145
+ ).to(DEVICE)
146
+
147
+ notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
148
+ DEVICE
149
+ )
150
+ difficulty_input = torch.tensor([1.0] * batch_size, dtype=torch.float32).to(
151
+ DEVICE
152
+ ) # Example difficulty
153
+ level_input = torch.tensor([5.0] * batch_size, dtype=torch.float32).to(
154
+ DEVICE
155
+ ) # Example level
156
+
157
+ output = model(
158
+ input_mel,
159
+ conformer_lengths,
160
+ notes_per_second_input,
161
+ difficulty_input,
162
+ level_input,
163
+ )
164
+ print("Output shapes:")
165
+ for key, value in output.items():
166
+ print(f"{key}: {value.shape}")
tc7/preprocess.py ADDED
@@ -0,0 +1,400 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ from torchaudio.transforms import FrequencyMasking
6
+ from tja import parse_tja, PyParsingMode
7
+ from .config import N_TYPES, SAMPLE_RATE, N_MELS, HOP_LENGTH, TIME_SUB
8
+ from .model import TaikoConformer7
9
+
10
+
11
+ mel_transform = torchaudio.transforms.MelSpectrogram(
12
+ sample_rate=SAMPLE_RATE,
13
+ n_mels=N_MELS,
14
+ hop_length=HOP_LENGTH,
15
+ n_fft=2048,
16
+ )
17
+
18
+
19
+ freq_mask = FrequencyMasking(freq_mask_param=15)
20
+
21
+
22
+ def preprocess(example, difficulty="oni"):
23
+ wav_tensor = example["audio"]["array"]
24
+ sr = example["audio"]["sampling_rate"]
25
+ # 1) load & resample
26
+ if sr != SAMPLE_RATE:
27
+ wav_tensor = torchaudio.functional.resample(wav_tensor, sr, SAMPLE_RATE)
28
+ # normalize audio
29
+ wav_tensor = wav_tensor / (wav_tensor.abs().max() + 1e-8)
30
+ # add random Gaussian noise
31
+ if torch.rand(1).item() < 0.5:
32
+ wav_tensor = wav_tensor + 0.005 * torch.randn_like(wav_tensor)
33
+ # 2) mel: (1, N_MELS, T)
34
+ mel = mel_transform(wav_tensor).unsqueeze(0)
35
+ # apply SpecAugment
36
+ mel = freq_mask(mel)
37
+ _, _, T = mel.shape
38
+ # 3) build label sequence of length ceil(T / TIME_SUB)
39
+ T_sub = math.ceil(T / TIME_SUB)
40
+
41
+ # Initialize energy-based labels for Don, Ka, Drumroll
42
+ don_labels = torch.zeros(T_sub, dtype=torch.float32)
43
+ ka_labels = torch.zeros(T_sub, dtype=torch.float32)
44
+ drumroll_labels = torch.zeros(T_sub, dtype=torch.float32)
45
+ sliding_nps_labels = torch.zeros(
46
+ T_sub, dtype=torch.float32
47
+ ) # New label for sliding NPS
48
+
49
+ # Define exponential decay tail parameters
50
+ tail_length = 40 # number of frames for decay tail
51
+ decay_rate = 8.0 # decay rate parameter, adjust as needed
52
+ tail_kernel = torch.exp(
53
+ -torch.arange(0, tail_length, dtype=torch.float32) / decay_rate
54
+ )
55
+
56
+ fps = SAMPLE_RATE / HOP_LENGTH
57
+ num_valid_notes = 0
58
+ for onset in example[difficulty]:
59
+ typ, t_start, t_end, *_ = onset
60
+
61
+ # Assuming N_TYPES in config is appropriately set (e.g., 7 or more)
62
+ if typ < 1 or typ > N_TYPES: # Filter out invalid types
63
+ continue
64
+
65
+ num_valid_notes += 1
66
+
67
+ exact_frame_start = t_start.item() * fps
68
+
69
+ # Type 1 and 3 are Don, Type 2 and 4 are Ka
70
+ if typ == 1 or typ == 3 or typ == 2 or typ == 4:
71
+ exact_hit_time_sub = exact_frame_start / TIME_SUB
72
+
73
+ current_labels = don_labels if (typ == 1 or typ == 3) else ka_labels
74
+
75
+ start_points_info = []
76
+ rounded_hit_time_sub = round(exact_hit_time_sub)
77
+
78
+ if (
79
+ abs(exact_hit_time_sub - rounded_hit_time_sub) < 1e-6
80
+ ): # Tolerance for float precision
81
+ idx_single = int(rounded_hit_time_sub)
82
+ if 0 <= idx_single < T_sub:
83
+ start_points_info.append({"idx": idx_single, "weight": 1.0})
84
+ else:
85
+ idx_floor = math.floor(exact_hit_time_sub)
86
+ idx_ceil = idx_floor + 1
87
+
88
+ frac = exact_hit_time_sub - idx_floor
89
+ weight_ceil = frac
90
+ weight_floor = 1.0 - frac
91
+
92
+ if weight_floor > 1e-6 and 0 <= idx_floor < T_sub:
93
+ start_points_info.append({"idx": idx_floor, "weight": weight_floor})
94
+ if weight_ceil > 1e-6 and 0 <= idx_ceil < T_sub:
95
+ start_points_info.append({"idx": idx_ceil, "weight": weight_ceil})
96
+
97
+ for point_info in start_points_info:
98
+ start_idx = point_info["idx"]
99
+ weight = point_info["weight"]
100
+ for k_idx, kernel_val in enumerate(tail_kernel):
101
+ target_idx = start_idx + k_idx
102
+ if 0 <= target_idx < T_sub:
103
+ current_labels[target_idx] = max(
104
+ current_labels[target_idx].item(),
105
+ weight * kernel_val.item(),
106
+ )
107
+
108
+ # Type 5, 6, 7 are Drumroll
109
+ elif typ >= 5 and typ <= 7:
110
+ exact_frame_end = t_end.item() * fps
111
+ exact_start_time_sub = exact_frame_start / TIME_SUB
112
+ exact_end_time_sub = exact_frame_end / TIME_SUB
113
+
114
+ # Improved drumroll body
115
+ body_loop_start_idx = math.floor(exact_start_time_sub)
116
+ body_loop_end_idx = math.ceil(exact_end_time_sub)
117
+
118
+ for dr_idx in range(body_loop_start_idx, body_loop_end_idx):
119
+ if 0 <= dr_idx < T_sub:
120
+ drumroll_labels[dr_idx] = 1.0
121
+
122
+ # Improved drumroll tail (starts from exact_end_time_sub)
123
+ tail_start_points_info = []
124
+ rounded_end_time_sub = round(exact_end_time_sub)
125
+ if abs(exact_end_time_sub - rounded_end_time_sub) < 1e-6:
126
+ idx_single_tail = int(rounded_end_time_sub)
127
+ if 0 <= idx_single_tail < T_sub:
128
+ tail_start_points_info.append(
129
+ {"idx": idx_single_tail, "weight": 1.0}
130
+ )
131
+ else:
132
+ idx_floor_tail = math.floor(exact_end_time_sub)
133
+ idx_ceil_tail = idx_floor_tail + 1
134
+
135
+ frac_tail = exact_end_time_sub - idx_floor_tail
136
+ weight_ceil_tail = frac_tail
137
+ weight_floor_tail = 1.0 - frac_tail
138
+
139
+ if weight_floor_tail > 1e-6 and 0 <= idx_floor_tail < T_sub:
140
+ tail_start_points_info.append(
141
+ {"idx": idx_floor_tail, "weight": weight_floor_tail}
142
+ )
143
+ if weight_ceil_tail > 1e-6 and 0 <= idx_ceil_tail < T_sub:
144
+ tail_start_points_info.append(
145
+ {"idx": idx_ceil_tail, "weight": weight_ceil_tail}
146
+ )
147
+
148
+ for point_info in tail_start_points_info:
149
+ start_idx = point_info["idx"]
150
+ weight = point_info["weight"]
151
+ for k_idx, kernel_val in enumerate(tail_kernel):
152
+ target_idx = start_idx + k_idx
153
+ if 0 <= target_idx < T_sub:
154
+ drumroll_labels[target_idx] = max(
155
+ drumroll_labels[target_idx].item(),
156
+ weight * kernel_val.item(),
157
+ )
158
+
159
+ # Calculate sliding window NPS
160
+ note_events = (
161
+ []
162
+ ) # Store tuples of (time_sec, type_is_drumroll_start_or_end, duration_if_drumroll)
163
+ for onset in example[difficulty]:
164
+ typ, t_start_tensor, t_end_tensor, *_ = onset
165
+ t_start = t_start_tensor.item()
166
+ t_end = t_end_tensor.item()
167
+
168
+ if typ in [1, 2, 3, 4]: # Don or Ka
169
+ note_events.append(
170
+ (t_start, False, 0)
171
+ ) # False indicates not a drumroll event, duration 0
172
+ elif typ >= 5 and typ <= 7: # Drumroll
173
+ note_events.append(
174
+ (t_start, True, t_end - t_start)
175
+ ) # True indicates drumroll start, store duration
176
+ # We don't explicitly need a drumroll end event for this calculation method
177
+
178
+ note_events.sort(key=lambda x: x[0]) # Sort by time
179
+
180
+ window_duration_seconds = 0.5
181
+ # drumroll_nps_rate = 10.0 # Removed: Will use adaptive rate
182
+
183
+ # Step 1: Calculate base_sliding_nps_labels (Don/Ka only)
184
+ base_don_ka_sliding_nps = torch.zeros(T_sub, dtype=torch.float32)
185
+ time_step_duration_sec = TIME_SUB / fps # Duration of one T_sub segment
186
+
187
+ for k_idx in range(T_sub):
188
+ k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
189
+ k_window_start_sec = k_window_end_sec - window_duration_seconds
190
+
191
+ current_don_ka_count = 0.0
192
+ for event_t, is_drumroll, _ in note_events:
193
+ if not is_drumroll: # Don or Ka hit
194
+ if k_window_start_sec <= event_t < k_window_end_sec:
195
+ current_don_ka_count += 1
196
+ base_don_ka_sliding_nps[k_idx] = current_don_ka_count / window_duration_seconds
197
+
198
+ # Step 2: Calculate adaptive_drumroll_rates_for_all_events
199
+ adaptive_drumroll_rates_for_all_events = []
200
+ for event_t, is_drumroll, drumroll_dur in note_events:
201
+ if is_drumroll:
202
+ drumroll_start_sec = event_t
203
+ drumroll_end_sec = event_t + drumroll_dur
204
+
205
+ slice_start_idx = math.floor(drumroll_start_sec / time_step_duration_sec)
206
+ slice_end_idx = math.ceil(drumroll_end_sec / time_step_duration_sec)
207
+
208
+ slice_start_idx = max(0, slice_start_idx)
209
+ slice_end_idx = min(T_sub, slice_end_idx)
210
+
211
+ max_nps_in_drumroll_period = 0.0
212
+ if slice_start_idx < slice_end_idx:
213
+ relevant_base_nps_values = base_don_ka_sliding_nps[
214
+ slice_start_idx:slice_end_idx
215
+ ]
216
+ if relevant_base_nps_values.numel() > 0:
217
+ max_nps_in_drumroll_period = torch.max(
218
+ relevant_base_nps_values
219
+ ).item()
220
+
221
+ rate = max(5.0, max_nps_in_drumroll_period)
222
+ adaptive_drumroll_rates_for_all_events.append(rate)
223
+ else:
224
+ adaptive_drumroll_rates_for_all_events.append(0.0) # Placeholder
225
+
226
+ # Step 3: Calculate final sliding_nps_labels using adaptive rates
227
+ # sliding_nps_labels is already initialized with zeros earlier in the function.
228
+ for k_idx in range(T_sub):
229
+ k_window_end_sec = ((k_idx + 1) * TIME_SUB) / fps
230
+ k_window_start_sec = k_window_end_sec - window_duration_seconds
231
+
232
+ current_window_total_nps_contribution = 0.0
233
+ for event_idx, (event_t, is_drumroll, drumroll_dur) in enumerate(note_events):
234
+ if is_drumroll:
235
+ drumroll_start_sec = event_t
236
+ drumroll_end_sec = event_t + drumroll_dur
237
+
238
+ overlap_start = max(k_window_start_sec, drumroll_start_sec)
239
+ overlap_end = min(k_window_end_sec, drumroll_end_sec)
240
+
241
+ if overlap_end > overlap_start:
242
+ overlap_duration = overlap_end - overlap_start
243
+ current_adaptive_rate = adaptive_drumroll_rates_for_all_events[
244
+ event_idx
245
+ ]
246
+ current_window_total_nps_contribution += (
247
+ overlap_duration * current_adaptive_rate
248
+ )
249
+ else: # Don or Ka hit
250
+ if k_window_start_sec <= event_t < k_window_end_sec:
251
+ current_window_total_nps_contribution += (
252
+ 1 # Each hit contributes 1 to the count
253
+ )
254
+
255
+ sliding_nps_labels[k_idx] = (
256
+ current_window_total_nps_contribution / window_duration_seconds
257
+ )
258
+
259
+ # Normalize sliding_nps_labels to 0-1 range
260
+ if T_sub > 0: # Ensure there are elements to normalize
261
+ min_nps_val = torch.min(sliding_nps_labels)
262
+ max_nps_val = torch.max(sliding_nps_labels)
263
+ denominator = max_nps_val - min_nps_val
264
+ if denominator > 1e-6: # Use a small epsilon for float comparison
265
+ sliding_nps_labels = (sliding_nps_labels - min_nps_val) / denominator
266
+ else:
267
+ # If all values are (nearly) the same
268
+ if max_nps_val > 1e-6: # If the constant value is positive
269
+ sliding_nps_labels = torch.ones_like(sliding_nps_labels)
270
+ else: # If the constant value is zero (or very close to it)
271
+ sliding_nps_labels = torch.zeros_like(sliding_nps_labels)
272
+
273
+ duration_seconds = wav_tensor.shape[-1] / SAMPLE_RATE
274
+ nps = num_valid_notes / duration_seconds if duration_seconds > 0 else 0.0
275
+
276
+ parsed = parse_tja(example["tja"], mode=PyParsingMode.Full)
277
+ chart = next(
278
+ (chart for chart in parsed.charts if chart.course.lower() == difficulty), None
279
+ )
280
+ difficulty_id = (
281
+ 0
282
+ if difficulty == "easy"
283
+ else (
284
+ 1
285
+ if difficulty == "normal"
286
+ else 2 if difficulty == "hard" else 3 if difficulty == "oni" else 4
287
+ ) # Assuming 4 for edit/ura
288
+ )
289
+ level = chart.level if chart else 0
290
+
291
+ # --- CNN shape inference and label padding/truncation ---
292
+ # Simulate CNN to get output time length (T_cnn)
293
+ dummy_model = TaikoConformer7()
294
+ with torch.no_grad():
295
+ cnn_out = dummy_model.cnn(mel.unsqueeze(0)) # (1, C, F, T_cnn)
296
+ _, _, _, T_cnn = cnn_out.shape
297
+
298
+ # Pad or truncate labels to T_cnn
299
+ def pad_or_truncate(label, out_len):
300
+ if label.shape[0] < out_len:
301
+ pad = torch.zeros(out_len - label.shape[0], dtype=label.dtype)
302
+ return torch.cat([label, pad], dim=0)
303
+ else:
304
+ return label[:out_len]
305
+
306
+ don_labels = pad_or_truncate(don_labels, T_cnn)
307
+ ka_labels = pad_or_truncate(ka_labels, T_cnn)
308
+ drumroll_labels = pad_or_truncate(drumroll_labels, T_cnn)
309
+ sliding_nps_labels = pad_or_truncate(sliding_nps_labels, T_cnn) # Pad new label
310
+
311
+ # For conformer input lengths: this should be T_cnn
312
+ conformer_sequence_length = T_cnn # This is the actual sequence length after CNN
313
+
314
+ print(
315
+ f"Processed {num_valid_notes} notes in {duration_seconds:.2f} seconds, NPS: {nps:.2f}, Difficulty: {difficulty_id}, Level: {level}"
316
+ )
317
+
318
+ return {
319
+ "mel": mel, # (1, N_MELS, T)
320
+ "don_labels": don_labels, # (T_cnn,)
321
+ "ka_labels": ka_labels, # (T_cnn,)
322
+ "drumroll_labels": drumroll_labels, # (T_cnn,)
323
+ "sliding_nps_labels": sliding_nps_labels, # Add new label (T_cnn,)
324
+ "nps": torch.tensor(nps, dtype=torch.float32),
325
+ "difficulty": torch.tensor(difficulty_id, dtype=torch.long),
326
+ "level": torch.tensor(level, dtype=torch.long),
327
+ "duration_seconds": torch.tensor(duration_seconds, dtype=torch.float32),
328
+ "length": torch.tensor(
329
+ conformer_sequence_length, dtype=torch.long
330
+ ), # Use T_cnn for conformer and loss masking
331
+ }
332
+
333
+
334
+ def collate_fn(batch):
335
+ mels_list = [b["mel"].squeeze(0).transpose(0, 1) for b in batch] # (T, N_MELS)
336
+ don_labels_list = [b["don_labels"] for b in batch]
337
+ ka_labels_list = [b["ka_labels"] for b in batch]
338
+ drumroll_labels_list = [b["drumroll_labels"] for b in batch]
339
+ sliding_nps_labels_list = [b["sliding_nps_labels"] for b in batch] # New label list
340
+ nps_list = [b["nps"] for b in batch]
341
+ difficulty_list = [b["difficulty"] for b in batch]
342
+ level_list = [b["level"] for b in batch]
343
+ durations_list = [b["duration_seconds"] for b in batch]
344
+ lengths_list = [b["length"] for b in batch] # These are T_cnn_i for each example
345
+
346
+ # Pad mels
347
+ padded_mels = nn.utils.rnn.pad_sequence(
348
+ mels_list, batch_first=True
349
+ ) # (B, T_max_mel, N_MELS)
350
+ reshaped_mels = padded_mels.transpose(1, 2).unsqueeze(1)
351
+ # T_max_mel_batch = padded_mels.shape[1] # Max mel length in batch, not used for label padding anymore
352
+
353
+ # Determine max sequence length for labels (max T_cnn in batch)
354
+ max_label_len = 0
355
+ if lengths_list: # handle empty batch case
356
+ max_label_len = max(l.item() for l in lengths_list) if lengths_list else 0
357
+
358
+ # Pad labels to max_label_len (max_t_cnn_in_batch)
359
+ def pad_label_to_max_len(label_tensor, target_len):
360
+ current_len = label_tensor.shape[0]
361
+ if current_len < target_len:
362
+ padding_size = target_len - current_len
363
+ # Ensure padding is created on the same device as the label_tensor
364
+ padding = torch.zeros(
365
+ padding_size, dtype=label_tensor.dtype, device=label_tensor.device
366
+ )
367
+ return torch.cat((label_tensor, padding), dim=0)
368
+ elif (
369
+ current_len > target_len
370
+ ): # Should ideally not happen if lengths_list is correct
371
+ return label_tensor[:target_len]
372
+ return label_tensor
373
+
374
+ don_labels = torch.stack(
375
+ [pad_label_to_max_len(l, max_label_len) for l in don_labels_list]
376
+ )
377
+ ka_labels = torch.stack(
378
+ [pad_label_to_max_len(l, max_label_len) for l in ka_labels_list]
379
+ )
380
+ drumroll_labels = torch.stack(
381
+ [pad_label_to_max_len(l, max_label_len) for l in drumroll_labels_list]
382
+ )
383
+ sliding_nps_labels = torch.stack(
384
+ [pad_label_to_max_len(l, max_label_len) for l in sliding_nps_labels_list]
385
+ ) # Pad new labels
386
+
387
+ actual_lengths = torch.tensor([l.item() for l in lengths_list], dtype=torch.long)
388
+
389
+ return {
390
+ "mel": reshaped_mels,
391
+ "don_labels": don_labels,
392
+ "ka_labels": ka_labels,
393
+ "drumroll_labels": drumroll_labels,
394
+ "sliding_nps_labels": sliding_nps_labels, # Add new batched labels
395
+ "lengths": actual_lengths, # for conformer and loss masking (T_cnn_i for each item)
396
+ "nps": torch.stack(nps_list),
397
+ "difficulty": torch.stack(difficulty_list),
398
+ "level": torch.stack(level_list),
399
+ "durations": torch.stack(durations_list),
400
+ }
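A minimal collate sketch with dummy data (not part of the commit; it assumes the imports at the top of tc7/preprocess.py). Each example's labels already match its own T_cnn; collate_fn pads the mel spectrograms and all label tracks to the longest sequence in the batch and records the true lengths for masking.

def fake_example(T):  # hypothetical stand-in for a preprocess() output
    return {
        "mel": torch.randn(1, N_MELS, T),
        "don_labels": torch.zeros(T),
        "ka_labels": torch.zeros(T),
        "drumroll_labels": torch.zeros(T),
        "sliding_nps_labels": torch.zeros(T),
        "nps": torch.tensor(4.0),
        "difficulty": torch.tensor(3),
        "level": torch.tensor(8),
        "duration_seconds": torch.tensor(30.0),
        "length": torch.tensor(T),
    }

batch = collate_fn([fake_example(512), fake_example(384)])
print(batch["mel"].shape)         # (2, 1, N_MELS, 512)
print(batch["don_labels"].shape)  # (2, 512)
print(batch["lengths"])           # tensor([512, 384])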
tc7/train.py ADDED
@@ -0,0 +1,300 @@
1
+ from accelerate.utils import set_seed
2
+
3
+ set_seed(1024)
4
+
5
+ import math
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ from tqdm import tqdm
10
+ from datasets import concatenate_datasets
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from .config import (
14
+ BATCH_SIZE,
15
+ DEVICE,
16
+ EPOCHS,
17
+ LR,
18
+ GRAD_ACCUM_STEPS,
19
+ HOP_LENGTH,
20
+ NPS_PENALTY_WEIGHT_ALPHA,
21
+ NPS_PENALTY_WEIGHT_BETA,
22
+ SAMPLE_RATE,
23
+ )
24
+ from .model import TaikoConformer7
25
+ from .dataset import ds
26
+ from .preprocess import preprocess, collate_fn
27
+ from .loss import TaikoLoss
28
+ from huggingface_hub import upload_folder
29
+
30
+
31
+ def log_energy_plots_to_tensorboard(
32
+ writer,
33
+ tag_prefix,
34
+ epoch,
35
+ pred_don,
36
+ pred_ka,
37
+ pred_drumroll,
38
+ true_don,
39
+ true_ka,
40
+ true_drumroll,
41
+ valid_length,
42
+ hop_sec,
43
+ ):
44
+ """
45
+ Logs a plot of predicted vs. true energies for one sample to TensorBoard.
46
+ Energies should be 1D numpy arrays for the single sample, up to valid_length.
47
+ """
48
+ pred_don = pred_don[:valid_length].detach().cpu().numpy()
49
+ pred_ka = pred_ka[:valid_length].detach().cpu().numpy()
50
+ pred_drumroll = pred_drumroll[:valid_length].detach().cpu().numpy()
51
+ true_don = true_don[:valid_length].cpu().numpy()
52
+ true_ka = true_ka[:valid_length].cpu().numpy()
53
+ true_drumroll = true_drumroll[:valid_length].cpu().numpy()
54
+
55
+ time_axis = np.arange(valid_length) * hop_sec
56
+
57
+ fig, axs = plt.subplots(3, 1, figsize=(15, 10), sharex=True)
58
+ fig.suptitle(f"{tag_prefix} - Epoch {epoch}", fontsize=16)
59
+
60
+ axs[0].plot(time_axis, true_don, label="True Don", color="blue", linestyle="--")
61
+ axs[0].plot(time_axis, pred_don, label="Pred Don", color="lightblue", alpha=0.8)
62
+ axs[0].set_ylabel("Don Energy")
63
+ axs[0].legend()
64
+ axs[0].grid(True)
65
+
66
+ axs[1].plot(time_axis, true_ka, label="True Ka", color="red", linestyle="--")
67
+ axs[1].plot(time_axis, pred_ka, label="Pred Ka", color="lightcoral", alpha=0.8)
68
+ axs[1].set_ylabel("Ka Energy")
69
+ axs[1].legend()
70
+ axs[1].grid(True)
71
+
72
+ axs[2].plot(
73
+ time_axis, true_drumroll, label="True Drumroll", color="green", linestyle="--"
74
+ )
75
+ axs[2].plot(
76
+ time_axis, pred_drumroll, label="Pred Drumroll", color="lightgreen", alpha=0.8
77
+ )
78
+ axs[2].set_ylabel("Drumroll Energy")
79
+ axs[2].set_xlabel("Time (s)")
80
+ axs[2].legend()
81
+ axs[2].grid(True)
82
+
83
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
84
+ writer.add_figure(f"{tag_prefix}/Energy_Comparison", fig, epoch)
85
+ plt.close(fig)
86
+
87
+
88
+ def main():
89
+ global ds
90
+
91
+ output_frame_hop_sec = HOP_LENGTH / SAMPLE_RATE
92
+
93
+ best_val_loss = float("inf")
94
+ patience = 10
95
+ pat_count = 0
96
+
97
+ ds_oni = ds.map(
98
+ preprocess,
99
+ remove_columns=ds.column_names,
100
+ fn_kwargs={"difficulty": "oni"},
101
+ writer_batch_size=10,
102
+ )
103
+ ds_hard = ds.map(
104
+ preprocess,
105
+ remove_columns=ds.column_names,
106
+ fn_kwargs={"difficulty": "hard"},
107
+ writer_batch_size=10,
108
+ )
109
+ ds_normal = ds.map(
110
+ preprocess,
111
+ remove_columns=ds.column_names,
112
+ fn_kwargs={"difficulty": "normal"},
113
+ writer_batch_size=10,
114
+ )
115
+ ds = concatenate_datasets([ds_oni, ds_hard, ds_normal])
116
+
117
+ ds_train_test = ds.train_test_split(test_size=0.1, seed=42)
118
+ train_loader = DataLoader(
119
+ ds_train_test["train"],
120
+ batch_size=BATCH_SIZE,
121
+ shuffle=True,
122
+ collate_fn=collate_fn,
123
+ num_workers=8,
124
+ persistent_workers=True,
125
+ prefetch_factor=4,
126
+ )
127
+ val_loader = DataLoader(
128
+ ds_train_test["test"],
129
+ batch_size=BATCH_SIZE,
130
+ shuffle=False,
131
+ collate_fn=collate_fn,
132
+ num_workers=8,
133
+ persistent_workers=True,
134
+ prefetch_factor=4,
135
+ )
136
+
137
+ model = TaikoConformer7().to(DEVICE)
138
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
139
+
140
+ criterion = TaikoLoss(
141
+ reduction="mean",
142
+ nps_penalty_weight_alpha=NPS_PENALTY_WEIGHT_ALPHA,
143
+ nps_penalty_weight_beta=NPS_PENALTY_WEIGHT_BETA,
144
+ ).to(DEVICE)
145
+
146
+ num_optimizer_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
147
+ total_optimizer_steps = EPOCHS * num_optimizer_steps_per_epoch
148
+
149
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
150
+ optimizer, max_lr=LR, total_steps=total_optimizer_steps
151
+ )
152
+
153
+ writer = SummaryWriter()
154
+
155
+ for epoch in range(1, EPOCHS + 1):
156
+ model.train()
157
+ total_epoch_loss = 0.0
158
+ optimizer.zero_grad()
159
+
160
+ for idx, batch in enumerate(tqdm(train_loader, desc=f"Train Epoch {epoch}")):
161
+ mel = batch["mel"].to(DEVICE)
162
+ lengths = batch["lengths"].to(DEVICE)
163
+ nps = batch["nps"].to(DEVICE)
164
+ difficulty = batch["difficulty"].to(DEVICE)
165
+ level = batch["level"].to(DEVICE)
166
+
167
+ outputs = model(mel, lengths, nps, difficulty, level)
168
+ loss = criterion(outputs, batch)
169
+
170
+ total_epoch_loss += loss.item()
171
+
172
+ loss = loss / GRAD_ACCUM_STEPS
173
+ loss.backward()
174
+
175
+ if (idx + 1) % GRAD_ACCUM_STEPS == 0 or (idx + 1) == len(train_loader):
176
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
177
+ optimizer.step()
178
+ scheduler.step()
179
+ optimizer.zero_grad()
180
+
181
+ writer.add_scalar(
182
+ "Loss/Train_Step",
183
+ loss.item() * GRAD_ACCUM_STEPS,
184
+ epoch * len(train_loader) + idx,
185
+ )
186
+ writer.add_scalar(
187
+ "LR", scheduler.get_last_lr()[0], epoch * len(train_loader) + idx
188
+ )
189
+
190
+ if idx < 3:
191
+ if mel.size(0) > 0:
192
+ pred_don = outputs["presence"][0, :, 0]
193
+ pred_ka = outputs["presence"][0, :, 1]
194
+ pred_drumroll = outputs["presence"][0, :, 2]
195
+ true_don = batch["don_labels"][0]
196
+ true_ka = batch["ka_labels"][0]
197
+ true_drumroll = batch["drumroll_labels"][0]
198
+ valid_length = batch["lengths"][0].item()
199
+
200
+ log_energy_plots_to_tensorboard(
201
+ writer,
202
+ f"Train_Sample_Batch_{idx}_Sample_0",
203
+ epoch,
204
+ pred_don,
205
+ pred_ka,
206
+ pred_drumroll,
207
+ true_don,
208
+ true_ka,
209
+ true_drumroll,
210
+ valid_length,
211
+ output_frame_hop_sec,
212
+ )
213
+
214
+ avg_train_loss = total_epoch_loss / len(train_loader)
215
+ writer.add_scalar("Loss/Train_Avg", avg_train_loss, epoch)
216
+
217
+ model.eval()
218
+ total_val_loss = 0.0
219
+
220
+ with torch.no_grad():
221
+ for idx, batch in enumerate(tqdm(val_loader, desc=f"Val Epoch {epoch}")):
222
+ mel = batch["mel"].to(DEVICE)
223
+ lengths = batch["lengths"].to(DEVICE)
224
+ nps = batch["nps"].to(DEVICE)
225
+ difficulty = batch["difficulty"].to(DEVICE)
226
+ level = batch["level"].to(DEVICE)
227
+
228
+ outputs = model(mel, lengths, nps, difficulty, level)
229
+ loss = criterion(outputs, batch)
230
+ total_val_loss += loss.item()
231
+
232
+ if idx < 3:
233
+ if mel.size(0) > 0:
234
+ pred_don = outputs["presence"][0, :, 0]
235
+ pred_ka = outputs["presence"][0, :, 1]
236
+ pred_drumroll = outputs["presence"][0, :, 2]
237
+ true_don = batch["don_labels"][0]
238
+ true_ka = batch["ka_labels"][0]
239
+ true_drumroll = batch["drumroll_labels"][0]
240
+ valid_length = batch["lengths"][0].item()
241
+
242
+ log_energy_plots_to_tensorboard(
243
+ writer,
244
+ f"Val_Sample_Batch_{idx}_Sample_0",
245
+ epoch,
246
+ pred_don,
247
+ pred_ka,
248
+ pred_drumroll,
249
+ true_don,
250
+ true_ka,
251
+ true_drumroll,
252
+ valid_length,
253
+ output_frame_hop_sec,
254
+ )
255
+
256
+ avg_val_loss = total_val_loss / len(val_loader)
257
+ writer.add_scalar("Loss/Val_Avg", avg_val_loss, epoch)
258
+
259
+ current_lr = optimizer.param_groups[0]["lr"]
260
+ writer.add_scalar("LR/learning_rate", current_lr, epoch)
261
+
262
+ if "nps" in batch:
263
+ writer.add_scalar(
264
+ "NPS/GT_Val_LastBatch_Avg", batch["nps"].mean().item(), epoch
265
+ )
266
+
267
+ print(
268
+ f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}"
269
+ )
270
+
271
+ if avg_val_loss < best_val_loss:
272
+ best_val_loss = avg_val_loss
273
+ pat_count = 0
274
+ torch.save(model.state_dict(), "best_model.pt")
275
+ print(f"Saved new best model to best_model.pt at epoch {epoch}")
276
+ else:
277
+ pat_count += 1
278
+ if pat_count >= patience:
279
+ print("Early stopping!")
280
+ break
281
+ writer.close()
282
+
283
+ model_id = "JacobLinCool/taiko-conformer-7"
284
+ try:
285
+ model.push_to_hub(
286
+ model_id, commit_message=f"Epoch {epoch}, Val Loss: {avg_val_loss:.4f}"
287
+ )
288
+ upload_folder(
289
+ repo_id=model_id,
290
+ folder_path="runs",
291
+ path_in_repo="runs",
292
+ commit_message="Upload TensorBoard logs",
293
+ )
294
+ except Exception as e:
295
+ print(f"Error uploading model or logs: {e}")
296
+ print("Make sure you have the correct permissions and try again.")
297
+
298
+
299
+ if __name__ == "__main__":
300
+ main()
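An editorial sketch (not part of the commit) of reloading the checkpoint that the early-stopping logic writes to best_model.pt:

model = TaikoConformer7()
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()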