diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py
index 20540ac5837c..74126af1be01 100644
--- a/test/quantization/jit/test_ondevice_quantization.py
+++ b/test/quantization/jit/test_ondevice_quantization.py
@@ -10,6 +10,7 @@ from torch.ao.quantization import (
 from torch.ao.quantization.quantize_jit import (
     _prepare_ondevice_dynamic_jit,
+    _convert_ondevice_dynamic_jit,
 )

 from torch.testing._internal.common_utils import TestCase
@@ -55,8 +56,18 @@ class OnDevicePTQUtils(object):
     observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

     @staticmethod
-    def insert_observers(m, qconfig_dict):
+    def insert_observers(model, qconfig_dict):
+        inputs = model.get_example_inputs()
+        scripted_model = get_script_module(model, False, inputs)
+        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
+        return scripted_model
+
+    @staticmethod
+    def insert_observers_quant_dequant(model, qconfig_dict):
+        inputs = model.get_example_inputs()
+        m = get_script_module(model, False, inputs)
         m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
+        m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
         return m

     @staticmethod
@@ -75,23 +86,25 @@ class OnDevicePTQUtils(object):
                 return True
         return False

+    @staticmethod
+    def is_calculate_qparam(node):
+        if node.kind() == "prim::CallMethod":
+            if node.s('name') == "calculate_qparams":
+                return True
+        return False
+

 class TestOnDeviceDynamicPTQInsertObservers(TestCase):
-    def _insert_observers(self, model, qconfig_dict):
-        inputs = model.get_example_inputs()
-        scripted_model = get_script_module(model, False, inputs)
-        scripted_model = OnDevicePTQUtils.insert_observers(scripted_model, qconfig_dict)
-        return scripted_model
-
     def _check_num_and_type_of_observers(self, model, num_observers):
         qconfig_dict = {"": default_dynamic_qconfig}
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
         observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
         self.assertTrue(len(observer_modules) == num_observers)
         for observer in observer_modules:
             self.assertTrue(observer.original_name == 'MinMaxObserver')
+
         qconfig_dict = {"": per_channel_dynamic_qconfig}
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
         observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
         self.assertTrue(len(observer_modules) == num_observers)
         for observer in observer_modules:
@@ -103,7 +116,7 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
         orig_scripted_model = get_script_module(model, False, inputs)
         torch._C._jit_pass_inline(orig_scripted_model.graph)
         orig_forward_graph = orig_scripted_model.graph.str()
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
         quant_forward_graph = scripted_model.graph.str()
         # exact graph matching is difficult so just resorting to # of lines
         # instead of implementing graph matching
@@ -135,10 +148,80 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase):
         model = MyConvLinearModule()
         qconfig_dict = {"": default_dynamic_qconfig}
         inputs = model.get_example_inputs()
-        scripted_model = self._insert_observers(model, qconfig_dict)
+        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
         observe_forward_graph = scripted_model.observe_forward.graph
         num_weight_only_observers = 0
         for node in observe_forward_graph.nodes():
             if (self._observer_is_weight_only(node)):
                 num_weight_only_observers += 1
         self.assertEqual(num_weight_only_observers, 3)
+
+
+class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
+    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
+        quantize_forward_graph = model.quantize_forward.graph
+        quantize_per_tensor = quantize_per_channel = dequantize = 0
+        for n in quantize_forward_graph.nodes():
+            if "aten::quantize_per_tensor" in n.kind():
+                quantize_per_tensor += 1
+            if "aten::quantize_per_channel" in n.kind():
+                quantize_per_channel += 1
+            if "aten::dequantize" in n.kind():
+                dequantize += 1
+        self.assertEqual(quantize_per_tensor + quantize_per_channel, dequantize)
+        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
+
+    def _validate_calculate_qparams(self, model, num_nodes):
+        quantize_forward_graph = model.quantize_forward.graph
+        num_calculate_qparams = 0
+        for n in quantize_forward_graph.nodes():
+            if OnDevicePTQUtils.is_calculate_qparam(n):
+                num_calculate_qparams += 1
+        self.assertEqual(num_calculate_qparams, num_nodes)
+
+    def _validate_no_observer_forward(self, model):
+        quantize_forward_graph = model.quantize_forward.graph
+        for n in quantize_forward_graph.nodes():
+            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
+                if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))):
+                    return False
+        return True
+
+    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
+        qconfig_dict = {"": default_dynamic_qconfig}
+        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
+        self._validate_quant_dequant_nodes(m, num_nodes)
+        self._validate_calculate_qparams(m, num_nodes)
+        self.assertTrue(self._validate_no_observer_forward(m))
+
+        qconfig_dict = {"": per_channel_dynamic_qconfig}
+        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
+        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
+        self._validate_calculate_qparams(m, num_nodes)
+        self.assertTrue(self._validate_no_observer_forward(m))
+
+    def _check_quantize_forward_runs(self, model):
+        inputs = model.get_example_inputs()
+        qconfig_dict = {"": default_dynamic_qconfig}
+        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
+        m.observe_forward(*inputs)
+        m.quantize_forward(*inputs)
+
+        qconfig_dict = {"": per_channel_dynamic_qconfig}
+        m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict)
+        # Must run observe_forward first to record the stats needed to produce
+        # correct scales and zero points.
+        m.observe_forward(*inputs)
+        m.quantize_forward(*inputs)
+
+    def test_num_quant_dequant_nodes(self):
+        model = LinearAddModel()
+        self._check_quant_dequant_and_calc_qparams(model, 2)
+        model = MyConvLinearModule()
+        self._check_quant_dequant_and_calc_qparams(model, 3)
+
+    def test_quantize_forward_runs(self):
+        model = LinearAddModel()
+        self._check_quantize_forward_runs(model)
+        model = MyConvLinearModule()
+        self._check_quantize_forward_runs(model)
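As a minimal end-to-end sketch of the flow these tests exercise (not part of the patch; it assumes a scriptable module with a get_example_inputs() method like the test models above, and uses torch.jit.script directly where the tests go through the get_script_module helper):

    import torch
    from torch.ao.quantization import default_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import (
        _prepare_ondevice_dynamic_jit,
        _convert_ondevice_dynamic_jit,
    )

    model = MyConvLinearModule()          # one of the test models above
    inputs = model.get_example_inputs()

    scripted = torch.jit.script(model)
    # Prepare: attaches observer modules and adds an observe_forward method.
    scripted = _prepare_ondevice_dynamic_jit(scripted, {"": default_dynamic_qconfig})
    # Convert: clones observe_forward into quantize_forward and inserts
    # calculate_qparams + quantize/dequantize calls into the clone.
    scripted = _convert_ondevice_dynamic_jit(scripted, 'forward')

    # On device: record stats first, then run the quantizing method.
    scripted.observe_forward(*inputs)
    scripted.quantize_forward(*inputs)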
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 126ed69ab58c..a5991ddfbb33 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -261,6 +261,11 @@ def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule',
                                    inplace: _bool,
                                    debug: _bool,
                                    quant_type: _int): ...
+def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
+                                                    method_name: str,
+                                                    inplace: _bool,
+                                                    debug: _bool,
+                                                    quant_type: _int): ...
 def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule',
                              quant_type: _int,
                              preserved_attrs: Sequence[str]): ...
diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py
index c27d17086da8..00b0cd7dd7f0 100644
--- a/torch/ao/quantization/quantize_jit.py
+++ b/torch/ao/quantization/quantize_jit.py
@@ -116,6 +116,21 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
         torch._C._jit_pass_dce(model.graph)
     return model

+
+def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
+    _check_is_script_module(model)
+    assert quant_type == QuantType.DYNAMIC, "This API, while it should work for static quant, is only tested for dynamic quant."
+    assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
+    observe_method_name = "observe_" + method_name
+    model_c = model._c
+    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
+        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
+    if inplace:
+        model._reconstruct(model_c)
+    else:
+        model = wrap_cpp_module(model_c)
+    return model
+
 def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
     torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
     return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)
@@ -124,6 +139,10 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
     torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
     return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)

+
+def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
+    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)
+
 def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
     # Always do inplace convert because the Tensor is already
     # copied in prepare_jit when inplace is False
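Note the method-name convention the asserts in _convert_ondevice_jit enforce: the caller passes the original method name, and the pass derives the observed method from it. A short sketch (the `prepared` variable is hypothetical, standing for the output of _prepare_ondevice_dynamic_jit):

    # `prepared` already has an observe_forward method from the prepare step.
    quantized = _convert_ondevice_dynamic_jit(prepared, 'forward')   # converts observe_forward, adds quantize_forward

    # Passing the observed name directly would trip the assert above:
    # _convert_ondevice_dynamic_jit(prepared, 'observe_forward')     # AssertionError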
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index 9f08e8edb644..54bd6679980e 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -2,6 +2,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -24,6 +25,16 @@ using DynamicQuantOps = std::tuple;

 std::string kScalarType = "_scalar_type";

+struct QuantOpParams {
+  c10::QScheme qscheme{c10::kPerTensorAffine};
+  std::vector<Value*> qparams;
+  // This is only so that insertQuantizationOps can be templatized
+  // and subsequently a significant portion of that code can be reused.
+  std::string back() const {
+    return "AttributeDoesNotExist";
+  }
+};
+
 c10::QScheme toAffine(c10::QScheme qscheme) {
   switch (qscheme) {
     case c10::kPerTensorAffine:
@@ -279,12 +290,20 @@ bool isEmbeddingBagOp(
       embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
 }

-// Insert quant and dequant nodes into the graph for both static and dynamic
-// quant.
+template <typename T>
 Node* insertQuantDequantNodes(
     Value* self,
     Node* observer,
-    const std::vector<std::string>& qparam_names,
+    T& qparams,
+    const std::string& quantize_func);
+
+// Insert quant and dequant nodes into the graph for both static and dynamic
+// quant.
+template <>
+Node* insertQuantDequantNodes<std::vector<std::string>>(
+    Value* self,
+    Node* observer,
+    std::vector<std::string>& qparam_names,
     const std::string& quantize_func) {
   Graph* g = observer->owningGraph();
   Value* observer_out = observer->output();
@@ -416,12 +435,13 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
   return qembedding_bag;
 }

+template <typename T>
 void insertQuantizationOps(
     Module& module,
     Value* self,
     Node* observer,
     bool is_per_channel,
-    const std::vector<std::string>& qparam_names,
+    T& qparams,
     QuantType quant_type = QuantType::STATIC) {
   Graph* g = observer->owningGraph();
   // Observer output
@@ -467,7 +487,7 @@ void insertQuantizationOps(
         observer_compute_dtype == at::ScalarType::QUInt8 ||
         observer_compute_dtype == at::ScalarType::QInt8) {
       // For activation tensors we insert choose_qparams, quant, dequant ops.
-      Value* dtype = g->insertGetAttr(self, qparam_names.back());
+      Value* dtype = g->insertGetAttr(self, qparams.back());
       std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant(
           g, observer_out, dtype, at::Symbol::aten(quantize_func));
@@ -479,12 +499,10 @@ void insertQuantizationOps(
       }
     } else {
       // For weight tensors we insert quant-dequant ops.
-      dequant =
-          insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+      dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
     }
   } else { // Static quant
-    dequant =
-        insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+    dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
   }

   observer_out->replaceAllUsesWith(original_val);
@@ -700,10 +718,16 @@ class InsertQuantDeQuantHelper {

   void run(Module& module, const std::string& method_name);

+  void runForOnDevicePTQ(Module& module, const std::string& method_name);
+
   // Cleanup observer nodes from graph and observer modules
   // from module object and ClassType
   void cleanup(Module& module);

+  // Cleanup observer nodes only, but not observer modules.
+  // This is for on-device PTQ.
+  void removeObserverNodes(Module& m);
+
   // In order to propagate quantization ops through the ops that doesn't
   // require observation, we'll first inline the graph, and call the
   // PropgateQuantizationOps pass
@@ -725,6 +749,11 @@ class InsertQuantDeQuantHelper {
   std::tuple getQSchemeAndQParamVector(
       script::Module& module,
       Node* n);
+  QuantOpParams insertCalculateQParams(
+      script::Module& module,
+      Graph* g,
+      Node* n);
+
   void checkQScheme(Graph* g, c10::QScheme qscheme) {
     if (qscheme_for_graph_.count(g)) {
       // FIXME[T110786721]: This check was broken before nevery failing.
@@ -747,7 +776,12 @@ class InsertQuantDeQuantHelper {
   void collectObserverNodesAndValueToQuantize(Module& module, Value*);
   void cleanup(Module& module, Graph* g);
+  void removeObserverNodes(Graph* g);
   void quantizeTensors(Module& module, Graph* g, Value* self);
+  void insertCalculateQParamsAndQuantizationOps(
+      Module& module,
+      Graph* g,
+      Value* self);

   // Function that extracts and runs the weight observer in a separate
   // subgraph.
@@ -837,6 +871,27 @@ void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
   observer_nodes_for_graph_[g].push_back(observer);
 }

+void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
+  for (auto& method : module.get_methods()) {
+    removeObserverNodes(method.graph().get());
+  }
+  for (Module m : module.children()) {
+    removeObserverNodes(m);
+  }
+}
+
+void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
+  if (nodes_to_destroy_.count(g)) {
+    for (auto& n : nodes_to_destroy_.at(g)) {
+      n->removeAllInputs();
+    }
+    for (auto& n : nodes_to_destroy_.at(g)) {
+      n->destroy();
+    }
+    nodes_to_destroy_.at(g).clear();
+  }
+}
+
 void InsertQuantDeQuantHelper::cleanup(Module& module) {
   for (auto& method : module.get_methods()) {
     cleanup(module, method.graph().get());
@@ -848,15 +903,7 @@ void InsertQuantDeQuantHelper::cleanup(Module& module) {

 void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
   GRAPH_DUMP("Before Remove Observers:", g);
-  if (nodes_to_destroy_.count(g)) {
-    for (auto& n : nodes_to_destroy_.at(g)) {
-      n->removeAllInputs();
-    }
-    for (auto& n : nodes_to_destroy_.at(g)) {
-      n->destroy();
-    }
-    nodes_to_destroy_.at(g).clear();
-  }
+  removeObserverNodes(g);

   // 1. If we have seen this graph before, this means the observer
   // attributes has been removed from the type(see step 2) but the slot
@@ -1471,6 +1518,200 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
     RemoveRedundantDequantize(graph);
 }

+// Insert quant and dequant nodes into the graph for both static and dynamic
+// quant.
+template <>
+Node* insertQuantDequantNodes<QuantOpParams>(
+    Value* self,
+    Node* observer,
+    QuantOpParams& qparams,
+    const std::string& quantize_func) {
+  (void)self;
+  Graph* g = observer->owningGraph();
+  Value* observer_out = observer->output();
+  Value* original_val = observer->input(1);
+  std::vector<Value*> inputs;
+  // + 1 for tensor to be quantized
+  inputs.reserve(qparams.qparams.size() + 1);
+  inputs.push_back({observer_out});
+  for (const auto& qparam_values : qparams.qparams) {
+    inputs.push_back(qparam_values);
+  }
+  Node* quant = insertQuant(
+      g,
+      inputs,
+      at::Symbol::aten(quantize_func),
+      original_val->debugName() + ".quant");
+  // Have to make sure that quant node appears after the values it depends on.
+  for (Value* v : inputs) {
+    quant->moveAfter(v->node());
+  }
+  Node* dequant = insertDeQuant(g, quant->output(), original_val);
+  dequant->moveAfter(quant);
+  return dequant;
+}
+
+void checkCalculateQParamsResultTypes(const Node* out) {
+  TORCH_CHECK(
+      out->outputs().size() == 2,
+      "calculate_qparams should produce output of size 2 (scale, zero_point).");
+  Value* scale = out->output(0);
+  Value* zp = out->output(1);
+  TORCH_CHECK(
+      scale->type()->expect<TensorType>(),
+      "Scale value should be of Tensor type.");
+  TORCH_CHECK(
+      zp->type()->expect<TensorType>(),
+      "Zero-point value should be of Tensor type.");
+}
+
+QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
+    script::Module& module,
+    Graph* g,
+    Node* n) {
+  // TODO: refactor findObserverName to take Node* as input
+  Value* self = g->inputs()[0];
+  Value* v = n->output();
+  TORCH_INTERNAL_ASSERT(
+      v->type()->isSubtypeOf(*TensorType::get()),
+      "Expected output of observer node to be Tensor");
+  auto observer_name = findObserverName(v);
+  TORCH_INTERNAL_ASSERT(
+      observer_name,
+      "insertCalculateQParams expects the corresponding observer for ",
+      v->debugName(),
+      " to exist.");
+  std::vector<Value*> qparams_graph_values;
+  QuantOpParams quant_op_params;
+
+  TORCH_CHECK(
+      !isPlaceholderObserver(n->input(0)),
+      "Placeholder observers are not supported in on-device PTQ.");
+  auto observer_module = module.attr(observer_name.value()).toModule();
+  Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
+  auto scalar_type = observer_module.attr("dtype");
+  TORCH_CHECK(
+      scalar_type.toScalarType() != at::ScalarType::Undefined,
+      "dtype of observer can't be undefined");
+  // Not sure if we need to support this for on-device PTQ.
+  if (scalar_type == at::ScalarType::Half) {
+    return quant_op_params;
+  }
+  auto calculate_qparams = observer_module.get_method("calculate_qparams");
+  auto calculate_qparams_schema = calculate_qparams.function().getSchema();
+  MatchedSchema matched_schema = matchSchema(
+      calculate_qparams_schema,
+      v->node()->sourceRange(),
+      *g,
+      {observer_module_value},
+      {});
+  Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
+  Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
+  checkCalculateQParamsResultTypes(scale_zp_node);
+  auto qscheme = observer_module.attr("qscheme").toQScheme();
+  quant_op_params.qscheme = qscheme;
+  quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
+  quant_op_params.qparams.push_back(
+      scale_zp_node->output(1)); // zero_point Value*
+  if (isPerChannel(qscheme)) {
+    Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
+    quant_op_params.qparams.push_back(ch_axis_value);
+  }
+  Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
+  quant_op_params.qparams.push_back(scalar_type_value);
+  return quant_op_params;
+}
+
+void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
+    Module& module,
+    Graph* graph,
+    Value* self) {
+  if (!observer_nodes_for_graph_.count(graph)) {
+    return;
+  }
+  for (auto* n : observer_nodes_for_graph_.at(graph)) {
+    Graph* g = n->owningGraph();
+    // Observer output
+    Value* observer_out = n->output();
+    // Inserting before insert point
+    WithInsertPoint insert_qparams_calc(observer_out->node()->next());
+    auto quant_op_params = insertCalculateQParams(module, g, n);
+    insertQuantizationOps(
+        module,
+        self,
+        n,
+        isPerChannel(quant_op_params.qscheme),
+        quant_op_params,
+        quant_type_);
+  }
+}
+
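Taken together, insertCalculateQParams and insertQuantizationOps rewrite each observed weight so that quantize_forward recomputes qparams from the recorded statistics and then quantizes and immediately dequantizes the tensor. A rough eager-mode Python illustration of the per-tensor sequence the emitted IR corresponds to (the tensor and observer are stand-ins; this is an illustration, not the pass itself):

    import torch
    from torch.ao.quantization import MinMaxObserver

    weight = torch.randn(8, 4)
    observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
    observer(weight)  # what observe_forward has already done

    # What the inserted nodes compute inside quantize_forward:
    scale, zero_point = observer.calculate_qparams()        # prim::CallMethod("calculate_qparams")
    w_q = torch.quantize_per_tensor(                         # aten::quantize_per_tensor
        weight, float(scale), int(zero_point), observer.dtype)
    w_dq = w_q.dequantize()                                  # aten::dequantize
    # w_dq then feeds the original op (e.g. aten::linear) in place of the weight.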
+void InsertQuantDeQuantHelper::runForOnDevicePTQ(
+    Module& module,
+    const std::string& method_name) {
+  // In all likelihood this really won't do anything because we expect that
+  // the input method for quantization's prepare step will be inlined. Thus
+  // the only called methods we will see will belong to observers' forward calls.
+  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
+    auto& invoked_module = std::get<0>(invoked_methods);
+    const auto& invoked_method_name = std::get<1>(invoked_methods);
+    runForOnDevicePTQ(invoked_module, invoked_method_name);
+  }
+
+  Method method = module.get_method(method_name);
+  auto graph = method.graph();
+  // Unlike the run method, we don't need to extract new qparam values for the
+  // same graph used in different call sites.
+  // The reason is that for on-device PTQ we don't:
+  // 1. Run calculate_qparams
+  // 2. Get the scale and zero point
+  // 3. Get axis and dtype
+  // 4. Register values from 2 and 3 as attributes on the parent module.
+  // Instead we insert a call to calculate_qparams (1), via insertCalculateQParams,
+  // in the graph itself. Then, instead of 2, we use the resulting output Value*s,
+  // and for 3 we insert GetAttr for axis and dtype and use those Value*s with
+  // insertQuantizationOps.
+
+  // prim::Param nodes do not belong to the graph. Hence the insert
+  // point is the beginning of the graph. This also safeguards against
+  // observing a potentially mutated value due to some in-place operation.
+  std::vector<Value*> input_values;
+  for (const auto idx : c10::irange(1, method.num_inputs())) {
+    auto& v = graph->inputs()[idx];
+    if (v->type()->isSubtypeOf(*TensorType::get())) {
+      input_values.push_back(v);
+    }
+  }
+
+  std::stack<Block*> blocks_to_visit;
+  blocks_to_visit.push(graph->block());
+  while (!blocks_to_visit.empty()) {
+    Block* b = blocks_to_visit.top();
+    blocks_to_visit.pop();
+    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
+      Node* n = *it++;
+      for (Value* v : n->outputs()) {
+        if (!v->type()->isSubtypeOf(*TensorType::get())) {
+          continue;
+        }
+        collectObserverNodesAndValueToQuantize(module, v);
+      }
+
+      for (Block* subblock : n->blocks()) {
+        blocks_to_visit.push(subblock);
+      }
+    }
+  }
+
+  for (Value* v : input_values) {
+    collectObserverNodesAndValueToQuantize(module, v);
+  }
+
+  GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph);
+  Value* self = graph->inputs()[0];
+  insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
+  GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph);
+}
+
 } // namespace

 void ReplicateQuant(std::shared_ptr& graph) {
@@ -1573,5 +1814,57 @@ Module InsertQuantDeQuant(
   return module;
 }

+/*
+ *
+ * Assumption: the method_name method has observers placed.
+ * Objective: modify that method to insert calls to:
+ * 1. calculate_qparams
+ * 2. GetAttr for axis and dtype values
+ * 3. Use Values from the above two to insert calls to quant + dequant
+ * Thus after this step you have a graph of, e.g., observe_forward,
+ * that has observer nodes, calculate_qparams run on those observer nodes,
+ * the output of which is used by quant-dequant nodes. The output of dequant
+ * is used by the actual op.
+ * Later on we will replace dequant + op (e.g. linear) with
+ * 1. prepacked_op context
+ * 2. unpack
+ * 3. dequantize
+ * 4. linear
+ *
+ * Of the above pattern, 2, 3, and 4 can be replaced by a linear_run op.
+ */
+// Module InsertQuantDeQuantForOnDevicePTQ(
+Module InsertQuantDeQuantOnDevicePTQ(
+    Module& input_module,
+    const std::string& method_name,
+    bool inplace,
+    bool debug,
+    QuantType quant_type) {
+  Module module = input_module.clone(inplace);
+  const std::string kObserveString = "observe_";
+  const auto matched_pos = method_name.find(kObserveString);
+  const auto end_pos = matched_pos + kObserveString.length();
+  const std::string orig_method_name = method_name.substr(end_pos);
+  TORCH_CHECK(
+      matched_pos == 0,
+      "Quant dequant nodes can only be added to observe_",
+      orig_method_name,
+      ". Please make sure to run the prepare step for on-device PTQ.");
+
+  std::string quantize_method_name = "quantize_" + orig_method_name;
+  cloneMethod(module, method_name, quantize_method_name);
+  InsertQuantDeQuantHelper h(quant_type, debug);
+  h.runForOnDevicePTQ(module, quantize_method_name);
+  h.removeObserverNodes(module);
+  // Don't need:
+  // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's
+  // quant-dequant.
+  // RemoveRedundantQuantizationOps: This is removing activation
+  // observers for dynamic quant when the op related to it is not dynamically
+  // quantizable. Doesn't really make sense. In our case we won't have those
+  // anyway, since for dynamic quant activations won't be observed.
+  // We can still use this function because the above two methods should
+  // really be no-ops.
+  h.propagateQuantizationOps(module);
+  return module;
+}
 } // namespace jit
 } // namespace torch
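For reference, this is roughly how _convert_ondevice_jit (in quantize_jit.py above) reaches InsertQuantDeQuantOnDevicePTQ through the Python binding registered in torch/csrc/jit/python/init.cpp below; `prepared` is a hypothetical prepared ScriptModule that already has an observe_forward method:

    import torch
    from torch.ao.quantization.quant_type import QuantType
    from torch.jit._recursive import wrap_cpp_module

    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        prepared._c, "observe_forward", False, False, QuantType.DYNAMIC)
    quantized = wrap_cpp_module(model_c)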
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.h b/torch/csrc/jit/passes/quantization/insert_quant_dequant.h
index 2b3f5d711cbe..de2b31fdba7c 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.h
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.h
@@ -35,5 +35,12 @@ TORCH_API Module InsertQuantDeQuant(
     bool debug,
     QuantType quant_type = QuantType::STATIC);

+TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
+    Module& module,
+    const std::string& method_name,
+    bool inplace,
+    bool debug,
+    QuantType quant_type = QuantType::STATIC);
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 9a55dcd63064..1a62b834b44b 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -443,6 +443,22 @@ void initJITBindings(PyObject* module) {
           py::arg("inplace"),
           py::arg("debug"),
           py::arg("quant_type_int") = 1)
+      .def(
+          "_jit_pass_insert_quant_dequant_for_ondevice_ptq",
+          [](Module& module,
+             const std::string& method_name,
+             bool inplace,
+             bool debug,
+             int quant_type_int) {
+            auto quant_type = static_cast<QuantType>(quant_type_int);
+            return InsertQuantDeQuantOnDevicePTQ(
+                module, method_name, inplace, debug, quant_type);
+          },
+          py::arg("module"),
+          py::arg("method_name"),
+          py::arg("inplace"),
+          py::arg("debug"),
+          py::arg("quant_type_int") = 1)
       .def(
           "_jit_pass_insert_prepack_unpack",
           [](std::shared_ptr& g) { return InsertPrepackUnpack(g); })
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index 77933f8e62dd..48ba1abdd90a 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -25,7 +25,7 @@ _all__ = [
     'quantize', 'quantize_dynamic', 'quantize_qat',
     'prepare', 'convert', 'prepare_qat',
     # Top level API for graph mode quantization on TorchScript
-    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit'
+    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
     # Top level API for graph mode quantization on GraphModule(torch.fx)
     # 'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
     # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',