Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Pytorch][quantization][ondevice] Add a wrapper API for server side prep for ondevice quantization (#83742)

Summary: This diff just wraps the existing API for on-device quantization.

Test Plan: test/quantization/jit/test_ondevice_quantization.py

Differential Revision: [D38868647](https://our.internmc.facebook.com/intern/diff/D38868647)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83742
Approved by: https://github.com/jerryzh168
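For context, a minimal sketch of the single-call flow the new wrapper enables (hedged: `float_model` stands in for any float `torch.nn.Module`; the call mirrors the docstring example added in this diff):

```python
import torch
from torch.ao.quantization import per_channel_dynamic_qconfig
from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit

# Script the eval-mode float model, then produce a quant-ready model in one call.
ts_model = torch.jit.script(float_model.eval())
quant_ready_model = _quantize_ondevice_dynamic_jit(
    ts_model,                            # scripted float model
    {'': per_channel_dynamic_qconfig},   # qconfig_dict: qconfig for the whole model
    'forward',                           # method_name to prepare
    True)                                # inplace
```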
committed by PyTorch MergeBot
parent 5c7e801c50
commit eebdcb5a2e
@@ -12,7 +12,7 @@ from torch.ao.quantization.quantize_jit import (
     prepare_dynamic_jit,
     convert_dynamic_jit,
     _prepare_ondevice_dynamic_jit,
     _convert_ondevice_dynamic_jit,
+    _quantize_ondevice_dynamic_jit,
 )

 from torch.testing._internal.common_utils import TestCase

@@ -22,6 +22,8 @@ from torch.testing._internal.common_quantization import (
     LinearAddModel,
 )

+from torch.jit.mobile import _load_for_lite_interpreter
+
 from torch.testing import FileCheck

 import io
@@ -69,8 +71,7 @@ class OnDevicePTQUtils(object):
     def ptq_dynamic_quantize(model, qconfig_dict):
         inputs = model.get_example_inputs()
         m = get_script_module(model, False, inputs)
-        m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
-        m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
+        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
         return m

     @staticmethod
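The helper above is the only call site changed: the explicit prepare/convert pair is replaced by the new wrapper, and the two forms are intended to be equivalent. A side-by-side sketch of the lines this hunk swaps:

```python
# Before: two explicit steps
m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)

# After: one wrapper call (method_name='forward', inplace=True)
m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
```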
@@ -420,6 +421,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))

+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+
         model.eval()
         inputs = model.get_example_inputs()
         ref_m = torch.jit.script(model)
@@ -444,6 +456,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))

+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+

     def test_quantize_forward(self):
         model = LinearAddModel()
@@ -145,6 +145,12 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
 def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
     return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)

+
+def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
+    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
+    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
+    return model
+
 def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
     # Always do inplace convert because the Tensor is already
     # copied in prepare_jit when inplace is False

@@ -255,3 +261,63 @@ def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
     """
     torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
     return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
+
+
+def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
+    r"""Prepares the input float TorchScript model for
+    *on-device* post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: dictionary with names of submodules as keys and the
+            qconfig for each module as value; please see detailed
+            descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `method_name`: name of the method within the model to be prepared for quantization
+        `inplace`: carry out model transformations in-place; the original module is mutated
+
+    Return:
+        TorchScript model that is ready for on-device quantization.
+        This means that the returned model has:
+        - The target method inlined.
+        - Observer modules inserted in the model.
+        - Packed params inserted in the model. However, they are empty, in that they don't
+          contain valid quantized weights.
+        - observe_<method_name> added, which observes the values to be quantized.
+        - reset_observers_<method_name> added, which resets the observers.
+        - quantize_<method_name> added to the model.
+          - This method extracts scales and zero points.
+          - Quantizes the observed weights.
+          - Creates packed params from them and updates the model's attributes with the new
+            values for the packed params.
+          - Resets the original fp32 weights to empty tensors using SetAttr.
+        - quantized_<method_name> added to the model.
+          - This method uses the quantized weights and quantized linear ops instead of fp32 ops.
+          - This method should be used for inference post PTQ.
+        - Note that all added methods have the same signature as method_name.
+
+    Later on device:
+        - Run reset_observers_<method_name>.
+        - Run observe_<method_name>.
+        - Run quantize_<method_name>.
+        - Now the model can be saved and loaded later.
+        - Run the model with quantized_<method_name>.
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import per_channel_dynamic_qconfig
+    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = per_channel_dynamic_qconfig
+    quant_ready_model = _quantize_ondevice_dynamic_jit(
+        ts_model,
+        {'': qconfig},
+        'forward',
+        True)
+    ```
+    """
+    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
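The "Later on device" steps in the docstring correspond to the lite-interpreter calls that the tests above exercise. A condensed sketch of that on-device sequence (assumptions: `quant_ready_model` comes from `_quantize_ondevice_dynamic_jit` and `inputs` is an example-input tuple):

```python
import io
from torch.jit.mobile import _load_for_lite_interpreter

# Export the quant-ready model and reload it with the lite interpreter.
buffer = io.BytesIO(quant_ready_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
m = _load_for_lite_interpreter(buffer)

# On-device: reset observers, observe example inputs, quantize the weights,
# then run inference through the quantized entry point.
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
```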
@@ -25,7 +25,8 @@ _all__ = [
     'quantize', 'quantize_dynamic', 'quantize_qat',
     'prepare', 'convert', 'prepare_qat',
     # Top level API for graph mode quantization on TorchScript
-    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
+    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit',
+    '_convert_ondevice_dynamic_jit', '_quantize_ondevice_dynamic_jit',
     # Top level API for graph mode quantization on GraphModule(torch.fx)
     # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
     # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',