# Hugging Face Space page status at time of capture: "Runtime error"
# https://huggingface.co/St0nedB/deepest-public
import os | |
import numpy as np | |
import logging | |
import gradio as gr | |
import subprocess | |
import sys | |
import matplotlib | |
import matplotlib.pyplot as plt | |
from dataset import UAVDataset | |
# Configure the root logger once, then obtain a module-level logger.
# logging.basicConfig() returns None, so its result cannot serve as the logger.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
# Select matplotlib's non-interactive Agg backend: figures are handed to
# gr.Plot components rather than shown in a GUI window.
matplotlib.use("agg")
# Keyword arguments for plt.plot that draw the ground-truth point as a hollow
# ring: no connecting line, transparent fill, only a 2-pt white edge visible.
MARKER_STYLE = {
    "linestyle": "none",
    "markersize": 15,
    "marker": "o",
    "fillstyle": "none",
    "markeredgewidth": 2,
    "color": "none",
    "markerfacecolor": "none",
    "markerfacecoloralt": "none",
    "markeredgecolor": "white",
}
# Dataset scenario whose files are downloaded/loaded, and the receivers
# plotted side by side ("VGH0", "VGH1", "VGH2").
SCENARIO = "1to2_H15_V11"
RXS = [f"VGH{i}" for i in range(3)]
# Filled in by update_globals() at start-up; empty until then.
DATASETS, CHANNELS, GROUNDTRUTHS = [], [], []
# Slow-time samples per displayed window, and how many windows the slider offers.
SLOWTIME_WINDOW = 100
NUM_WINDOWS = 0
def get_channel(
    x: np.ndarray,
    start_idx: int,
    window_slowtime: int,
    filter_clutter: bool = False,
    upsample: int = 1,
    num_delay_bins: int = 80,
) -> np.ndarray:
    """Compute a normalized delay-Doppler map for one slow-time window of ``x``.

    Rows ``start_idx`` .. ``start_idx + window_slowtime`` (``window_slowtime + 1``
    rows, so the optional first-order difference still leaves ``window_slowtime``
    rows) are taken from ``x``. Axis 1 is inverse-FFT'd (delay) and axis 0 is
    FFT'd and ``fftshift``-ed (Doppler); both transforms are zero-padded by
    ``upsample``.

    Args:
        x: 2-D complex array; presumably (slowtime, frequency) — TODO confirm.
        start_idx: First slow-time row of the window; must be non-negative.
        window_slowtime: Number of slow-time samples per window.
        filter_clutter: If True, take ``np.diff`` along slow time, which
            removes components that are constant over the window.
        upsample: Zero-padding factor applied to both FFT lengths.
        num_delay_bins: Delay bins (before upsampling) to keep; the map is
            truncated to ``num_delay_bins * upsample`` columns. Defaults to 80.

    Returns:
        Complex array with unit Frobenius norm. An all-zero window is
        returned as zeros instead of being divided by a zero norm.

    Raises:
        ValueError: If ``start_idx`` is negative or greater than
            ``x.shape[0] - window_slowtime``.
    """
    if start_idx < 0:
        raise ValueError("Start index must be non-negative.")
    if start_idx > x.shape[0] - window_slowtime:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")
    # One extra row so np.diff below still yields window_slowtime rows.
    # NOTE(review): at start_idx == x.shape[0] - window_slowtime this slice is
    # one row short, so the returned map is correspondingly smaller.
    x = x[start_idx:start_idx + window_slowtime + 1, :]
    if filter_clutter:
        x = np.diff(x, n=1, axis=0)
    t_n, f_n = x.shape
    y = np.fft.fft(np.fft.ifft(x, n=f_n * upsample, axis=1), n=t_n * upsample, axis=0)
    norm = np.linalg.norm(y)
    # Guard the normalization: a constant window differenced by filter_clutter
    # is all zeros, and dividing by 0 would fill the map with NaN.
    if norm > 0:
        y /= norm
    # Centre zero Doppler along the slow-time axis.
    y = np.fft.fftshift(y, axes=0)
    return y[:, :num_delay_bins * upsample]
def get_groundtruth(x: np.ndarray, start_idx: int, window_slowtime: int):
    """Return the ``[delay, doppler]`` ground truth at the centre of a window.

    Reads columns 0 and 1 of row ``start_idx + window_slowtime // 2`` of ``x``.

    Raises:
        ValueError: If ``start_idx`` is greater than ``x.shape[0] - window_slowtime``.
    """
    last_allowed_start = x.shape[0] - window_slowtime
    if start_idx > last_allowed_start:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")
    centre = start_idx + window_slowtime // 2
    delay, doppler = x[centre, 0], x[centre, 1]
    return np.array([delay, doppler])
def get_data(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int):
    """Return ``(delay_doppler_map, [delay, doppler])`` for one slow-time window.

    The map is clutter-filtered and upsampled by a factor of 2 via ``get_channel``;
    the ground truth is taken at the window centre via ``get_groundtruth``.
    """
    return (
        get_channel(channel, start_idx, window_slowtime, filter_clutter=True, upsample=2),
        get_groundtruth(groundtruth, start_idx, window_slowtime),
    )
