mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 07:24:58 +08:00
This reverts commit 6b7efac3c9ea5c9fbfb18069abd254ad7d9a103e. Reverted https://github.com/pytorch/pytorch/pull/90205 on behalf of https://github.com/seemethere due to Reverting since this caused failures in internal systems, see https://fb.workplace.com/groups/802176577445480/posts/894284641568006 for discussion
753 lines
32 KiB
Python
753 lines
32 KiB
Python
from typing import Any, Dict, Optional, Set, Tuple, Union
|
|
import warnings
|
|
|
|
import torch
|
|
from torch.fx import GraphModule
|
|
from .fx.tracer import QuantizationTracer
|
|
from .fx import fuse # noqa: F401
|
|
from .fx import prepare # noqa: F401
|
|
from .fx.convert import convert
|
|
from .backend_config import ( # noqa: F401
|
|
BackendConfig,
|
|
get_tensorrt_backend_config,
|
|
)
|
|
from .fx.graph_module import ObservedGraphModule
|
|
from .fx.custom_config import (
|
|
ConvertCustomConfig,
|
|
FuseCustomConfig,
|
|
PrepareCustomConfig,
|
|
)
|
|
from .fx.utils import get_custom_module_class_keys # noqa: F401
|
|
from .fx.utils import get_skipped_module_name_and_classes
|
|
from .qconfig_mapping import QConfigMapping
|
|
|
|
def _check_is_graph_module(model: torch.nn.Module) -> None:
|
|
if not isinstance(model, GraphModule):
|
|
raise ValueError(
|
|
"input model must be a GraphModule, "
|
|
+ "Got type:"
|
|
+ str(type(model))
|
|
+ " Please make "
|
|
+ "sure to follow the tutorials."
|
|
)
|
|
|
|
|
|
def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
|
|
r""" Swap FloatFunctional with FXFloatFunctional
|
|
"""
|
|
modules_to_swap = []
|
|
for name, module in model.named_children():
|
|
if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
|
|
modules_to_swap.append(name)
|
|
else:
|
|
_swap_ff_with_fxff(module)
|
|
|
|
for name in modules_to_swap:
|
|
del model._modules[name]
|
|
model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
|
|
|
|
|
|
def _fuse_fx(
    graph_module: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Internal helper that fuses modules in preparation for quantization.

    Args:
        graph_module: GraphModule object from symbolic tracing
            (e.g. ``torch.fx.symbolic_trace``).
        is_qat: forwarded to ``fuse``. ``prepare_qat_fx`` ends up passing
            ``True``, while ``fuse_fx`` and ``prepare_fx`` pass ``False``.
        fuse_custom_config: forwarded unchanged to ``fuse``.
        backend_config: forwarded unchanged to ``fuse``.

    Returns:
        The GraphModule produced by ``fuse``.

    Raises:
        ValueError: if ``graph_module`` is not a ``GraphModule``.
    """
    _check_is_graph_module(graph_module)
    fused = fuse(graph_module, is_qat, fuse_custom_config, backend_config)  # type: ignore[operator]
    return fused
|
|
|
|
|
|
class Scope:
    """Record the module path and the module type of a module.

    Scope is used to track which module contains a Node in the Graph of a
    GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: qualified name of the owning module ("" for the top level,
            as in the example above).
        module_type: type of the owning module.

    Both values are stored as plain, mutable attributes of the same names.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(module_path={self.module_path!r}, "
            f"module_type={self.module_type!r})"
        )
|
|
|
|
|
|
class ScopeContextManager:
    """A context manager to track the Scope of a Node during symbolic tracing.

    Constructing the manager (not ``__enter__``) saves the current
    ``scope.module_path`` / ``scope.module_type`` and then overwrites them with
    ``current_module_path`` / ``type(current_module)``. ``__exit__`` writes the
    saved values back. It is typically used as
    ``with ScopeContextManager(scope, module, path): ...`` so that the scope
    refers to the module whose forward is being traced, and is restored
    afterwards. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(
        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
    ):
        super().__init__()
        # Remember the enclosing scope so __exit__ can put it back.
        self.prev_module_path = scope.module_path
        self.prev_module_type = scope.module_type
        self.scope = scope
        scope.module_path = current_module_path
        scope.module_type = type(current_module)

    def __enter__(self):
        return None

    def __exit__(self, *args):
        restored = (self.prev_module_path, self.prev_module_type)
        self.scope.module_path, self.scope.module_type = restored
        return None
|
|
|
|
|
|
def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> ObservedGraphModule:
    r"""Internal helper for :func:`prepare_fx` and :func:`prepare_qat_fx`.

    The pipeline, in order:

    1. default / normalize the configs (a dict ``prepare_custom_config`` is
       deprecated and converted with a warning);
    2. swap ``FloatFunctional`` submodules for ``FXFloatFunctional``, which
       mutates ``model`` in place;
    3. symbolically trace ``model`` with ``QuantizationTracer``, skipping the
       modules named/typed by ``prepare_custom_config``;
    4. fuse the traced graph;
    5. insert observers / fake quants via ``prepare``.

    Attributes listed in ``prepare_custom_config.preserved_attributes`` are
    copied from ``model`` onto both the traced module and the prepared module.

    Args:
        `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
            see docs for :func:`~torch.ao.quantization.prepare_fx`
        `is_qat`: forwarded to fusion and to ``prepare``; ``True`` for quantization
            aware training.
        `example_inputs`: forwarded to ``prepare``.
        `backend_config`: forwarded to fusion and to ``prepare``.
        `is_standalone_module`: a boolean flag indicating whether we are
            quantizing a standalone module or not. A standalone module
            is a submodule of the parent module that is not inlined in the
            forward graph of the parent module; the way we quantize standalone
            modules is described in:
            :func:`~torch.ao.quantization._prepare_standalone_module_fx`

    Returns:
        The module produced by ``prepare``.
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(prepare_custom_config, dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skip_names, skip_classes = get_skipped_module_name_and_classes(
        prepare_custom_config, is_standalone_module)
    attrs_to_keep = prepare_custom_config.preserved_attributes

    # symbolically trace the model
    tracer = QuantizationTracer(skip_names, skip_classes)  # type: ignore[arg-type]
    traced = GraphModule(model, tracer.trace(model))
    for attr in attrs_to_keep:
        setattr(traced, attr, getattr(model, attr))

    fused = _fuse_fx(
        traced,
        is_qat,
        FuseCustomConfig().set_preserved_attributes(attrs_to_keep),
        backend_config,
    )

    prepared = prepare(  # type: ignore[operator]
        fused,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )

    # Tracing/preparing may not carry arbitrary attributes over, so restore them.
    for attr in attrs_to_keep:
        setattr(prepared, attr, getattr(model, attr))
    return prepared
|
|
|
|
|
|
def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""[Internal use only] Prepare a standalone module, so that it can be used when
    quantizing the parent module.

    A standalone module is a submodule that is not inlined in the parent module
    and is quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.

    Args:
        `model`, `qconfig_mapping`, `example_inputs`, `prepare_custom_config`,
        `backend_config`: see :func:`~torch.ao.quantization.prepare_fx`
        `is_qat`: whether the standalone module is prepared for quantization
            aware training
        No equalization config is forwarded, so ``_prepare_fx`` uses its default.

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes:

            * `_standalone_module_input_quantized_idxs(List[Int])`: a list of
              indexes for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `_standalone_module_output_quantized_idxs(List[Int])`: a list of
              indexes for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module
    """
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat,
        example_inputs,
        prepare_custom_config=prepare_custom_config,
        backend_config=backend_config,
        is_standalone_module=True,
    )
|
|
|
|
|
|
def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.

    The model is first symbolically traced with ``torch.fx.symbolic_trace`` and
    the resulting GraphModule is then fused. The fusion patterns supported by a
    backend are described by a ``BackendConfig`` (see ``backend_config`` below).

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details.
            Passing a plain dict is deprecated: it is converted with
            ``FuseCustomConfig.from_dict`` and a warning is emitted.
        * `backend_config` (BackendConfig): config that specifies how operators are
            quantized in a backend, including the supported fusion patterns.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A fused GraphModule. Attributes listed in
        ``fuse_custom_config.preserved_attributes`` are copied over from ``model``.

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    graph_module = torch.fx.symbolic_trace(model)
    # fuse_custom_config is guaranteed to be a FuseCustomConfig at this point
    # (defaulted or converted above), so its preserved attributes can be read
    # directly without a truthiness guard.
    preserved_attributes: Set[str] = set(fuse_custom_config.preserved_attributes)
    for attr_name in preserved_attributes:
        setattr(graph_module, attr_name, getattr(model, attr_name))
    return _fuse_fx(graph_module, False, fuse_custom_config, backend_config)
|
|
|
|
|
|
def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> ObservedGraphModule:
    r""" Prepare a model for post training static quantization

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model

      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
         quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
         for more details

      * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
         Tuple of positional args (keyword args can be passed as positional args as well)

      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
          See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details.
          Passing a plain dict is deprecated: it is converted with
          ``PrepareCustomConfig.from_dict`` and a warning is emitted.

      * `_equalization_config`: config for specifying how to perform equalization on the model

      * `backend_config` (BackendConfig): config that specifies how operators are quantized
         in a backend, this includes how the operators are observed,
         supported fusion patterns, how quantize/dequantize ops are
         inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
      A GraphModule with observer (configured by qconfig_mapping), ready for calibration

    Note:
      ``model`` itself is modified before tracing: every ``FloatFunctional``
      submodule is replaced in place by an ``FXFloatFunctional``.

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization import prepare_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # `activation` and `weight` are constructors of observer module

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways.
        # e.g. set the global qconfig, which means we will use the same qconfig for
        # all operators in the model, this can be overwritten by other settings
        # qconfig_mapping = QConfigMapping().set_global(qconfig)
        # e.g. quantize the linear submodule with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
        # e.g. quantize all nn.Linear modules with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # for a more complete list, please see the docstring for
        # :class:`torch.ao.quantization.QConfigMapping`

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs;
        # here the last dimension must match the in_features (5) of the Linear layers
        example_inputs = (torch.randn(1, 5),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config. If the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert observer modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        #
        # Example:
        # in qconfig_mapping, user sets linear module to be quantized with quint8 for
        # activation and qint8 for weight:
        # qconfig = torch.ao.quantization.QConfig(
        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # Note: current qconfig api does not support setting output observer, but
        # we may extend this to support these more fine grained control in the
        # future
        #
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # in backend config, linear module also supports in this configuration:
        # weighted_int8_dtype_config = DTypeConfig(
        #   input_dtype=torch.quint8,
        #   output_dtype=torch.quint8,
        #   weight_dtype=torch.qint8,
        #   bias_dtype=torch.float)

        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        #    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        #    .add_dtype_config(weighted_int8_dtype_config) \
        #    ...

        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
        # `prepare_fx` will check that the setting requested by user in qconfig_mapping
        # is supported by the backend_config and insert observers and fake quant modules
        # in the model
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration; `calibration_data_loader` is a user-provided iterable of
        # (input, target) batches
        calibrate(prepared_model, calibration_data_loader)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        False,  # is_qat
        example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
    )
|
|
|
|
|
|
def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> ObservedGraphModule:
    r""" Prepare a model for quantization aware training

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model
      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
      quantization aware training

    Note:
      ``model`` itself is modified before tracing: every ``FloatFunctional``
      submodule is replaced in place by an ``FXFloatFunctional``.

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization import prepare_qat_fx

        class Submodule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().train()
        # (optional, but preferred) load the weights from pretrained model
        # float_model.load_weights(...)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # qconfig is the configuration for how we insert fake quantize modules
        # for a particular operator
        # qconfig = get_default_qat_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
        # `activation` and `weight` are constructors of fake quantize modules

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways, please take a look at
        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
        # to configure this

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs;
        # here the last dimension must match the in_features (5) of the Linear layers
        example_inputs = (torch.randn(1, 5),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config, if the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
        # how qconfig_mapping interacts with backend_config
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training; `train_data` is a user-provided iterable of
        # (input, target) batches
        train_loop(prepared_model, train_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        True,  # is_qat
        example_inputs,
        prepare_custom_config=prepare_custom_config,
        backend_config=backend_config,
    )
|
|
|
|
|
|
def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> torch.nn.Module:
    """Shared implementation behind the public/internal convert entry points.

    Args:
        graph_module: a prepared (and calibrated/trained) GraphModule.
        is_reference: forwarded to ``convert``; ``convert_fx`` passes ``False``
            and ``convert_to_reference_fx`` passes ``True``.
        convert_custom_config: a ``ConvertCustomConfig``. ``None`` means a
            default config. A plain dict is deprecated: it is converted with
            ``ConvertCustomConfig.from_dict`` and a warning is emitted.
        is_standalone_module: see docs in
            :func:`~torch.ao.quantization.quantize_fx._prepare_standalone_module_fx`
        _remove_qconfig: forwarded to ``convert`` as ``_remove_qconfig_flag``.
        qconfig_mapping, backend_config, is_decomposed: forwarded unchanged to
            ``convert``.

    Returns:
        The module produced by ``convert``. Attributes listed in
        ``convert_custom_config.preserved_attributes`` are copied over from
        ``graph_module``.

    Raises:
        ValueError: if ``graph_module`` is not a ``GraphModule``.
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    _check_is_graph_module(graph_module)

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=is_decomposed,
    )

    # Re-attach user-requested attributes that conversion does not carry over.
    preserved_attributes = convert_custom_config.preserved_attributes
    for attr_name in preserved_attributes:
        setattr(quantized, attr_name, getattr(graph_module, attr_name))
    return quantized
|
|
|
|
|
|
def convert_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details.
            Passing a plain dict is deprecated and emits a warning.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.

            The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
            with the same values or `None`. Additional keys can be specified with values set to `None`.

            For each entry whose value is set to None, we skip quantizing that entry in the model::

                qconfig_mapping = (
                    QConfigMapping()
                    .set_global(qconfig_from_prepare)
                    .set_object_type(torch.nn.functional.add, None)  # skip quantizing torch.nn.functional.add
                    .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
                    .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
                )

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend, this includes quantization
            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
            observer placement for each operators and fused operators.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A quantized model (torch.nn.Module)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # convert_fx converts a calibrated/trained model to a quantized model for the
        # target hardware, this includes converting the model first to a reference
        # quantized model, and then lower the reference quantized model to a backend
        # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and
        # they share the same set of quantized operators, so we are using the same
        # lowering procedure
        #
        # backend_config defines the corresponding reference quantized module for
        # the weighted modules in the model, e.g. nn.Linear
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        quantized_model = convert_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(
        graph_module,
        is_reference=False,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )
|
|
|
|
|
|
def convert_to_reference_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a reference quantized model.

    A reference quantized model is a standard representation of a quantized model
    provided by FX Graph Mode Quantization; it can be further lowered to run on the
    target hardware, like accelerators. See
    https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md
    for more details.

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = convert_to_reference_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
    reference_model = _convert_fx(
        graph_module,
        is_reference=True,
        backend_config=backend_config,
        qconfig_mapping=qconfig_mapping,
        _remove_qconfig=_remove_qconfig,
        convert_custom_config=convert_custom_config,
    )
    return reference_model
|
|
|
|
def _convert_to_reference_decomposed_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" Convert a calibrated or trained model to a reference quantized model that
    uses a decomposed representation for quantized Tensors.

    A reference quantized model is a standard representation of a quantized model
    provided by FX Graph Mode Quantization; it can be further lowered to run on the
    target hardware, like accelerators. See
    https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md
    for more details.

    Note: this is not public API

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

         * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx")
    decomposed_model = _convert_fx(
        graph_module,
        is_reference=True,
        is_decomposed=True,
        backend_config=backend_config,
        qconfig_mapping=qconfig_mapping,
        _remove_qconfig=_remove_qconfig,
        convert_custom_config=convert_custom_config,
    )
    return decomposed_model
|
|
|
|
|
|
def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
    r""" [Internal use only] Convert a model produced by
    :func:`~torch.ao.quantization.quantize_fx._prepare_standalone_module_fx`
    into a quantized standalone module.

    Whether the inputs/outputs of the returned module are quantized is specified
    by ``input_quantized_idxs`` / ``output_quantized_idxs`` in the
    prepare_custom_config used when preparing it; please see the docs for
    prepare_fx for details.

    Args:
        graph_module: the prepared (and calibrated/trained) standalone module.
        is_reference: when ``True``, produce a reference quantized module.
        convert_custom_config: see :func:`~torch.ao.quantization.quantize_fx.convert_fx`.

    Returns:
        The quantized standalone module.
    """
    return _convert_fx(
        graph_module,
        is_reference=is_reference,
        convert_custom_config=convert_custom_config,
        is_standalone_module=True,
    )
|