Fix replaceAtenConvolution for BC. (#44036)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44036

Running replaceAtenConvolution on older traced model wont work as
_convolution signature has changed and replaceAtenConvolution was
changed to account for that.
But we did not preserve the old behavior during that. This change
restores the old behavior while keeing the new one.

Test Plan: Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D23476775

fbshipit-source-id: 73a0c2b7387f2a8d82a8d26070d0059972126836
This commit is contained in:
Kimish Patel
2020-09-03 12:53:22 -07:00
committed by Facebook GitHub Bot
parent ba65cce2a2
commit a153f69417

View File

@ -70,6 +70,7 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
return (%r) )";
std::string convolution = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -78,6 +79,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
return (%r) )";
std::string conv2d_for_deprecated_conv = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
%deterministic:bool, %cudnn_enabled:bool):
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv2d = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -85,6 +92,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv2d_transpose_for_deprecated_conv = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
%deterministic:bool, %cudnn_enabled:bool):
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
return (%r) )";
std::string conv2d_transpose = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -92,6 +105,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
return (%r) )";
std::string conv1d_for_deprecated_conv = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
%deterministic:bool, %cudnn_enabled:bool):
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv1d = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -99,6 +118,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv3d_for_deprecated_conv = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
%deterministic:bool, %cudnn_enabled:bool):
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv3d = R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@ -174,19 +199,24 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter_conv1d;
rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
rewriter_conv1d.RegisterRewritePattern(convolution_deprecated, conv1d);
rewriter_conv1d.RegisterRewritePattern(
convolution_deprecated, conv1d_for_deprecated_conv);
rewriter_conv1d.runOnGraph(graph, filter_conv1d);
SubgraphRewriter rewriter_conv2d;
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
rewriter_conv2d.RegisterRewritePattern(convolution_deprecated, conv2d);
rewriter_conv2d.RegisterRewritePattern(
convolution_deprecated, conv2d_for_deprecated_conv);
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
SubgraphRewriter rewriter_conv2d_transpose;
rewriter_conv2d_transpose.RegisterRewritePattern(
convolution, conv2d_transpose);
rewriter_conv2d_transpose.RegisterRewritePattern(
convolution_deprecated, conv2d_transpose_for_deprecated_conv);
rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose);
SubgraphRewriter rewriter_conv3d;
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
rewriter_conv3d.RegisterRewritePattern(convolution_deprecated, conv3d);
rewriter_conv3d.RegisterRewritePattern(
convolution_deprecated, conv3d_for_deprecated_conv);
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
}