Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[FakeTensor] Supplement the relevant logic for converting conv1d to conv2d in meta_conv (#160408)
Fixes https://github.com/pytorch/pytorch/issues/159462; also fixes #163569 and #163604.

## Summary
The issue is caused by the incorrect stride of conv1d's output computed by meta_conv: 4d5b3f2d5a/torch/_meta_registrations.py (L2453-L2471)
That incorrect stride is then used to codegen the size/stride assert in Inductor: 4d5b3f2d5a/torch/_inductor/ir.py (L6152-L6163)

## Reason
Why is the computed stride wrong in meta_conv? The corresponding backend converts conv1d to conv2d and changes the input tensor's size and memory format (channels last), but meta_conv does not perform this transformation, so the strides mismatch: 4d5b3f2d5a/aten/src/ATen/native/Convolution.cpp (L1502-L1510)
This PR adds the corresponding logic to meta_conv.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160408
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/mlazos
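For illustration, here is a minimal reproducer sketch in the spirit of the new `test_conv1d_with_permute` test added below (assumes a PyTorch build with the Inductor backend; before this fix, the compiled run could trip the size/stride assert generated by Inductor because meta_conv reported an output stride that did not match the eager channels-last result):

```python
import torch
import torch.nn as nn


class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        # conv1d whose input arrives through a permute, i.e. as a strided view
        self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L), non-contiguous
        return self.conv(x)


model = ConvModel()
x = torch.randn(32, 100, 1)

eager_out = model(x)
# Before this fix, the stride computed by meta_conv for the conv1d output
# could disagree with the eager result, failing the codegen'd assert.
compiled_out = torch.compile(model, backend="inductor")(x)
torch.testing.assert_close(compiled_out, eager_out)
```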
Committed by: PyTorch MergeBot
Parent: 8aba513506
Commit: c106ee8515
@@ -311,6 +311,33 @@ class CPUReproTests(TestCase):
            (v,),
        )

    def test_conv1d_strided_weight_torch_compile(self):
        def fn(x, w):
            wt = w.transpose(2, 1)
            y = F.conv1d(x, wt)
            return y.clone()

        x_eager = torch.randn(2, 3, 5, requires_grad=True)
        w_eager = torch.randn(4, 2, 3, requires_grad=True)

        out_eager = fn(x_eager, w_eager)
        grad = torch.randn_like(out_eager)
        out_eager_val = out_eager.detach()
        out_eager.backward(grad)
        grad_x_eager = x_eager.grad.detach().clone()
        grad_w_eager = w_eager.grad.detach().clone()

        x_comp = x_eager.detach().requires_grad_(True)
        w_comp = w_eager.detach().requires_grad_(True)
        compiled = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)
        out_comp = compiled(x_comp, w_comp)
        out_comp_val = out_comp.detach()
        out_comp.backward(grad)

        torch.testing.assert_close(out_comp_val, out_eager_val)
        torch.testing.assert_close(x_comp.grad, grad_x_eager)
        torch.testing.assert_close(w_comp.grad, grad_w_eager)

    @config.patch(freezing=True)
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @patch("torch.cuda.is_available", lambda: False)
@@ -4639,6 +4639,41 @@ class CommonTemplate:
            (torch.randn([4, 4, 4]),),
        )

    def test_conv1d_with_permute(self):
        # fix https://github.com/pytorch/pytorch/issues/159462
        class ConvModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv1d(1, 64, kernel_size=3, padding=1)

            def forward(self, x):
                x = x.permute(0, 2, 1)
                return self.conv(x)

        self.common(ConvModel(), (torch.randn([32, 100, 1]),), check_lowp=False)

    def test_conv1d_depthwise(self):
        class ConvModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv1d(
                    768,
                    768,
                    kernel_size=(9,),
                    stride=(1,),
                    padding=(4,),
                    groups=768,
                    bias=False,
                )

            def forward(self, x):
                return self.conv(x)

        input_tensor = torch.randn([1, 768, 512]).as_strided(
            (1, 768, 512), (393216, 1, 768)
        )
        self.common(ConvModel(), (input_tensor,), check_lowp=False)

    def test_convolution1(self):
        m = torch.nn.Sequential(
            torch.nn.Conv2d(5, 6, [3, 3]),
@@ -110,6 +110,7 @@ test_failures = {
    #
    # Failed to find dynamic for loop variable:
    #
    "test_conv1d_with_permute_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
    "test_arange1_dynamic_shapes": TestFailure(("cpu",)),
    "test_arange2_dynamic_shapes": TestFailure(("cpu",)),
    "test_arange3_dynamic_shapes": TestFailure(("cpu",)),
@@ -134,11 +134,11 @@ class TestUtils(TestCase):
                torch.Tensor(2, 2, 3),
                torch.Tensor(2, 2, 2),
                torch.Tensor(2),
                (1, 1),
                (0, 0),
                (1, 1),
                (1,),
                (0,),
                (1,),
                True,
                (0, 0),
                (0,),
                1,
            ),
            {},
@@ -201,6 +201,26 @@ class FakeTensorTest(TestCase):

        self.assertEqual(torch.ones([10]), out[0])

    def test_conv_nhwc(self):
        x = torch.randn([1, 1024, 16, 16]).to(memory_format=torch.channels_last)
        w = torch.randn([256, 1024, 4, 4]).to(memory_format=torch.channels_last)
        b = torch.randn([256])

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.ops.aten.convolution(
                    x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
                )

        model = Model()
        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
            fake_out = model.forward(x, w, b)
        eager_out = model.forward(x, w, b)
        self.assertEqual(fake_out.stride(), eager_out.stride())

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        with FakeTensorMode() as mode:
@@ -2447,18 +2447,6 @@ def meta_conv(
    output_padding: list[int],
    groups: int,
):
    def pick_memory_format():
        if device_hint(input_tensor) == "cuda":
            if is_channels_last(input_tensor) or is_channels_last(weight):
                return torch.channels_last
        else:
            if is_channels_last(input_tensor):
                return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
@@ -2476,7 +2464,6 @@ def meta_conv(
        shape_out[output_channels_dim] = 0

    out = input_tensor.new_empty(shape_out)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out

@@ -1021,8 +1021,6 @@ def conv(fake_mode, func, *args, **kwargs):
        # TODO: We can make this a little more faithful with best effort
        # channels last detection (but only if it's statically obvious!)
        mem_fmt = None
    elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
        mem_fmt = None
    else:
        if func is aten.convolution.default:
            conv_backend = torch._C._select_conv_backend(**kwargs)
@@ -1039,15 +1037,40 @@ def conv(fake_mode, func, *args, **kwargs):
                groups=kwargs["groups"],
                bias_sizes=kwargs["bias_sizes"],
            )
        # Expand 1d -> 2d.
        # Note: Avoid expanding before calling _select_conv_backend,
        # as the function handles 2D expansion internally.
        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            # Note: Using input.to(memory_format=contiguous) does not work.
            kwargs["input"] = kwargs["input"].contiguous().unsqueeze(2)
            kwargs["weight"] = kwargs["weight"].unsqueeze(2)
            if len(kwargs["stride"]) == 1:
                kwargs["stride"].insert(0, 1)
                kwargs["padding"].insert(0, 0)
                kwargs["dilation"].insert(0, 1)
                kwargs["output_padding"].insert(0, 0)
        mem_fmt = torch._C._conv_determine_backend_memory_format(
            kwargs["input"], kwargs["weight"], conv_backend
        )
        # revert 2d -> 1d
        if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            kwargs["input"] = kwargs["input"].squeeze(2)
            kwargs["weight"] = kwargs["weight"].squeeze(2)
            if len(kwargs["stride"]) == 2:
                kwargs["stride"].pop(0)
                kwargs["padding"].pop(0)
                kwargs["dilation"].pop(0)
                kwargs["output_padding"].pop(0)

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
            # channels last only support 4d, try to expand dim then convert it back later.
            if t.dim() == 3 and mem_fmt == torch.channels_last:
                t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2)
            else:
                t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):