# Spaces:
# Build error
# Build error
import random

import matplotlib
import matplotlib.lines
import matplotlib.pyplot as plt
import numpy as np
from openTSNE import TSNE
def visualize(
    x,
    y,
    ax=None,
    title=None,
    draw_legend=True,
    draw_centers=False,
    draw_cluster_labels=False,
    colors=None,
    legend_kwargs=None,
    label_order=None,
    **kwargs
):
    """Scatter-plot the first two columns of ``x``, colored by the labels in ``y``.

    Args:
        x: 2-D array of points; only columns 0 and 1 are drawn.
        y: One label per row of ``x``. Lists are accepted and converted to a
            NumPy array so the per-class masks are computed element-wise.
        ax: Axes to draw on. When ``None`` a new 10x8 figure is created.
        title: Optional title set on ``ax``.
        draw_legend: Add a legend with one square marker per class.
        draw_centers: Mark the coordinate-wise median of every class.
        draw_cluster_labels: Write each class label slightly above its center.
            Only takes effect together with ``draw_centers``.
        colors: Mapping from label to color. When ``None`` the colors are
            taken from matplotlib's ``axes.prop_cycle``.
        legend_kwargs: Extra keyword arguments for ``ax.legend``; they
            override the defaults used here.
        label_order: Optional ordering of the classes (legend, centers).
            Must contain every label that occurs in ``y``.
        **kwargs: ``alpha`` (default 0.6) and ``s`` (default 1) for the main
            scatter, ``fontsize`` (default 6) for the cluster labels.

    Raises:
        AssertionError: If ``y`` holds a label that is missing from
            ``label_order`` (the check is skipped under ``python -O``).
    """
    if ax is None:
        _, ax = matplotlib.pyplot.subplots(figsize=(10, 8))

    if title is not None:
        ax.set_title(title)

    plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)}

    # Convert once: comparing a scalar against a plain list is not
    # element-wise, which would silently break the per-class masks below.
    labels = np.asarray(y)
    present = np.unique(labels)

    # Decide which classes are drawn, and in which order.
    if label_order is not None:
        assert all(np.isin(present, label_order)), (
            "label_order must contain every label found in y"
        )
        classes = [label for label in label_order if label in present]
    else:
        classes = present

    if colors is None:
        default_colors = matplotlib.rcParams["axes.prop_cycle"]
        colors = {k: v["color"] for k, v in zip(classes, default_colors())}

    # Main scatter plot.
    point_colors = [colors.get(label) for label in labels]
    ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params)

    # Per-class centers (coordinate-wise medians); skipped when there is
    # nothing to summarize so the empty array is never indexed.
    if draw_centers and len(classes) > 0:
        centers = np.array(
            [np.median(x[labels == cls, :2], axis=0) for cls in classes]
        )
        center_colors = [colors.get(cls) for cls in classes]
        ax.scatter(
            centers[:, 0], centers[:, 1], c=center_colors, s=48, alpha=1, edgecolor="k"
        )

        # Class names, placed a fixed offset above each center.
        if draw_cluster_labels:
            for idx, label in enumerate(classes):
                ax.text(
                    centers[idx, 0],
                    centers[idx, 1] + 2.2,
                    label,
                    fontsize=kwargs.get("fontsize", 6),
                    horizontalalignment="center",
                )

    # Hide ticks and axis.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")

    if draw_legend:
        legend_handles = [
            matplotlib.lines.Line2D(
                [],
                [],
                marker="s",
                color="w",
                markerfacecolor=colors[cls],
                ms=10,
                alpha=1,
                linewidth=0,
                label=cls,
                markeredgecolor="k",
            )
            for cls in classes
        ]
        legend_kwargs_ = dict(loc="best", bbox_to_anchor=(0.05, 0.5), frameon=False)
        if legend_kwargs is not None:
            legend_kwargs_.update(legend_kwargs)
        ax.legend(handles=legend_handles, **legend_kwargs_)
# Each landmark vector is flattened to this many values (see the reshapes below).
LANDMARK_DIM = 204
# Upper bound on the number of LRS3 prediction rows fed to t-SNE.
NUM_LRS3_SAMPLES = 10000
# One seed for both the subsampling and t-SNE so a run can be reproduced.
RANDOM_SEED = 42

LRS3_PRED_PATH = "infer_out/tmp_npys/lrs3_pred_all.npy"
PERSON_DATASET_PATH = "data/binary/videos/May/trainval_dataset.npy"
LRS3_STATS_PATH = '/home/yezhenhui/datasets/binary/lrs3_0702/stats.npy'
VAE_PRED_PATH = "infer_out/tmp_npys/pred_exp_0_vae.npy"
POSTNET_PRED_PATH = "infer_out/tmp_npys/pred_exp_0_postnet_hubert.npy"
OUTPUT_FIGURE_PATH = "infer_out/tmp_npys/lrs3_pred_all_0k.png"


def subsample_rows(array, max_rows, seed=None):
    """Return at most ``max_rows`` rows of ``array``, drawn without replacement.

    Drawing without replacement keeps duplicate rows out of the t-SNE input.
    When ``array`` has ``max_rows`` rows or fewer it is returned unchanged.
    """
    total = len(array)
    if total <= max_rows:
        return array
    rng = np.random.default_rng(seed)
    return array[rng.choice(total, size=max_rows, replace=False)]


def denormalize(values, mean, std):
    """Undo a ``(v - mean) / std`` standardization."""
    return values * std + mean


def expand_labels(named_arrays):
    """Return one label per row for a sequence of ``(name, array)`` pairs."""
    return [name for name, rows in named_arrays for _ in range(len(rows))]


def load_pickled_dict(path):
    """Load a ``.npy`` file and unwrap its content with ``.tolist()``.

    NOTE(security): ``allow_pickle=True`` can execute arbitrary code while
    loading; only use it on files from a trusted source.
    """
    # presumably the file stores a single pickled dict - verify against the
    # code that writes these files.
    return np.load(path, allow_pickle=True).tolist()


def main():
    """Embed predicted and person-specific landmarks with t-SNE and save the plot."""
    lrs3_pred = subsample_rows(
        np.load(LRS3_PRED_PATH), NUM_LRS3_SAMPLES, seed=RANDOM_SEED
    )

    person_ds = load_pickled_dict(PERSON_DATASET_PATH)
    person_mean = person_ds['idexp_lm3d_mean'].reshape([1, LANDMARK_DIM])
    person_std = person_ds['idexp_lm3d_std'].reshape([1, LANDMARK_DIM])
    person_train = np.stack(
        [
            sample['idexp_lm3d_normalized'].reshape([LANDMARK_DIM])
            for sample in person_ds['train_samples']
        ]
    )

    lrs3_stats = load_pickled_dict(LRS3_STATS_PATH)
    lrs3_mean = lrs3_stats['idexp_lm3d_mean'].reshape([1, LANDMARK_DIM])
    lrs3_std = lrs3_stats['idexp_lm3d_std'].reshape([1, LANDMARK_DIM])

    # Undo each source's own normalization before the sets are mixed.
    person_train = denormalize(person_train, person_mean, person_std)
    lrs3_pred = denormalize(lrs3_pred, lrs3_mean, lrs3_std)

    # The VAE / PostNet predictions are used as stored (no denormalization).
    vae_pred = np.load(VAE_PRED_PATH).reshape([-1, LANDMARK_DIM])
    postnet_pred = np.load(POSTNET_PRED_PATH).reshape([-1, LANDMARK_DIM])

    groups = [
        ("pred_lrs3", lrs3_pred),
        ("person_train", person_train),
        ("vae", vae_pred),
        ("postnet", postnet_pred),
    ]
    all_points = np.concatenate([rows for _, rows in groups])

    tsne = TSNE(
        perplexity=30,
        metric="euclidean",
        n_jobs=8,
        random_state=RANDOM_SEED,
        verbose=True,
    )
    embedding = tsne.fit(all_points)  # (N, 204) -> low-dimensional embedding

    visualize(embedding, expand_labels(groups))
    plt.savefig(OUTPUT_FIGURE_PATH)


if __name__ == "__main__":
    main()