pytorch/torch/csrc/jit/passes/graph_rewrite_helper.cpp

#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch::jit::graph_rewrite_helper {
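
// Returns the unqualified name of the function that func_value refers to,
// i.e. the text after the last '.' in its qualified name (or the whole
// name if it contains no '.').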
std::string getFuncName(Value* func_value) {
  auto func = func_value->type()->expectRef<FunctionType>().function();
  const auto& qname = func->qualname();
  const auto& name = qname.qualifiedName();
  auto rdot_idx = name.rfind('.');
  if (rdot_idx != std::string::npos) {
    return name.substr(rdot_idx + 1);
  } else {
    return name;
  }
}
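
// Translates a placeholder name from the pattern graph into the concrete
// Value matched in the target graph: vmap maps pattern names to pattern
// Values, and match_vmap maps those to the matched graph's Values.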
Value* getValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return match_vmap.at(vmap.at(name));
}
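
// Same lookup as getValue, but additionally folds the matched Value to a
// constant IValue; yields std::nullopt when the Value is not a constant.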
std::optional<IValue> getIValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return toIValue(getValue(name, match_vmap, vmap));
}
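
// Gathers the constant convolution arguments (transposed, output_padding,
// stride, padding, dilation) for a match. The .value() calls assume each of
// these arguments is a compile-time constant; the ConstantPropagation run at
// the start of replaceConvolutionWithAtenConv is what makes that hold.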
static std::unordered_map<std::string, c10::IValue> getConvParams(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  std::unordered_map<std::string, c10::IValue> calc_values;
  const auto& match_vmap = match.values_map;
  auto transposed_value = getIValue("transposed", match_vmap, vmap).value();
  calc_values["transposed"] = transposed_value;
  auto output_padding_value =
      getIValue("output_padding", match_vmap, vmap).value();
  calc_values["output_padding"] = output_padding_value;
  auto stride_value = getIValue("stride", match_vmap, vmap).value();
  calc_values["stride"] = stride_value;
  auto padding_value = getIValue("padding", match_vmap, vmap).value();
  calc_values["padding"] = padding_value;
  auto dilation_value = getIValue("dilation", match_vmap, vmap).value();
  calc_values["dilation"] = dilation_value;
  return calc_values;
}
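
// Rewrites aten::_convolution calls into the specific aten::conv{1,2,3}d or
// aten::conv_transpose{1,2,3}d op they correspond to. Two source patterns
// are registered because aten::_convolution also exists in a deprecated
// form without the trailing allow_tf32 argument.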
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
  // TODO: remove constant prop in the pass
  ConstantPropagation(graph);
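
  // Source patterns: the deprecated aten::_convolution schema (without
  // allow_tf32) and the current one (with allow_tf32).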
  std::string convolution_deprecated = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
        return (%r) )";

  std::string convolution = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
        return (%r) )";
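
  // Replacement patterns for regular (non-transposed) convolutions: only
  // stride, padding, dilation, and groups are forwarded; the benchmark,
  // deterministic, cudnn_enabled, and allow_tf32 flags are dropped.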
  std::string conv2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
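
  // Replacement patterns for transposed convolutions. Note the argument
  // order of aten::conv_transposeNd: output_padding is forwarded and
  // dilation comes last, unlike aten::convNd.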
  std::string conv_transpose1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";
  // Filters that reject unsupported matches: each one requires the list
  // arguments to have the dimensionality of its target op, and checks
  // %transposed to route between aten::convNd and aten::conv_transposeNd.
  auto filter_conv1d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 1 ||
        calc_value_map["stride"].toIntList().size() != 1 ||
        calc_value_map["padding"].toIntList().size() != 1 ||
        calc_value_map["dilation"].toIntList().size() != 1) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };

  auto filter_conv2d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 2 ||
        calc_value_map["stride"].toIntList().size() != 2 ||
        calc_value_map["padding"].toIntList().size() != 2 ||
        calc_value_map["dilation"].toIntList().size() != 2) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };

  auto filter_conv3d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 3 ||
        calc_value_map["stride"].toIntList().size() != 3 ||
        calc_value_map["padding"].toIntList().size() != 3 ||
        calc_value_map["dilation"].toIntList().size() != 3) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
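
  // The transpose filters mirror the ones above but accept a match only
  // when %transposed is true.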
  auto filter_conv_transpose1d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 1 ||
            calc_value_map["stride"].toIntList().size() != 1 ||
            calc_value_map["padding"].toIntList().size() != 1 ||
            calc_value_map["dilation"].toIntList().size() != 1) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };

  auto filter_conv_transpose2d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 2 ||
            calc_value_map["stride"].toIntList().size() != 2 ||
            calc_value_map["padding"].toIntList().size() != 2 ||
            calc_value_map["dilation"].toIntList().size() != 2) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };

  auto filter_conv_transpose3d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 3 ||
            calc_value_map["stride"].toIntList().size() != 3 ||
            calc_value_map["padding"].toIntList().size() != 3 ||
            calc_value_map["dilation"].toIntList().size() != 3) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };
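
  // One rewriter per target op. Each registers both the current and the
  // deprecated source pattern and runs under its filter; since the filters
  // are mutually exclusive (dimensionality plus the transposed flag), a
  // given aten::_convolution node is rewritten by at most one of them.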
  SubgraphRewriter rewriter_conv1d;
  rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
  rewriter_conv1d.RegisterRewritePattern(
      convolution_deprecated, conv1d_for_deprecated_conv);
  rewriter_conv1d.runOnGraph(graph, filter_conv1d);

  SubgraphRewriter rewriter_conv2d;
  rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
  rewriter_conv2d.RegisterRewritePattern(
      convolution_deprecated, conv2d_for_deprecated_conv);
  rewriter_conv2d.runOnGraph(graph, filter_conv2d);

  SubgraphRewriter rewriter_conv3d;
  rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
  rewriter_conv3d.RegisterRewritePattern(
      convolution_deprecated, conv3d_for_deprecated_conv);
  rewriter_conv3d.runOnGraph(graph, filter_conv3d);

  SubgraphRewriter rewriter_conv_transpose1d;
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution, conv_transpose1d);
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose1d_for_deprecated_conv);
  rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d);

  SubgraphRewriter rewriter_conv_transpose2d;
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution, conv_transpose2d);
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose2d_for_deprecated_conv);
  rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d);

  SubgraphRewriter rewriter_conv_transpose3d;
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution, conv_transpose3d);
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose3d_for_deprecated_conv);
  rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d);
}
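
// Filter for clamp-style fusion patterns (relu, or hardtanh with explicit
// bounds): a match is fusable when the pattern's dummy_min_max placeholder
// is None and, where the pattern carries output_min/output_max, those
// bounds are compile-time constants.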
bool isClampFusable(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  const auto& match_vmap = match.values_map;
  TORCH_CHECK(
      vmap.find("dummy_min_max") != vmap.end(),
      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
  auto dummy_min_max =
      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
  // Also check that the output_min and output_max Values are actually
  // constant. If hardtanh's min/max Values are not constants, we would end
  // up rerouting them to the prepack op, and then we would not be able to
  // remove the prepacking ops.
  if (vmap.find("output_min") != vmap.end()) {
    // The aten::relu pattern does not have output_min/output_max;
    // aten::hardtanh/_ does.
    TORCH_CHECK(
        vmap.find("output_max") != vmap.end(),
        "Expected to find output_max as well given "
        "output_min exists in the pattern graph.");
    // If output_min/max are not constant, we get std::nullopt.
    auto output_min =
        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
    auto output_max =
        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
    is_fusable =
        is_fusable && (output_min.has_value() && output_max.has_value());
  }
  return is_fusable;
}
} // namespace torch::jit::graph_rewrite_helper