Fix edge case for size 1 channels dim in AdaptiveMaxPool (#116482)

Fixes https://github.com/pytorch/pytorch/issues/107842

Unlike `AdaptiveAvgPool`, `AdaptiveMaxPool` does not have a CUDA kernel for ChannelsLast. We work around this by calling `contiguous()` on the input. However, there is an edge case when the channels dimension has size 1.

```python
>>> t = torch.randn(2, 1, 3, 3)
>>> t.stride()
(9, 9, 3, 1)
>>> t_c = t.to(memory_format=torch.channels_last)
>>> t_c.stride()
(9, 1, 3, 1)  # (CHW, 1, CW, C)
>>> t_c.is_contiguous()
True  # contiguity check doesn't check strides for singleton dimensions
```
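Because the contiguity check passes, the `contiguous()` workaround becomes a no-op for such inputs and no NCHW copy is made. Continuing the snippet above as a quick sanity check:

```python
>>> t_c.contiguous().data_ptr() == t_c.data_ptr()
True  # no copy is made, so the kernel still sees strides (9, 1, 3, 1)
```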

The CUDA kernel treats the batch (`B`) and channels (`C`) dimensions as implicitly flattened and increments the data pointer for `input` to the start of the next plane using

669b182d33/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu (L67)

If our input falls into the aforementioned edge case, the `data_ptr` will not be incremented correctly, since `stride(1)` is 1 rather than the size of a plane. The simple fix is to compute the stride for the channels dimension as $\prod_{i > 1} \text{size}(i)$ instead of reading it from the input.

Analogous fix for the 3D case.
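As a rough illustration (plain Python, not the kernel code): for the `(2, 1, 3, 3)` example above, the second flattened `(B, C)` plane starts 9 elements into the buffer, but the old `istrideD = stride(1)` would step by only 1, while the recomputed stride gives the correct offset:

```python
import torch

t = torch.randn(2, 1, 3, 3).to(memory_format=torch.channels_last)
N, C, H, W = t.shape

old_istrideD = t.stride(1)  # 1 -> lands inside the first plane
new_istrideD = H * W        # 9 -> product of sizes after dim 1

# element offset of each flattened (B * C) plane under the two choices
print([p * old_istrideD for p in range(N * C)])  # [0, 1]
print([p * new_istrideD for p in range(N * C)])  # [0, 9] == [0, t.stride(0)]
```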

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116482
Approved by: https://github.com/albanD
Author: Mikayla Gawarecki
Date: 2023-12-27 19:36:34 -08:00
Committed by: PyTorch MergeBot
Commit: b5e83b8c50 (parent dfc898ede4)
3 changed files with 34 additions and 16 deletions

@@ -268,7 +268,11 @@ const Tensor& indices) {
   int64_t isizeH = input_.size(2);
   int64_t isizeW = input_.size(3);
-  int64_t istrideD = input_.stride(1);
+  // In the kernel, the batch and channel dimensions are treated as if they
+  // are flattened and istrideD is used as the stride of this flattened dim
+  // Handle the edge case where input_.size(1) == 1, where despite passing the
+  // contiguity check the stride might not be H * W
+  int64_t istrideD = isizeH * isizeW;
   int64_t istrideH = input_.stride(2);
   int64_t istrideW = input_.stride(3);

@@ -346,7 +346,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cuda)
   isizeH = input_.size(3);
   isizeW = input_.size(4);
-  istrideD = input_.stride(1);
+  // In the kernel, the batch and channel dimensions are treated as if they
+  // are flattened and istrideD is used as the stride of this flattened dim
+  // Handle the edge case where input_.size(1) == 1, where despite passing the
+  // contiguity check the stride might not be T * H * W
+  istrideD = isizeT * isizeH * isizeW;
   istrideT = input_.stride(2);
   istrideH = input_.stride(3);
   istrideW = input_.stride(4);

@@ -1072,38 +1072,48 @@ torch.cuda.synchronize()
     @dtypes(torch.float, torch.double)
     def test_adaptive_pooling_max_nhwc(self, device, dtype):
-        def helper(n, c, h, w, output_height, output_width, contig):
-            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
-            input = input.contiguous(memory_format=torch.channels_last)
-            grad = torch.randint(1, 10, (4, 8, output_height, output_width), device=device, dtype=dtype)
-            grad = grad.contiguous(memory_format=torch.channels_last)
+        def helper(input_size, output_plane_size, contig):
+            n_plane_dims = len(output_plane_size)
+            mod = torch.nn.AdaptiveMaxPool2d if n_plane_dims == 2 else torch.nn.AdaptiveMaxPool3d
+            channels_last = torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
+            output_size = input_size[:2] + output_plane_size
+            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
+            input = input.contiguous(memory_format=channels_last)
+            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
+            grad = grad.contiguous(memory_format=channels_last)
             if not contig:
-                input = input[:, ::2, :, :]
-                grad = grad[:, ::2, :, :]
+                input = input[:, ::2]
+                grad = grad[:, ::2]
             input.requires_grad_(True)
-            pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            pool = mod(output_plane_size, return_indices=True).to(device)
             ref_input = input.detach().clone().contiguous().requires_grad_(True)
             ref_grad = grad.detach().clone().contiguous()
-            ref_pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            ref_pool = mod(output_plane_size, return_indices=True).to(device)
             out, ind = pool(input)
             out.backward(grad)
             ref_out, ref_ind = ref_pool(ref_input)
             ref_out.backward(ref_grad)
-            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            # channels_last_3d case does not return channels_last_3d outputs
+            if n_plane_dims == 2:
+                self.assertTrue(out.is_contiguous(memory_format=channels_last))
+                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
             self.assertTrue(ref_out.is_contiguous())
-            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
             self.assertTrue(ref_ind.is_contiguous())
             self.assertEqual(out, ref_out)
             self.assertEqual(ind, ref_ind)
             self.assertEqual(input.grad, ref_input.grad)

         for contig in [True, False]:
-            helper(4, 8, 10, 10, 7, 7, contig)
-            helper(4, 8, 9, 14, 5, 8, contig)
-            helper(4, 8, 11, 11, 1, 1, contig)
+            helper((4, 8, 10, 10), (7, 7), contig)
+            helper((4, 8, 9, 14), (5, 8), contig)
+            helper((4, 8, 11, 11), (1, 1), contig)
+            helper((2, 1, 3, 3), (1, 1), contig)
+            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
+            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
+            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)

     @dtypes(torch.float, torch.double)
     def test_pooling_max_nhwc(self, device, dtype):
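
For reference, a condensed repro sketch (not the test itself) of the case the new `(2, 1, 3, 3)` entries cover, checking the fixed CUDA path against the CPU result; it assumes a CUDA device is available and uses a `(1, 1)` output size as in the added cases:

```python
import torch

x = torch.randn(2, 1, 3, 3).to(memory_format=torch.channels_last)
pool = torch.nn.AdaptiveMaxPool2d((1, 1), return_indices=True)

out_cuda, ind_cuda = pool(x.cuda())  # channels-last input with C == 1 hits the fixed kernel
out_cpu, ind_cpu = pool(x)           # CPU reference

assert torch.equal(out_cuda.cpu(), out_cpu)
assert torch.equal(ind_cuda.cpu(), ind_cpu)
```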