# Captured from the Hugging Face file view by St0nedB
# (commit 6e4ade2, "initial commit"; file size 4.39 kB).
# https://huggingface.co/St0nedB/deepest-public
import os
import numpy as np
import logging
import gradio as gr
import subprocess
import sys
import matplotlib
import matplotlib.pyplot as plt
from dataset import UAVDataset
# logging.basicConfig() configures the root logger and returns None, so its
# result cannot serve as a logger; request a named module logger instead.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Non-interactive backend: figures are only handed to gr.Plot, never shown
# in a GUI window.
matplotlib.use("agg")

# Hollow white circle used to mark the ground-truth target on each map.
MARKER_STYLE = dict(
    linestyle="none",
    markersize=15,
    marker="o",
    fillstyle="none",
    markeredgewidth=2,
    color="none",
    markerfacecolor="none",
    markerfacecoloralt="none",
    markeredgecolor="white",
)

# Scenario identifier; prefix of every data file name ("{SCENARIO}_{rx}_*.h5").
SCENARIO = "1to2_H15_V11"
# Receiver identifiers; one dataset and one plot per entry.
RXS = ["VGH0", "VGH1", "VGH2"]

# Populated by update_globals() once the data files are available.
DATASETS = []
CHANNELS = []
GROUNDTRUTHS = []

# Number of slow-time samples per displayed window.
SLOWTIME_WINDOW = 100
# Upper bound of the window slider; computed in update_globals().
NUM_WINDOWS = 0
def get_channel(
    x: np.ndarray,
    start_idx: int,
    window_slowtime: int,
    filter_clutter: bool = False,
    upsample: int = 1,
    num_delay_bins: int = 80,
) -> np.ndarray:
    """Turn one slow-time window of ``x`` into a normalised delay-Doppler map.

    Args:
        x: 2-D array. Axis 0 is slow time (it is windowed by ``start_idx``).
            Axis 1 is converted to delay by an inverse FFT (presumably
            frequency bins / subcarriers - confirm against UAVDataset).
        start_idx: First slow-time row of the window; must be non-negative.
        window_slowtime: Window length. ``window_slowtime + 1`` rows are
            sliced so that, after first-order clutter differencing, the map
            has ``window_slowtime`` Doppler rows. At
            ``start_idx == x.shape[0] - window_slowtime`` only
            ``window_slowtime`` rows remain, so the map is one row shorter.
        filter_clutter: Difference consecutive slow-time rows, removing
            components that are constant over slow time.
        upsample: Zero-padding factor applied to both FFTs.
        num_delay_bins: Number of (non-upsampled) delay bins to keep; at most
            ``num_delay_bins * upsample`` columns are returned.

    Returns:
        Complex array of shape ``(rows * upsample, <= num_delay_bins * upsample)``,
        scaled to unit Frobenius norm before cropping (left all-zero when the
        window is all zero). The Doppler axis (axis 0) is fft-shifted so zero
        Doppler sits in the middle.

    Raises:
        ValueError: If ``start_idx`` is negative or leaves fewer than
            ``window_slowtime`` rows.
    """
    if start_idx < 0:
        # A negative index would silently slice from the end of the array.
        raise ValueError("Start index must be non-negative.")
    if start_idx > x.shape[0] - window_slowtime:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")

    # One extra row so that np.diff (clutter filtering) still yields a full window.
    x = x[start_idx:start_idx + window_slowtime + 1, :]
    if filter_clutter:
        x = np.diff(x, n=1, axis=0)

    t_n, f_n = x.shape
    # Axis 1 -> delay (inverse FFT), axis 0 -> Doppler (FFT), both zero-padded.
    y = np.fft.fft(np.fft.ifft(x, n=f_n * upsample, axis=1), n=t_n * upsample, axis=0)

    norm = np.linalg.norm(y)
    if norm > 0:
        # Guarded: an all-zero window would otherwise turn the map into NaNs.
        y /= norm

    y = np.fft.fftshift(y, axes=0)
    return y[:, :num_delay_bins * upsample]
def get_groundtruth(x: np.ndarray, start_idx: int, window_slowtime: int) -> np.ndarray:
    """Return the ground-truth ``[delay, doppler]`` at the centre of a window.

    Args:
        x: 2-D array indexed by slow-time row; column 0 is read as delay and
            column 1 as Doppler.
        start_idx: First slow-time row of the window; must be non-negative.
        window_slowtime: Window length; the row at
            ``start_idx + window_slowtime // 2`` is sampled.

    Returns:
        1-D array ``[delay, doppler]``.

    Raises:
        ValueError: If ``start_idx`` is negative or leaves fewer than
            ``window_slowtime`` rows.
    """
    if start_idx < 0:
        # A negative index would silently wrap around to the end of the array.
        raise ValueError("Start index must be non-negative.")
    if start_idx > x.shape[0] - window_slowtime:
        raise ValueError(
            "Start index must be smaller than the number of slowtime samples minus the window size.")

    center = start_idx + window_slowtime // 2
    delay = x[center, 0]
    doppler = x[center, 1]
    return np.array([delay, doppler])
def get_data(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int):
    """Extract the plotting inputs for one slow-time window.

    The channel is clutter-filtered and upsampled by 2 via ``get_channel``.
    The ground truth is sampled at the window centre via ``get_groundtruth``.

    Returns:
        Tuple ``(delay_doppler_map, [delay, doppler])``.
    """
    # Tuple items are evaluated left to right: channel first, then ground truth.
    return (
        get_channel(channel, start_idx, window_slowtime, filter_clutter=True, upsample=2),
        get_groundtruth(groundtruth, start_idx, window_slowtime),
    )
def update_fig(channel: np.ndarray, groundtruth: np.ndarray):
    """Render a delay-Doppler map in dB with the ground-truth target overlaid.

    Args:
        channel: 2-D complex map; rows are plotted along the Doppler (y) axis
            and columns along the delay (x) axis.
        groundtruth: ``[delay, doppler]`` position of the marker.

    Returns:
        The newly created matplotlib Figure.
    """
    # Close the current pyplot figure (the one created by the previous call)
    # before creating a new one.
    plt.close()
    fig, ax = plt.subplots()

    # An all-zero map gives log10(0) = -inf; that is drawn at/below vmin, so
    # the divide-by-zero warning is only noise.
    with np.errstate(divide="ignore"):
        power_db = 20 * np.log10(np.abs(channel))

    ax.imshow(
        power_db,
        aspect="auto",
        cmap="inferno",
        vmin=-70,
        vmax=0,
        # NOTE(review): the x range is 0..1e-6 while the label says
        # microseconds - if delays are in seconds the tick values and the
        # label disagree; confirm units. The Doppler limits +-1/(2*320e-6) Hz
        # presumably derive from a 320 us slow-time interval - confirm.
        extent=[0, 1e-6, +1 / (2 * 320e-6), -1 / (2 * 320e-6)],
    )
    ax.plot(groundtruth[0], groundtruth[1], **MARKER_STYLE)
    # Raw string: "\m" is an invalid escape sequence (SyntaxWarning on
    # Python 3.12+); the rendered label text is unchanged.
    ax.set_xlabel(r"Delay [$\mu s$]")
    ax.set_ylabel("Doppler-Shift [Hz]")
    return fig
def update(channel: np.ndarray, groundtruth: np.ndarray, window_slowtime: int, start_idx: int):
    """Build the figure for the slow-time window starting at ``start_idx``."""
    # start_idx may arrive as a float (it is derived from a slider value);
    # slicing needs an integer.
    window_map, window_truth = get_data(channel, groundtruth, window_slowtime, int(start_idx))
    return update_fig(window_map, window_truth)
def update_all(start_idx: int):
    """Return one figure per receiver for slider position ``start_idx``.

    The slider counts windows, so the position is scaled by
    ``SLOWTIME_WINDOW`` to obtain the slow-time sample offset.
    """
    return [
        update(rx_channel, rx_truth, SLOWTIME_WINDOW, start_idx * SLOWTIME_WINDOW)
        for rx_channel, rx_truth in zip(CHANNELS, GROUNDTRUTHS)
    ]
def demo():
    """Build the Gradio UI (one plot per receiver plus a window slider) and launch it."""
    # Render the initial window once; previously update_all(1) ran once per
    # plot, building three times as many figures as were displayed.
    initial_figs = update_all(1)

    with gr.Blocks() as blocks:
        gr.Markdown(
            "Demo for the [ISAC-UAV-Dataset](https://github.com/EMS-TU-Ilmenau/isac-uav-dataset)"
        )
        with gr.Row():
            # One plot per receiver, labelled with the receiver name.
            plots = [gr.Plot(fig, label=rx) for rx, fig in zip(RXS, initial_figs)]
        with gr.Row():
            slider = gr.Slider(1, NUM_WINDOWS, 1, step=1, label="Slowtime Window", queue=True, every=1)
        # Re-render every receiver's plot whenever the slider moves.
        slider.input(update_all, [slider], plots, show_progress="minimal")
    blocks.launch()
def prepare_data():
    """Ensure the channel and target files for every receiver are present.

    update_globals() opens both ``{SCENARIO}_{rx}_channel.h5`` and
    ``{SCENARIO}_{rx}_target.h5``, so both are checked. If any is missing,
    ``downloader.py`` is run with the current interpreter.
    ``subprocess.check_call`` raises ``CalledProcessError`` if it fails.
    """
    # NOTE(review): assumes downloader.py also fetches the *_target.h5 files
    # for the scenario - confirm.
    required = [
        f"{SCENARIO}_{rx}_{kind}.h5"
        for rx in RXS
        for kind in ("channel", "target")
    ]
    if all(os.path.exists(path) for path in required):
        return
    subprocess.check_call([sys.executable, "downloader.py", "--scenario", SCENARIO])
def update_globals():
    """Load every receiver's dataset and refresh the module-level caches.

    Populates DATASETS, CHANNELS and GROUNDTRUTHS (one entry per receiver
    in RXS) and NUM_WINDOWS, the upper bound of the window slider.
    """
    global DATASETS, CHANNELS, GROUNDTRUTHS, NUM_WINDOWS
    DATASETS = [
        UAVDataset(
            f"{SCENARIO}_{rx}_channel.h5",
            f"{SCENARIO}_{rx}_target.h5",
        ) for rx in RXS
    ]
    CHANNELS = [d.channel for d in DATASETS]
    GROUNDTRUTHS = [d.groundtruth for d in DATASETS]

    # The slider drives all receivers at once, so bound it by the shortest
    # recording; otherwise a longer first receiver could push a shorter one
    # past its end.
    # NOTE(review): assumes len(UAVDataset) equals the number of slow-time
    # rows in .channel - confirm. When (length - window) is an exact multiple
    # of the window, the last window's get_channel slice is one row short.
    shortest = min(len(d) for d in DATASETS)
    NUM_WINDOWS = (shortest - SLOWTIME_WINDOW) // SLOWTIME_WINDOW
def main():
    """Entry point: fetch data if missing, load it, then serve the demo."""
    for step in (prepare_data, update_globals, demo):
        step()
if __name__ == "__main__":
main()