[NNC] support aten::_convolution when it is 2D conv (#84038)
## Motivation

Currently, only `aten::conv2d` is supported in NNC. When a model is captured with `torch.jit.trace`, the convolution appears on the graph as an `aten::_convolution` node instead, so NNC misses it. This PR adds support for `aten::_convolution` nodes that correspond to a 2D convolution.

## Pitch

Support `aten::_convolution` in NNC whenever its parameters show that it is a 2D convolution, so that models obtained from `torch.jit.trace` are covered as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84038
Approved by: https://github.com/huiguoo
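To make the motivation concrete, here is a minimal sketch (module and shapes chosen arbitrarily) of the behavior described above: at the time of this PR, tracing an `nn.Conv2d` records the dispatcher-level `aten::_convolution` call, while `torch.jit.script` keeps the `aten::conv2d` call that NNC already matched.

```python
import torch
import torch.nn as nn

m = nn.Conv2d(3, 8, kernel_size=3).eval()
x = torch.randn(1, 3, 32, 32)

# The traced graph contains an aten::_convolution node; before this PR,
# NNC's tensorexpr fuser only recognized aten::conv2d, so traced models
# never reached the fused path.
traced = torch.jit.trace(m, x)
print(traced.graph)
```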
Committed by: PyTorch MergeBot
Parent: b049493ed5
Commit: ebf45a0785
test/test_mkldnn_fusion.py

@@ -19,7 +19,7 @@ class TestMkldnnFusion(JitTestCase):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)
 
-    def _check_model(self, m, x):
+    def _check_model(self, m, x, trace=False):
         old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
         torch._C._debug_set_fusion_group_inlining(False)
 
@@ -31,7 +31,10 @@ class TestMkldnnFusion(JitTestCase):
 
         m.eval()
         with torch.no_grad():
-            script = torch.jit.script(m)
+            if trace:
+                script = torch.jit.trace(m, x)
+            else:
+                script = torch.jit.script(m)
         script = torch.jit.freeze(script)
 
         with torch.no_grad():
@@ -61,28 +64,30 @@ class TestMkldnnFusion(JitTestCase):
             [torch.contiguous_format, False],
             [torch.channels_last, True],
         ]:
-            input_size = 224
-            batch_size = 1
-            kernel_size = 3
-            options = itertools.product([True, False], [1, 2], [1, 4])
-            for bias, dilation, groups in options:
-                iC = 3 * groups
-                oC = 10 * groups
-                m = M(iC,
-                      oC,
-                      bias,
-                      kernel_size=(kernel_size, kernel_size),
-                      stride=2,
-                      padding=1,
-                      dilation=dilation,
-                      groups=groups).to(memory_format=memory_format)
-                x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
-                graph = self._check_model(m, x)
-                if enabled:
-                    self.assertFused(graph, ['aten::conv2d'])
-                    self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
-                else:
-                    self.assertGraphContains(graph, kind='aten::conv2d')
+            for trace in [True, False]:
+                input_size = 224
+                batch_size = 1
+                kernel_size = 3
+                options = itertools.product([True, False], [1, 2], [1, 4])
+                for bias, dilation, groups in options:
+                    iC = 3 * groups
+                    oC = 10 * groups
+                    m = M(iC,
+                          oC,
+                          bias,
+                          kernel_size=(kernel_size, kernel_size),
+                          stride=2,
+                          padding=1,
+                          dilation=dilation,
+                          groups=groups).to(memory_format=memory_format)
+                    x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
+                    graph = self._check_model(m, x, trace)
+                    conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
+                    if enabled:
+                        self.assertFused(graph, [conv_node_name])
+                        self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
+                    else:
+                        self.assertGraphContains(graph, kind=conv_node_name)
 
     def test_conv_eltwise(self):
         class M(nn.Module):
@@ -113,6 +118,47 @@ class TestMkldnnFusion(JitTestCase):
                 else:
                     self.assertGraphContains(graph, kind='aten::conv2d')
 
+    def test_unsupported_conv(self):
+        class M(nn.Module):
+            def __init__(self, m, in_channels, out_channels, bias, **kwargs):
+                super(M, self).__init__()
+                self.conv = m(in_channels, out_channels, bias=bias, **kwargs)
+
+            def forward(self, x):
+                res = self.conv(x)
+                return res
+
+        for module, dim, memory_format in [
+            [nn.Conv3d, 3, torch.contiguous_format],
+            [nn.Conv3d, 3, torch.channels_last_3d],
+            [nn.ConvTranspose2d, 2, torch.contiguous_format],
+            [nn.ConvTranspose2d, 2, torch.channels_last],
+        ]:
+            trace = True
+            input_size = 224
+            batch_size = 1
+            kernel_size = 3
+            groups = 2
+            bias = True
+            iC = 3 * groups
+            oC = 10 * groups
+            dilation = 2
+            m = M(module,
+                  iC,
+                  oC,
+                  bias,
+                  kernel_size=kernel_size,
+                  stride=2,
+                  padding=1,
+                  dilation=dilation,
+                  groups=groups).to(memory_format=memory_format)
+            input_sizes = [batch_size, iC, input_size, input_size]
+            if dim == 3:
+                input_sizes.append(input_size)
+            x = torch.randn(input_sizes).to(memory_format=memory_format)
+            graph = self._check_model(m, x, trace)
+            self.assertGraphContains(graph, kind='aten::_convolution')
+
 
 if __name__ == "__main__":
     run_tests()
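Reading aid for the new test: `test_unsupported_conv` traces 3D and transposed convolutions, both of which also surface as `aten::_convolution` nodes; the test asserts that the node survives unfused, exercising the `isConv2d` rejection path added in the fuser and kernel changes below.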
torch/csrc/jit/passes/mkldnn_rewrite.cpp

@@ -105,9 +105,6 @@ void insertPrePackedConvOp(Block* b) {
 }
 
 void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
-  // Replace _convolution with conv2d
-  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
-
   insertPrePackedConvOp(graph->block());
 }
 
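Note that the `replaceConvolutionWithAtenConv` call removed above is not dropped: it reappears in `TensorExprKernel::optimizeOwningGraph` (see the last hunk below), so the `_convolution`-to-`conv2d` normalization now runs as part of the kernel's own graph optimization rather than inside the MKL-DNN prepack pass.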
torch/csrc/jit/passes/tensorexpr_fuser.cpp

@@ -80,6 +80,7 @@ static const OperatorSet& supported_non_eltwise_set() {
   static const OperatorSet supported_non_eltwise_set{
       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
       "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
       "aten::matmul(Tensor self, Tensor other) -> Tensor",
   };
   // clang-format on
@@ -897,6 +898,7 @@ class TensorExprFuser {
   };
   static const OperatorSet cpu_compute_heavy_set{
       "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
+      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
       "aten::matmul(Tensor self, Tensor other) -> Tensor",
   };
   static const OperatorSet gpu_only_operator_set{
@@ -1043,7 +1045,11 @@ class TensorExprFuser {
       }
     }
 
-    if (node->kind() == aten::conv2d) {
+    if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) {
+      GRAPH_DEBUG("This aten::_convolution node is not a 2D conv");
+      return false;
+    }
+    if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) {
       if (!tensorexpr::conv2dIsSupportedJit(node) &&
           !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) {
         GRAPH_DEBUG("Params of conv2d are not supported");
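The gate now works in two steps: an `aten::_convolution` node that is not a plain 2D convolution is rejected outright, and any surviving `aten::_convolution` then goes through the same parameter checks as `aten::conv2d`.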
torch/csrc/jit/tensorexpr/kernel.cpp

@@ -8,6 +8,7 @@
 #include <c10/util/irange.h>
 #include <c10/util/string_utils.h>
 #include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/mkldnn_rewrite.h>
 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/tensorexpr/analysis.h>
@@ -233,6 +234,20 @@ bool isContiguous(const torch::jit::Value* v, at::MemoryFormat memory_format) {
   return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
 }
 
+size_t get_conv_groups_index(const torch::jit::Node* node) {
+  switch (node->kind()) {
+    case aten::conv2d:
+      return 6;
+    case aten::_convolution:
+      return 8;
+    default:
+      TORCH_CHECK(
+          false,
+          "mkldnnPrepackedConvIsSupportedJit expects node kind to be conv2d or _convolution but got ",
+          node->kind());
+  }
+}
+
 // The fuser only supports conv2d with very specific properties:
 // - Static shapes: 4-d input and filter, 1-d bias.
 // - Constant strides/padding/dilation/groups
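Why 6 versus 8: in the `aten::conv2d` schema quoted earlier, the inputs are (input, weight, bias, stride, padding, dilation, groups), so `groups` sits at index 6; the `aten::_convolution` schema inserts `transposed` (index 6) and `output_padding` (index 7) before `groups`, pushing it to index 8. `get_conv_groups_index` centralizes that difference for the two `...IsSupportedJit` checks below.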
@@ -246,7 +261,8 @@ bool conv2dIsSupportedJit(const torch::jit::Node* node) {
   auto const& stride = toIValue(node->input(3));
   auto const& pad = toIValue(node->input(4));
   auto const& dilation = toIValue(node->input(5));
-  auto const& groups = toIValue(node->input(6));
+  size_t groups_index = get_conv_groups_index(node);
+  auto const& groups = toIValue(node->input(groups_index));
 
   // Everything should be statically known.
   if (!input || !weight || !bias || !stride || !pad || !dilation || !groups) {
@@ -278,7 +294,8 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
   auto const& stride = toIValue(node->input(3));
   auto const& pad = toIValue(node->input(4));
   auto const& dilation = toIValue(node->input(5));
-  auto const& groups = toIValue(node->input(6));
+  size_t groups_index = get_conv_groups_index(node);
+  auto const& groups = toIValue(node->input(groups_index));
 
   // Everything should be statically known (bias could be NoneType =
   // prim::Constant()).
@@ -314,6 +331,37 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
   return false;
 }
 
+bool isConv2d(const Node* node) {
+  if (node->kind() != aten::_convolution) {
+    return false;
+  }
+
+  auto const& stride = toIValue(node->input(3));
+  auto const& pad = toIValue(node->input(4));
+  auto const& dilation = toIValue(node->input(5));
+  auto const& transposed = toIValue(node->input(6));
+  auto const& output_padding = toIValue(node->input(7));
+
+  if (!stride || !pad || !dilation || !transposed || !output_padding) {
+    GRAPH_DEBUG("some params aren't static");
+    return false;
+  }
+
+  if (stride.value().toIntList().size() != 2 ||
+      pad.value().toIntList().size() != 2 ||
+      dilation.value().toIntList().size() != 2 ||
+      output_padding.value().toIntList().size() != 2) {
+    GRAPH_DEBUG("Conv not 2d");
+    return false;
+  }
+
+  if (transposed.value().toBool()) {
+    GRAPH_DEBUG("transposed Conv");
+    return false;
+  }
+  return true;
+}
+
 // The fuser currently only supports matmul of 2D x 2D matrices
 bool matmulIsSupported(const torch::jit::Node* node) {
   auto const& input0 = getTensorInfoJit(node->input(0));
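A quick way to see what `isConv2d` guards against (a hedged sketch; module and shapes are arbitrary): a traced `nn.Conv3d` also shows up as `aten::_convolution`, but its stride/padding/dilation constants have three elements each, so the length-2 checks above return false and the fuser leaves the node alone.

```python
import torch
import torch.nn as nn

m3 = nn.Conv3d(3, 8, kernel_size=3).eval()
x3 = torch.randn(1, 3, 16, 16, 16)

# The printed graph contains an aten::_convolution node whose stride,
# padding, and dilation constants are length-3 lists; tensorexpr::isConv2d()
# rejects it, matching the test_unsupported_conv expectations above.
print(torch.jit.trace(m3, x3).graph)
```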
@@ -1606,6 +1654,7 @@ void TensorExprKernel::optimizeOwningGraph() {
   deduceMemoryLayoutPolicy();
 
   // Fuse Conv with Eltwise Op
+  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph_);
   FuseConvWithEltwise(graph_);
 
   // Optimize the concatenation
torch/csrc/jit/tensorexpr/kernel.h

@@ -26,6 +26,8 @@ struct SmallSizeTPairHash {
 bool conv2dIsSupportedJit(const Node* node);
 // Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv.
 bool mkldnnPrepackedConvIsSupportedJit(const Node* node);
+// Returns true if the _convolution node is Conv2d.
+bool isConv2d(const Node* node);
 // Returns true if the TE fuser supports this matmul.
 bool matmulIsSupported(const Node* node);
 template <typename T>