Align CPU behavior with CUDA for ConvTranspose when out_channels=0 (#142859)

Fixes https://github.com/pytorch/pytorch/issues/142466.
Remove the `weight.numel() != 0` check to align CPU behavior with CUDA for `ConvTranspose` when `out_channels=0`. Once the check is removed, the existing code already produces an empty output in this case.
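
For illustration, a minimal sketch of the aligned behavior (it mirrors the new unit test added below; the input shape here is illustrative):

```python
import torch
import torch.nn as nn

# With out_channels=0, CPU now returns an empty output just as CUDA does.
m = nn.ConvTranspose2d(in_channels=1, out_channels=0, kernel_size=1)
x = torch.randn(1, 1, 4, 4)
y = m(x)
print(y.shape)    # torch.Size([1, 0, 4, 4])
print(y.numel())  # 0
```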

Test plan:
```
python -u test/nn/test_convolution.py -k test_ConvTranspose_output_channels_0_cpu_float32
python -u test/nn/test_convolution.py -k test_ConvTranspose_output_channels_0_cuda_float32
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142859
Approved by: https://github.com/mingfeima, https://github.com/malfet
Author: Wu, Chunyuan
Date: 2025-01-23 01:04:47 -05:00
Committed by: PyTorch MergeBot
Commit: cb814c0b96 (parent: 90448f0128)
3 changed files with 28 additions and 4 deletions

```diff
@@ -78,8 +78,8 @@ static inline void slow_conv_transpose2d_shape_check(
   if (weight.defined()) {
     TORCH_CHECK(
-        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
-        "non-empty 2D or 4D weight tensor expected, but got: ",
+        (weight.dim() == 2 || weight.dim() == 4),
+        "2D or 4D weight tensor expected, but got: ",
         weight.sizes());
     if (bias.defined()) {
       check_dim_size(bias, 1, 0, weight.size(1));
```
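The removed `numel()` condition was the wrong invariant here: a `ConvTranspose2d` constructed with `out_channels=0` holds a perfectly well-formed 4D weight that simply has zero elements. A quick sketch:

```python
import torch.nn as nn

# The weight layout is (in_channels, out_channels/groups, kH, kW), so
# out_channels=0 yields an empty but still 4D tensor.
m = nn.ConvTranspose2d(in_channels=1, out_channels=0, kernel_size=1)
print(m.weight.shape)    # torch.Size([1, 0, 1, 1])
print(m.weight.numel())  # 0, while m.weight.dim() == 4 still holds
```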

```diff
@@ -98,8 +98,8 @@ static inline void slow_conv_transpose3d_shape_check(
   if (weight.defined()) {
     /* TODO: TORCH_CHECK just have 2 args: condition and message */
     TORCH_CHECK(
-        weight.numel() != 0 && weight.dim() == 5,
-        "non-empty 5D (n_output_plane x n_input_plane x kernel_depth",
+        weight.dim() == 5,
+        "5D (n_output_plane x n_input_plane x kernel_depth",
         " x kernel_height x kernel_width) tensor ",
         "expected for weight, but got: ",
         weight.sizes());
```
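The same reasoning applies to the 3D variant. After the change, the functional form should also accept a zero-output-channel weight on CPU (a sketch under that assumption, with illustrative shapes):

```python
import torch
import torch.nn.functional as F

# Weight layout for conv_transpose3d is (in_channels, out_channels/groups,
# kD, kH, kW); out_channels=0 gives an empty output on CPU as on CUDA.
x = torch.randn(1, 1, 2, 2, 2)
w = torch.empty(1, 0, 1, 1, 1)
y = F.conv_transpose3d(x, w)
print(y.shape)  # torch.Size([1, 0, 2, 2, 2])
```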

```diff
--- a/test/nn/test_convolution.py
+++ b/test/nn/test_convolution.py
@@ -40,6 +40,7 @@ from torch.testing._internal.common_device_type import (
     skipCUDAIfRocmVersionLessThan,
     skipMeta,
     skipMPS,
+    skipXLA,
 )
 from torch.testing._internal.common_dtype import (
     floating_and_complex_types_and,
@@ -1749,6 +1750,29 @@ class TestConvolutionNNDeviceType(NNTestCase):
         actual = F.conv2d(x, y, padding="same", dilation=3)
         self.assertEqual(expect, actual, rtol=rtol, atol=atol)
 
+    @dtypes(torch.float)
+    # aten/src/ATen/native/mps/OperationUtils.mm: TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); on MPS
+    @expectedFailureMPS
+    @skipXLA
+    def test_ConvTranspose_output_channels_0(self, device, dtype):
+        class Model(nn.Module):
+            def __init__(self, operator, dim):
+                super().__init__()
+                self.op = eval(
+                    f"torch.nn.{operator}{dim}d(in_channels=1, out_channels=0, kernel_size={tuple([1] * dim)})"
+                )
+
+            def forward(self, x):
+                x = self.op(x)
+                return x
+
+        for dim in [1, 2, 3]:
+            x = torch.randn([1] * (dim + 1), device=device, dtype=dtype)
+            model = Model("ConvTranspose", dim).to(device).to(dtype=dtype)
+            y = model(x)
+            self.assertEqual(y.numel(), 0)
+            self.assertEqual(x.shape[1:], y.shape[1:])
+
     @dtypes(torch.float, torch.cfloat)
     def test_conv3d_same_padding(self, device, dtype):
         if dtype is torch.cfloat:
```