Align CPU behavior with CUDA for ConvTranspose when out_channels=0 (#142859)
Fixes https://github.com/pytorch/pytorch/issues/142466. Remove the `weight.numel() != 0` check to align CPU behavior with CUDA for `ConvTranspose` when `out_channels=0`. With this check removed, the existing code already produces an empty output in that case.

Test plan:
```
python -u test/nn/test_convolution.py -k test_ConvTranspose_output_channels_0_cpu_float32
python -u test/nn/test_convolution.py -k test_ConvTranspose_output_channels_0_cuda_float32
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142859
Approved by: https://github.com/mingfeima, https://github.com/malfet
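For context, a minimal sketch of the behavior this change enables (the module configuration mirrors the new test below; the input shape is illustrative only):

```python
import torch
import torch.nn as nn

# A ConvTranspose2d with out_channels=0 has an empty weight (numel() == 0).
# With the weight.numel() != 0 shape check removed, the CPU path no longer
# raises and instead returns an empty output, matching CUDA.
conv = nn.ConvTranspose2d(in_channels=1, out_channels=0, kernel_size=(1, 1))
x = torch.randn(1, 1, 4, 4)  # (N, C_in, H, W)
y = conv(x)

print(y.shape)    # torch.Size([1, 0, 4, 4]) -- zero output channels
print(y.numel())  # 0
```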
Committed by: PyTorch MergeBot
Parent: 698106951e
Commit: 0bff377880
@@ -78,8 +78,8 @@ static inline void slow_conv_transpose2d_shape_check(
 
   if (weight.defined()) {
     TORCH_CHECK(
-        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
-        "non-empty 2D or 4D weight tensor expected, but got: ",
+        (weight.dim() == 2 || weight.dim() == 4),
+        "2D or 4D weight tensor expected, but got: ",
         weight.sizes());
     if (bias.defined()) {
       check_dim_size(bias, 1, 0, weight.size(1));
@@ -98,8 +98,8 @@ static inline void slow_conv_transpose3d_shape_check(
   if (weight.defined()) {
     /* TODO: TORCH_CHECK just have 2 args: condition and message */
     TORCH_CHECK(
-        weight.numel() != 0 && weight.dim() == 5,
-        "non-empty 5D (n_output_plane x n_input_plane x kernel_depth",
+        weight.dim() == 5,
+        "5D (n_output_plane x n_input_plane x kernel_depth",
         " x kernel_height x kernel_width) tensor ",
         "expected for weight, but got: ",
         weight.sizes());
@@ -1749,6 +1749,28 @@ class TestConvolutionNNDeviceType(NNTestCase):
         actual = F.conv2d(x, y, padding="same", dilation=3)
         self.assertEqual(expect, actual, rtol=rtol, atol=atol)
 
+    @dtypes(torch.float)
+    # aten/src/ATen/native/mps/OperationUtils.mm: TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); on MPS
+    @expectedFailureMPS
+    def test_ConvTranspose_output_channels_0(self, device, dtype):
+        class Model(nn.Module):
+            def __init__(self, operator, dim):
+                super().__init__()
+                self.op = eval(
+                    f"torch.nn.{operator}{dim}d(in_channels=1, out_channels=0, kernel_size={tuple([1] * dim)})"
+                )
+
+            def forward(self, x):
+                x = self.op(x)
+                return x
+
+        for dim in [1, 2, 3]:
+            x = torch.randn([1] * (dim + 1), device=device, dtype=dtype)
+            model = Model("ConvTranspose", dim).to(device).to(dtype=dtype)
+            y = model(x)
+            self.assertEqual(y.numel(), 0)
+            self.assertEqual(x.shape[1:], y.shape[1:])
+
     @dtypes(torch.float, torch.cfloat)
     def test_conv3d_same_padding(self, device, dtype):
         if dtype is torch.cfloat: