mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fix replaceAtenConvolution for BC. (#44036)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44036 Running replaceAtenConvolution on older traced model wont work as _convolution signature has changed and replaceAtenConvolution was changed to account for that. But we did not preserve the old behavior during that. This change restores the old behavior while keeing the new one. Test Plan: Imported from OSS Reviewed By: jerryzh168 Differential Revision: D23476775 fbshipit-source-id: 73a0c2b7387f2a8d82a8d26070d0059972126836
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ba65cce2a2
commit
a153f69417
@ -70,6 +70,7 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
|
||||
return (%r) )";
|
||||
|
||||
std::string convolution = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -78,6 +79,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_for_deprecated_conv = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
std::string conv2d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -85,6 +92,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_transpose_for_deprecated_conv = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
|
||||
return (%r) )";
|
||||
std::string conv2d_transpose = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -92,6 +105,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv1d_for_deprecated_conv = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
std::string conv1d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -99,6 +118,12 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv3d_for_deprecated_conv = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
%deterministic:bool, %cudnn_enabled:bool):
|
||||
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
std::string conv3d = R"(
|
||||
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
|
||||
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
|
||||
@ -174,19 +199,24 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
|
||||
|
||||
SubgraphRewriter rewriter_conv1d;
|
||||
rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
|
||||
rewriter_conv1d.RegisterRewritePattern(convolution_deprecated, conv1d);
|
||||
rewriter_conv1d.RegisterRewritePattern(
|
||||
convolution_deprecated, conv1d_for_deprecated_conv);
|
||||
rewriter_conv1d.runOnGraph(graph, filter_conv1d);
|
||||
SubgraphRewriter rewriter_conv2d;
|
||||
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
|
||||
rewriter_conv2d.RegisterRewritePattern(convolution_deprecated, conv2d);
|
||||
rewriter_conv2d.RegisterRewritePattern(
|
||||
convolution_deprecated, conv2d_for_deprecated_conv);
|
||||
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
|
||||
SubgraphRewriter rewriter_conv2d_transpose;
|
||||
rewriter_conv2d_transpose.RegisterRewritePattern(
|
||||
convolution, conv2d_transpose);
|
||||
rewriter_conv2d_transpose.RegisterRewritePattern(
|
||||
convolution_deprecated, conv2d_transpose_for_deprecated_conv);
|
||||
rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose);
|
||||
SubgraphRewriter rewriter_conv3d;
|
||||
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
|
||||
rewriter_conv3d.RegisterRewritePattern(convolution_deprecated, conv3d);
|
||||
rewriter_conv3d.RegisterRewritePattern(
|
||||
convolution_deprecated, conv3d_for_deprecated_conv);
|
||||
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user