# Gradient checks (torch.autograd.gradcheck) for interpol's grid operations.
import torch
from torch.autograd import gradcheck
from interpol import grid_pull, grid_push, grid_count, grid_grad, add_identity_grid_
import pytest
import inspect
# global parameters
dtype = torch.double    # double precision advised to check gradients
shape1 = 3              # size along each spatial dimension
extrapolate = True

# make results reproducible across runs where torch supports it
if hasattr(torch, 'use_deterministic_algorithms'):
    torch.use_deterministic_algorithms(True)

# gradcheck options, filtered by what this torch version accepts
kwargs = dict(rtol=1., raise_exception=True)
_gradcheck_params = inspect.signature(gradcheck).parameters
if 'check_undefined_grad' in _gradcheck_params:
    kwargs['check_undefined_grad'] = False
if 'nondet_tol' in _gradcheck_params:
    kwargs['nondet_tol'] = 1e-3

# devices (and thread count / device index) to test
devices = [('cpu', 1)]
if torch.backends.openmp.is_available() or torch.backends.mkl.is_available():
    print('parallel backend available')
    devices.append(('cpu', 10))
if torch.cuda.is_available():
    print('cuda backend available')
    devices.append('cuda')

dims = [1, 2, 3]
bounds = list(range(7))
# every bound for orders 0..2; only test dc2 for order > 2
order_bounds = [(o, b) for o in range(3) for b in bounds]
order_bounds += [(o, 3) for o in range(3, 8)]
def make_data(shape, device, dtype):
    """Build a random test volume and a randomly perturbed identity grid.

    Returns ``(vol, grid)`` where ``vol`` has shape ``(2, 1, *shape)`` and
    ``grid`` has shape ``(2, *shape, len(shape))``.
    """
    ndim = len(shape)
    displacement = torch.randn([2, *shape, ndim], device=device, dtype=dtype)
    grid = add_identity_grid_(displacement)
    vol = torch.randn([2, 1, *shape], device=device, dtype=dtype)
    return vol, grid
def init_device(device):
    """Resolve a device spec into a ``torch.device``.

    ``device`` is either a name (``'cpu'``/``'cuda'``) or a
    ``(name, param)`` pair.  For CPU, ``param`` is the thread count
    (default 1); for CUDA it is the device index (default 0).
    """
    param = None
    if isinstance(device, (list, tuple)):
        device, param = device
    if device == 'cuda':
        index = 0 if param is None else param
        torch.cuda.set_device(index)
        torch.cuda.init()
        try:
            # best effort: empty_cache may fail on some drivers
            torch.cuda.empty_cache()
        except RuntimeError:
            pass
        return torch.device(f'{device}:{index}')
    assert device == 'cpu'
    torch.set_num_threads(1 if param is None else param)
    return torch.device(device)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_grad(device, dim, bound, interpolation):
    """Gradcheck ``grid_grad`` wrt both the volume and the grid."""
    print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
    dev = init_device(device)
    vol, grid = make_data((shape1,) * dim, dev, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    args = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_grad, args, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_pull(device, dim, bound, interpolation):
    """Gradcheck ``grid_pull`` wrt both the volume and the grid."""
    print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
    dev = init_device(device)
    vol, grid = make_data((shape1,) * dim, dev, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    args = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_pull, args, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_push(device, dim, bound, interpolation):
    """Gradcheck ``grid_push`` wrt both the volume and the grid."""
    print(f'push_{dim}d({interpolation}, {bound}) on {device}')
    dev = init_device(device)
    spatial = (shape1,) * dim
    vol, grid = make_data(spatial, dev, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    # grid_push additionally takes the output spatial shape
    args = (vol, grid, spatial, interpolation, bound, extrapolate)
    assert gradcheck(grid_push, args, **kwargs)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_count(device, dim, bound, interpolation):
    """Gradcheck ``grid_count`` wrt the grid (no volume input)."""
    print(f'count_{dim}d({interpolation}, {bound}) on {device}')
    dev = init_device(device)
    spatial = (shape1,) * dim
    _, grid = make_data(spatial, dev, dtype)
    grid.requires_grad_()
    args = (grid, spatial, interpolation, bound, extrapolate)
    assert gradcheck(grid_count, args, **kwargs)