Fix edge case for size 1 channels dim in AdaptiveMaxPool (#116482)
Fixes https://github.com/pytorch/pytorch/issues/107842
Unlike `AdaptiveAvgPool`, `AdaptiveMaxPool` does not have a CUDA kernel for ChannelsLast. We work around this by calling `contiguous()` on the input. However, there is an edge case when the channels dimension has size 1:
```python
>>> t = torch.randn(2, 1, 3, 3)
>>> t.stride()
(9, 9, 3, 1)
>>> t_c = t.to(memory_format=torch.channels_last)
>>> t_c.stride()
(9, 1, 3, 1) # (CHW, 1, CW, C)
>>> t_c.is_contiguous()
True # contiguity check doesn't check strides for singleton dimensions
```
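Note that the channels_last layout check also reports contiguous here, while the channels stride is 1 rather than `H * W`. A short check (stock PyTorch calls only) makes that explicit:

```python
import torch

t_c = torch.randn(2, 1, 3, 3).to(memory_format=torch.channels_last)

# Both layout checks pass, since stride requirements are skipped for size-1 dims.
print(t_c.is_contiguous())                                   # True
print(t_c.is_contiguous(memory_format=torch.channels_last))  # True

# But the channels stride is 1 rather than H * W == 9, so reading
# input_.stride(1) gives the wrong plane stride for this tensor.
print(t_c.stride(1), t_c.size(2) * t_c.size(3))              # 1 9
```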
The CUDA kernel treats the batch (`B`) and channels (`C`) dimensions as implicitly flattened and increments the data pointer for `input` to the start of the next plane using `istrideD`:
669b182d33/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu (L67)
If the input falls into the aforementioned edge case, the `data_ptr` will not be incremented correctly. The simple fix is to compute the stride for the channels dimension as $\prod_{i > 1} \text{size}(i)$ (i.e. `isizeH * isizeW`) instead of reading `input_.stride(1)`.
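For intuition, here is a rough host-side Python sketch of the plane indexing described above. It is not the actual CUDA kernel, and `pool_one_plane` is a hypothetical stand-in for the per-plane pooling the real kernel performs with raw device pointers:

```python
# Rough sketch of the indexing described above (not the real CUDA kernel).
# `pool_one_plane` is a hypothetical stand-in for the per-plane pooling.
def pool_all_planes(flat_input, sizeB, sizeD, istrideD, pool_one_plane):
    results = []
    for plane in range(sizeB * sizeD):
        # The kernel advances the input pointer by istrideD per flattened
        # (batch, channel) plane, so istrideD must equal H * W. With the buggy
        # istrideD = input_.stride(1) == 1, each plane would start one element
        # after the previous one instead of H * W elements later.
        offset = plane * istrideD
        results.append(pool_one_plane(flat_input, offset))
    return results
```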
An analogous fix is applied for the 3D case (`isizeT * isizeH * isizeW`).
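A minimal reproducer sketch, mirroring the `(2, 1, 3, 3)` case added to the tests (it assumes a CUDA device is available; with the fix, the channels_last result matches the contiguous reference):

```python
import torch

x = torch.randn(2, 1, 3, 3, device="cuda")
x_cl = x.to(memory_format=torch.channels_last)  # strides (9, 1, 3, 1), still reported contiguous

pool = torch.nn.AdaptiveMaxPool2d((1, 1), return_indices=True)
out_cl, ind_cl = pool(x_cl)              # hits the contiguous() workaround path
out_ref, ind_ref = pool(x.contiguous())  # reference on a truly contiguous tensor

# Before this fix the channels_last results could be wrong; with it, both agree.
print(torch.equal(out_cl, out_ref), torch.equal(ind_cl, ind_ref))
```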
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116482
Approved by: https://github.com/albanD
commit b5e83b8c50 (parent dfc898ede4), committed by PyTorch MergeBot
```diff
@@ -268,7 +268,11 @@ const Tensor& indices) {
     int64_t isizeH = input_.size(2);
     int64_t isizeW = input_.size(3);
 
-    int64_t istrideD = input_.stride(1);
+    // In the kernel, the batch and channel dimensions are treated as if they
+    // are flattened and istrideD is used as the stride of this flattened dim
+    // Handle the edge case where input_.size(1) == 1, where despite passing the
+    // contiguity check the stride might not be H * W
+    int64_t istrideD = isizeH * isizeW;
     int64_t istrideH = input_.stride(2);
     int64_t istrideW = input_.stride(3);
 
```
```diff
@@ -346,7 +346,11 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cuda)
     isizeH = input_.size(3);
     isizeW = input_.size(4);
 
-    istrideD = input_.stride(1);
+    // In the kernel, the batch and channel dimensions are treated as if they
+    // are flattened and istrideD is used as the stride of this flattened dim
+    // Handle the edge case where input_.size(1) == 1, where despite passing the
+    // contiguity check the stride might not be T * H * W
+    istrideD = isizeT * isizeH * isizeW;
     istrideT = input_.stride(2);
     istrideH = input_.stride(3);
     istrideW = input_.stride(4);
```
```diff
@@ -1072,38 +1072,48 @@ torch.cuda.synchronize()
 
     @dtypes(torch.float, torch.double)
     def test_adaptive_pooling_max_nhwc(self, device, dtype):
-        def helper(n, c, h, w, output_height, output_width, contig):
-            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
-            input = input.contiguous(memory_format=torch.channels_last)
-            grad = torch.randint(1, 10, (4, 8, output_height, output_width), device=device, dtype=dtype)
-            grad = grad.contiguous(memory_format=torch.channels_last)
+        def helper(input_size, output_plane_size, contig):
+            n_plane_dims = len(output_plane_size)
+            mod = torch.nn.AdaptiveMaxPool2d if n_plane_dims == 2 else torch.nn.AdaptiveMaxPool3d
+            channels_last = torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
+            output_size = input_size[:2] + output_plane_size
+            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
+            input = input.contiguous(memory_format=channels_last)
+            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
+            grad = grad.contiguous(memory_format=channels_last)
             if not contig:
-                input = input[:, ::2, :, :]
-                grad = grad[:, ::2, :, :]
+                input = input[:, ::2]
+                grad = grad[:, ::2]
             input.requires_grad_(True)
-            pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            pool = mod(output_plane_size, return_indices=True).to(device)
 
             ref_input = input.detach().clone().contiguous().requires_grad_(True)
             ref_grad = grad.detach().clone().contiguous()
-            ref_pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)
+            ref_pool = mod(output_plane_size, return_indices=True).to(device)
 
             out, ind = pool(input)
             out.backward(grad)
             ref_out, ref_ind = ref_pool(ref_input)
             ref_out.backward(ref_grad)
 
-            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            # channels_last_3d case does not return channels_last_3d outputs
+            if n_plane_dims == 2:
+                self.assertTrue(out.is_contiguous(memory_format=channels_last))
+                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
             self.assertTrue(ref_out.is_contiguous())
-            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
             self.assertTrue(ref_ind.is_contiguous())
             self.assertEqual(out, ref_out)
             self.assertEqual(ind, ref_ind)
             self.assertEqual(input.grad, ref_input.grad)
 
         for contig in [True, False]:
-            helper(4, 8, 10, 10, 7, 7, contig)
-            helper(4, 8, 9, 14, 5, 8, contig)
-            helper(4, 8, 11, 11, 1, 1, contig)
+            helper((4, 8, 10, 10), (7, 7), contig)
+            helper((4, 8, 9, 14), (5, 8), contig)
+            helper((4, 8, 11, 11), (1, 1), contig)
+            helper((2, 1, 3, 3), (1, 1), contig)
+            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
+            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
+            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)
 
     @dtypes(torch.float, torch.double)
     def test_pooling_max_nhwc(self, device, dtype):
```