svjack's picture
Upload folder using huggingface_hub
ef46f0f verified
raw
history blame
472 Bytes
import torch
def clean_memory_on_device(device: "torch.device | str"):
    """Release cached allocator memory held by ``device``'s backend.

    Calls ``empty_cache()`` on the matching backend (``cuda``, ``xpu`` or
    ``mps``). For any other device type (e.g. ``cpu``) this is a no-op.

    Args:
        device: A ``torch.device`` or anything ``torch.device()`` accepts,
            such as the string ``"cuda:0"``.
    """
    # Normalise so plain strings like "cuda:0" work too; an existing
    # torch.device passes through unchanged.
    device = torch.device(device)
    # Match on ``device.type`` so both "cuda" and "cuda:1" select CUDA.
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
    elif device.type == "mps":  # not tested
        torch.mps.empty_cache()
def synchronize_device(device: "torch.device | str"):
    """Wait for all pending work on ``device``'s backend to finish.

    Handles the ``cuda``, ``xpu`` and ``mps`` backends. For any other device
    type (e.g. ``cpu``) this is a no-op.

    Args:
        device: A ``torch.device`` or anything ``torch.device()`` accepts,
            such as the string ``"cuda:1"``.
    """
    # Normalise so plain strings work too; an existing torch.device passes
    # through unchanged.
    device = torch.device(device)
    if device.type == "cuda":
        # Pass the device through: with no argument torch.cuda.synchronize()
        # waits on the *current* device, which can differ from ``device``
        # (e.g. "cuda:1" while the current device is cuda:0).
        torch.cuda.synchronize(device)
    elif device.type == "xpu":
        # Same reasoning as CUDA: target the requested XPU, not the current one.
        torch.xpu.synchronize(device)
    elif device.type == "mps":
        # torch.mps.synchronize() takes no device argument.
        torch.mps.synchronize()