import os
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T

from config import LOGS_DIR


# Utility functions


def load_audio_files(file_paths: List[str]) -> List[Tuple[torch.Tensor, int]]:
    """
    Load multiple audio files and ensure they have the same length.

    Args:
        file_paths: List of paths to audio files

    Returns:
        List of tuples containing audio data and sample rate
    """
    audio_data = []
    for path in file_paths:
        # Load audio file
        waveform, sample_rate = torchaudio.load(path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        audio_data.append((waveform.squeeze(), sample_rate))

    # Verify all audio files have the same length and sample rate
    lengths = [len(audio) for audio, _ in audio_data]
    sample_rates = [sr for _, sr in audio_data]

    if len(set(lengths)) > 1:
        raise ValueError(f"Audio files have different lengths: {lengths}")
    if len(set(sample_rates)) > 1:
        raise ValueError(f"Audio files have different sample rates: {sample_rates}")

    return audio_data


def normalize_audio_volumes(audio_data: List[Tuple[torch.Tensor, int]]) -> List[Tuple[torch.Tensor, int]]:
    """
    Normalize the volume of each audio file to the same energy level.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of tuples containing normalized audio data and sample rate
    """
    normalized_data = []

    # Calculate RMS (root mean square) for each audio clip
    rms_values = []
    for audio, sr in audio_data:
        # Energy is the mean squared amplitude; RMS is its square root
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)
        rms_values.append(rms)

    # Use the median RMS as the target to stay robust to outliers
    # (torch.stack avoids the copy warning torch.tensor raises on a list of tensors)
    target_rms = torch.median(torch.stack(rms_values))

    # Scale each clip to the target RMS
    for (audio, sr), rms in zip(audio_data, rms_values):
        if rms > 0:  # Avoid division by zero on silent clips
            scaling_factor = target_rms / rms
            normalized_audio = audio * scaling_factor
        else:
            normalized_audio = audio
        normalized_data.append((normalized_audio, sr))

    return normalized_data
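
# A minimal sanity check for the RMS normalization above (illustrative only:
# the 16 kHz sample rate and tone amplitudes are arbitrary assumptions, not
# values the pipeline requires). Three sine tones at different amplitudes
# should all come out with the median RMS, here 0.3 / sqrt(2) ~= 0.2121.
def _demo_normalize_audio_volumes() -> None:
    sr = 16_000
    t = torch.arange(sr) / sr
    clips = [amp * torch.sin(2 * np.pi * 440.0 * t) for amp in (0.9, 0.3, 0.05)]
    for audio, _ in normalize_audio_volumes([(c, sr) for c in clips]):
        rms = torch.sqrt(torch.mean(audio ** 2))
        print(f"RMS after normalization: {rms:.4f}")  # all ~0.2121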

def plot_energy_comparison(
    original_metrics: List[dict],
    normalized_metrics: List[dict],
    file_names: List[str],
    output_path: str = "./logs/energy_comparison.png",
) -> None:
    """
    Plot a comparison of energy metrics before and after normalization.

    Args:
        original_metrics: List of dictionaries containing metrics for original audio
        normalized_metrics: List of dictionaries containing metrics for normalized audio
        file_names: List of audio file names
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))

    # Extract metrics
    orig_rms = [m['rms'] for m in original_metrics]
    norm_rms = [m['rms'] for m in normalized_metrics]
    orig_peak = [m['peak'] for m in original_metrics]
    norm_peak = [m['peak'] for m in normalized_metrics]
    orig_dr = [m['dynamic_range_db'] for m in original_metrics]
    norm_dr = [m['dynamic_range_db'] for m in normalized_metrics]
    orig_cf = [m['crest_factor'] for m in original_metrics]
    norm_cf = [m['crest_factor'] for m in normalized_metrics]

    # Prepare x-axis
    x = np.arange(len(file_names))
    width = 0.35

    # Plot RMS (volume)
    axs[0, 0].bar(x - width / 2, orig_rms, width, label='Original')
    axs[0, 0].bar(x + width / 2, norm_rms, width, label='Normalized')
    axs[0, 0].set_title('RMS Energy (Volume)')
    axs[0, 0].set_xticks(x)
    axs[0, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 0].set_ylabel('RMS Value')
    axs[0, 0].legend()

    # Plot peak amplitude
    axs[0, 1].bar(x - width / 2, orig_peak, width, label='Original')
    axs[0, 1].bar(x + width / 2, norm_peak, width, label='Normalized')
    axs[0, 1].set_title('Peak Amplitude')
    axs[0, 1].set_xticks(x)
    axs[0, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[0, 1].set_ylabel('Peak Value')
    axs[0, 1].legend()

    # Plot dynamic range
    axs[1, 0].bar(x - width / 2, orig_dr, width, label='Original')
    axs[1, 0].bar(x + width / 2, norm_dr, width, label='Normalized')
    axs[1, 0].set_title('Dynamic Range (dB)')
    axs[1, 0].set_xticks(x)
    axs[1, 0].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 0].set_ylabel('dB')
    axs[1, 0].legend()

    # Plot crest factor
    axs[1, 1].bar(x - width / 2, orig_cf, width, label='Original')
    axs[1, 1].bar(x + width / 2, norm_cf, width, label='Normalized')
    axs[1, 1].set_title('Crest Factor (Peak-to-RMS Ratio)')
    axs[1, 1].set_xticks(x)
    axs[1, 1].set_xticklabels(file_names, rotation=45, ha='right')
    axs[1, 1].set_ylabel('Ratio')
    axs[1, 1].legend()

    plt.tight_layout()

    # Create the output directory if needed, then save
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    plt.savefig(output_path)
    plt.close()


def calculate_audio_metrics(audio_data: List[Tuple[torch.Tensor, int]]) -> List[dict]:
    """
    Calculate various audio metrics for each audio file.

    Args:
        audio_data: List of tuples containing audio data and sample rate

    Returns:
        List of dictionaries containing metrics
    """
    metrics = []
    for audio, sr in audio_data:
        # RMS (root mean square)
        energy = torch.mean(audio ** 2)
        rms = torch.sqrt(energy)

        # Peak amplitude
        peak = torch.max(torch.abs(audio))

        # Dynamic range: ratio of the peak to the smallest non-zero amplitude.
        # Guard against all-zero audio, where torch.min would fail on an
        # empty tensor.
        non_zero = torch.abs(audio[audio != 0])
        if non_zero.numel() > 0:
            min_non_zero = torch.min(non_zero)
            dynamic_range_db = 20 * torch.log10(peak / min_non_zero)
        else:
            dynamic_range_db = torch.tensor(float('inf'))

        # Crest factor (peak-to-RMS ratio)
        crest_factor = peak / rms if rms > 0 else torch.tensor(float('inf'))

        metrics.append({
            'rms': rms.item(),
            'peak': peak.item(),
            'dynamic_range_db': dynamic_range_db.item() if not torch.isinf(dynamic_range_db) else float('inf'),
            'crest_factor': crest_factor.item() if not torch.isinf(crest_factor) else float('inf'),
        })

    return metrics
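
# Illustrative check of calculate_audio_metrics (assumed tone parameters, not
# pipeline requirements): a pure sine has a crest factor of sqrt(2) ~= 1.414
# and an RMS of amplitude / sqrt(2), so the returned numbers are easy to eyeball.
def _demo_calculate_audio_metrics() -> None:
    sr = 16_000
    t = torch.arange(sr) / sr
    sine = torch.sin(2 * np.pi * 440.0 * t)
    print(calculate_audio_metrics([(sine, sr)])[0])
    # Expect rms ~0.707, peak ~1.0, crest_factor ~1.414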

def create_weighted_composite(
    audio_data: List[Tuple[torch.Tensor, int]],
    weights: List[float],
) -> torch.Tensor:
    """
    Create a weighted composite of multiple audio files.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        weights: List of weights for each audio file

    Returns:
        Weighted composite audio data
    """
    if len(audio_data) != len(weights):
        raise ValueError("Number of audio files and weights must match")

    # Normalize weights to sum to 1
    weights = torch.tensor(weights) / sum(weights)

    # Initialize composite audio with zeros
    composite = torch.zeros_like(audio_data[0][0])

    # Add weighted audio data
    for (audio, _), weight in zip(audio_data, weights):
        composite += audio * weight

    # Normalize to prevent clipping
    max_val = torch.max(torch.abs(composite))
    if max_val > 1.0:
        composite = composite / max_val

    return composite


def create_melspectrograms(
    audio_data: List[Tuple[torch.Tensor, int]],
    composite: torch.Tensor,
    sr: int,
) -> List[torch.Tensor]:
    """
    Create mel spectrograms for individual audio files and the composite.

    Args:
        audio_data: List of tuples containing audio data and sample rate
        composite: Composite audio data
        sr: Sample rate

    Returns:
        List of mel spectrogram data
    """
    specs = []

    # Create the mel spectrogram transform
    mel_transform = T.MelSpectrogram(
        sample_rate=sr,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        f_max=8000,
    )

    # Generate spectrograms for individual audio files
    for audio, _ in audio_data:
        specs.append(mel_transform(audio))

    # Generate spectrogram for the composite audio
    specs.append(mel_transform(composite))

    return specs


def plot_melspectrograms(
    specs: List[torch.Tensor],
    sr: int,
    file_names: List[str],
    weights: List[float],
    output_path: str = "melspectrograms.png",
) -> None:
    """
    Plot mel spectrograms for individual audio files and the composite.

    Args:
        specs: List of mel spectrogram data
        sr: Sample rate
        file_names: List of audio file names
        weights: List of weights for each audio file
        output_path: Path to save the plot
    """
    fig, axs = plt.subplots(len(specs), 1, figsize=(12, 4 * len(specs)))

    # Create labels for the plots
    labels = [f"{name} (weight: {weight:.2f})" for name, weight in zip(file_names, weights)]
    labels.append("Composite.wav")

    # Convert to dB scale (similar to librosa's power_to_db)
    def power_to_db(spec):
        return 10 * torch.log10(spec + 1e-10)

    # Plot each mel spectrogram
    for i, (spec, label) in enumerate(zip(specs, labels)):
        spec_db = power_to_db(spec).numpy().squeeze()

        # plt.subplots returns a bare Axes (not an array) for a single subplot
        ax = axs if len(specs) == 1 else axs[i]

        # The y-axis is mel bins: mel spacing is nonlinear in Hz and the
        # transform above caps frequency at f_max=8000 Hz, so a linear Hz
        # extent up to sr/2 would be misleading.
        ax.imshow(
            spec_db,
            aspect='auto',
            origin='lower',
            interpolation='none',
        )
        ax.set_title(label)
        ax.set_ylabel('Mel bin')
        ax.set_xlabel('Time Frames')
        # No colorbar as requested

    plt.tight_layout()

    # Create the output directory if needed, then save
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    plt.savefig(output_path, dpi=300)
    plt.close()
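
# Weighted-mix sanity check for create_weighted_composite (illustrative
# values): the weights are normalized to sum to 1, so mixing a constant +1
# signal with a constant -1 signal using weights 3 and 1 gives
# 0.75 * 1 + 0.25 * (-1) = 0.5.
def _demo_weighted_composite() -> None:
    sr = 16_000
    ones = torch.ones(sr)
    mix = create_weighted_composite([(ones, sr), (-ones, sr)], [3.0, 1.0])
    print(mix[0].item())  # 0.5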
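
# Quick shape check for the MelSpectrogram settings above (the 16 kHz rate
# and one-second duration are arbitrary assumptions). With hop_length=512 and
# torchaudio's default centered framing, one second of 16 kHz audio yields
# 1 + 16000 // 512 = 32 frames across n_mels=128 bands.
def _demo_melspectrogram_shape() -> None:
    sr = 16_000
    audio = torch.randn(sr)
    spec = create_melspectrograms([(audio, sr)], audio, sr)[0]
    print(spec.shape)  # torch.Size([128, 32])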

def compose_audio(
    file_paths: List[str],
    weights: List[float],
    output_audio_path: str = os.path.join(LOGS_DIR, "composite.wav"),
    output_plot_path: str = os.path.join(LOGS_DIR, "plot/melspectrograms.png"),
    energy_plot_path: str = os.path.join(LOGS_DIR, "plot/energy_comparison.png"),
) -> None:
    """
    Main entry point: mix the audio files and create visualizations.

    Args:
        file_paths: List of paths to audio files (any number, e.g., four)
        weights: List of weights for each audio file
        output_audio_path: Path to save the composite audio
        output_plot_path: Path to save the mel spectrogram plot
        energy_plot_path: Path to save the energy comparison plot
    """
    # Load audio files
    audio_data = load_audio_files(file_paths)

    # Calculate metrics for the original audio
    print("Calculating metrics for original audio...")
    original_metrics = calculate_audio_metrics(audio_data)

    # Normalize audio volumes to the same energy level
    print("Normalizing audio volumes...")
    normalized_audio_data = normalize_audio_volumes(audio_data)

    # Calculate metrics for the normalized audio
    print("Calculating metrics for normalized audio...")
    normalized_metrics = calculate_audio_metrics(normalized_audio_data)

    # Print energy comparison
    print("\nAudio Energy Comparison (RMS values):")
    print("-" * 50)
    print(f"{'File':<20} {'Original':<15} {'Normalized':<15} {'Scaling Factor':<15}")
    print("-" * 50)
    file_names = [os.path.basename(path) for path in file_paths]
    for i, file_name in enumerate(file_names):
        orig_rms = original_metrics[i]['rms']
        norm_rms = normalized_metrics[i]['rms']
        scaling = norm_rms / orig_rms if orig_rms > 0 else float('inf')
        print(f"{file_name[:20]:<20} {orig_rms:<15.6f} {norm_rms:<15.6f} {scaling:<15.6f}")

    # Create energy comparison plot
    print("\nCreating energy comparison plot...")
    plot_energy_comparison(original_metrics, normalized_metrics, file_names, energy_plot_path)

    # Get sample rate (load_audio_files guarantees all files share it)
    sr = normalized_audio_data[0][1]

    # Create weighted composite
    print("\nCreating weighted composite...")
    composite = create_weighted_composite(normalized_audio_data, weights)

    # Create the output directory if needed, then save the mix
    os.makedirs(os.path.dirname(output_audio_path) or '.', exist_ok=True)
    print("Saving composite audio...")
    torchaudio.save(output_audio_path, composite.unsqueeze(0), sr)

    # Create mel spectrograms for the normalized audio (not the original)
    print("Creating melspectrograms for normalized audio...")
    specs = create_melspectrograms(normalized_audio_data, composite, sr)

    # Plot mel spectrograms
    print("Plotting melspectrograms...")
    plot_melspectrograms(specs, sr, file_names, weights, output_plot_path)

    print(f"\nComposite audio saved to {output_audio_path}")
    print(f"Melspectrograms saved to {output_plot_path}")
    print(f"Energy comparison saved to {energy_plot_path}")


# if __name__ == "__main__":
#     import argparse
#
#     parser = argparse.ArgumentParser(description="Mix audio files with weights and create melspectrograms")
#     parser.add_argument("--files", nargs="+", required=True, help="Paths to audio files")
#     parser.add_argument("--weights", nargs="+", type=float, required=True, help="Weights for each audio file")
#     parser.add_argument("--output-audio", default="./logs/composite.wav", help="Path to save the composite audio")
#     parser.add_argument("--output-plot", default="./logs/melspectrograms.png", help="Path to save the melspectrogram plot")
#     args = parser.parse_args()
#
#     os.makedirs("./logs", exist_ok=True)
#     compose_audio(args.files, args.weights, args.output_audio, args.output_plot)

# Example usage:
# python audio_mixer.py --files audio1.wav audio2.wav audio3.wav audio4.wav --weights 0.4 0.3 0.2 0.1
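
# Programmatic equivalent of the CLI example above (hypothetical file names;
# outputs go under LOGS_DIR by default):
# compose_audio(
#     ["audio1.wav", "audio2.wav", "audio3.wav", "audio4.wav"],
#     weights=[0.4, 0.3, 0.2, 0.1],
# )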