[FakeTensor] Supplement the relevant logic for converting conv1d to conv2d in meta_conv (#160408)

## Fixes https://github.com/pytorch/pytorch/issues/159462; also fixes #163569 and #163604

## Summary
The issue is caused by the incorrect stride of conv1d's output as computed by meta_conv:
4d5b3f2d5a/torch/_meta_registrations.py (L2453-L2471)

This incorrect stride is then used by Inductor to codegen the size/stride assert:
4d5b3f2d5a/torch/_inductor/ir.py (L6152-L6163)
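
A minimal repro, distilled from the `test_conv1d_with_permute` test added in this PR (the exact failure message may vary by build):

```python
import torch
import torch.nn as nn


class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)

    def forward(self, x):
        # the permute makes the conv1d input non-contiguous, so the backend
        # takes the channels-last conv2d path that meta_conv did not model
        x = x.permute(0, 2, 1)
        return self.conv(x)


# before this fix, the Inductor-generated size/stride assert could fail here
out = torch.compile(ConvModel())(torch.randn(32, 100, 1))
```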

## Reason
Why is the computed stride wrong in meta_conv? Because the corresponding backend converts conv1d to conv2d, changing the input tensor's size and memory format (to channels last), while meta_conv does not perform this transformation, so a mismatch occurs:
4d5b3f2d5a/aten/src/ATen/native/Convolution.cpp (L1502-L1510)
The fix simply adds the corresponding conversion logic on the meta/fake conv path.
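
As a sketch, the core of the fix (simplified from the fake-tensor conv change in the last hunk below; the wrapper name here is illustrative, not from the patch): temporarily expand the conv1d arguments to conv2d before asking the backend which memory format it will pick, then squeeze back so downstream shape logic still sees conv1d arguments:

```python
import torch

# Illustrative sketch only: mirrors how the patch queries the backend's
# preferred memory format for conv1d by round-tripping through conv2d shapes.
def conv1d_memory_format_via_2d(kwargs, conv_backend):
    # expand 1d -> 2d, as the native backend does internally
    kwargs["input"] = kwargs["input"].contiguous().unsqueeze(2)
    kwargs["weight"] = kwargs["weight"].unsqueeze(2)
    for key, fill in (("stride", 1), ("padding", 0),
                      ("dilation", 1), ("output_padding", 0)):
        if len(kwargs[key]) == 1:
            kwargs[key].insert(0, fill)
    mem_fmt = torch._C._conv_determine_backend_memory_format(
        kwargs["input"], kwargs["weight"], conv_backend
    )
    # revert 2d -> 1d so the shape computation still sees conv1d args
    kwargs["input"] = kwargs["input"].squeeze(2)
    kwargs["weight"] = kwargs["weight"].squeeze(2)
    for key in ("stride", "padding", "dilation", "output_padding"):
        if len(kwargs[key]) == 2:
            kwargs[key].pop(0)
    return mem_fmt
```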

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160408
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/mlazos
Author: thenumberouscode
Date: 2025-09-26 15:45:02 +00:00
Committed by: PyTorch MergeBot
Parent: 8aba513506
Commit: c106ee8515

7 changed files with 113 additions and 20 deletions


@@ -311,6 +311,33 @@ class CPUReproTests(TestCase):
             (v,),
         )

+    def test_conv1d_strided_weight_torch_compile(self):
+        def fn(x, w):
+            wt = w.transpose(2, 1)
+            y = F.conv1d(x, wt)
+            return y.clone()
+
+        x_eager = torch.randn(2, 3, 5, requires_grad=True)
+        w_eager = torch.randn(4, 2, 3, requires_grad=True)
+        out_eager = fn(x_eager, w_eager)
+        grad = torch.randn_like(out_eager)
+        out_eager_val = out_eager.detach()
+        out_eager.backward(grad)
+        grad_x_eager = x_eager.grad.detach().clone()
+        grad_w_eager = w_eager.grad.detach().clone()
+
+        x_comp = x_eager.detach().requires_grad_(True)
+        w_comp = w_eager.detach().requires_grad_(True)
+        compiled = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
+        out_comp = compiled(x_comp, w_comp)
+        out_comp_val = out_comp.detach()
+        out_comp.backward(grad)
+
+        torch.testing.assert_close(out_comp_val, out_eager_val)
+        torch.testing.assert_close(x_comp.grad, grad_x_eager)
+        torch.testing.assert_close(w_comp.grad, grad_w_eager)
+
     @config.patch(freezing=True)
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     @patch("torch.cuda.is_available", lambda: False)


@@ -4639,6 +4639,41 @@ class CommonTemplate:
             (torch.randn([4, 4, 4]),),
         )

+    def test_conv1d_with_permute(self):
+        # fix https://github.com/pytorch/pytorch/issues/159462
+        class ConvModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)
+
+            def forward(self, x):
+                x = x.permute(0, 2, 1)
+                return self.conv(x)
+
+        self.common(ConvModel(), (torch.randn([32, 100, 1]),), check_lowp=False)
+
+    def test_conv1d_depthwise(self):
+        class ConvModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv1d(
+                    768,
+                    768,
+                    kernel_size=(9,),
+                    stride=(1,),
+                    padding=(4,),
+                    groups=768,
+                    bias=False,
+                )
+
+            def forward(self, x):
+                return self.conv(x)
+
+        input_tensor = torch.randn([1, 768, 512]).as_strided(
+            (1, 768, 512), (393216, 1, 768)
+        )
+        self.common(ConvModel(), (input_tensor,), check_lowp=False)
+
     def test_convolution1(self):
         m = torch.nn.Sequential(
             torch.nn.Conv2d(5, 6, [3, 3]),

@@ -110,6 +110,7 @@ test_failures = {
     #
     # Failed to find dynamic for loop variable:
    #
+    "test_conv1d_with_permute_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
    "test_arange1_dynamic_shapes": TestFailure(("cpu",)),
    "test_arange2_dynamic_shapes": TestFailure(("cpu",)),
    "test_arange3_dynamic_shapes": TestFailure(("cpu",)),

@@ -134,11 +134,11 @@ class TestUtils(TestCase):
                 torch.Tensor(2, 2, 3),
                 torch.Tensor(2, 2, 2),
                 torch.Tensor(2),
-                (1, 1),
-                (0, 0),
-                (1, 1),
+                (1,),
+                (0,),
+                (1,),
                 True,
-                (0, 0),
+                (0,),
                 1,
             ),
             {},

@@ -201,6 +201,26 @@ class FakeTensorTest(TestCase):
         self.assertEqual(torch.ones([10]), out[0])

+    def test_conv_nhwc(self):
+        x = torch.randn([1, 1024, 16, 16]).to(memory_format=torch.channels_last)
+        w = torch.randn([256, 1024, 4, 4]).to(memory_format=torch.channels_last)
+        b = torch.randn([256])
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, w, b):
+                return torch.ops.aten.convolution(
+                    x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
+                )
+
+        model = Model()
+        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
+            fake_out = model.forward(x, w, b)
+        eager_out = model.forward(x, w, b)
+        self.assertEqual(fake_out.stride(), eager_out.stride())
+
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_zero_dim(self):
         with FakeTensorMode() as mode:

@@ -2447,18 +2447,6 @@ def meta_conv(
     output_padding: list[int],
     groups: int,
 ):
-    def pick_memory_format():
-        if device_hint(input_tensor) == "cuda":
-            if is_channels_last(input_tensor) or is_channels_last(weight):
-                return torch.channels_last
-        else:
-            if is_channels_last(input_tensor):
-                return torch.channels_last
-        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
-            return torch.contiguous_format
-        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
-            return torch.preserve_format
-
     shape_out = calc_conv_nd_return_shape(
         input_tensor,
         weight,
@@ -2476,7 +2464,6 @@
         shape_out[output_channels_dim] = 0

     out = input_tensor.new_empty(shape_out)
-    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
     return out


@@ -1021,8 +1021,6 @@ def conv(fake_mode, func, *args, **kwargs):
         # TODO: We can make this a little more faithful with best effort
         # channels last detection (but only if it's statically obvious!)
         mem_fmt = None
-    elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
-        mem_fmt = None
     else:
         if func is aten.convolution.default:
             conv_backend = torch._C._select_conv_backend(**kwargs)
@@ -1039,14 +1037,39 @@
                 groups=kwargs["groups"],
                 bias_sizes=kwargs["bias_sizes"],
             )
+        # Expand 1d -> 2d.
+        # Note: Avoid expanding before calling _select_conv_backend,
+        # as the function handles 2D expansion internally.
+        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
+            # Note: Using input.to(memory_format=contiguous) does not work.
+            kwargs["input"] = kwargs["input"].contiguous().unsqueeze(2)
+            kwargs["weight"] = kwargs["weight"].unsqueeze(2)
+            if len(kwargs["stride"]) == 1:
+                kwargs["stride"].insert(0, 1)
+                kwargs["padding"].insert(0, 0)
+                kwargs["dilation"].insert(0, 1)
+                kwargs["output_padding"].insert(0, 0)
         mem_fmt = torch._C._conv_determine_backend_memory_format(
             kwargs["input"], kwargs["weight"], conv_backend
         )
+        # revert 2d -> 1d
+        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
+            kwargs["input"] = kwargs["input"].squeeze(2)
+            kwargs["weight"] = kwargs["weight"].squeeze(2)
+            if len(kwargs["stride"]) == 2:
+                kwargs["stride"].pop(0)
+                kwargs["padding"].pop(0)
+                kwargs["dilation"].pop(0)
+                kwargs["output_padding"].pop(0)

     def convert(t, mem_fmt):
         if t is None:
             return t

         if mem_fmt is not None:
-            t = t.to(memory_format=mem_fmt)
+            # channels last only support 4d, try to expand dim then convert it back later.
+            if t.dim() == 3 and mem_fmt == torch.channels_last:
+                t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2)
+            else:
+                t = t.to(memory_format=mem_fmt)
         return FakeTensor(fake_mode, t, device)
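
One subtlety in `convert` above: `torch.channels_last` is only defined for 4d tensors, so the 3d conv1d output is unsqueezed to 4d, converted, and squeezed back, and the channels-last-like strides survive the squeeze. A quick illustration of that round trip:

```python
import torch

t = torch.empty(1, 64, 100)                    # conv1d-style output (N, C, L)
t4 = t.unsqueeze(2)                            # (N, C, 1, L): 4d, so channels_last applies
t4 = t4.to(memory_format=torch.channels_last)  # strides (6400, 1, 6400, 64)
t3 = t4.squeeze(2)                             # back to 3d
print(t3.stride())                             # (6400, 1, 64): channel stride is 1
```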