# letitiaaa's picture
# Initial commit
# e66e8cc
import torch
from contextlib import suppress
from functools import partial
def get_autocast(precision, device_type='cuda'):
    """Return a zero-argument context-manager factory for *precision*.

    The result is meant to be called and entered, e.g.
    ``with get_autocast(precision)(): ...``.

    Args:
        precision: Precision mode name. ``'amp'`` selects ``torch.float16``;
            ``'amp_bfloat16'`` and ``'amp_bf16'`` select ``torch.bfloat16``.
            Any other value disables autocasting.
        device_type: Forwarded to ``torch.amp.autocast`` as ``device_type``.

    Returns:
        ``functools.partial(torch.amp.autocast, device_type=..., dtype=...)``
        for the AMP modes, otherwise ``contextlib.suppress``.
    """
    if precision == 'amp':
        amp_dtype = torch.float16
    elif precision in ('amp_bfloat16', 'amp_bf16'):
        amp_dtype = torch.bfloat16
    else:
        # ``suppress()`` called with no exception types suppresses nothing,
        # so it serves as a do-nothing context manager for non-AMP modes.
        return suppress
    return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype)