RamziBm's picture
init
bdb955e
raw
history blame
1.29 kB
import yaml
import json
from typing import List
import torch
def tensor2list(d: dict):
    """Convert tensor values of ``d`` to plain Python lists, in place.

    A value that is a ``torch.Tensor`` is detached, moved to CPU and turned
    into nested Python lists (a 0-dim tensor becomes a scalar). A value that
    is a ``list`` containing tensors has each tensor element converted the
    same way; non-tensor elements are kept as they are. Every other value,
    including empty lists, is left untouched.

    NOTE(review): a later ``def tensor2list`` in this module rebinds the
    name, so this version is shadowed once the module has been imported --
    confirm which definition is intended to be the public one.

    Args:
        d: Dictionary to convert. It is mutated.

    Returns:
        The same dictionary object ``d``.
    """

    def _to_list(tensor):
        return tensor.detach().cpu().tolist()

    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            d[key] = _to_list(value)
        # ``elif``: a converted tensor must not be re-inspected as a list,
        # and ``any`` is safe on empty lists where indexing ``[0]`` is not.
        elif isinstance(value, list) and any(
            isinstance(item, torch.Tensor) for item in value
        ):
            d[key] = [
                _to_list(item) if isinstance(item, torch.Tensor) else item
                for item in value
            ]
    return d
def write_json(json_serializable_dict, fout, indent=2):
    """Write ``json_serializable_dict`` as JSON to the file at path ``fout``.

    The file is opened in text write mode, so existing content is replaced.

    Args:
        json_serializable_dict: Object passed to ``json.dump``.
        fout: Path of the output file.
        indent: Indentation forwarded to ``json.dump`` (default 2).
    """
    with open(fout, mode="w") as out_file:
        json.dump(json_serializable_dict, fp=out_file, indent=indent)
def write_yaml(json_serializable_dict, fout):
    """Write ``json_serializable_dict`` as YAML to the file at path ``fout``.

    The file is opened in text write mode, so existing content is replaced.
    ``default_flow_style=False`` is passed to ``yaml.dump``.

    Args:
        json_serializable_dict: Object passed to ``yaml.dump``.
        fout: Path of the output file.
    """
    with open(fout, mode="w") as out_file:
        yaml.dump(
            json_serializable_dict,
            stream=out_file,
            default_flow_style=False,
        )
def detach_dict(x_dict):
    """Detach every tensor in ``x_dict`` and move it to CPU, in place.

    Values that are ``torch.Tensor`` are replaced by ``.detach().cpu()``;
    values that are dictionaries are processed recursively. Any other value
    is left as it is.

    Args:
        x_dict: Dictionary to process. It is mutated.

    Returns:
        The same dictionary object ``x_dict``.
    """
    with torch.no_grad():
        for name, entry in x_dict.items():
            if isinstance(entry, dict):
                # Nested dictionaries are handled by the same routine.
                x_dict[name] = detach_dict(entry)
            elif isinstance(entry, torch.Tensor):
                x_dict[name] = entry.detach().cpu()
    return x_dict
def tensor2list(xdict):
    """Recursively convert tensors in ``xdict`` to plain Python lists, in place.

    * ``torch.Tensor`` values are detached, moved to CPU and converted with
      ``tolist()`` (a 0-dim tensor becomes a scalar). Detaching first means
      tensors that require grad or live on another device are accepted.
    * ``dict`` values are processed recursively.
    * ``list`` values that contain tensors have each tensor element
      converted; other elements are kept as they are.
    * Everything else is left untouched.

    NOTE(review): this definition rebinds the name of an earlier
    ``tensor2list`` in this module; the list handling here keeps the
    behavior that earlier version offered for lists of tensors.

    Args:
        xdict: Dictionary to convert. It is mutated.

    Returns:
        The same dictionary object ``xdict``.
    """
    for key, value in xdict.items():
        if isinstance(value, torch.Tensor):
            xdict[key] = value.detach().cpu().tolist()
        elif isinstance(value, dict):
            xdict[key] = tensor2list(value)
        elif isinstance(value, list) and any(
            isinstance(item, torch.Tensor) for item in value
        ):
            xdict[key] = [
                item.detach().cpu().tolist()
                if isinstance(item, torch.Tensor)
                else item
                for item in value
            ]
    return xdict