| import torch | |
| from contextlib import suppress | |
| from functools import partial | |
| def get_autocast(precision, device_type='cuda'): | |
| if precision =='amp': | |
| amp_dtype = torch.float16 | |
| elif precision == 'amp_bfloat16' or precision == 'amp_bf16': | |
| amp_dtype = torch.bfloat16 | |
| else: | |
| return suppress | |
| return partial(torch.amp.autocast, device_type=device_type, dtype=amp_dtype) |