Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-25 16:14:55 +08:00
[On Device Quantization][pytorch] Make insert_quant_dequant support ondevice PTQ (#83570)
Summary:
This diff adds a way to:
- Clone the previously observed method
- Add calls to the observers' calculate_qparams methods
- Extract the scale and zero point
- Use them to insert quant-dequant nodes

Now for the forward method we have:
- observe_forward
- quantize_forward

observe_forward is used post-training to observe statistics. In the case of dynamic PTQ this requires running that method just once to update the weight observer statistics.

quantize_forward uses the observer statistics to calculate quantization parameters and applies them via quant-dequant ops. Subsequent diffs will replace dequant + op with their quantized op counterparts, and replace quantize ops with the relevant packed-params class where possible.

Test Plan: To be written

Differential Revision: [D38771419](https://our.internmc.facebook.com/intern/diff/D38771419)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83570
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 6a5d9f1be0
Commit: 446afb5f9f
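For orientation, here is a minimal sketch of the two-step flow this diff enables, mirroring the tests below. MyModel is a hypothetical toy module; the two underscore-prefixed helpers are the internal APIs touched by this PR:

import torch
from torch.ao.quantization import default_dynamic_qconfig
from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _convert_ondevice_dynamic_jit,
)

class MyModel(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)

qconfig_dict = {"": default_dynamic_qconfig}
m = torch.jit.script(MyModel())
# Prepare: adds observe_forward with calls into the observers.
m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
# Convert: clones observe_forward into quantize_forward, inserting calls to
# calculate_qparams followed by quant-dequant nodes.
m = _convert_ondevice_dynamic_jit(m, 'forward')

x = torch.randn(2, 8)
m.observe_forward(x)   # run once to record weight observer statistics
m.quantize_forward(x)  # compute qparams and apply quant-dequant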
@@ -10,6 +10,7 @@ from torch.ao.quantization import (

from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _convert_ondevice_dynamic_jit,
)

from torch.testing._internal.common_utils import TestCase
@@ -55,8 +56,18 @@ class OnDevicePTQUtils(object):
    observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

    @staticmethod
-    def insert_observers(m, qconfig_dict):
+    def insert_observers(model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

    @staticmethod
    def insert_observers_quant_dequant(model, qconfig_dict):
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
        m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
        return m

    @staticmethod
@@ -75,23 +86,25 @@ class OnDevicePTQUtils(object):
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        if node.kind() == "prim::CallMethod":
            if node.s('name') == "calculate_qparams":
                return True
        return False


class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _insert_observers(self, model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = OnDevicePTQUtils.insert_observers(scripted_model, qconfig_dict)
        return scripted_model

    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'MinMaxObserver')

        qconfig_dict = {"": per_channel_dynamic_qconfig}
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
@@ -103,7 +116,7 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # exact graph matching is difficult so just resorting to # of lines
        # instead of implementing graph matching
@@ -135,10 +148,80 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if (self._observer_is_weight_only(node)):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)


class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = dequantize = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "aten::dequantize" in n.kind():
                dequantize += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, dequantize)
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self._validate_no_observer_forward(m)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self._validate_no_observer_forward(m)

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
        # Must first run observe_forward to record the stats that produce
        # correct scales and zero points
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)

@@ -261,6 +261,11 @@ def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule',
                                   inplace: _bool,
                                   debug: _bool,
                                   quant_type: _int): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
                                                    method_name: str,
                                                    inplace: _bool,
                                                    debug: _bool,
                                                    quant_type: _int): ...
def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule',
                             quant_type: _int,
                             preserved_attrs: Sequence[str]): ...

@@ -116,6 +116,21 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
    torch._C._jit_pass_dce(model.graph)
    return model


def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    assert quant_type == QuantType.DYNAMIC, "This API, while it should work for static quant, is only tested for dynamic quant."
    assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
    observe_method_name = "observe_" + method_name
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model

def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)

@@ -124,6 +139,10 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)


def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)

def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
    # Always do inplace convert because the Tensor is already
    # copied in prepare_jit when inplace is False

@@ -2,6 +2,7 @@

#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
@@ -24,6 +25,16 @@ using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;

std::string kScalarType = "_scalar_type";

struct QuantOpParams {
  c10::QScheme qscheme{c10::kPerTensorAffine};
  std::vector<Value*> qparams;
  // This is only so that insertQuantizationOps can be templatized
  // and subsequently a significant portion of that code can be reused.
  std::string back() const {
    return "AttributeDoesNotExist";
  }
};

c10::QScheme toAffine(c10::QScheme qscheme) {
  switch (qscheme) {
    case c10::kPerTensorAffine:
@@ -279,12 +290,20 @@ bool isEmbeddingBagOp(
      embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
}

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <typename T>
Node* insertQuantDequantNodes(
    Value* self,
    Node* observer,
-    const std::vector<std::string>& qparam_names,
+    T& qparams,
    const std::string& quantize_func);

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <>
Node* insertQuantDequantNodes<std::vector<std::string>>(
    Value* self,
    Node* observer,
    std::vector<std::string>& qparam_names,
    const std::string& quantize_func) {
  Graph* g = observer->owningGraph();
  Value* observer_out = observer->output();
@@ -416,12 +435,13 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
  return qembedding_bag;
}

template <typename T>
void insertQuantizationOps(
    Module& module,
    Value* self,
    Node* observer,
    bool is_per_channel,
-    const std::vector<std::string>& qparam_names,
+    T& qparams,
    QuantType quant_type = QuantType::STATIC) {
  Graph* g = observer->owningGraph();
  // Observer output
@@ -467,7 +487,7 @@ void insertQuantizationOps(
      observer_compute_dtype == at::ScalarType::QUInt8 ||
      observer_compute_dtype == at::ScalarType::QInt8) {
    // For activation tensors we insert choose_qparams, quant, dequant ops.
-    Value* dtype = g->insertGetAttr(self, qparam_names.back());
+    Value* dtype = g->insertGetAttr(self, qparams.back());
    std::tie(choose_qparams, quant, dequant) =
        insertChooseQParamQuantDequant(
            g, observer_out, dtype, at::Symbol::aten(quantize_func));
@@ -479,12 +499,10 @@ void insertQuantizationOps(
      }
    } else {
      // For weight tensors we insert quant-dequant ops.
-      dequant =
-          insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+      dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
    }
  } else { // Static quant
-    dequant =
-        insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+    dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
  }
  observer_out->replaceAllUsesWith(original_val);

@@ -700,10 +718,16 @@ class InsertQuantDeQuantHelper {

  void run(Module& module, const std::string& method_name);

  void runForOnDevicePTQ(Module& module, const std::string& method_name);

  // Cleanup observer nodes from graph and observer modules
  // from module object and ClassType
  void cleanup(Module& module);

  // Cleanup observer nodes only, but not observer modules.
  // This is for ondevice PTQ
  void removeObserverNodes(Module& m);

  // In order to propagate quantization ops through the ops that don't
  // require observation, we'll first inline the graph, and call the
  // PropagateQuantizationOps pass
@@ -725,6 +749,11 @@ class InsertQuantDeQuantHelper {
  std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector(
      script::Module& module,
      Node* n);
  QuantOpParams insertCalculateQParams(
      script::Module& module,
      Graph* g,
      Node* n);

  void checkQScheme(Graph* g, c10::QScheme qscheme) {
    if (qscheme_for_graph_.count(g)) {
      // FIXME[T110786721]: This check was broken before, never failing.
@@ -747,7 +776,12 @@ class InsertQuantDeQuantHelper {

  void collectObserverNodesAndValueToQuantize(Module& module, Value*);
  void cleanup(Module& module, Graph* g);
  void removeObserverNodes(Graph* g);
  void quantizeTensors(Module& module, Graph* g, Value* self);
  void insertCalculateQParamsAndQuantizationOps(
      Module& module,
      Graph* g,
      Value* self);

  // Function that extracts and runs the weight observer in a separate
  // subgraph.
@@ -837,6 +871,27 @@ void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
  observer_nodes_for_graph_[g].push_back(observer);
}

void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
  for (auto& method : module.get_methods()) {
    removeObserverNodes(method.graph().get());
  }
  for (Module m : module.children()) {
    removeObserverNodes(m);
  }
}

void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
  if (nodes_to_destroy_.count(g)) {
    for (auto& n : nodes_to_destroy_.at(g)) {
      n->removeAllInputs();
    }
    for (auto& n : nodes_to_destroy_.at(g)) {
      n->destroy();
    }
    nodes_to_destroy_.at(g).clear();
  }
}

void InsertQuantDeQuantHelper::cleanup(Module& module) {
  for (auto& method : module.get_methods()) {
    cleanup(module, method.graph().get());
@@ -848,15 +903,7 @@ void InsertQuantDeQuantHelper::cleanup(Module& module) {

void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
  GRAPH_DUMP("Before Remove Observers:", g);
-  if (nodes_to_destroy_.count(g)) {
-    for (auto& n : nodes_to_destroy_.at(g)) {
-      n->removeAllInputs();
-    }
-    for (auto& n : nodes_to_destroy_.at(g)) {
-      n->destroy();
-    }
-    nodes_to_destroy_.at(g).clear();
-  }
+  removeObserverNodes(g);

  // 1. If we have seen this graph before, this means the observer
  // attributes have been removed from the type (see step 2) but the slot
@@ -1471,6 +1518,200 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
  RemoveRedundantDequantize(graph);
}

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <>
Node* insertQuantDequantNodes<QuantOpParams>(
    Value* self,
    Node* observer,
    QuantOpParams& qparams,
    const std::string& quantize_func) {
  (void)self;
  Graph* g = observer->owningGraph();
  Value* observer_out = observer->output();
  Value* original_val = observer->input(1);
  std::vector<Value*> inputs;
  // + 1 for the tensor to be quantized
  inputs.reserve(qparams.qparams.size() + 1);
  inputs.push_back({observer_out});
  for (const auto& qparam_values : qparams.qparams) {
    inputs.push_back(qparam_values);
  }
  Node* quant = insertQuant(
      g,
      inputs,
      at::Symbol::aten(quantize_func),
      original_val->debugName() + ".quant");
  // Have to make sure that the quant node appears after the values it
  // depends on.
  for (Value* v : inputs) {
    quant->moveAfter(v->node());
  }
  Node* dequant = insertDeQuant(g, quant->output(), original_val);
  dequant->moveAfter(quant);
  return dequant;
}

void checkCalculateQParamsResultTypes(const Node* out) {
  TORCH_CHECK(
      out->outputs().size() == 2,
      "calculate_qparams should produce output of size 2 (scale, zero_point).");
  Value* scale = out->output(0);
  Value* zp = out->output(1);
  TORCH_CHECK(
      scale->type()->expect<TensorType>(),
      "Scale value should be of Tensor type.");
  TORCH_CHECK(
      zp->type()->expect<TensorType>(),
      "Zero-point value should be of Tensor type.");
}

QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
    script::Module& module,
    Graph* g,
    Node* n) {
  // TODO: refactor findObserverName to take Node* as input
  Value* self = g->inputs()[0];
  Value* v = n->output();
  TORCH_INTERNAL_ASSERT(
      v->type()->isSubtypeOf(*TensorType::get()),
      "Expected output of observer node to be Tensor");
  auto observer_name = findObserverName(v);
  TORCH_INTERNAL_ASSERT(
      observer_name,
      "getQSchemeAndParamMap expects the corresponding observer for ",
      v->debugName(),
      " exists.");
  std::vector<Value*> qparams_graph_values;
  QuantOpParams quant_op_params;

  TORCH_CHECK(
      !isPlaceholderObserver(n->input(0)),
      "Placeholder observers are not supported in ondevice PTQ.");
  auto observer_module = module.attr(observer_name.value()).toModule();
  Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
  auto scalar_type = observer_module.attr("dtype");
  TORCH_CHECK(
      scalar_type.toScalarType() != at::ScalarType::Undefined,
      "dtype of observer can't be undefined");
  // Not sure if we need to support this for on-device PTQ.
  if (scalar_type == at::ScalarType::Half) {
    return quant_op_params;
  }
  auto calculate_qparams = observer_module.get_method("calculate_qparams");
  auto calculate_qparams_schema = calculate_qparams.function().getSchema();
  MatchedSchema matched_schema = matchSchema(
      calculate_qparams_schema,
      v->node()->sourceRange(),
      *g,
      {observer_module_value},
      {});
  Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
  Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
  checkCalculateQParamsResultTypes(scale_zp_node);
  auto qscheme = observer_module.attr("qscheme").toQScheme();
  quant_op_params.qscheme = qscheme;
  quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
  quant_op_params.qparams.push_back(
      scale_zp_node->output(1)); // zero_point Value*
  if (isPerChannel(qscheme)) {
    Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
    quant_op_params.qparams.push_back(ch_axis_value);
  }
  Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
  quant_op_params.qparams.push_back(scalar_type_value);
  return quant_op_params;
}

void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
    Module& module,
    Graph* graph,
    Value* self) {
  if (!observer_nodes_for_graph_.count(graph)) {
    return;
  }
  for (auto* n : observer_nodes_for_graph_.at(graph)) {
    Graph* g = n->owningGraph();
    // Observer output
    Value* observer_out = n->output();
    // Insert point is right after the observer node
    WithInsertPoint insert_qparams_calc(observer_out->node()->next());
    auto quant_op_params = insertCalculateQParams(module, g, n);
    insertQuantizationOps(
        module,
        self,
        n,
        isPerChannel(quant_op_params.qscheme),
        quant_op_params,
        quant_type_);
  }
}

void InsertQuantDeQuantHelper::runForOnDevicePTQ(
    Module& module,
    const std::string& method_name) {
  // In all likelihood this really won't do anything because we expect that
  // the input method for quantization's prepare step will be inlined. Thus
  // the only call methods we will see will belong to observers' forward calls.
  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    runForOnDevicePTQ(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // Unlike the run method, we don't need to extract new qparam values for
  // the same graph used in different call sites.
  // The reason is that for on-device PTQ we don't:
  // 1. Run calculate_qparams
  // 2. Get the scale and zero point
  // 3. Get axis and dtype
  // 4. Register values from 2 and 3 as attributes on the parent module.
  // Instead we insert a call to calculate_qparams (1) via insertCalculateQParams
  // in the graph itself. Then instead of 2 and 3, we get the output Value*
  // and for 3, we insert GetAttr for axis and dtype and use those Value*
  // with insertQuantizationOps.

  // prim::Param nodes do not belong to the graph. Hence the insert
  // point is the beginning of graph node. This also safeguards against
  // observing a potentially mutated value due to some in-place operation.
  std::vector<Value*> input_values;
  for (const auto idx : c10::irange(1, method.num_inputs())) {
    auto& v = graph->inputs()[idx];
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      input_values.push_back(v);
    }
  }

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
      Node* n = *it++;
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        collectObserverNodesAndValueToQuantize(module, v);
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  for (Value* v : input_values) {
    collectObserverNodesAndValueToQuantize(module, v);
  }

  GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph);
  Value* self = graph->inputs()[0];
  insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
  GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph);
}

} // namespace

void ReplicateQuant(std::shared_ptr<Graph>& graph) {
@@ -1573,5 +1814,57 @@ Module InsertQuantDeQuant(
  return module;
}

/*
 * Assumption: the method_name method has observers placed.
 * Objective: modify that method to insert calls to:
 * 1. calculate_qparams
 * 2. GetAttr for axis and dtype values
 * 3. Use Values from the above two to insert calls to quant + dequant
 * Thus after this step you have a graph of, e.g., observe_forward,
 * that has observer nodes, calculate_qparams run on those observer nodes,
 * the output of which is used by quant-dequant nodes. The output of dequant
 * is used by the actual op.
 * Later on we will replace dequant + op (e.g. linear) with
 * 1. prepacked_op context
 * 2. unpack
 * 3. dequantize
 * 4. linear
 *
 * Of the above pattern, 2, 3, and 4 can be replaced by the linear_run op.
 */
// Module InsertQuantDeQuantForOnDevicePTQ(
Module InsertQuantDeQuantOnDevicePTQ(
    Module& input_module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type) {
  Module module = input_module.clone(inplace);
  const std::string kObserveString = "observe_";
  const auto matched_pos = method_name.find(kObserveString);
  const auto end_pos = matched_pos + kObserveString.length();
  const std::string orig_method_name = method_name.substr(end_pos);
  TORCH_CHECK(
      matched_pos == 0,
      "Quant dequant nodes can only be added to observe_",
      orig_method_name,
      ". Please make sure to run the prepare step for on-device PTQ.");

  std::string quantize_method_name = "quantize_" + orig_method_name;
  cloneMethod(module, method_name, quantize_method_name);
  InsertQuantDeQuantHelper h(quant_type, debug);
  h.runForOnDevicePTQ(module, quantize_method_name);
  h.removeObserverNodes(module);
  // Don't need:
  // ReplicateChooseQParamsQuantDequant: this propagates dynamic quant's
  //   quant-dequant.
  // RemoveRedundantQuantizationOps: this removes activation observers for
  //   dynamic quant when the op related to it is not dynamically quantizable.
  //   That doesn't really apply here; in our case we won't have those anyway,
  //   since for dynamic quant activations won't be observed. We can still
  //   use this function because the above two methods should really be a no-op.
  h.propagateQuantizationOps(module);
  return module;
}
} // namespace jit
} // namespace torch

@@ -35,5 +35,12 @@ TORCH_API Module InsertQuantDeQuant(
    bool debug,
    QuantType quant_type = QuantType::STATIC);

TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);

} // namespace jit
} // namespace torch

@@ -443,6 +443,22 @@ void initJITBindings(PyObject* module) {
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
          [](Module& module,
             const std::string& method_name,
             bool inplace,
             bool debug,
             int quant_type_int) {
            auto quant_type = static_cast<QuantType>(quant_type_int);
            return InsertQuantDeQuantOnDevicePTQ(
                module, method_name, inplace, debug, quant_type);
          },
          py::arg("module"),
          py::arg("method_name"),
          py::arg("inplace"),
          py::arg("debug"),
          py::arg("quant_type_int") = 1)
      .def(
          "_jit_pass_insert_prepack_unpack",
          [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })

@@ -25,7 +25,7 @@ _all__ = [
    'quantize', 'quantize_dynamic', 'quantize_qat',
    'prepare', 'convert', 'prepare_qat',
    # Top level API for graph mode quantization on TorchScript
-    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit'
+    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
    # Top level API for graph mode quantization on GraphModule(torch.fx)
    # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
    # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',