add channels last (2d) support for mkldnn_convolution (#55584)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55584

Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27941368
Pulled By: VitalyFedyunin
fbshipit-source-id: 7dd6f02a5787efa1995f31cdbd3244b25653840c

(cherry picked from commit bb555ed0fedafd529cb552807326384e95c90df9)
Committed by: PyTorch MergeBot
Parent: cebdca4191
Commit: 92a9c0e3e0
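In user terms, the change lets a CPU Conv2d fed channels-last (NHWC) tensors stay in that layout through mkldnn_convolution's forward and backward passes instead of round-tripping through contiguous (NCHW). Below is a minimal sketch of the behavior the new test exercises, assuming a PyTorch build with MKL-DNN (oneDNN) enabled; the shapes and module here are illustrative, not taken from the patch:

import torch

# Illustrative module: a plain Conv2d whose weights are converted to
# channels-last layout, as the test does via copy.deepcopy(...).to(...).
conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3,
                       stride=2, padding=1).float()
conv = conv.to(memory_format=torch.channels_last)

# Channels-last (NHWC) input with gradients enabled, mirroring the
# test's train=True case.
x = torch.randn(2, 3, 224, 224, dtype=torch.float32)
x = x.to(memory_format=torch.channels_last).requires_grad_()

y = conv(x)
y.sum().backward()

# The new test asserts that the input gradient comes back in
# channels-last layout rather than being converted back to NCHW.
assert x.grad.is_contiguous(memory_format=torch.channels_last)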
@@ -241,6 +241,47 @@ class TestMkldnn(TestCase):
     def test_conv3d(self):
         self._test_conv_base(dim=3)
 
+    def test_conv2d_nhwc(self):
+        conv_module = torch.nn.Conv2d
+        input_shapes = (224, 224)
+        options = itertools.product([True, False], [True, False], [1, 2], [1, 4])
+        for train, bias, dilation, groups in options:
+            N = torch.randint(3, 10, (1,)).item()
+            M = torch.randint(1, 3, (1,)).item() * groups
+            C = torch.randint(1, 3, (1,)).item() * groups
+            x_shape = (N, C) + input_shapes
+            x = torch.randn(x_shape, dtype=torch.float32)
+            # conv1: mkldnn conv2d in contiguous memory format (nchw)
+            # conv2: mkldnn conv2d in channels last memory format (nhwc)
+            conv1 = conv_module(in_channels=C,
+                                out_channels=M,
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                dilation=dilation,
+                                bias=bias,
+                                groups=groups).float()
+            conv2 = copy.deepcopy(conv1).to(memory_format=torch.channels_last)
+            x1 = x.clone()
+            x2 = x.clone().to(memory_format=torch.channels_last)
+            if train:
+                x1.requires_grad_()
+                x2.requires_grad_()
+            y1 = conv1(x1)
+            y2 = conv2(x2)
+            self.assertEqual(y1, y2)
+            if train:
+                y1.sum().backward()
+                y2.sum().backward()
+                self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
+                self.assertEqual(conv1.weight.grad,
+                                 conv2.weight.grad,
+                                 atol=1e-3,
+                                 rtol=1e-3)
+                if bias:
+                    self.assertEqual(conv1.bias.grad, conv2.bias.grad)
+                self.assertEqual(x1.grad, x2.grad)
+
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def _test_conv_bf16_base(self, dim):
         conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
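For completeness, one way to exercise just the new test from a PyTorch source checkout (a hypothetical invocation; it assumes test/test_mkldnn.py is on the import path and the build has MKL-DNN enabled):

import unittest

from test_mkldnn import TestMkldnn  # test/test_mkldnn.py in the PyTorch tree

# Build a one-test suite containing only the new channels-last test.
suite = unittest.TestSuite([TestMkldnn("test_conv2d_nhwc")])
unittest.TextTestRunner(verbosity=2).run(suite)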