## Motivation

Currently, only `aten::conv2d` is supported in NNC. When a model is obtained with `torch.jit.trace`, the graph contains `aten::_convolution` nodes instead. This PR adds support for the `aten::_convolution` node when it corresponds to a 2D convolution.

## Pitch

Support `aten::_convolution` in NNC whenever its parameters let us infer that it is a 2D convolution, so that models obtained from `torch.jit.trace` are covered as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84038
Approved by: https://github.com/huiguoo
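
For illustration, a minimal sketch (separate from the test file below, with arbitrary shapes chosen here) of the behavior the description refers to: freezing a traced `Conv2d` leaves an `aten::_convolution` node in the graph, whereas freezing a scripted one leaves `aten::conv2d`. Exact node kinds can vary across PyTorch versions.

```python
import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3).eval()
x = torch.randn(1, 3, 32, 32)

# Tracing records the low-level aten::_convolution op ...
traced = torch.jit.freeze(torch.jit.trace(conv, x))
# ... while scripting keeps the high-level aten::conv2d op.
scripted = torch.jit.freeze(torch.jit.script(conv))

print(len(traced.graph.findAllNodes('aten::_convolution')) > 0)   # expected: True
print(len(scripted.graph.findAllNodes('aten::conv2d')) > 0)       # expected: True
```

This is the distinction `test_single_conv` below encodes with `conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'`.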
# Owner(s): ["module: mkldnn"]
import itertools
import unittest

import torch
from torch import nn

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase

from test_tensorexpr import warmup_and_run_forward

FUSION_GROUP = 'prim::TensorExprGroup'


@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMkldnnFusion(JitTestCase):
    def assertFused(self, graph, fused_patterns):
        # A fused pattern must no longer appear as a standalone node in the graph.
        for pat in fused_patterns:
            self.assertGraphContainsExactly(graph, pat, 0)

    def _check_model(self, m, x, trace=False):
        # Enable the TE fuser on CPU and keep fusion groups from being inlined
        # so that they stay visible in the optimized graph; the previous
        # settings are restored before returning.
        old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)

        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_override_can_fuse_on_cpu(True)

        old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

        m.eval()
        with torch.no_grad():
            if trace:
                script = torch.jit.trace(m, x)
            else:
                script = torch.jit.script(m)
        script = torch.jit.freeze(script)

        with torch.no_grad():
            # Warm-up runs let the profiling executor specialize and fuse the graph.
            y = warmup_and_run_forward(script, x)
            y = script(x)
            y_ref = m(x)

            graph = script.graph_for(*x)
            self.assertEqual(y, y_ref)

        torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
        torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)
        return graph

    def test_single_conv(self):
        class M(nn.Module):
            def __init__(self, in_channels, out_channels, bias, **kwargs):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for trace in [True, False]:
                input_size = 224
                batch_size = 1
                kernel_size = 3
                options = itertools.product([True, False], [1, 2], [1, 4])
                for bias, dilation, groups in options:
                    iC = 3 * groups
                    oC = 10 * groups
                    m = M(iC,
                          oC,
                          bias,
                          kernel_size=(kernel_size, kernel_size),
                          stride=2,
                          padding=1,
                          dilation=dilation,
                          groups=groups).to(memory_format=memory_format)
                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
                    graph = self._check_model(m, x, trace)
                    # Tracing produces aten::_convolution; scripting produces aten::conv2d.
                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
                    if enabled:
                        self.assertFused(graph, [conv_node_name])
                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                    else:
                        self.assertGraphContains(graph, kind=conv_node_name)

    def test_conv_eltwise(self):
        class M(nn.Module):
            def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
                self.eltwise = eltwise_fn

            def forward(self, x):
                x = self.conv(x)
                x = self.eltwise(x)
                return x

        for memory_format, enabled in [
            [torch.contiguous_format, False],
            [torch.channels_last, True],
        ]:
            for eltwise_fn in [torch.relu]:
                for bias in [True, False]:
                    for oC in [1, 10]:
                        m = M(eltwise_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
                        x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)

                        graph = self._check_model(m, x)
                        if enabled:
                            self.assertFused(graph, ['aten::conv2d', 'aten::' + eltwise_fn.__name__])
                            self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
                        else:
                            self.assertGraphContains(graph, kind='aten::conv2d')

    def test_unsupported_conv(self):
        class M(nn.Module):
            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
                super(M, self).__init__()
                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)

            def forward(self, x):
                res = self.conv(x)
                return res

        for module, dim, memory_format in [
            [nn.Conv3d, 3, torch.contiguous_format],
            [nn.Conv3d, 3, torch.channels_last_3d],
            [nn.ConvTranspose2d, 2, torch.contiguous_format],
            [nn.ConvTranspose2d, 2, torch.channels_last],
        ]:
            trace = True
            input_size = 224
            batch_size = 1
            kernel_size = 3
            groups = 2
            bias = True
            iC = 3 * groups
            oC = 10 * groups
            dilation = 2
            m = M(module,
                  iC,
                  oC,
                  bias,
                  kernel_size=kernel_size,
                  stride=2,
                  padding=1,
                  dilation=dilation,
                  groups=groups).to(memory_format=memory_format)
            input_sizes = [batch_size, iC, input_size, input_size]
            if dim == 3:
                input_sizes.append(input_size)
            x = torch.randn(input_sizes).to(memory_format=memory_format)
            graph = self._check_model(m, x, trace)
            # These convolutions are not fused, so aten::_convolution must remain in the graph.
            self.assertGraphContains(graph, kind='aten::_convolution')


if __name__ == "__main__":
    run_tests()