masszhou commited on
Commit
0ec3ee3
·
1 Parent(s): 3e18896

fix tensor type

Browse files
Files changed (2) hide show
  1. mdxnet_model.py +1 -3
  2. uvr_processing.py +2 -0
mdxnet_model.py CHANGED
@@ -225,9 +225,7 @@ class MDX:
225
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
226
  mix_waves.append(waves)
227
 
228
- mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
229
- self.device
230
- )
231
 
232
  return mix_waves, pad, trim
233
 
 
225
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
226
  mix_waves.append(waves)
227
 
228
+ mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.device)
 
 
229
 
230
  return mix_waves, pad, trim
231
 
uvr_processing.py CHANGED
@@ -43,6 +43,7 @@ def run_mdx(model_params: Dict,
43
  device = torch.device("cpu")
44
  processor_num = -1
45
  m_threads = 1
 
46
 
47
  model_hash = MDX.get_hash(model_path) # type: str
48
  mp = model_params.get(model_hash)
@@ -90,6 +91,7 @@ def run_mdx_cpu(model_params: Dict,
90
  denoise: bool = False,
91
  m_threads: int = 2,
92
  device_base: str = ""):
 
93
  m_threads = 1
94
  duration = librosa.get_duration(filename=input_filename)
95
  if duration >= 60 and duration <= 120:
 
43
  device = torch.device("cpu")
44
  processor_num = -1
45
  m_threads = 1
46
+ print(f"device: {device}")
47
 
48
  model_hash = MDX.get_hash(model_path) # type: str
49
  mp = model_params.get(model_hash)
 
91
  denoise: bool = False,
92
  m_threads: int = 2,
93
  device_base: str = ""):
94
+ print("run_mdx_cpu")
95
  m_threads = 1
96
  duration = librosa.get_duration(filename=input_filename)
97
  if duration >= 60 and duration <= 120: