mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 16:14:55 +08:00 
			
		
		
		
	[On Device Quantization][pytorch]Make insert_quant_dequant support ondevice ptq (#83570)
Summary: This diff adds a way to: - clone previously observed method - Add calls to observer's calculate_qparams methods - Extract the scale and zero point - Use them to insert quant dequant nodes Now for forward method we have - observe_forward - quantize_forward observe_forward is used post training to observer statistics. In the case of dynamic PTQ this requires just running that method once to update weight observer statistics. quantize_forward method will be used to use the observer statistics to calculate quantization parameters and apply that to quant dequant op. Subsequent diffs will replace dequant + op with their quantized op counter parts and replace quantize ops with relevant packed params class where possible Test Plan: To be written Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D38771419](https://our.internmc.facebook.com/intern/diff/D38771419) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83570 Approved by: https://github.com/jerryzh168
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							6a5d9f1be0
						
					
				
				
					commit
					446afb5f9f
				
			| @ -10,6 +10,7 @@ from torch.ao.quantization import ( | ||||
|  | ||||
| from torch.ao.quantization.quantize_jit import ( | ||||
|     _prepare_ondevice_dynamic_jit, | ||||
|     _convert_ondevice_dynamic_jit, | ||||
| ) | ||||
|  | ||||
| from torch.testing._internal.common_utils import TestCase | ||||
| @ -55,8 +56,18 @@ class OnDevicePTQUtils(object): | ||||
|     observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver'] | ||||
|  | ||||
|     @staticmethod | ||||
|     def insert_observers(m, qconfig_dict): | ||||
|     def insert_observers(model, qconfig_dict): | ||||
|         inputs = model.get_example_inputs() | ||||
|         scripted_model = get_script_module(model, False, inputs) | ||||
|         scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict) | ||||
|         return scripted_model | ||||
|  | ||||
|     @staticmethod | ||||
|     def insert_observers_quant_dequant(model, qconfig_dict): | ||||
|         inputs = model.get_example_inputs() | ||||
|         m = get_script_module(model, False, inputs) | ||||
|         m = _prepare_ondevice_dynamic_jit(m, qconfig_dict) | ||||
|         m = _convert_ondevice_dynamic_jit(m, 'forward', True, False) | ||||
|         return m | ||||
|  | ||||
|     @staticmethod | ||||
| @ -75,23 +86,25 @@ class OnDevicePTQUtils(object): | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|     @staticmethod | ||||
|     def is_calculate_qparam(node): | ||||
|         if node.kind() == "prim::CallMethod": | ||||
|             if node.s('name') == "calculate_qparams": | ||||
|                 return True | ||||
|         return False | ||||
|  | ||||
|  | ||||
| class TestOnDeviceDynamicPTQInsertObservers(TestCase): | ||||
|     def _insert_observers(self, model, qconfig_dict): | ||||
|         inputs = model.get_example_inputs() | ||||
|         scripted_model = get_script_module(model, False, inputs) | ||||
|         scripted_model = OnDevicePTQUtils.insert_observers(scripted_model, qconfig_dict) | ||||
|         return scripted_model | ||||
|  | ||||
|     def _check_num_and_type_of_observers(self, model, num_observers): | ||||
|         qconfig_dict = {"": default_dynamic_qconfig} | ||||
|         scripted_model = self._insert_observers(model, qconfig_dict) | ||||
|         scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) | ||||
|         observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model) | ||||
|         self.assertTrue(len(observer_modules) == num_observers) | ||||
|         for observer in observer_modules: | ||||
|             self.assertTrue(observer.original_name == 'MinMaxObserver') | ||||
|  | ||||
|         qconfig_dict = {"": per_channel_dynamic_qconfig} | ||||
|         scripted_model = self._insert_observers(model, qconfig_dict) | ||||
|         scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) | ||||
|         observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model) | ||||
|         self.assertTrue(len(observer_modules) == num_observers) | ||||
|         for observer in observer_modules: | ||||
| @ -103,7 +116,7 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase): | ||||
|         orig_scripted_model = get_script_module(model, False, inputs) | ||||
|         torch._C._jit_pass_inline(orig_scripted_model.graph) | ||||
|         orig_forward_graph = orig_scripted_model.graph.str() | ||||
|         scripted_model = self._insert_observers(model, qconfig_dict) | ||||
|         scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) | ||||
|         quant_forward_graph = scripted_model.graph.str() | ||||
|         # exact graph matching is difficult so just resorting to # of lines | ||||
|         # instead of implementing graph matching | ||||
| @ -135,10 +148,80 @@ class TestOnDeviceDynamicPTQInsertObservers(TestCase): | ||||
|         model = MyConvLinearModule() | ||||
|         qconfig_dict = {"": default_dynamic_qconfig} | ||||
|         inputs = model.get_example_inputs() | ||||
|         scripted_model = self._insert_observers(model, qconfig_dict) | ||||
|         scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) | ||||
|         observe_forward_graph = scripted_model.observe_forward.graph | ||||
|         num_weight_only_observers = 0 | ||||
|         for node in observe_forward_graph.nodes(): | ||||
|             if (self._observer_is_weight_only(node)): | ||||
|                 num_weight_only_observers += 1 | ||||
|         self.assertEqual(num_weight_only_observers, 3) | ||||
|  | ||||
|  | ||||
| class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase): | ||||
|     def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0): | ||||
|         quantize_forward_graph = model.quantize_forward.graph | ||||
|         quantize_per_tensor = quantize_per_channel = dequantize = 0 | ||||
|         for n in quantize_forward_graph.nodes(): | ||||
|             if "aten::quantize_per_tensor" in n.kind(): | ||||
|                 quantize_per_tensor += 1 | ||||
|             if "aten::quantize_per_channel" in n.kind(): | ||||
|                 quantize_per_channel += 1 | ||||
|             if "aten::dequantize" in n.kind(): | ||||
|                 dequantize += 1 | ||||
|         self.assertEqual(quantize_per_tensor + quantize_per_channel, dequantize) | ||||
|         self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes) | ||||
|  | ||||
|     def _validate_calculate_qparams(self, model, num_nodes): | ||||
|         quantize_forward_graph = model.quantize_forward.graph | ||||
|         num_calculate_qparams = 0 | ||||
|         for n in quantize_forward_graph.nodes(): | ||||
|             if OnDevicePTQUtils.is_calculate_qparam(n): | ||||
|                 num_calculate_qparams += 1 | ||||
|         self.assertEqual(num_calculate_qparams, num_nodes) | ||||
|  | ||||
|     def _validate_no_observer_forward(self, model): | ||||
|         quantize_forward_graph = model.quantize_forward.graph | ||||
|         for n in quantize_forward_graph.nodes(): | ||||
|             if (n.kind() == "prim::CallMethod") and n.s("name") == "forward": | ||||
|                 if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))): | ||||
|                     return False | ||||
|         return True | ||||
|  | ||||
|     def _check_quant_dequant_and_calc_qparams(self, model, num_nodes): | ||||
|         qconfig_dict = {"": default_dynamic_qconfig} | ||||
|         m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict) | ||||
|         self._validate_quant_dequant_nodes(m, num_nodes) | ||||
|         self._validate_calculate_qparams(m, num_nodes) | ||||
|         self._validate_no_observer_forward(m) | ||||
|  | ||||
|         qconfig_dict = {"": per_channel_dynamic_qconfig} | ||||
|         m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict) | ||||
|         self._validate_quant_dequant_nodes(m, num_nodes, num_nodes) | ||||
|         self._validate_calculate_qparams(m, num_nodes) | ||||
|         self._validate_no_observer_forward(m) | ||||
|  | ||||
|     def _check_quantize_forward_runs(self, model): | ||||
|         inputs = model.get_example_inputs() | ||||
|         qconfig_dict = {"": default_dynamic_qconfig} | ||||
|         m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict) | ||||
|         m.observe_forward(*inputs) | ||||
|         m.quantize_forward(*inputs) | ||||
|  | ||||
|         qconfig_dict = {"": per_channel_dynamic_qconfig} | ||||
|         m = OnDevicePTQUtils.insert_observers_quant_dequant(model, qconfig_dict) | ||||
|         # First must run observe forward to record the stats to produce | ||||
|         # correct scales and zero points | ||||
|         m.observe_forward(*inputs) | ||||
|         m.quantize_forward(*inputs) | ||||
|  | ||||
|     def test_num_quant_dequant_nodes(self): | ||||
|         model = LinearAddModel() | ||||
|         self._check_quant_dequant_and_calc_qparams(model, 2) | ||||
|         model = MyConvLinearModule() | ||||
|         self._check_quant_dequant_and_calc_qparams(model, 3) | ||||
|  | ||||
|     def test_quantize_forward_runs(self): | ||||
|         model = LinearAddModel() | ||||
|         self._check_quantize_forward_runs(model) | ||||
|         model = MyConvLinearModule() | ||||
|         self._check_quantize_forward_runs(model) | ||||
|  | ||||
| @ -261,6 +261,11 @@ def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule', | ||||
|                                    inplace: _bool, | ||||
|                                    debug: _bool, | ||||
|                                    quant_type: _int): ... | ||||
| def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule', | ||||
|                                    method_name: str, | ||||
|                                    inplace: _bool, | ||||
|                                    debug: _bool, | ||||
|                                    quant_type: _int): ... | ||||
| def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule', | ||||
|                              quant_type: _int, | ||||
|                              preserved_attrs: Sequence[str]): ... | ||||
|  | ||||
| @ -116,6 +116,21 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC, | ||||
|     torch._C._jit_pass_dce(model.graph) | ||||
|     return model | ||||
|  | ||||
|  | ||||
| def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC): | ||||
|     _check_is_script_module(model) | ||||
|     assert quant_type == QuantType.DYNAMIC, "This API, while should work for static quant, is only tested for dynamic quant." | ||||
|     assert not method_name.startswith("observe_"), "Pass in valid method to be quantized, e.g. forward" | ||||
|     observe_method_name = "observe_" + method_name | ||||
|     model_c = model._c | ||||
|     model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq( | ||||
|         model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC) | ||||
|     if inplace: | ||||
|         model._reconstruct(model_c) | ||||
|     else: | ||||
|         model = wrap_cpp_module(model_c) | ||||
|     return model | ||||
|  | ||||
| def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): | ||||
|     torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") | ||||
|     return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs) | ||||
| @ -124,6 +139,10 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None) | ||||
|     torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") | ||||
|     return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs) | ||||
|  | ||||
|  | ||||
| def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False): | ||||
|     return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC) | ||||
|  | ||||
| def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC): | ||||
|     # Always do inplace convert because the Tensor is already | ||||
|     # copied in prepare_jit when inplace is False | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| #include <c10/core/QScheme.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <torch/csrc/jit/frontend/schema_matching.h> | ||||
| #include <torch/csrc/jit/ir/subgraph_matcher.h> | ||||
| #include <torch/csrc/jit/jit_log.h> | ||||
| #include <torch/csrc/jit/passes/constant_propagation.h> | ||||
| @ -24,6 +25,16 @@ using DynamicQuantOps = std::tuple<Node*, Node*, Node*>; | ||||
|  | ||||
| std::string kScalarType = "_scalar_type"; | ||||
|  | ||||
| struct QuantOpParams { | ||||
|   c10::QScheme qscheme{c10::kPerTensorAffine}; | ||||
|   std::vector<Value*> qparams; | ||||
|   // This is only so that insertQuantizationOps can be templatized | ||||
|   // and subsequntly significant portion of that code can be reused. | ||||
|   std::string back() const { | ||||
|     return "AttributeDoesNotExist"; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| c10::QScheme toAffine(c10::QScheme qscheme) { | ||||
|   switch (qscheme) { | ||||
|     case c10::kPerTensorAffine: | ||||
| @ -279,12 +290,20 @@ bool isEmbeddingBagOp( | ||||
|       embedding_bag_name.value().find("embedding_bag_") != std::string::npos; | ||||
| } | ||||
|  | ||||
| // Insert quant and dequant nodes into the graph for both static and dynamic | ||||
| // quant. | ||||
| template <typename T> | ||||
| Node* insertQuantDequantNodes( | ||||
|     Value* self, | ||||
|     Node* observer, | ||||
|     const std::vector<std::string>& qparam_names, | ||||
|     T& qparams, | ||||
|     const std::string& quantize_func); | ||||
|  | ||||
| // Insert quant and dequant nodes into the graph for both static and dynamic | ||||
| // quant. | ||||
| template <> | ||||
| Node* insertQuantDequantNodes<std::vector<std::string>>( | ||||
|     Value* self, | ||||
|     Node* observer, | ||||
|     std::vector<std::string>& qparam_names, | ||||
|     const std::string& quantize_func) { | ||||
|   Graph* g = observer->owningGraph(); | ||||
|   Value* observer_out = observer->output(); | ||||
| @ -416,12 +435,13 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { | ||||
|   return qembedding_bag; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| void insertQuantizationOps( | ||||
|     Module& module, | ||||
|     Value* self, | ||||
|     Node* observer, | ||||
|     bool is_per_channel, | ||||
|     const std::vector<std::string>& qparam_names, | ||||
|     T& qparams, | ||||
|     QuantType quant_type = QuantType::STATIC) { | ||||
|   Graph* g = observer->owningGraph(); | ||||
|   // Observer output | ||||
| @ -467,7 +487,7 @@ void insertQuantizationOps( | ||||
|           observer_compute_dtype == at::ScalarType::QUInt8 || | ||||
|           observer_compute_dtype == at::ScalarType::QInt8) { | ||||
|         // For activation tensors we insert choose_qparams, quant, dequant ops. | ||||
|         Value* dtype = g->insertGetAttr(self, qparam_names.back()); | ||||
|         Value* dtype = g->insertGetAttr(self, qparams.back()); | ||||
|         std::tie(choose_qparams, quant, dequant) = | ||||
|             insertChooseQParamQuantDequant( | ||||
|                 g, observer_out, dtype, at::Symbol::aten(quantize_func)); | ||||
| @ -479,12 +499,10 @@ void insertQuantizationOps( | ||||
|       } | ||||
|     } else { | ||||
|       // For weight tensors we insert quant-dequant ops. | ||||
|       dequant = | ||||
|           insertQuantDequantNodes(self, observer, qparam_names, quantize_func); | ||||
|       dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func); | ||||
|     } | ||||
|   } else { // Static quant | ||||
|     dequant = | ||||
|         insertQuantDequantNodes(self, observer, qparam_names, quantize_func); | ||||
|     dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func); | ||||
|   } | ||||
|   observer_out->replaceAllUsesWith(original_val); | ||||
|  | ||||
| @ -700,10 +718,16 @@ class InsertQuantDeQuantHelper { | ||||
|  | ||||
|   void run(Module& module, const std::string& method_name); | ||||
|  | ||||
|   void runForOnDevicePTQ(Module& module, const std::string& method_name); | ||||
|  | ||||
|   // Cleanup observer nodes from graph and observer modules | ||||
|   // from module object and ClassType | ||||
|   void cleanup(Module& module); | ||||
|  | ||||
|   // Cleanup observer nodes only but not modules | ||||
|   // This is for ondevice PTQ | ||||
|   void removeObserverNodes(Module& m); | ||||
|  | ||||
|   // In order to propagate quantization ops through the ops that doesn't | ||||
|   // require observation, we'll first inline the graph, and call the | ||||
|   // PropgateQuantizationOps pass | ||||
| @ -725,6 +749,11 @@ class InsertQuantDeQuantHelper { | ||||
|   std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector( | ||||
|       script::Module& module, | ||||
|       Node* n); | ||||
|   QuantOpParams insertCalculateQParams( | ||||
|       script::Module& module, | ||||
|       Graph* g, | ||||
|       Node* n); | ||||
|  | ||||
|   void checkQScheme(Graph* g, c10::QScheme qscheme) { | ||||
|     if (qscheme_for_graph_.count(g)) { | ||||
|       // FIXME[T110786721]: This check was broken before nevery failing. | ||||
| @ -747,7 +776,12 @@ class InsertQuantDeQuantHelper { | ||||
|  | ||||
|   void collectObserverNodesAndValueToQuantize(Module& module, Value*); | ||||
|   void cleanup(Module& module, Graph* g); | ||||
|   void removeObserverNodes(Graph* g); | ||||
|   void quantizeTensors(Module& module, Graph* g, Value* self); | ||||
|   void insertCalculateQParamsAndQuantizationOps( | ||||
|       Module& module, | ||||
|       Graph* g, | ||||
|       Value* self); | ||||
|  | ||||
|   // Function that extracts and runs the weight observer in a separate | ||||
|   // subgraph. | ||||
| @ -837,6 +871,27 @@ void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize( | ||||
|   observer_nodes_for_graph_[g].push_back(observer); | ||||
| } | ||||
|  | ||||
| void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) { | ||||
|   for (auto& method : module.get_methods()) { | ||||
|     removeObserverNodes(method.graph().get()); | ||||
|   } | ||||
|   for (Module m : module.children()) { | ||||
|     removeObserverNodes(m); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) { | ||||
|   if (nodes_to_destroy_.count(g)) { | ||||
|     for (auto& n : nodes_to_destroy_.at(g)) { | ||||
|       n->removeAllInputs(); | ||||
|     } | ||||
|     for (auto& n : nodes_to_destroy_.at(g)) { | ||||
|       n->destroy(); | ||||
|     } | ||||
|     nodes_to_destroy_.at(g).clear(); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void InsertQuantDeQuantHelper::cleanup(Module& module) { | ||||
|   for (auto& method : module.get_methods()) { | ||||
|     cleanup(module, method.graph().get()); | ||||
| @ -848,15 +903,7 @@ void InsertQuantDeQuantHelper::cleanup(Module& module) { | ||||
|  | ||||
| void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) { | ||||
|   GRAPH_DUMP("Before Remove Observers:", g); | ||||
|   if (nodes_to_destroy_.count(g)) { | ||||
|     for (auto& n : nodes_to_destroy_.at(g)) { | ||||
|       n->removeAllInputs(); | ||||
|     } | ||||
|     for (auto& n : nodes_to_destroy_.at(g)) { | ||||
|       n->destroy(); | ||||
|     } | ||||
|     nodes_to_destroy_.at(g).clear(); | ||||
|   } | ||||
|   removeObserverNodes(g); | ||||
|  | ||||
|   // 1. If we have seen this graph before, this means the observer | ||||
|   // attributes has been removed from the type(see step 2) but the slot | ||||
| @ -1471,6 +1518,200 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) { | ||||
|   RemoveRedundantDequantize(graph); | ||||
| } | ||||
|  | ||||
| // Insert quant and dequant nodes into the graph for both static and dynamic | ||||
| // quant. | ||||
| template <> | ||||
| Node* insertQuantDequantNodes<QuantOpParams>( | ||||
|     Value* self, | ||||
|     Node* observer, | ||||
|     QuantOpParams& qparams, | ||||
|     const std::string& quantize_func) { | ||||
|   (void)self; | ||||
|   Graph* g = observer->owningGraph(); | ||||
|   Value* observer_out = observer->output(); | ||||
|   Value* original_val = observer->input(1); | ||||
|   std::vector<Value*> inputs; | ||||
|   // + 1 for tensor to be quantized | ||||
|   inputs.reserve(qparams.qparams.size() + 1); | ||||
|   inputs.push_back({observer_out}); | ||||
|   for (const auto& qparam_values : qparams.qparams) { | ||||
|     inputs.push_back(qparam_values); | ||||
|   } | ||||
|   Node* quant = insertQuant( | ||||
|       g, | ||||
|       inputs, | ||||
|       at::Symbol::aten(quantize_func), | ||||
|       original_val->debugName() + ".quant"); | ||||
|   // Have to make sure that quant node appears after the values it depends on. | ||||
|   for (Value* v : inputs) { | ||||
|     quant->moveAfter(v->node()); | ||||
|   } | ||||
|   Node* dequant = insertDeQuant(g, quant->output(), original_val); | ||||
|   dequant->moveAfter(quant); | ||||
|   return dequant; | ||||
| } | ||||
|  | ||||
| void checkCalculateQParamsResultTypes(const Node* out) { | ||||
|   TORCH_CHECK( | ||||
|       out->outputs().size() == 2, | ||||
|       "cacluate_qparams should produce output of size 2 (scale, zero_point)."); | ||||
|   Value* scale = out->output(0); | ||||
|   Value* zp = out->output(1); | ||||
|   TORCH_CHECK( | ||||
|       scale->type()->expect<TensorType>(), | ||||
|       "Scale value should be of Tensor type."); | ||||
|   TORCH_CHECK( | ||||
|       zp->type()->expect<TensorType>(), "Scale value should be of float type."); | ||||
| } | ||||
|  | ||||
| QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams( | ||||
|     script::Module& module, | ||||
|     Graph* g, | ||||
|     Node* n) { | ||||
|   // TODO: refactor findObserverName to take Node* as input | ||||
|   Value* self = g->inputs()[0]; | ||||
|   Value* v = n->output(); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       v->type()->isSubtypeOf(*TensorType::get()), | ||||
|       "Expected output of observer node to be Tensor"); | ||||
|   auto observer_name = findObserverName(v); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       observer_name, | ||||
|       "getQSchemeAndParamMap expects the corresponding observer for ", | ||||
|       v->debugName(), | ||||
|       " exists."); | ||||
|   std::vector<Value*> qparams_graph_values; | ||||
|   QuantOpParams quant_op_params; | ||||
|  | ||||
|   TORCH_CHECK( | ||||
|       !isPlaceholderObserver(n->input(0)), | ||||
|       "Placeholder observers are not supported in ondevice PTQ."); | ||||
|   auto observer_module = module.attr(observer_name.value()).toModule(); | ||||
|   Value* observer_module_value = g->insertGetAttr(self, observer_name.value()); | ||||
|   auto scalar_type = observer_module.attr("dtype"); | ||||
|   TORCH_CHECK( | ||||
|       scalar_type.toScalarType() != at::ScalarType::Undefined, | ||||
|       "dtype of observer can't be undefined"); | ||||
|   // Not sure if we need to support this for on device PTQ. | ||||
|   if (scalar_type == at::ScalarType::Half) { | ||||
|     return quant_op_params; | ||||
|   } | ||||
|   auto calculate_qparams = observer_module.get_method("calculate_qparams"); | ||||
|   auto calculate_qparams_schema = calculate_qparams.function().getSchema(); | ||||
|   MatchedSchema matched_schema = matchSchema( | ||||
|       calculate_qparams_schema, | ||||
|       v->node()->sourceRange(), | ||||
|       *g, | ||||
|       {observer_module_value}, | ||||
|       {}); | ||||
|   Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node(); | ||||
|   Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0))); | ||||
|   checkCalculateQParamsResultTypes(scale_zp_node); | ||||
|   auto qscheme = observer_module.attr("qscheme").toQScheme(); | ||||
|   quant_op_params.qscheme = qscheme; | ||||
|   quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value* | ||||
|   quant_op_params.qparams.push_back( | ||||
|       scale_zp_node->output(1)); // zero_point Value* | ||||
|   if (isPerChannel(qscheme)) { | ||||
|     Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis"); | ||||
|     quant_op_params.qparams.push_back(ch_axis_value); | ||||
|   } | ||||
|   Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype"); | ||||
|   quant_op_params.qparams.push_back(scalar_type_value); | ||||
|   return quant_op_params; | ||||
| } | ||||
|  | ||||
| void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps( | ||||
|     Module& module, | ||||
|     Graph* graph, | ||||
|     Value* self) { | ||||
|   if (!observer_nodes_for_graph_.count(graph)) { | ||||
|     return; | ||||
|   } | ||||
|   for (auto* n : observer_nodes_for_graph_.at(graph)) { | ||||
|     Graph* g = n->owningGraph(); | ||||
|     // Observer output | ||||
|     Value* observer_out = n->output(); | ||||
|     // Inserting before insert point | ||||
|     WithInsertPoint insert_qparams_calc(observer_out->node()->next()); | ||||
|     auto quant_op_params = insertCalculateQParams(module, g, n); | ||||
|     insertQuantizationOps( | ||||
|         module, | ||||
|         self, | ||||
|         n, | ||||
|         isPerChannel(quant_op_params.qscheme), | ||||
|         quant_op_params, | ||||
|         quant_type_); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void InsertQuantDeQuantHelper::runForOnDevicePTQ( | ||||
|     Module& module, | ||||
|     const std::string& method_name) { | ||||
|   // In all likelihood this really wont do anything because we expect that | ||||
|   // the input method for quantization's prepare step will be inlined. Thus | ||||
|   // only call methods we will see will belong to observer's forward calls. | ||||
|   for (auto& invoked_methods : getInvokedMethods(module, method_name)) { | ||||
|     auto& invoked_module = std::get<0>(invoked_methods); | ||||
|     const auto& invoked_method_name = std::get<1>(invoked_methods); | ||||
|     runForOnDevicePTQ(invoked_module, invoked_method_name); | ||||
|   } | ||||
|  | ||||
|   Method method = module.get_method(method_name); | ||||
|   auto graph = method.graph(); | ||||
|   // Unliked the run method we dont need to extract new qparam values for the | ||||
|   // the same graph used in different call site. | ||||
|   // Reason is that for on device PTQ we dont: | ||||
|   // 1. Run calculate_qparams | ||||
|   // 2. Get the scale and zero point | ||||
|   // 3. get axis and dtype | ||||
|   // 4. register values from 2 and 3 as attributes on the parent module. | ||||
|   // Instead we insert call to calculate_qparams (1) via insertCalculateQParams | ||||
|   // in the graph itself. Then instead of 2 and 3, we get the output Value* | ||||
|   // and for 3, we insert GetAttr for axis and dtype and use those Value* | ||||
|   // with insterQuantizationOps | ||||
|  | ||||
|   // prim::Param nodes do not belong to the graph. Hence the Insert | ||||
|   // point is the beginning of graph node. This also safe guards against | ||||
|   // observing a potentially mutated value due to some in-place operation | ||||
|   std::vector<Value*> input_values; | ||||
|   for (const auto idx : c10::irange(1, method.num_inputs())) { | ||||
|     auto& v = graph->inputs()[idx]; | ||||
|     if (v->type()->isSubtypeOf(*TensorType::get())) { | ||||
|       input_values.push_back(v); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::stack<Block*> blocks_to_visit; | ||||
|   blocks_to_visit.push(graph->block()); | ||||
|   while (!blocks_to_visit.empty()) { | ||||
|     Block* b = blocks_to_visit.top(); | ||||
|     blocks_to_visit.pop(); | ||||
|     for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) { | ||||
|       Node* n = *it++; | ||||
|       for (Value* v : n->outputs()) { | ||||
|         if (!v->type()->isSubtypeOf(*TensorType::get())) { | ||||
|           continue; | ||||
|         } | ||||
|         collectObserverNodesAndValueToQuantize(module, v); | ||||
|       } | ||||
|  | ||||
|       for (Block* subblock : n->blocks()) { | ||||
|         blocks_to_visit.push(subblock); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   for (Value* v : input_values) { | ||||
|     collectObserverNodesAndValueToQuantize(module, v); | ||||
|   } | ||||
|  | ||||
|   GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph); | ||||
|   Value* self = graph->inputs()[0]; | ||||
|   insertCalculateQParamsAndQuantizationOps(module, graph.get(), self); | ||||
|   GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| void ReplicateQuant(std::shared_ptr<Graph>& graph) { | ||||
| @ -1573,5 +1814,57 @@ Module InsertQuantDeQuant( | ||||
|   return module; | ||||
| } | ||||
|  | ||||
| /* | ||||
|  * | ||||
|  * Assumption: method_name method has observer placed | ||||
|  * Objective: modify that method to insert calls to: | ||||
|  * 1. calculate_qparams | ||||
|  * 2. GetAttr for axis and dtype values | ||||
|  * 3. Use Values from above two to insert calls to quant + dequant | ||||
|  * Thus after this step you have a graph of, e.g., observe_forward, | ||||
|  * that has observer nodes, calculate_qparams run on those observer nodes, | ||||
|  * output of which is used by quant-dequant nodes. output of dequant is used | ||||
|  * by the actual op. | ||||
|  * Later on we will replace dequant + op (e.g. linear) with | ||||
|  * 1. prepacked_op context | ||||
|  * 2. unpack | ||||
|  * 3. dequantize | ||||
|  * 4. linear | ||||
|  * | ||||
|  * Of the above pattern 2, 3, and 4 can be replaced by linear_run op | ||||
|  */ | ||||
| // Module InsertQuantDeQuantForOnDevicePTQ( | ||||
| Module InsertQuantDeQuantOnDevicePTQ( | ||||
|     Module& input_module, | ||||
|     const std::string& method_name, | ||||
|     bool inplace, | ||||
|     bool debug, | ||||
|     QuantType quant_type) { | ||||
|   Module module = input_module.clone(inplace); | ||||
|   const std::string kObserveString = "observe_"; | ||||
|   const auto matched_pos = method_name.find(kObserveString); | ||||
|   const auto end_pos = matched_pos + kObserveString.length(); | ||||
|   const std::string orig_method_name = method_name.substr(end_pos); | ||||
|   TORCH_CHECK( | ||||
|       matched_pos == 0, | ||||
|       "Quant dequant nodes can only be added to observe_", | ||||
|       orig_method_name, | ||||
|       ". Please make sure to run prepare step for on-device PTQ."); | ||||
|  | ||||
|   std::string quantize_method_name = "quantize_" + orig_method_name; | ||||
|   cloneMethod(module, method_name, quantize_method_name); | ||||
|   InsertQuantDeQuantHelper h(quant_type, debug); | ||||
|   h.runForOnDevicePTQ(module, quantize_method_name); | ||||
|   h.removeObserverNodes(module); | ||||
|   // Dont need: | ||||
|   // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's | ||||
|   // quant dequant RemoveRedundantQuantizationOps: THis is removing activation | ||||
|   // observers for dynamic quant when the op related to it is not dynamically | ||||
|   // quantizable. Doesnt really make sense. In our case we wont have those | ||||
|   // anyway since for dynamic quant activations wont be observed We can still | ||||
|   // use this function because the above two methods should really be a noop | ||||
|   h.propagateQuantizationOps(module); | ||||
|   return module; | ||||
| } | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -35,5 +35,12 @@ TORCH_API Module InsertQuantDeQuant( | ||||
|     bool debug, | ||||
|     QuantType quant_type = QuantType::STATIC); | ||||
|  | ||||
| TORCH_API Module InsertQuantDeQuantOnDevicePTQ( | ||||
|     Module& module, | ||||
|     const std::string& method_name, | ||||
|     bool inplace, | ||||
|     bool debug, | ||||
|     QuantType quant_type = QuantType::STATIC); | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
|  | ||||
| @ -443,6 +443,22 @@ void initJITBindings(PyObject* module) { | ||||
|           py::arg("inplace"), | ||||
|           py::arg("debug"), | ||||
|           py::arg("quant_type_int") = 1) | ||||
|       .def( | ||||
|           "_jit_pass_insert_quant_dequant_for_ondevice_ptq", | ||||
|           [](Module& module, | ||||
|              const std::string& method_name, | ||||
|              bool inplace, | ||||
|              bool debug, | ||||
|              int quant_type_int) { | ||||
|             auto quant_type = static_cast<QuantType>(quant_type_int); | ||||
|             return InsertQuantDeQuantOnDevicePTQ( | ||||
|                 module, method_name, inplace, debug, quant_type); | ||||
|           }, | ||||
|           py::arg("module"), | ||||
|           py::arg("method_name"), | ||||
|           py::arg("inplace"), | ||||
|           py::arg("debug"), | ||||
|           py::arg("quant_type_int") = 1) | ||||
|       .def( | ||||
|           "_jit_pass_insert_prepack_unpack", | ||||
|           [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); }) | ||||
|  | ||||
| @ -25,7 +25,7 @@ _all__ = [ | ||||
|     'quantize', 'quantize_dynamic', 'quantize_qat', | ||||
|     'prepare', 'convert', 'prepare_qat', | ||||
|     # Top level API for graph mode quantization on TorchScript | ||||
|     'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit' | ||||
|     'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit', | ||||
|     # Top level API for graph mode quantization on GraphModule(torch.fx) | ||||
|     # 'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx | ||||
|     # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', | ||||
|  | ||||
		Reference in New Issue
	
	Block a user