[NNC] support aten::_convolution when it is 2D conv (#84038)

## Motivation
Currently, only `aten::conv2d` is supported in NNC. When a model is captured with `torch.jit.trace`, the convolution appears on the graph as an `aten::_convolution` node instead. This PR adds support for the `aten::_convolution` node when it corresponds to a 2D convolution.
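
As a minimal illustration (not part of this PR; the module, shapes, and exact printed graphs are assumed and can vary across PyTorch versions), tracing records the dispatcher-level `aten::_convolution` node, while scripting plus freezing yields `aten::conv2d`:

```python
import torch
import torch.nn as nn

m = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1).eval()
x = torch.randn(1, 3, 32, 32)

# Tracing records what actually ran, so the graph contains an aten::_convolution node.
traced = torch.jit.trace(m, x)
print(traced.graph)

# Scripting + freezing produces an aten::conv2d node, which NNC already supported.
frozen = torch.jit.freeze(torch.jit.script(m))
print(frozen.graph)
```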

## Pitch
Support `aten::_convolution` in NNC whenever its parameters indicate a 2D convolution, so that models obtained from `torch.jit.trace` can be fused as well.
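
A hedged end-to-end sketch of what this enables, loosely mirroring the updated `test_mkldnn_fusion.py` flow. The `torch._C` knobs are internal, the warm-up run count is an assumption about the profiling executor, and whether a fusion group actually forms depends on the build (e.g., oneDNN/LLVM availability) and memory format:

```python
import torch
import torch.nn as nn

# Let the TensorExpr (NNC) fuser run on CPU and keep fusion groups visible.
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)
torch._C._jit_set_te_must_use_llvm_cpu(False)  # allow fusion on builds without LLVM

# channels_last mirrors the "enabled" case in the test.
m = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1).eval().to(memory_format=torch.channels_last)
x = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)

with torch.no_grad():
    frozen = torch.jit.freeze(torch.jit.trace(m, x))
    frozen(x)  # warm-up runs so the profiling executor can apply fusion
    frozen(x)

graph = torch.jit.last_executed_optimized_graph()
# With this change, the traced aten::_convolution node can be placed inside a
# prim::TensorExprGroup instead of being left unfused.
print(graph.findAllNodes("prim::TensorExprGroup", True))
```
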
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84038
Approved by: https://github.com/huiguoo
Author: Wu, Chunyuan
Date: 2022-09-19 17:45:20 +00:00
Committed by: PyTorch MergeBot
Parent: b049493ed5
Commit: ebf45a0785

5 changed files with 130 additions and 30 deletions

View File

```diff
@@ -19,7 +19,7 @@ class TestMkldnnFusion(JitTestCase):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)
 
-    def _check_model(self, m, x):
+    def _check_model(self, m, x, trace=False):
         old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
         torch._C._debug_set_fusion_group_inlining(False)
@@ -31,7 +31,10 @@ class TestMkldnnFusion(JitTestCase):
 
         m.eval()
         with torch.no_grad():
-            script = torch.jit.script(m)
+            if trace:
+                script = torch.jit.trace(m, x)
+            else:
+                script = torch.jit.script(m)
         script = torch.jit.freeze(script)
 
         with torch.no_grad():
@@ -61,28 +64,30 @@
             [torch.contiguous_format, False],
             [torch.channels_last, True],
         ]:
-            input_size = 224
-            batch_size = 1
-            kernel_size = 3
-            options = itertools.product([True, False], [1, 2], [1, 4])
-            for bias, dilation, groups in options:
-                iC = 3 * groups
-                oC = 10 * groups
-                m = M(iC,
-                      oC,
-                      bias,
-                      kernel_size=(kernel_size, kernel_size),
-                      stride=2,
-                      padding=1,
-                      dilation=dilation,
-                      groups=groups).to(memory_format=memory_format)
-                x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
-                graph = self._check_model(m, x)
-                if enabled:
-                    self.assertFused(graph, ['aten::conv2d'])
-                    self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
-                else:
-                    self.assertGraphContains(graph, kind='aten::conv2d')
+            for trace in [True, False]:
+                input_size = 224
+                batch_size = 1
+                kernel_size = 3
+                options = itertools.product([True, False], [1, 2], [1, 4])
+                for bias, dilation, groups in options:
+                    iC = 3 * groups
+                    oC = 10 * groups
+                    m = M(iC,
+                          oC,
+                          bias,
+                          kernel_size=(kernel_size, kernel_size),
+                          stride=2,
+                          padding=1,
+                          dilation=dilation,
+                          groups=groups).to(memory_format=memory_format)
+                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
+                    graph = self._check_model(m, x, trace)
+                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
+                    if enabled:
+                        self.assertFused(graph, [conv_node_name])
+                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
+                    else:
+                        self.assertGraphContains(graph, kind=conv_node_name)
 
     def test_conv_eltwise(self):
         class M(nn.Module):
@@ -113,6 +118,47 @@ class TestMkldnnFusion(JitTestCase):
                 else:
                     self.assertGraphContains(graph, kind='aten::conv2d')
 
+    def test_unsupported_conv(self):
+        class M(nn.Module):
+            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
+                super(M, self).__init__()
+                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)
+
+            def forward(self, x):
+                res = self.conv(x)
+                return res
+
+        for module, dim, memory_format in [
+            [nn.Conv3d, 3, torch.contiguous_format],
+            [nn.Conv3d, 3, torch.channels_last_3d],
+            [nn.ConvTranspose2d, 2, torch.contiguous_format],
+            [nn.ConvTranspose2d, 2, torch.channels_last],
+        ]:
+            trace = True
+            input_size = 224
+            batch_size = 1
+            kernel_size = 3
+            groups = 2
+            bias = True
+            iC = 3 * groups
+            oC = 10 * groups
+            dilation = 2
+            m = M(module,
+                  iC,
+                  oC,
+                  bias,
+                  kernel_size=kernel_size,
+                  stride=2,
+                  padding=1,
+                  dilation=dilation,
+                  groups=groups).to(memory_format=memory_format)
+            input_sizes = [batch_size, iC, input_size, input_size]
+            if dim == 3:
+                input_sizes.append(input_size)
+            x = torch.randn(input_sizes).to(memory_format=memory_format)
+            graph = self._check_model(m, x, trace)
+            self.assertGraphContains(graph, kind='aten::_convolution')
+
 
 if __name__ == "__main__":
     run_tests()
```

View File

```diff
@@ -105,9 +105,6 @@ void insertPrePackedConvOp(Block* b) {
 }
 
 void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
-  // Replace _convolution with conv2d
-  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
-
   insertPrePackedConvOp(graph->block());
 }
 
```

View File

```diff
@@ -80,6 +80,7 @@ static const OperatorSet& supported_non_eltwise_set() {
   static const OperatorSet supported_non_eltwise_set{
       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
       "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
      "aten::matmul(Tensor self, Tensor other) -> Tensor",
   };
   // clang-format on
@@ -897,6 +898,7 @@ class TensorExprFuser {
     };
     static const OperatorSet cpu_compute_heavy_set{
         "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+        "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
         "aten::matmul(Tensor self, Tensor other) -> Tensor",
     };
     static const OperatorSet gpu_only_operator_set{
@@ -1043,7 +1045,11 @@
      }
    }
 
-    if (node->kind() == aten::conv2d) {
+    if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) {
+      GRAPH_DEBUG("This aten::_convolution node is not a 2D conv");
+      return false;
+    }
+    if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) {
       if (!tensorexpr::conv2dIsSupportedJit(node) &&
           !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) {
         GRAPH_DEBUG("Params of conv2d are not supported");
```

View File

```diff
@@ -8,6 +8,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/string_utils.h>
 #include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/mkldnn_rewrite.h>
 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/tensorexpr/analysis.h>
@@ -233,6 +234,20 @@ bool isContiguous(const torch::jit::Value* v, at::MemoryFormat memory_format) {
   return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
 }
 
+size_t get_conv_groups_index(const torch::jit::Node* node) {
+  switch (node->kind()) {
+    case aten::conv2d:
+      return 6;
+    case aten::_convolution:
+      return 8;
+    default:
+      TORCH_CHECK(
+          false,
+          "mkldnnPrepackedConvIsSupportedJit expects node kind to be conv2d or _convolution but got ",
+          node->kind());
+  }
+}
+
 // The fuser only supports conv2d with very specific properties:
 // - Static shapes: 4-d input and filter, 1-d bias.
 // - Constant strides/padding/dilation/groups
@@ -246,7 +261,8 @@ bool conv2dIsSupportedJit(const torch::jit::Node* node) {
   auto const& stride = toIValue(node->input(3));
   auto const& pad = toIValue(node->input(4));
   auto const& dilation = toIValue(node->input(5));
-  auto const& groups = toIValue(node->input(6));
+  size_t groups_index = get_conv_groups_index(node);
+  auto const& groups = toIValue(node->input(groups_index));
 
   // Everything should be statically known.
   if (!input || !weight || !bias || !stride || !pad || !dilation || !groups) {
@@ -278,7 +294,8 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
   auto const& stride = toIValue(node->input(3));
   auto const& pad = toIValue(node->input(4));
   auto const& dilation = toIValue(node->input(5));
-  auto const& groups = toIValue(node->input(6));
+  size_t groups_index = get_conv_groups_index(node);
+  auto const& groups = toIValue(node->input(groups_index));
 
   // Everything should be statically known (bias could be NoneType =
   // prim::Constant()).
@@ -314,6 +331,37 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
   return false;
 }
 
+bool isConv2d(const Node* node) {
+  if (node->kind() != aten::_convolution) {
+    return false;
+  }
+
+  auto const& stride = toIValue(node->input(3));
+  auto const& pad = toIValue(node->input(4));
+  auto const& dilation = toIValue(node->input(5));
+  auto const& transposed = toIValue(node->input(6));
+  auto const& output_padding = toIValue(node->input(7));
+  if (!stride || !pad || !dilation || !transposed || !output_padding) {
+    GRAPH_DEBUG("some params aren't static");
+    return false;
+  }
+
+  if (stride.value().toIntList().size() != 2 ||
+      pad.value().toIntList().size() != 2 ||
+      dilation.value().toIntList().size() != 2 ||
+      output_padding.value().toIntList().size() != 2) {
+    GRAPH_DEBUG("Conv not 2d");
+    return false;
+  }
+
+  if (transposed.value().toBool()) {
+    GRAPH_DEBUG("transposed Conv");
+    return false;
+  }
+
+  return true;
+}
+
 // The fuser currently only supports matmul of 2D x 2D matrices
 bool matmulIsSupported(const torch::jit::Node* node) {
   auto const& input0 = getTensorInfoJit(node->input(0));
@@ -1606,6 +1654,7 @@ void TensorExprKernel::optimizeOwningGraph() {
   deduceMemoryLayoutPolicy();
 
   // Fuse Conv with Eltwise Op
+  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph_);
   FuseConvWithEltwise(graph_);
 
   // Optimize the concatenation
```

View File

```diff
@@ -26,6 +26,8 @@ struct SmallSizeTPairHash {
 bool conv2dIsSupportedJit(const Node* node);
 // Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv.
 bool mkldnnPrepackedConvIsSupportedJit(const Node* node);
+// Returns true if the _convolution node is Conv2d.
+bool isConv2d(const Node* node);
 // Returns true if the TE fuser supports this matmul.
 bool matmulIsSupported(const Node* node);
 template <typename T>
```