diff --git a/aten/src/ATen/native/cuda/GridSampler.cu b/aten/src/ATen/native/cuda/GridSampler.cu index e65f6d59c7d3..384658e75187 100644 --- a/aten/src/ATen/native/cuda/GridSampler.cu +++ b/aten/src/ATen/native/cuda/GridSampler.cu @@ -58,8 +58,8 @@ namespace { const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; // get the corresponding input x, y co-ordinates from grid - scalar_t x = grid.data[grid_offset]; - scalar_t y = grid.data[grid_offset + grid_sCoor]; + opmath_t x = grid.data[grid_offset]; + opmath_t y = grid.data[grid_offset + grid_sCoor]; opmath_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); opmath_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); @@ -194,9 +194,9 @@ namespace { const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; // get the corresponding input x, y, z co-ordinates from grid - scalar_t x = grid.data[grid_offset]; - scalar_t y = grid.data[grid_offset + grid_sCoor]; - scalar_t z = grid.data[grid_offset + 2 * grid_sCoor]; + opmath_t x = grid.data[grid_offset]; + opmath_t y = grid.data[grid_offset + grid_sCoor]; + opmath_t z = grid.data[grid_offset + 2 * grid_sCoor]; opmath_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); opmath_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); diff --git a/test/test_nn.py b/test/test_nn.py index 3e18244951b1..b883c97f9054 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9954,10 +9954,11 @@ class TestNNDeviceType(NNTestCase): (1, 1, 3, 3, 3)) grid[:, 1, 1, 1, 0] = float('inf') result = torch.nn.functional.grid_sample(image, grid, padding_mode='zeros') + tol_override = {'atol': 0.005, 'rtol': 0} if dtype == torch.half else {} self.assertEqual(result, torch.tensor([[[[[27., 26., 25.], [24., 23., 22.], [21., 20., 19.]], [[18., 17., 16.], [15., 0., 13.], [12., 11., 10.]], [[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]]]]], - device=device, dtype=dtype)) + device=device, dtype=dtype), **tol_override) result.backward(torch.ones_like(result)) expected_grad = torch.ones_like(image) expected_grad[0, 0, 1, 1, 1] = 0 @@ -10066,20 +10067,23 @@ class TestNNDeviceType(NNTestCase): @onlyCUDA def test_grid_sample_half_precision(self): - def helper(shape_in, shape_out): + def helper(shape_in, shape_out, align_corners): for mode in ('bilinear', 'nearest', 'bicubic'): if len(shape_in) != 4 and mode == 'bicubic': continue data = torch.randn(shape_in, device='cuda', dtype=torch.half) grid = torch.rand(shape_out, device='cuda', dtype=torch.half) * 2.0 - 1.0 - out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=False) - out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros', align_corners=False) + out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners) + out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros', + align_corners=align_corners) self.assertEqual(out_half, out_double.half(), msg="grid_sample with mode = {} doesn't match".format(mode)) - helper((32, 64, 16, 16), (32, 8, 8, 2)) - helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3)) + helper((32, 64, 16, 16), (32, 8, 8, 2), True) + helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True) + helper((32, 64, 16, 16), (32, 8, 8, 2), False) + helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False) def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected): logits = torch.randn(shape, dtype=torch.float, device=device)