Pull Request resolved: https://github.com/pytorch/pytorch/pull/132535 Approved by: https://github.com/aaronenyeshi
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch::jit::mobile::quantization {

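// Quantizes `method_name` of a mobile Module in place. The module is expected
// to have been prepared ahead of time with bundled example inputs and the
// generated helper methods reset_observers_<name>, observe_<name>,
// quantize_<name> and quantized_<name>; this routine runs them and installs
// the quantized implementation under the original method name.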
void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  // Keep the original weights when prepacking rather than releasing them.
  at::globalContext().setReleaseWeightsWhenPrepacking(false);

  // Names of the helper methods generated during module preparation.
  std::string reset_observers_method_name = "reset_observers_" + method_name;
  std::string observe_method_name = "observe_" + method_name;
  std::string quantize_method_name = "quantize_" + method_name;
  std::string quantized_method_name = "quantized_" + method_name;

  // A PTQ-ready module must expose all of the generated helper methods as
  // well as its bundled example inputs.
  TORCH_CHECK(
      m.find_method(reset_observers_method_name).has_value(),
      "PTQ ready module must have ",
      reset_observers_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(observe_method_name),
      "PTQ ready module must have ",
      observe_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantize_method_name),
      "PTQ ready module must have ",
      quantize_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantized_method_name),
      "PTQ ready module must have ",
      quantized_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs"),
      "PTQ ready module must have get_all_bundled_inputs method.");

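  // Calibrate on the first set of bundled inputs: reset the observers, run the
  // observe pass to collect statistics, then run the quantize pass to produce
  // the weights used by the quantized_<method_name> method.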
  auto inputs = m.run_method("get_all_bundled_inputs")
                    .toList()
                    .get(0)
                    .toTupleRef()
                    .elements()
                    .vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

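  // Swap the quantized implementation in under the original name: check that
  // the schemas match, drop the original method, copy quantized_<method_name>
  // over it, then remove the helper methods that are no longer needed.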
  m.compareMethodSchemas(method_name, quantized_method_name);
  m.unsafeRemoveMethod(method_name);
  const auto& quantized_method = m.find_method(quantized_method_name);
  const Function& to_be_copied = quantized_method.value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}
} // namespace torch::jit::mobile::quantization
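
For context, a minimal caller sketch (not part of this file): it assumes a module that was exported for mobile with bundled inputs and the generated PTQ helper methods for `forward`, and uses a hypothetical file name `model.ptl`.

#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

int main() {
  // Hypothetical path; the module must carry bundled inputs and the generated
  // reset_observers_/observe_/quantize_/quantized_ helper methods for "forward".
  torch::jit::mobile::Module m = torch::jit::_load_for_mobile("model.ptl");

  torch::jit::mobile::quantization::PTQQuanizationHelper helper;
  // Calibrates on the bundled inputs and replaces "forward" with its quantized
  // counterpart in place.
  helper.quantize_dynamic(m, "forward");

  // "forward" now dispatches to the quantized implementation.
  return 0;
}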