[quant][graphmode] Add observers for dynamic quant (#35121)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35121

For dynamic quantization we insert observers at the input to mimic the quantization of activations that happens inside the operator.
The observer for the weight is inserted in the same way as for static quantization.
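
To make the behavior concrete, here is a rough sketch of what these observers record, using the MinMaxObserver class that shows up in the graphs below (the tensor shapes and qscheme choices are illustrative assumptions, not taken from this PR):

import torch
from torch.quantization.observer import MinMaxObserver

# Activation observer: inserted at the input of the op. forward() records the
# running min/max of whatever flows through it and returns the tensor
# unchanged, mimicking the scale/zero_point computation that the dynamic
# quantized op performs on its input at runtime.
act_observer = MinMaxObserver(dtype=torch.quint8)
x = torch.randn(4, 8)
_ = act_observer(x)
print(act_observer.calculate_qparams())   # (scale, zero_point) from observed min/max

# Weight observer: attached to the weight attribute, just as in static quant.
wt_observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
_ = wt_observer(torch.randn(16, 8))
print(wt_observer.calculate_qparams())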

Test Plan:
python test/test_quantize_script.py

Sample output for a single-layer FC module:

graph(%self : __torch__.___torch_mangle_4.M,
      %x.2 : Tensor):
  %_observer_1 : __torch__.torch.quantization.observer.MinMaxObserver = prim::GetAttr[name="_observer_1"](%self)
  %x.1 : Tensor = prim::CallMethod[name="forward"](%_observer_1, %x.2)
  %2 : __torch__.torch.nn.modules.linear.___torch_mangle_5.Linear = prim::GetAttr[name="fc"](%self)
  %3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # test/test_quantize_script.py:19:23
  return (%3)

graph(%self : __torch__.torch.nn.modules.linear.___torch_mangle_5.Linear,
      %input.1 : Tensor):
  %2 : Function = prim::Constant[name="linear"]()
  %3 : Tensor = prim::GetAttr[name="weight"](%self)
  %_observer_0 : __torch__.torch.quantization.observer.MinMaxObserver = prim::GetAttr[name="_observer_0"](%self)
  %7 : Tensor = prim::CallMethod[name="forward"](%_observer_0, %3)
  %4 : Tensor = prim::GetAttr[name="bias"](%self)
  %5 : Tensor = prim::CallFunction(%2, %input.1, %7, %4) # /home/supriyar/miniconda3/envs/pytorch_py3/lib/python3.7/site-packages/torch/nn/modules/linear.py:87:15
  return (%5)
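
For reference, a module with the shape that produces the first graph above would look roughly like this (a minimal reconstruction from the graph; the layer sizes are arbitrary):

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

# Scripting gives the un-observed graph; after the insert-observers pass runs
# with is_dynamic=True, the input %x.2 is routed through a MinMaxObserver
# before reaching fc, as shown in the sample output above.
m = torch.jit.script(M())
print(m.graph)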

Imported from OSS

Differential Revision: D20599144

fbshipit-source-id: 9a8fa0e8655b9908826b981dce8a11d86efce5df
commit 55019d357e (parent a045343402)
Author: Supriya Rao
Date:   2020-03-24 10:46:38 -07:00
Committed by: Facebook GitHub Bot

5 changed files with 110 additions and 9 deletions


@@ -172,16 +172,18 @@ void initJITBindings(PyObject* module) {
           [](Module& module,
              const std::string& method_name,
              const py::dict& qconfig_dict,
-             bool inplace) {
+             bool inplace,
+             bool is_dynamic) {
             auto dict = py::cast<std::unordered_map<
                 std::string,
                 std::tuple<Module, Module>>>(qconfig_dict);
-            return InsertObservers(module, method_name, dict, inplace);
+            return InsertObservers(module, method_name, dict, inplace, is_dynamic);
           },
           py::arg("module"),
           py::arg("method_name"),
           py::arg("qconfig_dict"),
-          py::arg("inplace") = false)
+          py::arg("inplace") = false,
+          py::arg("is_dynamic") = false)
       .def(
           "_jit_pass_insert_quant_dequant",
           [](Module& module,
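
From the Python side, the new argument would be passed roughly like this. This is only a sketch: the binding name is inferred from the InsertObservers call (the .def name sits just above this excerpt), and the way the scripted observer pair is built here may differ from what the tests actually do.

import torch
from torch.quantization.observer import MinMaxObserver

m = torch.jit.script(M())    # M is the single-layer FC module sketched earlier

# qconfig_dict maps submodule names ("" = top level) to a pair of *scripted*
# (activation, weight) observer modules, matching the
# unordered_map<string, tuple<Module, Module>> cast in the binding above.
act_obs = torch.jit.script(MinMaxObserver(dtype=torch.quint8))
wt_obs = torch.jit.script(MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
qconfig_dict = {"": (act_obs._c, wt_obs._c)}

# Presumably exposed as torch._C._jit_pass_insert_observers; is_dynamic=True
# selects the input-observer behavior described in the summary.
observed = torch._C._jit_pass_insert_observers(
    m._c, "forward", qconfig_dict, inplace=False, is_dynamic=True)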