SteveZerb commited on
Commit
bf288c7
·
verified ·
1 Parent(s): bc32a90

Delete app_onnx.py

Browse files
Files changed (1) hide show
  1. app_onnx.py +0 -594
app_onnx.py DELETED
@@ -1,594 +0,0 @@
1
- import spaces
2
- import random
3
- import argparse
4
- import glob
5
- import json
6
- import os
7
- import time
8
- from concurrent.futures import ThreadPoolExecutor
9
-
10
- import gradio as gr
11
- import numpy as np
12
- import onnxruntime as rt
13
- import tqdm
14
- from huggingface_hub import hf_hub_download
15
-
16
- import MIDI
17
- from midi_synthesizer import MidiSynthesizer
18
- from midi_tokenizer import MIDITokenizer
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- in_space = os.getenv("SYSTEM") == "spaces"
22
-
23
-
24
- def softmax(x, axis):
25
- x_max = np.amax(x, axis=axis, keepdims=True)
26
- exp_x_shifted = np.exp(x - x_max)
27
- return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
28
-
29
-
30
- def sample_top_p_k(probs, p, k, generator=None):
31
- if generator is None:
32
- generator = np.random
33
- probs_idx = np.argsort(-probs, axis=-1)
34
- probs_sort = np.take_along_axis(probs, probs_idx, -1)
35
- probs_sum = np.cumsum(probs_sort, axis=-1)
36
- mask = probs_sum - probs_sort > p
37
- probs_sort[mask] = 0.0
38
- mask = np.zeros(probs_sort.shape[-1])
39
- mask[:k] = 1
40
- probs_sort = probs_sort * mask
41
- probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
42
- shape = probs_sort.shape
43
- probs_sort_flat = probs_sort.reshape(-1, shape[-1])
44
- probs_idx_flat = probs_idx.reshape(-1, shape[-1])
45
- next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
46
- next_token = next_token.reshape(*shape[:-1])
47
- return next_token
48
-
49
-
50
- def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
51
- io_binding = model.io_binding()
52
- for input_ in model.get_inputs():
53
- name = input_.name
54
- if name.startswith("past_key_values"):
55
- present_name = name.replace("past_key_values", "present")
56
- if present_name in outputs:
57
- v = outputs[present_name]
58
- else:
59
- v = rt.OrtValue.ortvalue_from_shape_and_type(
60
- (batch_size, input_.shape[1], past_len, input_.shape[3]),
61
- element_type=np.float32,
62
- device_type=device)
63
- inputs[name] = v
64
- else:
65
- v = inputs[name]
66
- io_binding.bind_ortvalue_input(name, v)
67
-
68
- for output in model.get_outputs():
69
- name = output.name
70
- if name.startswith("present"):
71
- v = rt.OrtValue.ortvalue_from_shape_and_type(
72
- (batch_size, output.shape[1], cur_len, output.shape[3]),
73
- element_type=np.float32,
74
- device_type=device)
75
- outputs[name] = v
76
- else:
77
- v = outputs[name]
78
- io_binding.bind_ortvalue_output(name, v)
79
- return io_binding
80
-
81
- def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
82
- disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
83
- tokenizer = model[2]
84
- if disable_channels is not None:
85
- disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
86
- else:
87
- disable_channels = []
88
- if generator is None:
89
- generator = np.random
90
- max_token_seq = tokenizer.max_token_seq
91
- if prompt is None:
92
- input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
93
- input_tensor[0, 0] = tokenizer.bos_id # bos
94
- input_tensor = input_tensor[None, :, :]
95
- input_tensor = np.repeat(input_tensor, repeats=batch_size, axis=0)
96
- else:
97
- if len(prompt.shape) == 2:
98
- prompt = prompt[None, :]
99
- prompt = np.repeat(prompt, repeats=batch_size, axis=0)
100
- elif prompt.shape[0] == 1:
101
- prompt = np.repeat(prompt, repeats=batch_size, axis=0)
102
- elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
103
- raise ValueError(f"invalid shape for prompt, {prompt.shape}")
104
- prompt = prompt[..., :max_token_seq]
105
- if prompt.shape[-1] < max_token_seq:
106
- prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
107
- mode="constant", constant_values=tokenizer.pad_id)
108
- input_tensor = prompt
109
- cur_len = input_tensor.shape[1]
110
- bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
111
- model0_inputs = {}
112
- model0_outputs = {}
113
- emb_size = 1024
114
- for output in model[0].get_outputs():
115
- if output.name == "hidden":
116
- emb_size = output.shape[2]
117
- past_len = 0
118
- with bar:
119
- while cur_len < max_len:
120
- end = [False] * batch_size
121
- model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
122
- model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
123
- (batch_size, cur_len - past_len, emb_size),
124
- element_type=np.float32,
125
- device_type=device)
126
- io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
127
- io_binding.synchronize_inputs()
128
- model[0].run_with_iobinding(io_binding)
129
- io_binding.synchronize_outputs()
130
-
131
- hidden = model0_outputs["hidden"].numpy()[:, -1:]
132
- next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
133
- event_names = [""] * batch_size
134
- model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
135
- model1_outputs = {}
136
- for i in range(max_token_seq):
137
- mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
138
- for b in range(batch_size):
139
- if end[b]:
140
- mask[b, tokenizer.pad_id] = 1
141
- continue
142
- if i == 0:
143
- mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
144
- if disable_patch_change:
145
- mask_ids.remove(tokenizer.event_ids["patch_change"])
146
- if disable_control_change:
147
- mask_ids.remove(tokenizer.event_ids["control_change"])
148
- mask[b, mask_ids] = 1
149
- else:
150
- param_names = tokenizer.events[event_names[b]]
151
- if i > len(param_names):
152
- mask[b, tokenizer.pad_id] = 1
153
- continue
154
- param_name = param_names[i - 1]
155
- mask_ids = tokenizer.parameter_ids[param_name]
156
- if param_name == "channel":
157
- mask_ids = [i for i in mask_ids if i not in disable_channels]
158
- mask[b, mask_ids] = 1
159
- mask = mask[:, None, :]
160
- x = next_token_seq
161
- if i != 0:
162
- # cached
163
- if i == 1:
164
- hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
165
- model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
166
- x = x[:, -1:]
167
- model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
168
- model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
169
- (batch_size, 1, tokenizer.vocab_size),
170
- element_type=np.float32,
171
- device_type=device
172
- )
173
- io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i+1)
174
- io_binding.synchronize_inputs()
175
- model[1].run_with_iobinding(io_binding)
176
- io_binding.synchronize_outputs()
177
- logits = model1_outputs["y"].numpy()
178
- scores = softmax(logits / temp, -1) * mask
179
- samples = sample_top_p_k(scores, top_p, top_k, generator)
180
- if i == 0:
181
- next_token_seq = samples
182
- for b in range(batch_size):
183
- if end[b]:
184
- continue
185
- eid = samples[b].item()
186
- if eid == tokenizer.eos_id:
187
- end[b] = True
188
- else:
189
- event_names[b] = tokenizer.id_events[eid]
190
- else:
191
- next_token_seq = np.concatenate([next_token_seq, samples], axis=1)
192
- if all([len(tokenizer.events[event_names[b]]) == i for b in range(batch_size) if not end[b]]):
193
- break
194
- if next_token_seq.shape[1] < max_token_seq:
195
- next_token_seq = np.pad(next_token_seq,
196
- ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
197
- mode="constant", constant_values=tokenizer.pad_id)
198
- next_token_seq = next_token_seq[:, None, :]
199
- input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
200
- past_len = cur_len
201
- cur_len += 1
202
- bar.update(1)
203
- yield next_token_seq[:, 0]
204
- if all(end):
205
- break
206
-
207
-
208
- def create_msg(name, data):
209
- return {"name": name, "data": data}
210
-
211
-
212
- def send_msgs(msgs):
213
- return json.dumps(msgs)
214
-
215
-
216
- def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
217
- time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
218
- remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
219
- t = gen_events // 28
220
- if "large" in model_name:
221
- t = gen_events // 20
222
- return t + 10
223
-
224
-
225
- @spaces.GPU(duration=get_duration)
226
- def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
227
- key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
228
- seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
229
- model = models[model_name]
230
- model_base = rt.InferenceSession(model[0], providers=providers)
231
- model_token = rt.InferenceSession(model[1], providers=providers)
232
- tokenizer = model[2]
233
- model = [model_base, model_token, tokenizer]
234
- bpm = int(bpm)
235
- if time_sig == "auto":
236
- time_sig = None
237
- time_sig_nn = 4
238
- time_sig_dd = 2
239
- else:
240
- time_sig_nn, time_sig_dd = time_sig.split('/')
241
- time_sig_nn = int(time_sig_nn)
242
- time_sig_dd = {2: 1, 4: 2, 8: 3}[int(time_sig_dd)]
243
- if key_sig == 0:
244
- key_sig = None
245
- key_sig_sf = 0
246
- key_sig_mi = 0
247
- else:
248
- key_sig = (key_sig - 1)
249
- key_sig_sf = key_sig // 2 - 7
250
- key_sig_mi = key_sig % 2
251
- gen_events = int(gen_events)
252
- max_len = gen_events
253
- if seed_rand:
254
- seed = random.randint(0, MAX_SEED)
255
- generator = np.random.RandomState(seed)
256
- disable_patch_change = False
257
- disable_channels = None
258
- if tab == 0:
259
- i = 0
260
- mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
261
- if tokenizer.version == "v2":
262
- if time_sig is not None:
263
- mid.append(tokenizer.event2tokens(["time_signature", 0, 0, 0, time_sig_nn - 1, time_sig_dd - 1]))
264
- if key_sig is not None:
265
- mid.append(tokenizer.event2tokens(["key_signature", 0, 0, 0, key_sig_sf + 7, key_sig_mi]))
266
- if bpm != 0:
267
- mid.append(tokenizer.event2tokens(["set_tempo", 0, 0, 0, bpm]))
268
- patches = {}
269
- if instruments is None:
270
- instruments = []
271
- for instr in instruments:
272
- patches[i] = patch2number[instr]
273
- i = (i + 1) if i != 8 else 10
274
- if drum_kit != "None":
275
- patches[9] = drum_kits2number[drum_kit]
276
- for i, (c, p) in enumerate(patches.items()):
277
- mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i + 1, c, p]))
278
- mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
279
- mid_seq = mid.tolist()
280
- if len(instruments) > 0:
281
- disable_patch_change = True
282
- disable_channels = [i for i in range(16) if i not in patches]
283
- elif tab == 1 and mid is not None:
284
- eps = 4 if reduce_cc_st else 0
285
- mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
286
- remap_track_channel=remap_track_channel,
287
- add_default_instr=add_default_instr,
288
- remove_empty_channels=remove_empty_channels)
289
- mid = mid[:int(midi_events)]
290
- mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
291
- mid_seq = mid.tolist()
292
- elif tab == 2 and mid_seq is not None:
293
- mid = np.asarray(mid_seq, dtype=np.int64)
294
- if continuation_select > 0:
295
- continuation_state.append(mid_seq)
296
- mid = np.repeat(mid[continuation_select - 1:continuation_select], repeats=OUTPUT_BATCH_SIZE, axis=0)
297
- mid_seq = mid.tolist()
298
- else:
299
- continuation_state.append(mid.shape[1])
300
- else:
301
- continuation_state = [0]
302
- mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
303
- mid = np.asarray([mid] * OUTPUT_BATCH_SIZE, dtype=np.int64)
304
- mid_seq = mid.tolist()
305
-
306
- if mid is not None:
307
- max_len += mid.shape[1]
308
-
309
- init_msgs = [create_msg("progress", [0, gen_events])]
310
- if not (tab == 2 and continuation_select == 0):
311
- for i in range(OUTPUT_BATCH_SIZE):
312
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
313
- init_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
314
- create_msg("visualizer_append", [i, events])]
315
- yield mid_seq, continuation_state, seed, send_msgs(init_msgs)
316
- midi_generator = generate(model, mid, batch_size=OUTPUT_BATCH_SIZE, max_len=max_len, temp=temp,
317
- top_p=top_p, top_k=top_k, disable_patch_change=disable_patch_change,
318
- disable_control_change=not allow_cc, disable_channels=disable_channels,
319
- generator=generator)
320
- events = [list() for i in range(OUTPUT_BATCH_SIZE)]
321
- t = time.time() + 1
322
- for i, token_seqs in enumerate(midi_generator):
323
- token_seqs = token_seqs.tolist()
324
- for j in range(OUTPUT_BATCH_SIZE):
325
- token_seq = token_seqs[j]
326
- mid_seq[j].append(token_seq)
327
- events[j].append(tokenizer.tokens2event(token_seq))
328
- if time.time() - t > 0.5:
329
- msgs = [create_msg("progress", [i + 1, gen_events])]
330
- for j in range(OUTPUT_BATCH_SIZE):
331
- msgs += [create_msg("visualizer_append", [j, events[j]])]
332
- events[j] = list()
333
- yield mid_seq, continuation_state, seed, send_msgs(msgs)
334
- t = time.time()
335
- yield mid_seq, continuation_state, seed, send_msgs([])
336
-
337
-
338
- def finish_run(model_name, mid_seq):
339
- if mid_seq is None:
340
- outputs = [None] * OUTPUT_BATCH_SIZE
341
- return *outputs, []
342
- tokenizer = models[model_name][2]
343
- outputs = []
344
- end_msgs = [create_msg("progress", [0, 0])]
345
- if not os.path.exists("outputs"):
346
- os.mkdir("outputs")
347
- for i in range(OUTPUT_BATCH_SIZE):
348
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
349
- mid = tokenizer.detokenize(mid_seq[i])
350
- with open(f"outputs/output{i + 1}.mid", 'wb') as f:
351
- f.write(MIDI.score2midi(mid))
352
- outputs.append(f"outputs/output{i + 1}.mid")
353
- end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
354
- create_msg("visualizer_append", [i, events]),
355
- create_msg("visualizer_end", i)]
356
- return *outputs, send_msgs(end_msgs)
357
-
358
-
359
- def synthesis_task(mid):
360
- return synthesizer.synthesis(MIDI.score2opus(mid))
361
-
362
- def render_audio(model_name, mid_seq, should_render_audio):
363
- if (not should_render_audio) or mid_seq is None:
364
- outputs = [None] * OUTPUT_BATCH_SIZE
365
- return tuple(outputs)
366
- tokenizer = models[model_name][2]
367
- outputs = []
368
- if not os.path.exists("outputs"):
369
- os.mkdir("outputs")
370
- audio_futures = []
371
- for i in range(OUTPUT_BATCH_SIZE):
372
- mid = tokenizer.detokenize(mid_seq[i])
373
- audio_future = thread_pool.submit(synthesis_task, mid)
374
- audio_futures.append(audio_future)
375
- for future in audio_futures:
376
- outputs.append((44100, future.result()))
377
- if OUTPUT_BATCH_SIZE == 1:
378
- return outputs[0]
379
- return tuple(outputs)
380
-
381
-
382
- def undo_continuation(model_name, mid_seq, continuation_state):
383
- if mid_seq is None or len(continuation_state) < 2:
384
- return mid_seq, continuation_state, send_msgs([])
385
- tokenizer = models[model_name][2]
386
- if isinstance(continuation_state[-1], list):
387
- mid_seq = continuation_state[-1]
388
- else:
389
- mid_seq = [ms[:continuation_state[-1]] for ms in mid_seq]
390
- continuation_state = continuation_state[:-1]
391
- end_msgs = [create_msg("progress", [0, 0])]
392
- for i in range(OUTPUT_BATCH_SIZE):
393
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
394
- end_msgs += [create_msg("visualizer_clear", [i, tokenizer.version]),
395
- create_msg("visualizer_append", [i, events]),
396
- create_msg("visualizer_end", i)]
397
- return mid_seq, continuation_state, send_msgs(end_msgs)
398
-
399
-
400
- def load_javascript(dir="javascript"):
401
- scripts_list = glob.glob(f"{dir}/*.js")
402
- javascript = ""
403
- for path in scripts_list:
404
- with open(path, "r", encoding="utf8") as jsfile:
405
- js_content = jsfile.read()
406
- js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
407
- f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
408
- javascript += f"\n<!-- {path} --><script>{js_content}</script>"
409
- template_response_ori = gr.routes.templates.TemplateResponse
410
-
411
- def template_response(*args, **kwargs):
412
- res = template_response_ori(*args, **kwargs)
413
- res.body = res.body.replace(
414
- b'</head>', f'{javascript}</head>'.encode("utf8"))
415
- res.init_headers()
416
- return res
417
-
418
- gr.routes.templates.TemplateResponse = template_response
419
-
420
-
421
- def hf_hub_download_retry(repo_id, filename):
422
- print(f"downloading {repo_id} {filename}")
423
- retry = 0
424
- err = None
425
- while retry < 30:
426
- try:
427
- return hf_hub_download(repo_id=repo_id, filename=filename)
428
- except Exception as e:
429
- err = e
430
- retry += 1
431
- if err:
432
- raise err
433
-
434
-
435
- def get_tokenizer(repo_id):
436
- config_path = hf_hub_download_retry(repo_id=repo_id, filename=f"config.json")
437
- with open(config_path, "r") as f:
438
- config = json.load(f)
439
- tokenizer = MIDITokenizer(config["tokenizer"]["version"])
440
- tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
441
- return tokenizer
442
-
443
-
444
- number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
445
- 40: "Blush", 48: "Orchestra"}
446
- patch2number = {v: k for k, v in MIDI.Number2patch.items()}
447
- drum_kits2number = {v: k for k, v in number2drum_kits.items()}
448
- key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
449
- 'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
450
-
451
- if __name__ == "__main__":
452
- parser = argparse.ArgumentParser()
453
- parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
454
- parser.add_argument("--port", type=int, default=7860, help="gradio server port")
455
- parser.add_argument("--device", type=str, default="cuda", help="device to run model")
456
- parser.add_argument("--batch", type=int, default=8, help="batch size")
457
- parser.add_argument("--max-gen", type=int, default=1024, help="max")
458
- opt = parser.parse_args()
459
- OUTPUT_BATCH_SIZE = opt.batch
460
- soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
461
- thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
462
- synthesizer = MidiSynthesizer(soundfont_path)
463
- models_info = {
464
- "generic pretrain model (tv2o-medium) by skytnt": [
465
- "skytnt/midi-model-tv2o-medium", "", {
466
- "jpop": "skytnt/midi-model-tv2om-jpop-lora",
467
- "touhou": "skytnt/midi-model-tv2om-touhou-lora"
468
- }
469
- ],
470
- "generic pretrain model (tv2o-large) by asigalov61": [
471
- "asigalov61/Music-Llama", "", {}
472
- ],
473
- "generic pretrain model (tv2o-medium) by asigalov61": [
474
- "asigalov61/Music-Llama-Medium", "", {}
475
- ],
476
- "generic pretrain model (tv1-medium) by skytnt": [
477
- "skytnt/midi-model", "", {}
478
- ]
479
- }
480
- models = {}
481
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
482
- device = "cuda"
483
-
484
- for name, (repo_id, path, loras) in models_info.items():
485
- model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
486
- model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
487
- tokenizer = get_tokenizer(repo_id)
488
- models[name] = [model_base_path, model_token_path, tokenizer]
489
- for lora_name, lora_repo in loras.items():
490
- model_base_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_base.onnx")
491
- model_token_path = hf_hub_download_retry(repo_id=lora_repo, filename=f"onnx/model_token.onnx")
492
- models[f"{name} with {lora_name} lora"] = [model_base_path, model_token_path, tokenizer]
493
-
494
- load_javascript()
495
- app = gr.Blocks(theme=gr.themes.Soft())
496
- with app:
497
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
498
- gr.Markdown("\n\n"
499
- "A modified version of the Midi-Generator for the IAT-360 Course by Ethan Lum\n\n"
500
- "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
501
- "[Open In Colab]"
502
- "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
503
- " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
504
- " for unlimited generation\n\n"
505
- "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
506
- "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
507
- )
508
- js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
509
- js_msg.change(None, [js_msg], [], js="""
510
- (msg_json) =>{
511
- let msgs = JSON.parse(msg_json);
512
- executeCallbacks(msgReceiveCallbacks, msgs);
513
- return [];
514
- }
515
- """)
516
- input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
517
- type="value", value=list(models.keys())[0])
518
- tab_select = gr.State(value=0)
519
- with gr.Tabs():
520
-
521
- with gr.TabItem("midi prompt") as tab2:
522
- input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
523
- input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
524
- step=1,
525
- value=128)
526
- input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
527
- input_remap_track_channel = gr.Checkbox(
528
- label="remap tracks and channels so each track has only one channel and in order", value=True)
529
- input_add_default_instr = gr.Checkbox(
530
- label="add a default instrument to channels that don't have an instrument", value=True)
531
- input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
532
- example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
533
- [input_midi, input_midi_events])
534
- with gr.TabItem("last output prompt") as tab3:
535
- gr.Markdown("Continue generating on the last output.")
536
- input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
537
- choices=["all"] + [f"output{i + 1}" for i in
538
- range(OUTPUT_BATCH_SIZE)],
539
- type="index"
540
- )
541
- undo_btn = gr.Button("undo the last continuation")
542
-
543
-
544
- tab2.select(lambda: 1, None, tab_select, queue=False)
545
- tab3.select(lambda: 2, None, tab_select, queue=False)
546
- input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
547
- step=1, value=0)
548
- input_seed_rand = gr.Checkbox(label="random seed", value=True)
549
- input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
550
- step=1, value=opt.max_gen // 2)
551
- with gr.Accordion("options", open=False):
552
- input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
553
- input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
554
- input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
555
- input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
556
- input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
557
- example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
558
- [input_temp, input_top_p, input_top_k])
559
- run_btn = gr.Button("generate", variant="primary")
560
- # stop_btn = gr.Button("stop and output")
561
- output_midi_seq = gr.State()
562
- output_continuation_state = gr.State([0])
563
- midi_outputs = []
564
- audio_outputs = []
565
- with gr.Tabs(elem_id="output_tabs"):
566
- for i in range(OUTPUT_BATCH_SIZE):
567
- with gr.TabItem(f"output {i + 1}") as tab1:
568
- output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
569
- output_audio = gr.Audio(label="output audio", format="mp3", elem_id=f"midi_audio_{i}")
570
- output_midi = gr.File(label="output midi", file_types=[".mid"])
571
- midi_outputs.append(output_midi)
572
- audio_outputs.append(output_audio)
573
- run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
574
- input_continuation_select, input_instruments, input_drum_kit, input_bpm,
575
- input_time_sig, input_key_sig, input_midi, input_midi_events,
576
- input_reduce_cc_st, input_remap_track_channel,
577
- input_add_default_instr, input_remove_empty_channels,
578
- input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
579
- input_top_k, input_allow_cc],
580
- [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
581
- finish_run_event = run_event.then(fn=finish_run,
582
- inputs=[input_model, output_midi_seq],
583
- outputs=midi_outputs + [js_msg],
584
- queue=False)
585
- finish_run_event.then(fn=render_audio,
586
- inputs=[input_model, output_midi_seq, input_render_audio],
587
- outputs=audio_outputs,
588
- queue=False)
589
- # stop_btn.click(None, [], [], cancels=run_event,
590
- # queue=False)
591
- undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
592
- [output_midi_seq, output_continuation_state, js_msg], queue=False)
593
- app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
594
- thread_pool.shutdown()