Hev832 commited on
Commit
00e0077
·
verified ·
1 Parent(s): 525511e

Delete lib/mdx.py

Browse files
Files changed (1) hide show
  1. lib/mdx.py +0 -289
lib/mdx.py DELETED
@@ -1,289 +0,0 @@
1
- import gc
2
- import hashlib
3
- import os
4
- import queue
5
- import threading
6
- import warnings
7
-
8
- import librosa
9
- import numpy as np
10
- import onnxruntime as ort
11
- import soundfile as sf
12
- import torch
13
- from tqdm import tqdm
14
-
15
- warnings.filterwarnings("ignore")
16
- stem_naming = {'Vocals': 'Instrumental', 'Other': 'Instruments', 'Instrumental': 'Vocals', 'Drums': 'Drumless', 'Bass': 'Bassless'}
17
-
18
-
19
- class MDXModel:
20
- def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
21
- self.dim_f = dim_f
22
- self.dim_t = dim_t
23
- self.dim_c = 4
24
- self.n_fft = n_fft
25
- self.hop = hop
26
- self.stem_name = stem_name
27
- self.compensation = compensation
28
-
29
- self.n_bins = self.n_fft // 2 + 1
30
- self.chunk_size = hop * (self.dim_t - 1)
31
- self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
32
-
33
- out_c = self.dim_c
34
-
35
- self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
36
-
37
- def stft(self, x):
38
- x = x.reshape([-1, self.chunk_size])
39
- x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
40
- x = torch.view_as_real(x)
41
- x = x.permute([0, 3, 1, 2])
42
- x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
43
- return x[:, :, :self.dim_f]
44
-
45
- def istft(self, x, freq_pad=None):
46
- freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
47
- x = torch.cat([x, freq_pad], -2)
48
- # c = 4*2 if self.target_name=='*' else 2
49
- x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
50
- x = x.permute([0, 2, 3, 1])
51
- x = x.contiguous()
52
- x = torch.view_as_complex(x)
53
- x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
54
- return x.reshape([-1, 2, self.chunk_size])
55
-
56
-
57
- class MDX:
58
- DEFAULT_SR = 44100
59
- # Unit: seconds
60
- DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
61
- DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
62
-
63
- DEFAULT_PROCESSOR = 0
64
-
65
- def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
66
-
67
- # Set the device and the provider (CPU or CUDA)
68
- self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
69
- self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
70
-
71
- self.model = params
72
-
73
- # Load the ONNX model using ONNX Runtime
74
- self.ort = ort.InferenceSession(model_path, providers=self.provider)
75
- # Preload the model for faster performance
76
- self.ort.run(None, {'input': torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
77
- self.process = lambda spec: self.ort.run(None, {'input': spec.cpu().numpy()})[0]
78
-
79
- self.prog = None
80
-
81
- @staticmethod
82
- def get_hash(model_path):
83
- try:
84
- with open(model_path, 'rb') as f:
85
- f.seek(- 10000 * 1024, 2)
86
- model_hash = hashlib.md5(f.read()).hexdigest()
87
- except:
88
- model_hash = hashlib.md5(open(model_path, 'rb').read()).hexdigest()
89
-
90
- return model_hash
91
-
92
- @staticmethod
93
- def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
94
- """
95
- Segment or join segmented wave array
96
-
97
- Args:
98
- wave: (np.array) Wave array to be segmented or joined
99
- combine: (bool) If True, combines segmented wave array. If False, segments wave array.
100
- chunk_size: (int) Size of each segment (in samples)
101
- margin_size: (int) Size of margin between segments (in samples)
102
-
103
- Returns:
104
- numpy array: Segmented or joined wave array
105
- """
106
-
107
- if combine:
108
- processed_wave = None # Initializing as None instead of [] for later numpy array concatenation
109
- for segment_count, segment in enumerate(wave):
110
- start = 0 if segment_count == 0 else margin_size
111
- end = None if segment_count == len(wave) - 1 else -margin_size
112
- if margin_size == 0:
113
- end = None
114
- if processed_wave is None: # Create array for first segment
115
- processed_wave = segment[:, start:end]
116
- else: # Concatenate to existing array for subsequent segments
117
- processed_wave = np.concatenate((processed_wave, segment[:, start:end]), axis=-1)
118
-
119
- else:
120
- processed_wave = []
121
- sample_count = wave.shape[-1]
122
-
123
- if chunk_size <= 0 or chunk_size > sample_count:
124
- chunk_size = sample_count
125
-
126
- if margin_size > chunk_size:
127
- margin_size = chunk_size
128
-
129
- for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
130
-
131
- margin = 0 if segment_count == 0 else margin_size
132
- end = min(skip + chunk_size + margin_size, sample_count)
133
- start = skip - margin
134
-
135
- cut = wave[:, start:end].copy()
136
- processed_wave.append(cut)
137
-
138
- if end == sample_count:
139
- break
140
-
141
- return processed_wave
142
-
143
- def pad_wave(self, wave):
144
- """
145
- Pad the wave array to match the required chunk size
146
-
147
- Args:
148
- wave: (np.array) Wave array to be padded
149
-
150
- Returns:
151
- tuple: (padded_wave, pad, trim)
152
- - padded_wave: Padded wave array
153
- - pad: Number of samples that were padded
154
- - trim: Number of samples that were trimmed
155
- """
156
- n_sample = wave.shape[1]
157
- trim = self.model.n_fft // 2
158
- gen_size = self.model.chunk_size - 2 * trim
159
- pad = gen_size - n_sample % gen_size
160
-
161
- # Padded wave
162
- wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
163
-
164
- mix_waves = []
165
- for i in range(0, n_sample + pad, gen_size):
166
- waves = np.array(wave_p[:, i:i + self.model.chunk_size])
167
- mix_waves.append(waves)
168
-
169
- mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
170
-
171
- return mix_waves, pad, trim
172
-
173
- def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
174
- """
175
- Process each wave segment in a multi-threaded environment
176
-
177
- Args:
178
- mix_waves: (torch.Tensor) Wave segments to be processed
179
- trim: (int) Number of samples trimmed during padding
180
- pad: (int) Number of samples padded during padding
181
- q: (queue.Queue) Queue to hold the processed wave segments
182
- _id: (int) Identifier of the processed wave segment
183
-
184
- Returns:
185
- numpy array: Processed wave segment
186
- """
187
- mix_waves = mix_waves.split(1)
188
- with torch.no_grad():
189
- pw = []
190
- for mix_wave in mix_waves:
191
- self.prog.update()
192
- spec = self.model.stft(mix_wave)
193
- processed_spec = torch.tensor(self.process(spec))
194
- processed_wav = self.model.istft(processed_spec.to(self.device))
195
- processed_wav = processed_wav[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).cpu().numpy()
196
- pw.append(processed_wav)
197
- processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
198
- q.put({_id: processed_signal})
199
- return processed_signal
200
-
201
- def process_wave(self, wave: np.array, mt_threads=1):
202
- """
203
- Process the wave array in a multi-threaded environment
204
-
205
- Args:
206
- wave: (np.array) Wave array to be processed
207
- mt_threads: (int) Number of threads to be used for processing
208
-
209
- Returns:
210
- numpy array: Processed wave array
211
- """
212
- self.prog = tqdm(total=0)
213
- chunk = wave.shape[-1] // mt_threads
214
- waves = self.segment(wave, False, chunk)
215
-
216
- # Create a queue to hold the processed wave segments
217
- q = queue.Queue()
218
- threads = []
219
- for c, batch in enumerate(waves):
220
- mix_waves, pad, trim = self.pad_wave(batch)
221
- self.prog.total = len(mix_waves) * mt_threads
222
- thread = threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
223
- thread.start()
224
- threads.append(thread)
225
- for thread in threads:
226
- thread.join()
227
- self.prog.close()
228
-
229
- processed_batches = []
230
- while not q.empty():
231
- processed_batches.append(q.get())
232
- processed_batches = [list(wave.values())[0] for wave in
233
- sorted(processed_batches, key=lambda d: list(d.keys())[0])]
234
- assert len(processed_batches) == len(waves), 'Incomplete processed batches, please reduce batch size!'
235
- return self.segment(processed_batches, True, chunk)
236
-
237
-
238
- def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
239
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
240
-
241
- device_properties = torch.cuda.get_device_properties(device)
242
- vram_gb = device_properties.total_memory / 1024**3
243
- m_threads = 1 if vram_gb < 8 else 2
244
-
245
- model_hash = MDX.get_hash(model_path)
246
- mp = model_params.get(model_hash)
247
- model = MDXModel(
248
- device,
249
- dim_f=mp["mdx_dim_f_set"],
250
- dim_t=2 ** mp["mdx_dim_t_set"],
251
- n_fft=mp["mdx_n_fft_scale_set"],
252
- stem_name=mp["primary_stem"],
253
- compensation=mp["compensate"]
254
- )
255
-
256
- mdx_sess = MDX(model_path, model)
257
- wave, sr = librosa.load(filename, mono=False, sr=44100)
258
- # normalizing input wave gives better output
259
- peak = max(np.max(wave), abs(np.min(wave)))
260
- wave /= peak
261
- if denoise:
262
- wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
263
- wave_processed *= 0.5
264
- else:
265
- wave_processed = mdx_sess.process_wave(wave, m_threads)
266
- # return to previous peak
267
- wave_processed *= peak
268
- stem_name = model.stem_name if suffix is None else suffix
269
-
270
- main_filepath = None
271
- if not exclude_main:
272
- main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
273
- sf.write(main_filepath, wave_processed.T, sr)
274
-
275
- invert_filepath = None
276
- if not exclude_inversion:
277
- diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
278
- stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
279
- invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
280
- sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
281
-
282
- if not keep_orig:
283
- os.remove(filename)
284
-
285
- del mdx_sess, wave_processed, wave
286
- if torch.cuda.is_available():
287
- torch.cuda.empty_cache()
288
- gc.collect()
289
- return main_filepath, invert_filepath