Make the Index Rounding Mode Consistent Between the 2D and 3D GridSample Nearest Neighbor Interpolations (#97000)

## BC-breaking note:

This is technically a bugfix. Prior to this PR, for `torch.nn.functional.grid_sample(mode='nearest')` the 2D kernel used `std::nearbyint` whereas the 3D kernel used `std::round` in order to determine the nearest pixel locations after un-normalization of the grid. This PR fixes the 3D kernel to use `std::nearbyint` which rounds values that are exactly `<>.5` to the nearest even which is consistent with the behavior of `torch.round`. Unnormalized indices that are exactly `<>.5` will now be rounded to the nearest even instead of being rounded away from 0.

## Description
In the nearest neighbor interpolation mode, the 2D GridSample rounds index to the nearest even using [std::nearbyint](https://github.com/pytorch/pytorch/blob/v2.0.0/aten/src/ATen/native/cpu/zmath.h#L182) whereas the 3D GridSample rounds index away from zero using std::round. This discrepancy needs to be resolved. We are making both 2D GridSample and 3D GridSample to round to the nearest even.

## Unit Test Goals
1. Make sure the x dimension and y dimension rounding behaviors are the same for 2D GridSample.
2. ~~Make sure the 2D GridSample rounding mode is rounding to the nearest even.~~
3. Make sure the x dimension, y dimension, and z dimension rounding behaviors are the same for 3D GridSample.
4. ~~Make sure the 3D GridSample rounding mode is rounding to the nearest even.~~
5. The 2D GridSample and 3D GridSample rounding behaviors are exactly the same.

After some experiments, I found 2 and 4 are difficult to achieve. Even though I can compute the normalized coordinates corresponding to the unnormalized coordinates including [0, 0.5, 1.0, 1.5, 2.0, 2.5, ..., 10.0], the unnormalization process in the GridSample implementations always have a small chance of having floating point error. Therefore, it's not possible to unit test the rounding mode from the normalized coordinates.

## Unit Test Methods

The unit test is simple. By using the same values along the dimension to be tested in the input tensor and the same normalized indices in the grid tensor. The interpolation on the 2D GridSample x-dimension, 2D GridSample y-dimension, 3D GridSample x-dimension, 3D GridSample y-dimension, 3D GridSample z-dimension. Should produce exactly the same interpolated values.
If one CPU/CUDA 2D/3D implementation use a different rounding mode from others, the unit test shall fail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97000
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Lei Mao
2023-04-05 18:47:03 +00:00
committed by PyTorch MergeBot
parent dcec2100b1
commit 937ba248eb
3 changed files with 176 additions and 14 deletions

View File

@ -174,9 +174,9 @@ namespace {
}
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));
// assign nearest neighor pixel value to output pixel
scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
@ -411,9 +411,9 @@ namespace {
gGrid_ptr_NDHW[1] = giy_mult * giy;
gGrid_ptr_NDHW[2] = giz_mult * giz;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
int64_t ix_nearest = static_cast<int64_t>(std::round(ix));
int64_t iy_nearest = static_cast<int64_t>(std::round(iy));
int64_t iz_nearest = static_cast<int64_t>(std::round(iz));
int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));
// assign nearest neighor pixel value to output pixel
scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;

View File

