[Pytorch][quantization][ondevice] Add a wrapper API for server side prep for ondevice quantization (#83742)

Summary:
This diff just wraps the existing APIs for on-device quantization.
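
For illustration, the wrapper is equivalent to calling the existing prepare and convert steps back to back. A minimal sketch of the before/after call sites (the toy `torch.nn.Linear` model and `per_channel_dynamic_qconfig` here are illustrative, not part of this diff):

```python
import torch
from torch.ao.quantization import per_channel_dynamic_qconfig
from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _convert_ondevice_dynamic_jit,
    _quantize_ondevice_dynamic_jit,
)

model = torch.jit.script(torch.nn.Linear(5, 5).eval())
qconfig_dict = {'': per_channel_dynamic_qconfig}

# Before: two explicit calls, prepare then convert
m = _prepare_ondevice_dynamic_jit(model, qconfig_dict, 'forward', False)
m = _convert_ondevice_dynamic_jit(m, 'forward', False)

# After: one wrapper call performing the same prepare + convert sequence
m = _quantize_ondevice_dynamic_jit(model, qconfig_dict, 'forward', False)
```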

Test Plan:
test/quantization/jit/test_ondevice_quantization.py

Differential Revision: [D38868647](https://our.internmc.facebook.com/intern/diff/D38868647)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83742
Approved by: https://github.com/jerryzh168
Author: Kimish Patel
Date: 2022-08-27 16:06:16 -07:00
Committed by: PyTorch MergeBot
parent 5c7e801c50
commit eebdcb5a2e
3 changed files with 94 additions and 4 deletions

test/quantization/jit/test_ondevice_quantization.py

@@ -12,7 +12,7 @@ from torch.ao.quantization.quantize_jit import (
     prepare_dynamic_jit,
     convert_dynamic_jit,
     _prepare_ondevice_dynamic_jit,
-    _convert_ondevice_dynamic_jit,
+    _quantize_ondevice_dynamic_jit,
 )
 from torch.testing._internal.common_utils import TestCase
@@ -22,6 +22,8 @@ from torch.testing._internal.common_quantization import (
     LinearAddModel,
 )
+from torch.jit.mobile import _load_for_lite_interpreter
+
 from torch.testing import FileCheck
 import io
@@ -69,8 +71,7 @@ class OnDevicePTQUtils(object):
     def ptq_dynamic_quantize(model, qconfig_dict):
         inputs = model.get_example_inputs()
         m = get_script_module(model, False, inputs)
-        m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
-        m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
+        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
         return m

     @staticmethod
@@ -420,6 +421,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))
+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+
         model.eval()
         inputs = model.get_example_inputs()
         ref_m = torch.jit.script(model)
@@ -444,6 +456,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))
+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+
     def test_quantize_forward(self):
         model = LinearAddModel()

torch/ao/quantization/quantize_jit.py

@@ -145,6 +145,12 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
 def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
     return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)
+
+def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
+    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
+    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
+    return model
+
 def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
     # Always do inplace convert because the Tensor is already
     # copied in prepare_jit when inplace is False
@@ -255,3 +261,63 @@ def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
     """
     torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
     return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
+
+def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
+    r"""Prepares the input float TorchScript model for
+    *on-device* post-training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: dictionary with names of submodules as keys and the
+            qconfig for each submodule as values; please see the detailed
+            descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `method_name`: name of the method within the model to be prepared for quantization
+        `inplace`: carry out model transformations in-place; the original module is mutated
+
+    Return:
+        TorchScript model that is ready for on-device quantization.
+        This means that the returned model has:
+        - The method inlined.
+        - Observer modules inserted in the model.
+        - Packed params inserted in the model. However, they are empty, meaning they
+          do not contain valid quantized weights.
+        - An observe_<method_name> method that observes the values to be quantized.
+        - A reset_observers_<method_name> method that resets the observers.
+        - A quantize_<method_name> method that:
+          - Extracts scales and zero points.
+          - Quantizes the observed weights.
+          - Creates packed params from them and updates the model's attributes with
+            the new values for the packed params.
+          - Resets the original fp32 weights to an empty tensor using SetAttr.
+        - A quantized_<method_name> method that uses the quantized weights and
+          quantized linear ops instead of the fp32 ops. This method should be used
+          for inference post-PTQ.
+        - Note that all added methods have the same signature as method_name.
+
+        Later, on device:
+        - Run reset_observers_<method_name>.
+        - Run observe_<method_name>.
+        - Run quantize_<method_name>.
+        - Now the model can be saved and loaded later.
+        - Run the model with quantized_<method_name>.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import get_default_qconfig
+    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = get_default_qconfig('fbgemm')
+    quant_ready_model = _quantize_ondevice_dynamic_jit(
+        ts_model,
+        {'': qconfig},
+        'forward',
+        True)
+    ```
+    """
+    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
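
For illustration, the "Later, on device" steps in the docstring above map onto the lite-interpreter flow exercised in the test diff; a minimal end-to-end sketch (the toy `torch.nn.Linear` model, qconfig choice, and example input are illustrative, not from this PR):

```python
import io

import torch
from torch.ao.quantization import per_channel_dynamic_qconfig
from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
from torch.jit.mobile import _load_for_lite_interpreter

# Server side: produce a quant-ready model and serialize it for the lite interpreter.
float_model = torch.nn.Linear(5, 5).eval()
ts_model = torch.jit.script(float_model)
quant_ready = _quantize_ondevice_dynamic_jit(ts_model, {'': per_channel_dynamic_qconfig})
buffer = io.BytesIO(quant_ready._save_to_buffer_for_lite_interpreter())
buffer.seek(0)

# Device side: run the generated methods in the documented order.
m = _load_for_lite_interpreter(buffer)
inputs = (torch.randn(4, 5),)
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
```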

torch/ao/quantization/__init__.py

@@ -25,7 +25,8 @@ _all__ = [
     'quantize', 'quantize_dynamic', 'quantize_qat',
     'prepare', 'convert', 'prepare_qat',
     # Top level API for graph mode quantization on TorchScript
-    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
+    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit',
+    '_convert_ondevice_dynamic_jit', '_quantize_ondevice_dynamic_jit',
     # Top level API for graph mode quantization on GraphModule(torch.fx)
     # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
     # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',