add channels last (2d) support for mkldnn_convolution (#55584)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55584

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27941368

Pulled By: VitalyFedyunin

fbshipit-source-id: 7dd6f02a5787efa1995f31cdbd3244b25653840c
(cherry picked from commit bb555ed0fedafd529cb552807326384e95c90df9)
This commit is contained in:
mingfeima
2022-04-20 15:28:43 -07:00
committed by PyTorch MergeBot
parent cebdca4191
commit 92a9c0e3e0
6 changed files with 239 additions and 119 deletions

View File

@ -241,6 +241,47 @@ class TestMkldnn(TestCase):
def test_conv3d(self):
self._test_conv_base(dim=3)
def test_conv2d_nhwc(self):
conv_module = torch.nn.Conv2d
input_shapes = (224, 224)
options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
for train, bias, dilation, groups in options:
N = torch.randint(3, 10, (1,)).item()
M = torch.randint(1, 3, (1,)).item() * groups
C = torch.randint(1, 3, (1,)).item() * groups
x_shape = (N, C) + input_shapes
x = torch.randn(x_shape, dtype=torch.float32)
# conv1: mkldnn conv2d in contiguous memory format (nchw)
# conv2: mkldnn conv2d in channels last memory format (nhwc)
conv1 = conv_module(in_channels=C,
out_channels=M,
kernel_size=3,
stride=2,
padding=1,
dilation=dilation,
bias=bias,
groups=groups).float()
conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
x1 = x.clone()
x2 = x.clone().to(memory_format=torch.channels_last)
if train:
x1.requires_grad_()
x2.requires_grad_()
y1 = conv1(x1)
y2 = conv2(x2)
self.assertEqual(y1, y2)
if train:
y1.sum().backward()
y2.sum().backward()
self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(conv1.weight.grad,
conv2.weight.grad,
atol=1e-3,
rtol=1e-3)
if bias:
self.assertEqual(conv1.bias.grad, conv2.bias.grad)
self.assertEqual(x1.grad, x2.grad)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def _test_conv_bf16_base(self, dim):
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}