import torch | |
def clean_memory_on_device(device): | |
if device.type == "cuda": | |
torch.cuda.empty_cache() | |
elif device.type == "cpu": | |
pass | |
elif device.type == "mps": # not tested | |
torch.mps.empty_cache() | |
def synchronize_device(device: torch.device): | |
if device.type == "cuda": | |
torch.cuda.synchronize() | |
elif device.type == "xpu": | |
torch.xpu.synchronize() | |
elif device.type == "mps": | |
torch.mps.synchronize() | |