@ -283,9 +283,9 @@ namespace {
*out_ptr_NCDHW = out_acc;
}
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
index_t ix_nearest = static_cast<index_t>(std::round(ix));
index_t iy_nearest = static_cast<index_t>(std::round(iy));
index_t iz_nearest = static_cast<index_t>(std::round(iz));
index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));
index_t iz_nearest = static_cast<index_t>(std::nearbyint(iz));
// assign nearest neighor pixel value to output pixel
auto inp_ptr_NC = input.data + n * inp_sN;
@ -431,8 +431,8 @@ namespace {
gGrid_ptr_NHW[1] = giy_mult * giy;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
if (input_requires_grad) {
index_t ix_nearest = static_cast<index_t>(std::round(ix));
index_t iy_nearest = static_cast<index_t>(std::round(iy));
index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));
// assign nearest neighor pixel value to output pixel
scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
@ -720,9 +720,9 @@ namespace {
gGrid_ptr_NDHW[2] = giz_mult * giz;
} else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
if (input_requires_grad) {
auto ix_nearest = static_cast<index_t>(std::round(ix));
auto iy_nearest = static_cast<index_t>(std::round(iy));
auto iz_nearest = static_cast<index_t>(std::round(iz));
auto ix_nearest = static_cast<index_t>(std::nearbyint(ix));
auto iy_nearest = static_cast<index_t>(std::nearbyint(iy));
auto iz_nearest = static_cast<index_t>(std::nearbyint(iz));
// assign nearest neighor pixel value to output pixel
scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;

View File

@ -6058,6 +6058,168 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
for input_requires_grad in [False, True]:
test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad)
def test_grid_sample_nearest_neighbor_rounding_mode_consistency(self):
device_list = ['cpu']
if TEST_CUDA:
device_list.append('cuda')
def normalize_indices(indices_unnormalized: torch.Tensor, dim_size: int, align_corners: bool):
if align_corners:
indices_normalized = 2 * indices_unnormalized / (dim_size - 1) - 1
else:
indices_normalized = (indices_unnormalized * 2 + 1) / dim_size - 1
return indices_normalized
test_dim_size = 10
non_test_dim_size = 9
step_size = 0.1
batch_size = 1
channel_size = 1
mode = 'nearest'
for device in device_list:
for padding_mode in ('zeros', 'border', 'reflection'):
for align_corners in (True, False):
# Unnormalized inquiry indices
inquiry_indices_unnormalized = torch.arange(
0,
test_dim_size - 1 + step_size, step_size,
dtype=torch.float32,
device=device
)
# Note that even though we are trying to create normalized indices
# which results in x.0 and x.5 indices after unnormalization,
# because of the numerical error,
# the rounding direction might not always be expected as designed.
# The best we could do is to ensure the rounding behaviors across
# different implementations for different dimensions are
# exactly the same.
inquiry_indices = normalize_indices(
indices_unnormalized=inquiry_indices_unnormalized,
dim_size=test_dim_size,
align_corners=align_corners
)
num_inqueries = inquiry_indices.shape[0]
inquiry_fixed_indices = torch.full((num_inqueries,), 0.5, dtype=torch.float32, device=device)
array_data = torch.rand(test_dim_size, dtype=torch.float32, device=device)
# 2D grid sample x-dim interpolation
# The input_tensor_2d_x is of shape
# [batch_size, channel_size, non_test_dim_size, test_dim_size]
input_tensor_2d_x = array_data.reshape(1, test_dim_size).repeat(
batch_size,
channel_size,
non_test_dim_size,
1
)
# The grid_tensor_2d_x is of shape
# [batch_size, 1, num_inqueries]
grid_tensor_2d_x = torch.cat(
tensors=(
inquiry_indices.reshape(num_inqueries, 1),
inquiry_fixed_indices.reshape(num_inqueries, 1),
),
dim=1
).repeat(batch_size, 1, 1, 1)
# The output_tensor_2d_x is of shape
# [batch_size, channel_size, 1, num_inqueries]
output_tensor_2d_x = F.grid_sample(
input=input_tensor_2d_x,
grid=grid_tensor_2d_x,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
# 2D grid sample y-dim interpolation
# The input_tensor_2d_y is of shape
# [batch_size, channel_size, test_dim_size, non_test_dim_size]
input_tensor_2d_y = torch.transpose(input_tensor_2d_x, 3, 2)
# The grid_tensor_2d_y is of shape
# [batch_size, 1, num_inqueries]
grid_tensor_2d_y = torch.index_select(
grid_tensor_2d_x,
-1,
torch.tensor([1, 0], dtype=torch.int64, device=device)
)
# The output_tensor_2d_y is of shape
# [batch_size, channel_size, 1, num_inqueries]
output_tensor_2d_y = F.grid_sample(
input=input_tensor_2d_y,
grid=grid_tensor_2d_y,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_2d_y[0, 0, 0, :], atol=0, rtol=0)
# 3D grid sample x-dim interpolation
# The input_tensor_3d_x is of shape
# [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size]
input_tensor_3d_x = array_data.reshape(1, test_dim_size).repeat(
batch_size, channel_size, non_test_dim_size, non_test_dim_size, 1)
# The grid_tensor_3d_x is of shape
# [batch_size, 1, 1, num_inqueries]
grid_tensor_3d_x = torch.cat(
tensors=(
inquiry_indices.reshape(num_inqueries, 1),
inquiry_fixed_indices.reshape(num_inqueries, 1),
inquiry_fixed_indices.reshape(num_inqueries, 1),
),
dim=1
).repeat(batch_size, 1, 1, 1, 1)
# The output_tensor_3d_x is of shape
# [batch_size, channel_size, 1, 1, num_inqueries]
output_tensor_3d_x = F.grid_sample(
input=input_tensor_3d_x,
grid=grid_tensor_3d_x,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_x[0, 0, 0, 0, :], atol=0, rtol=0)
# 3D grid sample y-dim interpolation
# The input_tensor_3d_y is of shape
# [batch_size, channel_size, non_test_dim_size, test_dim_size, non_test_dim_size]
input_tensor_3d_y = torch.transpose(input_tensor_3d_x, 4, 3)
# The grid_tensor_3d_y is of shape
# [batch_size, 1, 1, num_inqueries]
grid_tensor_3d_y = torch.index_select(
grid_tensor_3d_x,
-1,
torch.tensor([1, 0, 2], dtype=torch.int64, device=device)
)
# The output_tensor_3d_y is of shape
# [batch_size, channel_size, 1, 1, num_inqueries]
output_tensor_3d_y = F.grid_sample(
input=input_tensor_3d_y,
grid=grid_tensor_3d_y,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_y[0, 0, 0, 0, :], atol=0, rtol=0)
# 3D grid sample z-dim interpolation
# The input_tensor_3d_z is of shape
# [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size]
input_tensor_3d_z = torch.transpose(input_tensor_3d_x, 4, 2)
# The grid_tensor_3d_z is of shape
# [batch_size, 1, 1, num_inqueries]
grid_tensor_3d_z = torch.index_select(
grid_tensor_3d_x,
-1,
torch.tensor([1, 2, 0], dtype=torch.int64, device=device)
)
# The output_tensor_3d_z is of shape
# [batch_size, channel_size, 1, 1, num_inqueries]
output_tensor_3d_z = F.grid_sample(
input=input_tensor_3d_z,
grid=grid_tensor_3d_z,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_z[0, 0, 0, 0, :], atol=0, rtol=0)
def test_affine_grid(self):
# test known input on CPU
input = torch.arange(1., 7).view(1, 2, 3)