[On Device Quantization][pytorch]Make insert_quant_dequant support ondevice ptq (#83570)

Summary:
This diff adds a way to:
- Clone the previously observed method
- Add calls to the observers' calculate_qparams methods
- Extract the scale and zero point
- Use them to insert quant-dequant nodes
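The "extract the scale and zero point" step follows standard affine quantization. A minimal sketch of what an observer's calculate_qparams does, in plain Python (illustrative only, not the actual PyTorch observer API), assuming an asymmetric 8-bit (quint8) scheme:

```python
def calculate_qparams(min_val: float, max_val: float,
                      qmin: int = 0, qmax: int = 255):
    """Derive an affine (scale, zero_point) pair from observed min/max.

    Hypothetical stand-in for an observer's calculate_qparams();
    assumes an asymmetric quint8 scheme with qmin=0, qmax=255.
    """
    # Widen the range to include zero so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        # Degenerate range: every observed value was zero.
        return 1.0, 0
    # Map min_val onto qmin, then clamp into the quantized range.
    zero_point = round(qmin - min_val / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point
```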

Now, for a forward method, we have:
- observe_forward
- quantize_forward

observe_forward is used post-training to observe statistics. In the
case of dynamic PTQ this requires running the method just once to
update the weight observers' statistics.

quantize_forward uses the observer statistics to calculate
quantization parameters and applies them via inserted quant-dequant
ops.
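The inserted quant-dequant pair applies the computed scale and zero point. A sketch of the semantics in plain Python (function names are illustrative, not the actual op signatures):

```python
def quantize(x: float, scale: float, zero_point: int,
             qmin: int = 0, qmax: int = 255) -> int:
    """Affine-quantize a value: q = clamp(round(x / scale) + zero_point)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))

def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Recover an approximation of the original value: x ~ (q - zp) * scale."""
    return (q - zero_point) * scale
```

A quant-dequant round trip reproduces the input to within one quantization step (one unit of scale), which is what makes the later dequant + op to quantized-op rewrite numerically safe.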

Subsequent diffs will replace dequant + op patterns with their quantized
op counterparts, and replace quantize ops with the relevant packed params
classes where possible.

Test Plan:
To be written


Differential Revision: [D38771419](https://our.internmc.facebook.com/intern/diff/D38771419)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83570
Approved by: https://github.com/jerryzh168
This commit is contained in:
Kimish Patel
2022-08-27 16:06:14 -07:00
committed by PyTorch MergeBot
parent 6a5d9f1be0
commit 446afb5f9f
7 changed files with 453 additions and 30 deletions

@@ -443,6 +443,22 @@ void initJITBindings(PyObject* module) {
py::arg("inplace"),
py::arg("debug"),
py::arg("quant_type_int") = 1)
.def(
"_jit_pass_insert_quant_dequant_for_ondevice_ptq",
[](Module& module,
const std::string& method_name,
bool inplace,
bool debug,
int quant_type_int) {
auto quant_type = static_cast<QuantType>(quant_type_int);
return InsertQuantDeQuantOnDevicePTQ(
module, method_name, inplace, debug, quant_type);
},
py::arg("module"),
py::arg("method_name"),
py::arg("inplace"),
py::arg("debug"),
py::arg("quant_type_int") = 1)
.def(
"_jit_pass_insert_prepack_unpack",
[](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })