import yaml import json from typing import List import torch def tensor2list(d: dict): tensor2list_lambda = lambda x: x.detach().cpu().numpy().tolist() for k in d.keys(): if isinstance(d[k], torch.Tensor): d[k] = tensor2list_lambda(d[k]) if isinstance(d[k], List): if isinstance(d[k][0], torch.Tensor): d[k] = [tensor2list_lambda(x) for x in d[k]] return d def write_json(json_serializable_dict, fout, indent=2): with open(fout, "w") as fw: json.dump(json_serializable_dict, fw, indent=indent) def write_yaml(json_serializable_dict, fout): with open(fout, "w") as fw: yaml.dump(json_serializable_dict, fw, default_flow_style=False) def detach_dict(x_dict): with torch.no_grad(): for k in x_dict.keys(): if isinstance(x_dict[k], torch.Tensor): x_dict[k] = x_dict[k].detach().cpu() elif isinstance(x_dict[k], dict): x_dict[k] = detach_dict(x_dict[k]) return x_dict def tensor2list(xdict): for k in xdict.keys(): if isinstance(xdict[k], torch.Tensor): xdict[k] = xdict[k].numpy().tolist() elif isinstance(xdict[k], dict): xdict[k] = tensor2list(xdict[k]) return xdict