# https://huggingface.co/St0nedB/deepest-public
"""Gradio demo that shows delay-Doppler maps of the ISAC-UAV dataset.

For every receiver in ``RXS`` a window of ``SLOWTIME_WINDOW`` slow-time
samples is transformed into a delay-Doppler map, and the ground-truth value
at the centre of the window is overlaid as a marker. A slider selects which
window is displayed.
"""
import logging
import os
import subprocess
import sys

import gradio as gr
import matplotlib
import numpy as np

# Select the non-interactive backend before pyplot is imported so that no GUI
# toolkit is initialised on a headless server.
matplotlib.use("agg")
import matplotlib.pyplot as plt  # noqa: E402,F401  (kept for compatibility)
from matplotlib.figure import Figure  # noqa: E402

from dataset import UAVDataset  # noqa: E402

# basicConfig() returns None; fetch a real logger separately.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

MARKER_STYLE = dict(
    linestyle="none",
    markersize=15,
    marker="o",
    fillstyle="none",
    markeredgewidth=2,
    color="none",
    markerfacecolor="none",
    markerfacecoloralt="none",
    markeredgecolor="white",
)

SCENARIO = "1to2_H15_V11"
RXS = ["VGH0", "VGH1", "VGH2"]
DATASETS = []
CHANNELS = []
GROUNDTRUTHS = []
SLOWTIME_WINDOW = 100
NUM_WINDOWS = 0

# Vertical (Doppler) plot limit in Hz.
# NOTE(review): 320e-6 presumably is the slow-time sampling interval in
# seconds, so the axis spans +/- half the slow-time sampling rate -- confirm.
DOPPLER_LIMIT_HZ = 1 / (2 * 320e-6)
# Horizontal (delay) plot extent.
# NOTE(review): the value is 1e-6 while the axis label says microseconds. The
# ground-truth marker is drawn in the same data coordinates, so confirm the
# unit of the dataset's delay column before changing either.
DELAY_EXTENT = 1e-6


def get_channel(
    x: np.ndarray,
    start_idx: int,
    window_slowtime: int,
    filter_clutter: bool = False,
    upsample: int = 1,
    num_delay_bins: int = 80,
) -> np.ndarray:
    """Compute a normalised delay-Doppler map for one slow-time window.

    Args:
        x: 2-D channel matrix. Axis 0 is slow time. Axis 1 is transformed with
            an inverse FFT (presumably frequency -> delay; confirm).
        start_idx: First slow-time sample of the window.
        window_slowtime: Number of slow-time samples in the window.
        filter_clutter: If true, take the first difference along slow time.
            One extra sample is read so the differenced window still has
            ``window_slowtime`` rows when enough samples are available.
        upsample: Zero-padding factor applied along both axes.
        num_delay_bins: Number of (non-upsampled) columns to keep. The result
            has ``num_delay_bins * upsample`` columns.

    Returns:
        Complex array with the Doppler axis fftshifted (zero in the middle)
        and unit Frobenius norm (unless the window is all zeros).

    Raises:
        ValueError: If ``start_idx`` is negative or leaves fewer than
            ``window_slowtime`` samples.
    """
    if start_idx < 0:
        raise ValueError("Start index must be non-negative.")
    if start_idx > x.shape[0] - window_slowtime:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")

    # np.diff drops one row, so only read the extra sample when differencing.
    num_rows = window_slowtime + 1 if filter_clutter else window_slowtime
    x = x[start_idx:start_idx + num_rows, :]
    if filter_clutter:
        x = np.diff(x, n=1, axis=0)

    t_n, f_n = x.shape
    # IFFT along axis 1 and FFT along axis 0; `n=` zero-pads for upsampling.
    y = np.fft.fft(np.fft.ifft(x, n=f_n * upsample, axis=1), n=t_n * upsample, axis=0)
    norm = np.linalg.norm(y)
    if norm > 0:  # avoid 0/0 -> NaN for an all-zero window
        y /= norm
    y = np.fft.fftshift(y, axes=0)
    return y[:, :num_delay_bins * upsample]


def get_groundtruth(x: np.ndarray, start_idx: int, window_slowtime: int) -> np.ndarray:
    """Return ``[delay, doppler]`` at the centre sample of the window.

    Args:
        x: Ground-truth table. Column 0 is read as delay and column 1 as
            Doppler.
        start_idx: First slow-time sample of the window.
        window_slowtime: Number of slow-time samples in the window.

    Raises:
        ValueError: If ``start_idx`` is negative or too close to the end.
    """
    if start_idx < 0:
        raise ValueError("Start index must be non-negative.")
    if start_idx > x.shape[0] - window_slowtime:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")
    centre = start_idx + window_slowtime // 2
    delay = x[centre, 0]
    doppler = x[centre, 1]
    return np.array([delay, doppler])


def get_data(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int):
    """Return the clutter-filtered, 2x-upsampled map and ground truth for a window."""
    channel = get_channel(channel, start_idx, window_slowtime, filter_clutter=True, upsample=2)
    groundtruth = get_groundtruth(groundtruth, start_idx, window_slowtime)
    return channel, groundtruth


def update_fig(channel: np.ndarray, groundtruth: np.ndarray) -> Figure:
    """Render a delay-Doppler map (in dB) with the ground-truth marker.

    A standalone ``Figure`` is used instead of pyplot's global state. That
    state is not thread-safe, and figures created here are never registered
    with pyplot, so nothing needs to be closed.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(
        20 * np.log10(np.abs(channel)),
        aspect="auto",
        cmap="inferno",
        vmin=-70,
        vmax=0,
        # Row 0 (most negative Doppler after fftshift) is drawn at the top.
        extent=[0, DELAY_EXTENT, +DOPPLER_LIMIT_HZ, -DOPPLER_LIMIT_HZ],
    )
    ax.plot(groundtruth[0], groundtruth[1], **MARKER_STYLE)
    ax.set_xlabel(r"Delay [$\mu s$]")  # raw string: "\m" is not a valid escape
    ax.set_ylabel("Doppler-Shift [Hz]")
    return fig


def update(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int) -> Figure:
    """Build the figure for the window starting at slow-time sample ``start_idx``."""
    channel_window, groundtruth_window = get_data(
        channel, groundtruth, window_slowtime, int(start_idx))
    return update_fig(channel_window, groundtruth_window)


def update_all(start_idx: int) -> list:
    """Return one figure per receiver for window number ``start_idx``.

    ``start_idx`` is a window index (the slider value). It is scaled by
    ``SLOWTIME_WINDOW`` to obtain the first slow-time sample.
    """
    return [
        update(cc, gg, SLOWTIME_WINDOW, start_idx * SLOWTIME_WINDOW)
        for cc, gg in zip(CHANNELS, GROUNDTRUTHS)
    ]


def demo():
    """Build and launch the Gradio app (blocks until the server stops)."""
    # Compute the initial figures once rather than once per plot.
    initial_figs = update_all(1)
    with gr.Blocks() as app:
        gr.Markdown(
            "Demo for the [ISAC-UAV-Dataset](https://github.com/EMS-TU-Ilmenau/isac-uav-dataset)"
        )
        with gr.Row():
            plots = [gr.Plot(fig, label=rx) for fig, rx in zip(initial_figs, RXS)]
        with gr.Row():
            # `every` only applies to callable values and `queue` is not a
            # Slider argument, so neither is passed.
            slider = gr.Slider(1, NUM_WINDOWS, 1, step=1, label="Slowtime Window")
        # update callbacks
        slider.input(update_all, [slider], plots, show_progress="minimal")
    app.launch()


def prepare_data():
    """Run ``downloader.py`` unless every channel and target file is present."""
    # update_globals() opens both the *_channel.h5 and *_target.h5 files, so
    # both must exist. Presumably the downloader fetches both; confirm.
    required = [
        f"{SCENARIO}_{rx}_{kind}.h5" for rx in RXS for kind in ("channel", "target")
    ]
    if all(os.path.exists(path) for path in required):
        return
    subprocess.check_call([sys.executable, "downloader.py", "--scenario", SCENARIO])


def update_globals():
    """Load every receiver's dataset into the module-level globals."""
    global DATASETS, CHANNELS, GROUNDTRUTHS, NUM_WINDOWS
    DATASETS = [
        UAVDataset(
            f"{SCENARIO}_{rx}_channel.h5",
            f"{SCENARIO}_{rx}_target.h5",
        )
        for rx in RXS
    ]
    CHANNELS = [d.channel for d in DATASETS]
    GROUNDTRUTHS = [d.groundtruth for d in DATASETS]
    # Use the shortest receiver so every receiver can serve every window.
    NUM_WINDOWS = (min(len(d) for d in DATASETS) - SLOWTIME_WINDOW) // SLOWTIME_WINDOW


def main():
    """Ensure the data is present, load it, and start the demo."""
    prepare_data()
    update_globals()
    demo()


if __name__ == "__main__":
    main()