File size: 4,472 Bytes
2571f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
from torch.autograd import gradcheck
from interpol import grid_pull, grid_push, grid_count, grid_grad, add_identity_grid_
import pytest
import inspect

# Global test configuration shared by every test in this module.
dtype = torch.double        # double precision is advised when checking gradients
shape1 = 3                  # number of voxels along each spatial dimension
extrapolate = True          # passed through to every interpol call in the tests

# Request deterministic kernels when this torch version offers the switch.
if hasattr(torch, 'use_deterministic_algorithms'):
    torch.use_deterministic_algorithms(True)


def _gradcheck_options():
    """Return the keyword arguments forwarded to every ``gradcheck`` call.

    ``check_undefined_grad`` and ``nondet_tol`` are only added when the
    installed ``gradcheck`` accepts them, since older torch versions lack
    these parameters.
    """
    accepted = inspect.signature(gradcheck).parameters
    options = {'rtol': 1., 'raise_exception': True}
    for name, value in (('check_undefined_grad', False), ('nondet_tol', 1e-3)):
        if name in accepted:
            options[name] = value
    return options


kwargs = _gradcheck_options()

# Test matrix.
# Each device entry is either a (backend, param) pair -- for 'cpu' the param
# is the thread count -- or a bare backend name that uses its default param.
devices = [('cpu', 1)]
if torch.backends.openmp.is_available() or torch.backends.mkl.is_available():
    print('parallel backend available')
    devices += [('cpu', 10)]
if torch.cuda.is_available():
    print('cuda backend available')
    devices += ['cuda']

dims = [1, 2, 3]
bounds = list(range(7))
# Orders 0-2 are paired with every bound code; orders 3-7 are paired only
# with bound code 3 (dc2).
order_bounds = [(order, bound) for order in range(3) for bound in bounds]
order_bounds += [(order, 3) for order in range(3, 8)]


def make_data(shape, device, dtype):
    """Build a random input volume and a random sampling grid.

    Args:
        shape: spatial shape as a tuple of ints (it is concatenated onto a
            tuple below, so a list is rejected).
        device: torch device on which to allocate both tensors.
        dtype: floating-point dtype of both tensors.

    Returns:
        ``(vol, grid)``. ``vol`` has shape ``(2, 1, *shape)``. ``grid`` starts
        as standard-normal noise of shape ``(2, *shape, len(shape))`` and is
        passed through ``add_identity_grid_`` (presumably adding the identity
        grid in place, per its name -- TODO confirm).
    """
    factory = dict(device=device, dtype=dtype)
    # The grid is drawn before the volume; keep this order so the RNG
    # stream (and therefore the generated data) stays the same.
    noise = torch.randn([2, *shape, len(shape)], **factory)
    grid = add_identity_grid_(noise)
    vol = torch.randn((2, 1) + shape, **factory)
    return vol, grid


def init_device(device):
    """Configure torch for ``device`` and return the matching ``torch.device``.

    Args:
        device: either a backend name (``'cpu'`` or ``'cuda'``) or a
            ``(backend, param)`` list/tuple. For ``'cpu'`` the param is passed
            to ``torch.set_num_threads`` (default 1). For ``'cuda'`` it is the
            GPU index to select (default 0).

    Returns:
        ``torch.device('cpu')`` or ``torch.device('cuda:<param>')``.

    Raises:
        AssertionError: if the backend is neither ``'cpu'`` nor ``'cuda'``.
    """
    if isinstance(device, (list, tuple)):
        backend, param = device
    elif device == 'cpu':
        backend, param = device, 1
    else:
        backend, param = device, 0

    if backend != 'cuda':
        assert backend == 'cpu'
        torch.set_num_threads(param)
        return torch.device(backend)

    torch.cuda.set_device(param)
    torch.cuda.init()
    try:
        torch.cuda.empty_cache()
    except RuntimeError:
        pass  # best effort: a RuntimeError from empty_cache is ignored
    return torch.device(f'{backend}:{param}')


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_grad(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_grad`` with respect to the volume and grid."""
    print(f'grad_{dim}d({interpolation}, {bound}) on {device}')
    target = init_device(device)
    vol, grid = make_data((shape1,) * dim, target, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    inputs = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_grad, inputs, **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_pull(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_pull`` with respect to the volume and grid."""
    print(f'pull_{dim}d({interpolation}, {bound}) on {device}')
    target = init_device(device)
    vol, grid = make_data((shape1,) * dim, target, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    inputs = (vol, grid, interpolation, bound, extrapolate)
    assert gradcheck(grid_pull, inputs, **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_push(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_push`` with respect to the volume and grid.

    The spatial ``shape`` is forwarded as the third positional argument
    (presumably the output shape of the push -- TODO confirm).
    """
    print(f'push_{dim}d({interpolation}, {bound}) on {device}')
    target = init_device(device)
    spatial = (shape1,) * dim
    vol, grid = make_data(spatial, target, dtype)
    vol.requires_grad_()
    grid.requires_grad_()
    inputs = (vol, grid, spatial, interpolation, bound, extrapolate)
    assert gradcheck(grid_push, inputs, **kwargs)


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dim", dims)
@pytest.mark.parametrize("interpolation,bound", order_bounds)
def test_gradcheck_count(device, dim, bound, interpolation):
    """Run ``gradcheck`` on ``grid_count`` with respect to the grid only.

    A volume is still generated (and discarded) so that the random grid is
    drawn from the same RNG state as in the other tests.
    """
    print(f'count_{dim}d({interpolation}, {bound}) on {device}')
    target = init_device(device)
    spatial = (shape1,) * dim
    grid = make_data(spatial, target, dtype)[1]
    grid.requires_grad_()
    inputs = (grid, spatial, interpolation, bound, extrapolate)
    assert gradcheck(grid_count, inputs, **kwargs)