def update_fig(channel: np.ndarray, groundtruth: np.ndarray):
    """Render a delay-Doppler magnitude map with the ground-truth point overlaid.

    Args:
        channel: Complex map as returned by ``get_channel`` (rows: Doppler after
            ``fftshift``, columns: delay).
        groundtruth: ``[delay, doppler]`` pair, plotted at (x, y).

    Returns:
        The newly created ``matplotlib`` figure.
    """
    # Close pyplot's current figure (the previous render) so open figures
    # do not accumulate across UI updates.
    plt.close()
    fig = plt.figure()
    # Empty bins give log10(0) = -inf, which is expected; only silence the warning.
    with np.errstate(divide="ignore"):
        magnitude_db = 20 * np.log10(np.abs(channel))
    # NOTE(review): the x extent spans 0..1e-6 while the axis label says µs, and
    # 320e-6 presumably is the slow-time sampling interval — confirm units.
    plt.imshow(magnitude_db, aspect="auto",
               cmap="inferno", vmin=-70, vmax=0, extent=[0, 1e-6, +1/(2*320e-6), -1/(2*320e-6)])
    plt.plot(groundtruth[0], groundtruth[1], **MARKER_STYLE)
    # Raw string: "\m" is an invalid escape sequence in a normal string literal.
    plt.xlabel(r"Delay [$\mu s$]")
    plt.ylabel("Doppler-Shift [Hz]")
    return fig
def update(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int):
    """Build the figure for one receiver's window starting at slow-time row ``start_idx``.

    ``start_idx`` is coerced with ``int()`` before slicing.
    """
    window_map, window_truth = get_data(channel, groundtruth, window_slowtime, int(start_idx))
    return update_fig(window_map, window_truth)
def update_all(start_idx: int):
    """Render one figure per loaded receiver for window index ``start_idx``.

    The slow-time offset passed on is ``start_idx * SLOWTIME_WINDOW``. Figures
    are returned in the order of ``CHANNELS``/``GROUNDTRUTHS``.
    """
    return [
        update(chan, truth, SLOWTIME_WINDOW, start_idx * SLOWTIME_WINDOW)
        for chan, truth in zip(CHANNELS, GROUNDTRUTHS)
    ]
def demo():
    """Build and launch the Gradio UI: one plot per receiver plus a window slider.

    Relies on the module globals filled in by ``update_globals`` (``CHANNELS``,
    ``GROUNDTRUTHS``, ``NUM_WINDOWS``).
    """
    with gr.Blocks() as app:
        gr.Markdown(
            "Demo for the [ISAC-UAV-Dataset](https://github.com/EMS-TU-Ilmenau/isac-uav-dataset)"
        )
        # Render the initial window once; previously update_all(1) ran once per plot.
        initial_figs = update_all(1)
        with gr.Row():
            vgh0 = gr.Plot(initial_figs[0], label="VGH0")
            vgh1 = gr.Plot(initial_figs[1], label="VGH1")
            vgh2 = gr.Plot(initial_figs[2], label="VGH2")
        with gr.Row():
            # `queue=`/`every=` dropped: `queue` is not a gr.Slider parameter
            # (newer Gradio rejects unknown kwargs), and `every` only matters
            # when the value is a callable.
            slider = gr.Slider(1, NUM_WINDOWS, 1, step=1, label="Slowtime Window")
        # Re-render all three plots whenever the user moves the slider.
        slider.input(update_all, [slider], [vgh0, vgh1, vgh2], show_progress="minimal")
    app.launch()
def prepare_data():
    """Run the downloader unless every data file ``update_globals`` opens exists.

    For each receiver in ``RXS`` both ``{SCENARIO}_{rx}_channel.h5`` and
    ``{SCENARIO}_{rx}_target.h5`` are required (previously only the channel
    files were checked, so a missing target file skipped the download).
    ``downloader.py`` is run with the current interpreter, relative to the
    working directory.

    Raises:
        subprocess.CalledProcessError: If the downloader exits with a non-zero status.
    """
    required = [
        f"{SCENARIO}_{rx}_{kind}.h5"
        for rx in RXS
        for kind in ("channel", "target")
    ]
    if all(os.path.exists(path) for path in required):
        return
    subprocess.check_call([sys.executable, "downloader.py", "--scenario", SCENARIO])
def update_globals():
    """Load each receiver's dataset and refresh the module-level caches.

    Sets ``DATASETS``, ``CHANNELS`` and ``GROUNDTRUTHS`` (one entry per receiver
    in ``RXS``, in that order) and ``NUM_WINDOWS``, the largest slider value
    for which every receiver still yields a complete window.
    """
    global DATASETS, CHANNELS, GROUNDTRUTHS, NUM_WINDOWS
    DATASETS = [
        UAVDataset(f"{SCENARIO}_{rx}_channel.h5", f"{SCENARIO}_{rx}_target.h5")
        for rx in RXS
    ]
    CHANNELS = [d.channel for d in DATASETS]
    GROUNDTRUTHS = [d.groundtruth for d in DATASETS]
    # Bound by the shortest recording across all receivers (not receiver 0
    # only) so no slider position runs past any of them.
    num_rows = min(
        min(c.shape[0] for c in CHANNELS),
        min(g.shape[0] for g in GROUNDTRUTHS),
    )
    # Window k starts at row k * SLOWTIME_WINDOW, and get_channel reads
    # SLOWTIME_WINDOW + 1 rows (the extra row feeds np.diff), so the last start
    # must satisfy start + SLOWTIME_WINDOW + 1 <= num_rows.
    NUM_WINDOWS = max(0, (num_rows - SLOWTIME_WINDOW - 1) // SLOWTIME_WINDOW)
def main():
    """Fetch the data if needed, load it into the globals, then start the UI."""
    for step in (prepare_data, update_globals, demo):
        step()
# Only start the app when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()