from __future__ import absolute_import, division, print_function, unicode_literals

import enum
import torch
from .qconfig import QConfig
from torch.jit._recursive import wrap_cpp_module

# Quantization type (dynamic quantization, static quantization).
# Should match the C++ enum in quantization_type.h
class QuantType(enum.IntEnum):
    DYNAMIC = 0
    STATIC = 1

def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError('input must be a script module, got: ' + str(type(model)))

def _check_forward_method(model):
    if not model._c._has_method('forward'):
        raise ValueError('input script module does not have forward method')

def script_qconfig(qconfig):
    r"""Instantiate the activation and weight observer modules and script
    them; these observer module instances will be deep-copied during the
    `prepare_jit` step.
    """
    return QConfig(
        activation=torch.jit.script(qconfig.activation())._c,
        weight=torch.jit.script(qconfig.weight())._c)

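# A minimal usage sketch for `script_qconfig` (illustrative only; the
# 'fbgemm' backend choice is an assumption, not something this file fixes):
#
#     from torch.quantization import get_default_qconfig
#     qconfig = get_default_qconfig('fbgemm')
#     scripted_qconfig = script_qconfig(qconfig)
#     # scripted_qconfig.activation and scripted_qconfig.weight now hold
#     # scripted observer instances that prepare_jit can deep-copy.
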
def script_qconfig_dict(qconfig_dict):
    r"""Helper function used by `prepare_jit`.
    Apply `script_qconfig` to all entries in `qconfig_dict` that are
    not None.
    """
    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}

def fuse_conv_bn_jit(model):
    r"""Fuse conv - bn modules in the model.
    Works for eval models only.

    Args:
        model: TorchScript model from scripting or tracing
    """
    return torch.jit._recursive.wrap_cpp_module(torch._C._jit_pass_fold_convbn(model._c))

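# A minimal sketch of standalone conv-bn fusion (illustrative; `float_model`
# is a hypothetical eval-mode model containing Conv2d -> BatchNorm2d pairs):
#
#     ts_model = torch.jit.script(float_model.eval())
#     fused_model = fuse_conv_bn_jit(ts_model)
#     # Conv2d/BatchNorm2d pairs are now folded into single Conv2d modules.
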
def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
    assert not inplace, "The inplace support is still in development"
    _check_is_script_module(model)
    _check_forward_method(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names (str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    model = fuse_conv_bn_jit(model)
    return wrap_cpp_module(torch._C._jit_pass_insert_observers(model._c,
                                                               'forward',
                                                               scripted_qconfig_dict,
                                                               inplace,
                                                               quant_type))

def prepare_jit(model, qconfig_dict, inplace=False):
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)

def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)

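# A minimal sketch of the standalone prepare step (illustrative; `float_model`
# and the 'fbgemm' backend choice are assumptions):
#
#     from torch.quantization import get_default_qconfig
#     ts_model = torch.jit.script(float_model.eval())
#     prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
#     # `prepared` now contains observer modules that record activation
#     # statistics while calibration data is run through it.
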
def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC):
    assert not inplace, "The inplace support is still in development"
    _check_is_script_module(model)
    model.eval()
    model = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(model._c, 'forward', inplace, debug, quant_type))
    if not debug:
        # Moving model parameters to CPU since quantized operators
        # are only supported on CPU right now
        model.cpu()
        model = wrap_cpp_module(torch._C._jit_pass_quant_finalize(model._c, quant_type))
    return model

def convert_jit(model, inplace=False, debug=False):
    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC)

def convert_dynamic_jit(model, inplace=False, debug=False):
    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC)

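# A minimal sketch of the full manual flow that `_quantize_jit` automates
# (illustrative; `prepared` and `data_loader` are assumptions carried over
# from the sketch above):
#
#     for image, _ in data_loader:   # calibration
#         prepared(image)
#     quantized = convert_jit(prepared)
#     # Observers are replaced by quantize/dequantize ops and quantized
#     # weights; pass debug=True to keep a debug-friendly model instead.
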
def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
    assert not inplace, "We don't support inplace right now"
    # Always do inplace convert because the Tensor is already
    # copied in prepare_jit when inplace is False
    if quant_type == QuantType.DYNAMIC:
        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
        # TODO: change inplace to True
        model = convert_dynamic_jit(model, False, debug)
    else:
        assert run_fn, "Must provide calibration function for post training static quantization"
        assert run_args, "Must provide calibration dataset for post training static quantization"
        model = prepare_jit(model, qconfig_dict, inplace)
        run_fn(model, *run_args)
        # TODO: change inplace to True
        model = convert_jit(model, False, debug)

    return model

def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training static quantization.

    First it prepares the model for calibration, then it calls
    `run_fn`, which runs the calibration step, and after that it
    converts the model to a quantized model.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: dictionary with names of sub modules as keys and the
        qconfig for that module as value. An empty key means the qconfig will be applied
        to the whole model unless it is overridden by a more specific configuration; the
        qconfig for each module is either found in the dictionary or falls back to
        the qconfig of its parent module.

        Right now qconfig_dict is the only way to configure how the model is quantized,
        and it is done at the granularity of modules, that is, we only support one type
        of qconfig for each torch.nn.Module; the qconfig for a sub module overrides
        the qconfig of its parent module, and an empty string means global
        configuration.
        `run_fn`: a calibration function for calibrating the prepared model
        `run_args`: positional arguments for `run_fn`
        `inplace`: carry out model transformations in-place; the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserves the weight attribute)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization import quantize_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    quantized_model = quantize_jit(
        ts_model,
        {'': qconfig},
        calibrate,
        [data_loader_test])
    ```
    """
    return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)

def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with
    post training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: dictionary with names of sub modules as keys and the
        qconfig for that module as value, please see detailed
        descriptions in :func:`~torch.quantization.quantize_jit`
        `inplace`: carry out model transformations in-place; the original module is
        mutated
        `debug`: flag for producing a debug friendly model (preserves the weight attribute)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.quantization import per_channel_dynamic_qconfig
    from torch.quantization import quantize_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    quantized_model = quantize_dynamic_jit(
        ts_model,
        {'': per_channel_dynamic_qconfig})
    ```
    """
    return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)