#include <torch/csrc/jit/passes/quantization/finalize.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/quantization/quantization_patterns.h>
#include <torch/csrc/jit/passes/quantization/register_packed_params.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <utility>

namespace torch {
namespace jit {

namespace {

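// Applies the linear prepack/unpack rewrite patterns to the graph so that
// linear weights flow through explicit prepack/unpack ops.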
void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      linear_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

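// Applies the conv prepack/unpack rewrite patterns, mirroring the linear
// case above.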
void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      conv_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

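// Deletes SetAttr nodes whose target matches one of packed_param_attr_names,
// either by attribute name directly or by the full module access path. The
// follow-up ConstantPooling + DCE then sweep the now-dead producer nodes,
// including the fp-weight prepacking chain.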
void removePackedParamInsertionAndFPWeightsSetAttr(
    std::shared_ptr<Graph>& g,
    const std::unordered_set<std::string>& packed_param_attr_names) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::SetAttr) {
      const std::string& attr_name = n->s(attr::name);
      if (packed_param_attr_names.count(attr_name)) {
        nodes_to_delete.push_back(n);
      } else {
        Value* v = n->input(0);
        Value* self = g->inputs()[0];
        std::vector<std::string> paths = getModuleAccessPath(v, self);
        std::string path = joinPaths(paths);
        if (packed_param_attr_names.count(path)) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
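  // Clear inputs on every doomed node first so that nodes in the list that
  // feed one another can be destroy()'d without any remaining uses.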
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  ConstantPooling(g);
  EliminateDeadCode(g);
}

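// Deletes `calculate_qparams` CallMethod nodes on observer submodules.
// CallMethod has side effects, so DCE alone would never remove them.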
void removeObserverCallMethods(std::shared_ptr<Graph>& g) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::CallMethod) {
      const std::string& attr_name = n->s(attr::name);
      if (attr_name == "calculate_qparams") {
        auto observer_node = n->input(0)->node();
        if (observer_node->kind() == prim::GetAttr &&
            observer_node->s(attr::name).find("_observer_") !=
                std::string::npos) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  EliminateDeadCode(g);
}

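// Rewrites the method to return None; the final DCE then strips every node
// except the SetAttr-anchored packed-param generation.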
void keepOnlyPackedParamsGeneration(Module& m, const std::string& method_name) {
  auto g = m.get_method(method_name).graph();
  Function& function = m.get_method(method_name).function();
  const auto& schema = function.getSchema();
  auto new_schema = schema.cloneWithReturns({Argument("", NoneType::get())});
  // Erase from the back so indices remain valid as outputs are removed.
  for (size_t i = g->outputs().size(); i > 0; i--) {
    g->eraseOutput(i - 1);
  }
  Node* none_node = g->createNone();
  g->registerOutput(none_node->output());
  none_node->insertBefore(g->return_node());
  function.setSchema(std::move(new_schema));
  EliminateDeadCode(g);
}

} // namespace

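// Runs the quant-fusion subgraph rewrites, replacing dequant->op->quant
// patterns with fused quantized ops. For dynamic quantization, the linear
// patterns without dynamic activation quantization are appended as well.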
void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) {
  std::vector<QuantFusionInfo> patterns;
  if (quant_type == QuantType::DYNAMIC) {
    patterns = dynamic_quant_fusion_pattern_and_replacements();
    std::vector<QuantFusionInfo> patterns_wo_dynamic_activation_quant =
        dynamic_quantized_linear_pattern_and_replacements();
    patterns.insert(
        patterns.end(),
        patterns_wo_dynamic_activation_quant.begin(),
        patterns_wo_dynamic_activation_quant.end());
  } else {
    patterns = quant_fusion_pattern_and_replacements();
  }
  for (const auto& info : patterns) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(info.pattern, info.replacement);
    rewriter.runOnGraph(graph, info.filters);
  }
}

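// Inserts prepack/unpack ops for both linear and conv patterns in a graph.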
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
  insertPrepackUnpackForLinear(graph);
  insertPrepackUnpackForConv(graph);
}

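// Applies InsertPrepackUnpack to every method of the module and recurses
// into child modules.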
void InsertPrepackUnpack(Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    InsertPrepackUnpack(graph);
  }
  for (Module m : module.children()) {
    InsertPrepackUnpack(m);
  }
}

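// Folds the quantized linear/conv/conv_transpose prepack ops into module
// attributes (under the "quantized" prefix), so the packing is computed once
// at fold time rather than on every forward.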
void FoldQuantizedPrepackingOps(Module& module) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  PrePackingOpsFolder(module, filter_fn, "quantized");
}

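// Registers the outputs of the matched prepack ops as module attributes and
// returns the set of attribute names under which the packed params were
// stored, so later passes can find and remove their SetAttrs.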
static std::unordered_set<std::string> RegisterPrePackingParams(
    Module& module,
    const std::string& method_name) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  return RegisterPrePackParams(module, method_name, filter_fn, "");
}

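// Standard graph-mode finalization: insert prepack/unpack ops, fuse
// quantized patterns, freeze the module, and fold the prepacking ops.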
Module Finalize(
    Module& module,
    QuantType quant_type,
    const std::vector<std::string>& preserved_attrs) {
  // Tracing annotates the resulting graph with shape information. In many
  // cases users apply different input shapes to the traced graph, and it is
  // on them to know that doing so is correct. The quantized module needs to
  // be cleaned up accordingly: to prevent JIT optimizations from leveraging
  // the annotated shape info, clear the shape information from the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  auto graph = module.get_method("forward").graph();
  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto frozen = freeze_module(module, preserved_attrs);
  FoldQuantizedPrepackingOps(frozen);
  return frozen;
}

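// On-device PTQ finalization. Unlike Finalize, the module is not frozen:
// quantize_<method> is cloned into quantized_<method>, the clone is cleaned
// of packed-param/weight SetAttrs and observer calls, and the original
// quantize_<method> is reduced to packed-param generation only.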
Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name) {
  // Tracing annotates the resulting graph with shape information. In many
  // cases users apply different input shapes to the traced graph, and it is
  // on them to know that doing so is correct. The quantized module needs to
  // be cleaned up accordingly: to prevent JIT optimizations from leveraging
  // the annotated shape info, clear the shape information from the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

const std::string kQuantizeString = "quantize_";
|
|
const auto matched_pos = method_name.find(kQuantizeString);
|
|
const auto end_pos = matched_pos + kQuantizeString.length();
|
|
const std::string orig_method_name = method_name.substr(end_pos);
|
|
TORCH_CHECK(
|
|
matched_pos == 0,
|
|
"Quantized ops can only be added to quantize_",
|
|
orig_method_name,
|
|
". Please make sure to run quant/dequant nodes insertion step for on-device PTQ.");
|
|
|
|
const std::string quantized_method_name = "quantized_" + orig_method_name;
|
|
auto graph = module.get_method(method_name).graph();
|
|
// Doing some AOT optimizations here
|
|
// Of all CSE seems to be required otherwise in some experiments
|
|
// serialized model is incorrect. As in it cannot be deserialized
|
|
// Rest are included as canonical optimizations that are not for inference
|
|
EliminateCommonSubexpression(graph);
|
|
EliminateDeadCode(graph);
|
|
PeepholeOptimize(graph);
|
|
ConstantPropagation(graph);
|
|
UnrollConstantLoops(graph);
|
|
ConstantPooling(graph);
|
|
|
|
  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto packed_param_attr_names = RegisterPrePackingParams(module, method_name);
  GRAPH_DUMP("After QuantFusion + packed param registration:", graph);

  // At this point we have:
  // 1. Inserted packed params for the quantized weights
  // 2. Registered the packed params as module attributes
  // 3. Inserted the quantized ops
  // What remains is to:
  // 1. Replicate this method as quantized_<method>
  // 2. In the replica, remove the SetAttrs for the fp weights that
  //    quantize_<method> resets
  // 3. Remove the SetAttrs for the packed params, which subsequently lets
  //    the nodes producing them be optimized away
  // 4. Strip quantize_<method> down to only the packed-params generation
  cloneMethod(module, method_name, quantized_method_name);
  // removeWeightSetAttrs(module, quantized_method_name);
  auto quantized_graph = module.get_method(quantized_method_name).graph();
  removePackedParamInsertionAndFPWeightsSetAttr(
      quantized_graph, packed_param_attr_names);
  // Removing the packed-param SetAttrs alone is not sufficient: DCE will not
  // remove the observer nodes' GetAttrs and CallMethods, because CallMethods
  // have side effects.
  removeObserverCallMethods(quantized_graph);
  // This step removes the return output from the graph; the subsequent DCE
  // then removes all remaining ops, leaving only the packed_params
  // generation.
  keepOnlyPackedParamsGeneration(module, method_name);
  return module;
}

} // namespace jit
} // namespace torch