[On Device Quantization][pytorch] Make insert_quant_dequant support on-device PTQ (#83570)

Summary:
This diff adds a way to:
- Clone the previously observed method
- Add calls to the observers' calculate_qparams methods
- Extract the scale and zero point
- Use them to insert quant-dequant nodes

For the forward method we now have:
- observe_forward
- quantize_forward

observe_forward is used post training to record observer statistics. In the
case of dynamic PTQ this requires running that method just once to update
the weight observer statistics.

quantize_forward uses the observer statistics to calculate quantization
parameters and applies them via quant-dequant ops.

Subsequent diffs will replace dequant + op with their quantized op
counterparts and replace quantize ops with the relevant packed params
classes where possible.
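
A minimal sketch of the intended end-to-end flow, modeled on the test helpers
in this diff (the toy module, sizes, and example inputs are illustrative
assumptions; the underscore-prefixed entry points are internal APIs):

```python
import torch
from torch.ao.quantization import default_dynamic_qconfig
from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _convert_ondevice_dynamic_jit,
)

# Hypothetical toy model; any scriptable module with dynamically
# quantizable ops (e.g. nn.Linear) should behave the same way.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = torch.jit.script(M().eval())
qconfig_dict = {"": default_dynamic_qconfig}

# Prepare: adds observe_forward, which runs the (weight) observers.
model = _prepare_ondevice_dynamic_jit(model, qconfig_dict)

# Convert (this diff): clones observe_forward into quantize_forward and
# inserts calculate_qparams calls plus quant/dequant nodes.
model = _convert_ondevice_dynamic_jit(model, "forward")

x = torch.randn(2, 8)
model.observe_forward(x)   # record observer statistics; once suffices for dynamic PTQ
model.quantize_forward(x)  # compute qparams on device and run with quant/dequant inserted
```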

Test Plan:
To be written

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D38771419](https://our.internmc.facebook.com/intern/diff/D38771419)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83570
Approved by: https://github.com/jerryzh168
Author: Kimish Patel
Date: 2022-08-27 16:06:14 -07:00
Committed by: PyTorch MergeBot
Parent: 6a5d9f1be0
Commit: 446afb5f9f
7 changed files with 453 additions and 30 deletions


@ -10,6 +10,7 @@ from torch.ao.quantization import (
from torch.ao.quantization.quantize_jit import (
_prepare_ondevice_dynamic_jit,
_convert_ondevice_dynamic_jit,
)
from torch.testing._internal.common_utils import TestCase
@ -55,8 +56,18 @@ class OnDevicePTQUtils(object):
observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']
@staticmethod
def insert_observers(m, qconfig_dict):
def insert_observers(model, qconfig_dict):
inputs = model.get_example_inputs()
scripted_model = get_script_module(model, False, inputs)
scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
return scripted_model
@staticmethod
def insert_observers_quant_dequant(model, qconfig_dict):
inputs = model.get_example_inputs()
m = get_script_module(model, False, inputs)
m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
return m
@staticmethod
@ -75,23 +86,25 @@ class OnDevicePTQUtils(object):
return True
return False
@staticmethod
def is_calculate_qparam(node):
if node.kind() == "prim::CallMethod":
if node.s('name') == "calculate_qparams":
return True
return False
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
def _insert_observers(self, model, qconfig_dict):
inputs = model.get_example_inputs()
scripted_model = get_script_module(model, False, inputs)
scripted_model = OnDevicePTQUtils.insert_observers(scripted_model, qconfig_dict)
return scripted_model
def _check_num_and_type_of_observers(self, model, num_observers):
qconfig_dict = {"": default_dynamic_qconfig}
scripted_model = self._insert_observers(model, qconfig_dict)
scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
self.assertTrue(len(observer_modules) == num_observers)
for observer in observer_modules:
self.assertTrue(observer.original_name == 'MinMaxObserver')
qconfig_dict = {"": per_channel_dynamic_qconfig}
scripted_model = self._insert_observers(model, qconfig_dict)
scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
self.assertTrue(len(observer_modules) == num_observers)
for observer in observer_modules:
@ -103,7 +116,7 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
orig_scripted_model = get_script_module(model, False, inputs)
torch._C._jit_pass_inline(orig_scripted_model.graph)
orig_forward_graph = orig_scripted_model.graph.str()
scripted_model = self._insert_observers(model, qconfig_dict)
scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
quant_forward_graph = scripted_model.graph.str()
# exact graph matching is difficult so just resorting to # of lines
# instead of implementing graph matching
@ -135,10 +148,80 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
model = MyConvLinearModule()
qconfig_dict = {"": default_dynamic_qconfig}
inputs = model.get_example_inputs()
scripted_model = self._insert_observers(model, qconfig_dict)
scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
observe_forward_graph = scripted_model.observe_forward.graph
num_weight_only_observers = 0
for node in observe_forward_graph.nodes():
if (self._observer_is_weight_only(node)):
num_weight_only_observers += 1
self.assertEqual(num_weight_only_observers, 3)
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
quantize_forward_graph = model.quantize_forward.graph
quantize_per_tensor = quantize_per_channel = dequantize = 0
for n in quantize_forward_graph.nodes():
if "aten::quantize_per_tensor" in n.kind():
quantize_per_tensor += 1
if "aten::quantize_per_channel" in n.kind():
quantize_per_channel += 1
if "aten::dequantize" in n.kind():
dequantize += 1
self.assertEqual(quantize_per_tensor + quantize_per_channel, dequantize)
self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
def _validate_calculate_qparams(self, model, num_nodes):
quantize_forward_graph = model.quantize_forward.graph
num_calculate_qparams = 0
for n in quantize_forward_graph.nodes():
if OnDevicePTQUtils.is_calculate_qparam(n):
num_calculate_qparams += 1
self.assertEqual(num_calculate_qparams, num_nodes)
def _validate_no_observer_forward(self, model):
quantize_forward_graph = model.quantize_forward.graph
for n in quantize_forward_graph.nodes():
if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))):
return False
return True
def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
qconfig_dict = {"": default_dynamic_qconfig}
m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
self._validate_quant_dequant_nodes(m, num_nodes)
self._validate_calculate_qparams(m, num_nodes)
self.assertTrue(self._validate_no_observer_forward(m))
qconfig_dict = {"": per_channel_dynamic_qconfig}
m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
self._validate_calculate_qparams(m, num_nodes)
self.assertTrue(self._validate_no_observer_forward(m))
def _check_quantize_forward_runs(self, model):
inputs = model.get_example_inputs()
qconfig_dict = {"": default_dynamic_qconfig}
m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
qconfig_dict = {"": per_channel_dynamic_qconfig}
m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
# First must run observe forward to record the stats to produce
# correct scales and zero points
m.observe_forward(*inputs)
m.quantize_forward(*inputs)
def test_num_quant_dequant_nodes(self):
model = LinearAddModel()
self._check_quant_dequant_and_calc_qparams(model, 2)
model = MyConvLinearModule()
self._check_quant_dequant_and_calc_qparams(model, 3)
def test_quantize_forward_runs(self):
model = LinearAddModel()
self._check_quantize_forward_runs(model)
model = MyConvLinearModule()
self._check_quantize_forward_runs(model)


@ -261,6 +261,11 @@ def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule',
inplace: _bool,
debug: _bool,
quant_type: _int): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
method_name: str,
inplace: _bool,
debug: _bool,
quant_type: _int): ...
def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule',
quant_type: _int,
preserved_attrs: Sequence[str]): ...


@ -116,6 +116,21 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
torch._C._jit_pass_dce(model.graph)
return model
def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
_check_is_script_module(model)
assert quant_type == QuantType.DYNAMIC, "This API, while it should work for static quant, is only tested for dynamic quant."
assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
observe_method_name = "observe_" + method_name
model_c = model._c
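# The C++ pass below clones observe_<method> into quantize_<method> and inserts
# calculate_qparams calls plus quant/dequant nodes into the cloned method.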
model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
if inplace:
model._reconstruct(model_c)
else:
model = wrap_cpp_module(model_c)
return model
def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)
@ -124,6 +139,10 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)
def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)
def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
# Always do inplace convert because the Tensor is already
# copied in prepare_jit when inplace is False


@ -2,6 +2,7 @@
#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
@ -24,6 +25,16 @@ using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;
std::string kScalarType = "_scalar_type";
struct QuantOpParams {
c10::QScheme qscheme{c10::kPerTensorAffine};
std::vector<Value*> qparams;
// This exists only so that insertQuantizationOps can be templatized
// and subsequently a significant portion of that code can be reused.
std::string back() const {
return "AttributeDoesNotExist";
}
};
c10::QScheme toAffine(c10::QScheme qscheme) {
switch (qscheme) {
case c10::kPerTensorAffine:
@ -279,12 +290,20 @@ bool isEmbeddingBagOp(
embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
}
// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <typename T>
Node* insertQuantDequantNodes(
Value* self,
Node* observer,
const std::vector<std::string>& qparam_names,
T& qparams,
const std::string& quantize_func);
// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <>
Node* insertQuantDequantNodes<std::vector<std::string>>(
Value* self,
Node* observer,
std::vector<std::string>& qparam_names,
const std::string& quantize_func) {
Graph* g = observer->owningGraph();
Value* observer_out = observer->output();
@ -416,12 +435,13 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
return qembedding_bag;
}
template <typename T>
void insertQuantizationOps(
Module& module,
Value* self,
Node* observer,
bool is_per_channel,
const std::vector<std::string>& qparam_names,
T& qparams,
QuantType quant_type = QuantType::STATIC) {
Graph* g = observer->owningGraph();
// Observer output
@ -467,7 +487,7 @@ void insertQuantizationOps(
observer_compute_dtype == at::ScalarType::QUInt8 ||
observer_compute_dtype == at::ScalarType::QInt8) {
// For activation tensors we insert choose_qparams, quant, dequant ops.
Value* dtype = g->insertGetAttr(self, qparam_names.back());
Value* dtype = g->insertGetAttr(self, qparams.back());
std::tie(choose_qparams, quant, dequant) =
insertChooseQParamQuantDequant(
g, observer_out, dtype, at::Symbol::aten(quantize_func));
@ -479,12 +499,10 @@ void insertQuantizationOps(
}
} else {
// For weight tensors we insert quant-dequant ops.
dequant =
insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
}
} else { // Static quant
dequant =
insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
}
observer_out->replaceAllUsesWith(original_val);
@ -700,10 +718,16 @@ class InsertQuantDeQuantHelper {
void run(Module& module, const std::string& method_name);
void runForOnDevicePTQ(Module& module, const std::string& method_name);
// Cleanup observer nodes from graph and observer modules
// from module object and ClassType
void cleanup(Module& module);
// Clean up observer nodes only, but not observer modules.
// This is for on-device PTQ.
void removeObserverNodes(Module& m);
// In order to propagate quantization ops through the ops that don't
// require observation, we'll first inline the graph, and call the
// PropagateQuantizationOps pass
@ -725,6 +749,11 @@ class InsertQuantDeQuantHelper {
std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector(
script::Module& module,
Node* n);
QuantOpParams insertCalculateQParams(
script::Module& module,
Graph* g,
Node* n);
void checkQScheme(Graph* g, c10::QScheme qscheme) {
if (qscheme_for_graph_.count(g)) {
// FIXME[T110786721]: This check was broken before, never failing.
@ -747,7 +776,12 @@ class InsertQuantDeQuantHelper {
void collectObserverNodesAndValueToQuantize(Module& module, Value*);
void cleanup(Module& module, Graph* g);
void removeObserverNodes(Graph* g);
void quantizeTensors(Module& module, Graph* g, Value* self);
void insertCalculateQParamsAndQuantizationOps(
Module& module,
Graph* g,
Value* self);
// Function that extracts and runs the weight observer in a separate
// subgraph.
@ -837,6 +871,27 @@ void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
observer_nodes_for_graph_[g].push_back(observer);
}
void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
for (auto& method : module.get_methods()) {
removeObserverNodes(method.graph().get());
}
for (Module m : module.children()) {
removeObserverNodes(m);
}
}
void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
if (nodes_to_destroy_.count(g)) {
for (auto& n : nodes_to_destroy_.at(g)) {
n->removeAllInputs();
}
for (auto& n : nodes_to_destroy_.at(g)) {
n->destroy();
}
nodes_to_destroy_.at(g).clear();
}
}
void InsertQuantDeQuantHelper::cleanup(Module& module) {
for (auto& method : module.get_methods()) {
cleanup(module, method.graph().get());
@ -848,15 +903,7 @@ void InsertQuantDeQuantHelper::cleanup(Module& module) {
void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
GRAPH_DUMP("Before Remove Observers:", g);
if (nodes_to_destroy_.count(g)) {
for (auto& n : nodes_to_destroy_.at(g)) {
n->removeAllInputs();
}
for (auto& n : nodes_to_destroy_.at(g)) {
n->destroy();
}
nodes_to_destroy_.at(g).clear();
}
removeObserverNodes(g);
// 1. If we have seen this graph before, this means the observer
// attributes has been removed from the type(see step 2) but the slot
@ -1471,6 +1518,200 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
RemoveRedundantDequantize(graph);
}
// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <>
Node* insertQuantDequantNodes<QuantOpParams>(
Value* self,
Node* observer,
QuantOpParams& qparams,
const std::string& quantize_func) {
(void)self;
Graph* g = observer->owningGraph();
Value* observer_out = observer->output();
Value* original_val = observer->input(1);
std::vector<Value*> inputs;
// + 1 for tensor to be quantized
inputs.reserve(qparams.qparams.size() + 1);
inputs.push_back({observer_out});
for (const auto& qparam_values : qparams.qparams) {
inputs.push_back(qparam_values);
}
Node* quant = insertQuant(
g,
inputs,
at::Symbol::aten(quantize_func),
original_val->debugName() + ".quant");
// Have to make sure that quant node appears after the values it depends on.
for (Value* v : inputs) {
quant->moveAfter(v->node());
}
Node* dequant = insertDeQuant(g, quant->output(), original_val);
dequant->moveAfter(quant);
return dequant;
}
void checkCalculateQParamsResultTypes(const Node* out) {
TORCH_CHECK(
out->outputs().size() == 2,
"cacluate_qparams should produce output of size 2 (scale, zero_point).");
Value* scale = out->output(0);
Value* zp = out->output(1);
TORCH_CHECK(
scale->type()->expect<TensorType>(),
"Scale value should be of Tensor type.");
TORCH_CHECK(
zp->type()->expect<TensorType>(), "Scale value should be of float type.");
}
QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
script::Module& module,
Graph* g,
Node* n) {
// TODO: refactor findObserverName to take Node* as input
Value* self = g->inputs()[0];
Value* v = n->output();
TORCH_INTERNAL_ASSERT(
v->type()->isSubtypeOf(*TensorType::get()),
"Expected output of observer node to be Tensor");
auto observer_name = findObserverName(v);
TORCH_INTERNAL_ASSERT(
observer_name,
"getQSchemeAndParamMap expects the corresponding observer for ",
v->debugName(),
" exists.");
std::vector<Value*> qparams_graph_values;
QuantOpParams quant_op_params;
TORCH_CHECK(
!isPlaceholderObserver(n->input(0)),
"Placeholder observers are not supported in ondevice PTQ.");
auto observer_module = module.attr(observer_name.value()).toModule();
Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
auto scalar_type = observer_module.attr("dtype");
TORCH_CHECK(
scalar_type.toScalarType() != at::ScalarType::Undefined,
"dtype of observer can't be undefined");
// Not sure if we need to support this for on device PTQ.
if (scalar_type == at::ScalarType::Half) {
return quant_op_params;
}
auto calculate_qparams = observer_module.get_method("calculate_qparams");
auto calculate_qparams_schema = calculate_qparams.function().getSchema();
MatchedSchema matched_schema = matchSchema(
calculate_qparams_schema,
v->node()->sourceRange(),
*g,
{observer_module_value},
{});
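// Insert the call to the observer's calculate_qparams into the graph and
// unpack its (scale, zero_point) tuple so both are available as graph Values.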
Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
checkCalculateQParamsResultTypes(scale_zp_node);
auto qscheme = observer_module.attr("qscheme").toQScheme();
quant_op_params.qscheme = qscheme;
quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
quant_op_params.qparams.push_back(
scale_zp_node->output(1)); // zero_point Value*
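// Per-channel schemes additionally need the channel axis from the observer.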
if (isPerChannel(qscheme)) {
Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
quant_op_params.qparams.push_back(ch_axis_value);
}
Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
quant_op_params.qparams.push_back(scalar_type_value);
return quant_op_params;
}
void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
Module& module,
Graph* graph,
Value* self) {
if (!observer_nodes_for_graph_.count(graph)) {
return;
}
for (auto* n : observer_nodes_for_graph_.at(graph)) {
Graph* g = n->owningGraph();
// Observer output
Value* observer_out = n->output();
// Insert right after the observer node
WithInsertPoint insert_qparams_calc(observer_out->node()->next());
auto quant_op_params = insertCalculateQParams(module, g, n);
insertQuantizationOps(
module,
self,
n,
isPerChannel(quant_op_params.qscheme),
quant_op_params,
quant_type_);
}
}
void InsertQuantDeQuantHelper::runForOnDevicePTQ(
Module& module,
const std::string& method_name) {
// In all likelihood this really won't do anything, because we expect that
// the input method for quantization's prepare step will be inlined. Thus
// the only method calls we will see will belong to observers' forward calls.
for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
auto& invoked_module = std::get<0>(invoked_methods);
const auto& invoked_method_name = std::get<1>(invoked_methods);
runForOnDevicePTQ(invoked_module, invoked_method_name);
}
Method method = module.get_method(method_name);
auto graph = method.graph();
// Unlike the run method, we don't need to extract new qparam values for
// the same graph used in different call sites.
// The reason is that for on-device PTQ we don't:
// 1. Run calculate_qparams
// 2. Get the scale and zero point
// 3. Get axis and dtype
// 4. Register values from 2 and 3 as attributes on the parent module.
// Instead we insert a call to calculate_qparams (1) via insertCalculateQParams
// in the graph itself. Then instead of 2, we use its output Value*s directly,
// and for 3, we insert GetAttr for axis and dtype and use those Value*s
// with insertQuantizationOps.
// prim::Param nodes do not belong to the graph. Hence the insert
// point is the first node of the graph. This also safeguards against
// observing a potentially mutated value due to some in-place operation.
std::vector<Value*> input_values;
for (const auto idx : c10::irange(1, method.num_inputs())) {
auto& v = graph->inputs()[idx];
if (v->type()->isSubtypeOf(*TensorType::get())) {
input_values.push_back(v);
}
}
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
Node* n = *it++;
for (Value* v : n->outputs()) {
if (!v->type()->isSubtypeOf(*TensorType::get())) {
continue;
}
collectObserverNodesAndValueToQuantize(module, v);
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
for (Value* v : input_values) {
collectObserverNodesAndValueToQuantize(module, v);
}
GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph);
Value* self = graph->inputs()[0];
insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph);
}
} // namespace
void ReplicateQuant(std::shared_ptr<Graph>& graph) {
@ -1573,5 +1814,57 @@ Module InsertQuantDeQuant(
return module;
}
/*
*
* Assumption: the method_name method has observers placed.
* Objective: modify that method to insert calls to:
* 1. calculate_qparams
* 2. GetAttr for axis and dtype values
* 3. Use the Values from the above two to insert quant + dequant calls.
* Thus after this step you have a graph, e.g. quantize_forward (a clone of
* observe_forward), that has observer nodes, calculate_qparams run on those
* observer nodes, the output of which is used by quant-dequant nodes. The
* output of dequant is used by the actual op.
* Later on we will replace dequant + op (e.g. linear) with
* 1. prepacked_op context
* 2. unpack
* 3. dequantize
* 4. linear
*
* Of the above pattern, 2, 3, and 4 can be replaced by the linear_run op.
*/
Module InsertQuantDeQuantOnDevicePTQ(
Module& input_module,
const std::string& method_name,
bool inplace,
bool debug,
QuantType quant_type) {
Module module = input_module.clone(inplace);
const std::string kObserveString = "observe_";
const auto matched_pos = method_name.find(kObserveString);
TORCH_CHECK(
matched_pos == 0,
"Quant dequant nodes can only be added to an observe_ method, got ",
method_name,
". Please make sure to run the prepare step for on-device PTQ.");
const auto end_pos = matched_pos + kObserveString.length();
const std::string orig_method_name = method_name.substr(end_pos);
std::string quantize_method_name = "quantize_" + orig_method_name;
cloneMethod(module, method_name, quantize_method_name);
InsertQuantDeQuantHelper h(quant_type, debug);
h.runForOnDevicePTQ(module, quantize_method_name);
h.removeObserverNodes(module);
// Don't need:
// ReplicateChooseQParamsQuantDequant: this propagates dynamic quant's
// quant-dequant nodes.
// RemoveRedundantQuantizationOps: this removes activation observers for
// dynamic quant when the op related to them is not dynamically
// quantizable, which doesn't really make sense here. In our case we won't
// have those anyway, since for dynamic quant activations won't be observed.
// We can still use this function because the above two passes should
// really be no-ops.
h.propagateQuantizationOps(module);
return module;
}
} // namespace jit
} // namespace torch


@ -35,5 +35,12 @@ TORCH_API Module InsertQuantDeQuant(
bool debug,
QuantType quant_type = QuantType::STATIC);
TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
Module& module,
const std::string& method_name,
bool inplace,
bool debug,
QuantType quant_type = QuantType::STATIC);
} // namespace jit
} // namespace torch


@ -443,6 +443,22 @@ void initJITBindings(PyObject* module) {
py::arg("inplace"),
py::arg("debug"),
py::arg("quant_type_int") = 1)
.def(
"_jit_pass_insert_quant_dequant_for_ondevice_ptq",
[](Module& module,
const std::string& method_name,
bool inplace,
bool debug,
int quant_type_int) {
auto quant_type = static_cast<QuantType>(quant_type_int);
return InsertQuantDeQuantOnDevicePTQ(
module, method_name, inplace, debug, quant_type);
},
py::arg("module"),
py::arg("method_name"),
py::arg("inplace"),
py::arg("debug"),
py::arg("quant_type_int") = 1)
.def(
"_jit_pass_insert_prepack_unpack",
[](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })


@ -25,7 +25,7 @@ _all__ = [
'quantize', 'quantize_dynamic', 'quantize_qat',
'prepare', 'convert', 'prepare_qat',
# Top level API for graph mode quantization on TorchScript
'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit'
'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
# Top level API for graph mode quantization on GraphModule(torch.fx)
# 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
# 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',