mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556 Approved by: https://github.com/ezyang
1304 lines
47 KiB
Python
1304 lines
47 KiB
Python
# Owner(s): ["oncall: quantization"]
|
|
|
|
import copy
|
|
import math
|
|
|
|
from hypothesis import given, strategies as st
|
|
|
|
import torch
|
|
import torch.ao.nn.intrinsic.qat as nniqat
|
|
import torch.ao.nn.qat as nnqat
|
|
import torch.ao.nn.qat.dynamic as nnqatd
|
|
import torch.ao.nn.quantized as nnq
|
|
import torch.ao.nn.quantized.dynamic as nnqd
|
|
import torch.backends.mkldnn
|
|
import torch.nn as nn
|
|
import torch.testing._internal.hypothesis_utils as hu
|
|
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
|
|
from torch.ao.quantization import (
|
|
convert,
|
|
default_embedding_qat_qconfig,
|
|
default_qat_qconfig,
|
|
default_qconfig,
|
|
default_symmetric_qnnpack_qat_qconfig,
|
|
DeQuantStub,
|
|
FixedQParamsFakeQuantize,
|
|
FusedMovingAvgObsFakeQuantize,
|
|
get_default_qat_qconfig,
|
|
get_embedding_qat_module_mappings,
|
|
get_embedding_static_quant_module_mappings,
|
|
NoopObserver,
|
|
prepare,
|
|
prepare_qat,
|
|
quantize_qat,
|
|
QuantStub,
|
|
)
|
|
from torch.ao.quantization.qconfig import qconfig_equals
|
|
from torch.nn import BatchNorm2d, Conv2d, init, ReLU
|
|
from torch.nn.modules.utils import _pair
|
|
from torch.testing._internal.common_quantization import (
|
|
DeFusedEmbeddingBagLinear,
|
|
ManualConvLinearQATModel,
|
|
ManualConvLinearSymmQATModel,
|
|
ManualDropoutQATModel,
|
|
ManualEmbeddingBagLinear,
|
|
ManualLinearDynamicQATModel,
|
|
ManualLinearQATModel,
|
|
QuantizationTestCase,
|
|
QuantStubModel,
|
|
test_only_eval_fn,
|
|
test_only_train_fn,
|
|
TwoLayerLinearModel,
|
|
)
|
|
from torch.testing._internal.common_quantized import (
|
|
override_qengines,
|
|
override_quantized_engine,
|
|
supported_qengines,
|
|
)
|
|
from torch.testing._internal.common_utils import skipIfNoXNNPACK
|
|
|
|
|
|
hu.assert_deadline_disabled()
|
|
from functools import reduce
|
|
|
|
|
|
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
|
|
"""
|
|
Conv-BN fusion implemented with explicit folding. Useful
|
|
to verify numerical equivalency with non-folded version.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
# ConvNd args
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
bias,
|
|
padding_mode,
|
|
# BatchNormNd args
|
|
# num_features: out_channels
|
|
eps=1e-05,
|
|
momentum=0.1,
|
|
# affine: True
|
|
# track_running_stats: True
|
|
# Args for this module
|
|
freeze_bn=False,
|
|
qconfig=None,
|
|
):
|
|
nn.modules.conv._ConvNd.__init__(
|
|
self,
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
dilation,
|
|
transposed,
|
|
output_padding,
|
|
groups,
|
|
False,
|
|
padding_mode,
|
|
)
|
|
assert qconfig, "qconfig must be provided for QAT module"
|
|
self.qconfig = qconfig
|
|
self.eps = eps
|
|
self.momentum = momentum
|
|
self.freeze_bn = freeze_bn if self.training else True
|
|
self.num_features = out_channels
|
|
self.gamma = nn.Parameter(torch.empty(out_channels))
|
|
self.beta = nn.Parameter(torch.empty(out_channels))
|
|
self.affine = True
|
|
self.track_running_stats = True
|
|
self.running_mean = nn.Buffer(torch.zeros(out_channels))
|
|
self.running_var = nn.Buffer(torch.ones(out_channels))
|
|
self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long))
|
|
self.activation_post_process = self.qconfig.activation()
|
|
self.weight_fake_quant = self.qconfig.weight()
|
|
if bias:
|
|
self.bias = nn.Parameter(torch.empty(out_channels))
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
self.reset_bn_parameters()
|
|
|
|
def reset_running_stats(self):
|
|
self.running_mean.zero_()
|
|
self.running_var.fill_(1)
|
|
self.num_batches_tracked.zero_()
|
|
|
|
def reset_bn_parameters(self):
|
|
self.reset_running_stats()
|
|
init.uniform_(self.gamma)
|
|
init.zeros_(self.beta)
|
|
if self.bias is not None:
|
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
|
bound = 1 / math.sqrt(fan_in)
|
|
init.uniform_(self.bias, -bound, bound)
|
|
|
|
def reset_parameters(self):
|
|
super().reset_parameters()
|
|
# A hack to avoid resetting on undefined parameters
|
|
if hasattr(self, "gamma"):
|
|
self.reset_bn_parameters()
|
|
|
|
def update_bn_stats(self):
|
|
self.freeze_bn = False
|
|
return self
|
|
|
|
def freeze_bn_stats(self):
|
|
self.freeze_bn = True
|
|
return self
|
|
|
|
def _forward(self, input):
|
|
# exponential_average_factor is self.momentum set to
|
|
# (when it is available) only so that if gets updated
|
|
# in ONNX graph when this node is exported to ONNX.
|
|
if self.momentum is None:
|
|
exponential_average_factor = 0.0
|
|
else:
|
|
exponential_average_factor = self.momentum
|
|
|
|
if self.training and not self.freeze_bn and self.track_running_stats:
|
|
# TODO: if statement only here to tell the jit to skip emitting this when it is None
|
|
if self.num_batches_tracked is not None:
|
|
self.num_batches_tracked += 1
|
|
if self.momentum is None: # use cumulative moving average
|
|
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
|
else: # use exponential moving average
|
|
exponential_average_factor = self.momentum
|
|
|
|
# we use running statistics from the previous batch, so this is an
|
|
# approximation of the approach mentioned in the whitepaper, but we only
|
|
# need to do one convolution in this case instead of two
|
|
running_std = torch.sqrt(self.running_var + self.eps)
|
|
scale_factor = self.gamma / running_std
|
|
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
|
|
if self.bias is not None:
|
|
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
|
|
else:
|
|
zero_bias = torch.zeros(
|
|
self.out_channels, device=scaled_weight.device, dtype=input.dtype
|
|
)
|
|
conv = self._conv_forward(
|
|
input, self.weight_fake_quant(scaled_weight), zero_bias
|
|
)
|
|
|
|
if self.training and not self.freeze_bn:
|
|
# recovering original conv to get original batch_mean and batch_var
|
|
if self.bias is not None:
|
|
conv_orig = conv / scale_factor.reshape(
|
|
[1, -1, 1, 1]
|
|
) + self.bias.reshape([1, -1, 1, 1])
|
|
else:
|
|
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
|
|
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
|
|
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
|
|
n = float(conv_orig.numel() / conv_orig.size()[1])
|
|
unbiased_batch_var = batch_var * (n / (n - 1))
|
|
batch_rstd = torch.ones_like(
|
|
batch_var, memory_format=torch.contiguous_format
|
|
) / torch.sqrt(batch_var + self.eps)
|
|
|
|
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + (
|
|
self.beta - self.gamma * batch_rstd * batch_mean
|
|
).reshape([1, -1, 1, 1])
|
|
self.running_mean = (
|
|
exponential_average_factor * batch_mean.detach()
|
|
+ (1 - exponential_average_factor) * self.running_mean
|
|
)
|
|
self.running_var = (
|
|
exponential_average_factor * unbiased_batch_var.detach()
|
|
+ (1 - exponential_average_factor) * self.running_var
|
|
)
|
|
else:
|
|
if self.bias is None:
|
|
conv = conv + (
|
|
self.beta - self.gamma * self.running_mean / running_std
|
|
).reshape([1, -1, 1, 1])
|
|
else:
|
|
conv = conv + (
|
|
self.gamma * (self.bias - self.running_mean) / running_std
|
|
+ self.beta
|
|
).reshape([1, -1, 1, 1])
|
|
return conv
|
|
|
|
def extra_repr(self):
|
|
# TODO(jerryzh): extend
|
|
return super().extra_repr()
|
|
|
|
def forward(self, input):
|
|
return self.activation_post_process(self._forward(input))
|
|
|
|
@classmethod
|
|
def from_float(cls, mod, qconfig=None):
|
|
r"""Create a qat module from a float module or qparams_dict
|
|
Args: `mod` a float module, either produced by torch.ao.quantization utilities
|
|
or directly from user
|
|
"""
|
|
assert type(mod) == cls._FLOAT_MODULE, (
|
|
"qat."
|
|
+ cls.__name__
|
|
+ ".from_float only works for "
|
|
+ cls._FLOAT_MODULE.__name__
|
|
)
|
|
if not qconfig:
|
|
assert hasattr(mod, "qconfig"), (
|
|
"Input float module must have qconfig defined"
|
|
)
|
|
assert mod.qconfig, "Input float module must have a valid qconfig"
|
|
qconfig = mod.qconfig
|
|
conv, bn = mod[0], mod[1]
|
|
qat_convbn = cls(
|
|
conv.in_channels,
|
|
conv.out_channels,
|
|
conv.kernel_size,
|
|
conv.stride,
|
|
conv.padding,
|
|
conv.dilation,
|
|
conv.groups,
|
|
conv.bias is not None,
|
|
conv.padding_mode,
|
|
bn.eps,
|
|
bn.momentum,
|
|
False,
|
|
qconfig,
|
|
)
|
|
qat_convbn.weight = conv.weight
|
|
qat_convbn.bias = conv.bias
|
|
qat_convbn.gamma = bn.weight
|
|
qat_convbn.beta = bn.bias
|
|
qat_convbn.running_mean = bn.running_mean
|
|
qat_convbn.running_var = bn.running_var
|
|
qat_convbn.num_batches_tracked = bn.num_batches_tracked
|
|
return qat_convbn
|
|
|
|
|
|
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
    """2d specialization of the folded Conv-BN reference module."""

    # Float module type accepted by ``from_float``.
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=None,
        padding_mode="zeros",
        # BatchNorm2d args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        """Expand the geometry arguments with ``_pair`` and defer to ``_ReferenceConvBnNd``.

        The base module is always built non-transposed with zero output padding.
        """
        _ReferenceConvBnNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size=_pair(kernel_size),
            stride=_pair(stride),
            padding=_pair(padding),
            dilation=_pair(dilation),
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            eps=eps,
            momentum=momentum,
            freeze_bn=freeze_bn,
            qconfig=qconfig,
        )
|
|
|
|
|
|
class TestQuantizeEagerQAT(QuantizationTestCase):
    """Eager-mode QAT workflow tests: prepare_qat -> train -> convert -> eval."""

    def setUp(self):
        super().setUp()

        # Two (indices, target) training pairs for the embedding + linear models.
        self.embed_linear_data_train = [
            [
                torch.randint(0, 10, (12, 12), dtype=torch.long),
                torch.randn((12, 1), dtype=torch.float),
            ]
            for _ in range(2)
        ]
        self.embed_data = [[torch.randint(0, 10, (12, 1))]]

    def test_manual(self):
        """QAT of a linear model, both step by step and through quantize_qat."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(
                    ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]
                )
                checkQuantized(model)

    def test_dropout(self):
        """QAT of a linear + dropout model; dropout must convert to nnq.Dropout."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualDropoutQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.dropout), nnq.Dropout)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(
                    ManualDropoutQATModel(qengine),
                    test_only_train_fn,
                    [self.train_data],
                )
                checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode,
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)

                model = prepare_qat(model)
                self.checkObservers(model)

                model.eval()
                test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        """QAT of a conv + linear model, step by step and through quantize_qat."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualConvLinearQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearQATModel()
                model = quantize_qat(
                    model, test_only_train_fn, [self.img_data_2d_train]
                )
                checkQuantized(model)

    @skipIfNoXNNPACK
    def test_conv_linear_symm(self):
        r"""Same as test_conv_linear but with Symmetric quantization.
        Supported only with qengine=qnnpack, which uses symmetric
        kernels from xnnpack library."""
        for qengine in supported_qengines:
            if qengine != "qnnpack":
                continue
            with override_quantized_engine(qengine):
                model = ManualConvLinearSymmQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearSymmQATModel()
                model = quantize_qat(
                    model, test_only_train_fn, [self.img_data_2d_train]
                )
                checkQuantized(model)

    def test_dynamic_qat_linear(self):
        """Dynamic QAT for Linear: rejected with a non-memoryless observer, converts otherwise."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Dynamic QAT without memoryless observers should fail
                with self.assertRaisesRegex(
                    ValueError,
                    "Dynamic QAT requires a memoryless observer."
                    + "This means a MovingAverage observer with averaging constant equal to 1",
                ):
                    model = ManualLinearDynamicQATModel(default_qat_qconfig)
                    model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

                model = ManualLinearDynamicQATModel()
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
                self.assertEqual(type(model.fc1), nnqatd.Linear)
                self.assertEqual(type(model.fc2), nnqatd.Linear)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
                self.assertEqual(type(model.fc1), nnqd.Linear)
                self.assertEqual(type(model.fc2), nnqd.Linear)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

    def test_defused_embedding_bag_linear(self):
        """QAT of the de-fused embedding + linear model with the embedding mappings."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure activation_post_process is inserted after Linear.
                self.assertEqual(
                    type(model.linear.activation_post_process),
                    FusedMovingAvgObsFakeQuantize,
                )
                # make sure that Embedding has a noop for activation.
                self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
                # make sure that FakeQuant zero_points are correct dtype
                self.assertEqual(
                    model.emb.weight_fake_quant.zero_point.dtype, torch.float32
                )
                self.assertEqual(
                    model.linear.weight_fake_quant.zero_point.dtype, torch.int32
                )

                model = convert(
                    model, mapping=get_embedding_static_quant_module_mappings()
                )

                def checkQuantized(model):
                    # make sure Embedding is now a QuantizedEmbedding
                    self.assertEqual(type(model.emb), nn.quantized.Embedding)
                    # make sure Linear is now a QuantizedLinear
                    self.assertEqual(type(model.linear), nn.quantized.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

    def test_embedding_bag_linear(self):
        """QAT of the EmbeddingBag + linear model with the embedding mappings."""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure no activation_post_process is inserted for EmbeddingBag
                self.assertFalse(hasattr(model, "activation_post_process"))
                # make sure that FakeQuant zero_points are correct dtype
                self.assertEqual(
                    model.emb.weight_fake_quant.zero_point.dtype, torch.float32
                )
                self.assertEqual(
                    model.linear.weight_fake_quant.zero_point.dtype, torch.int32
                )
                model = convert(
                    model, mapping=get_embedding_static_quant_module_mappings()
                )

                def checkQuantized(model):
                    # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                    # assertEqual, not assertTrue: assertTrue would treat the
                    # expected type as the failure message and always pass.
                    self.assertEqual(type(model.emb), nnq.EmbeddingBag)
                    # Also test that Linear has been quantized.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

    def test_train_save_load_eval(self):
        r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
        During eval, we first call prepare_qat and convert on the model and then load the state_dict
        and compare results against original model
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                model = prepare_qat(model)

                fq_state_dict = model.state_dict()

                test_only_train_fn(model, self.train_data)
                model = convert(model)

                quant_state_dict = model.state_dict()

                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check model created using prepare has same state dict as quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(
                    set(model.state_dict().keys()), set(quant_state_dict.keys())
                )
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

    @override_qengines
    def test_forward_hooks_preserved(self):
        r"""Test QAT on preserving pre forward and post forward hooks of original model"""
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {
            "pre_forwards": 0,
            "forwards": 0,
        }

        def fw_pre_hook(h_module, input):
            counter["pre_forwards"] += 1

        def fw_hook(h_module, input, output):
            counter["forwards"] += 1

        model.fc.register_forward_pre_hook(fw_pre_hook)
        model.fc.register_forward_hook(fw_hook)

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        model = prepare_qat(model)

        def checkHooksIsPresent(model, before_convert=True):
            forward_hooks = 1
            if before_convert:
                self.assertEqual(
                    len(model.quant._forward_hooks.values()),
                    1,
                    "Quantization observer hook has disappeared",
                )
                # before convert, fc carries the user hook plus one more hook
                forward_hooks = 2
            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
            self.assertEqual(
                len(model.fc._forward_pre_hooks.values()),
                1,
                "Extra pre forward hooks have appeared on a layer",
            )
            self.assertEqual(
                len(model.fc._forward_hooks.values()),
                forward_hooks,
                "Extra post forward hooks have appeared on a layer",
            )

        checkHooksIsPresent(model, True)
        x = torch.rand(2, 5, dtype=torch.float)
        model(x)
        torch.ao.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)

    def test_add_scalar_uses_input_qparams(self):
        """After convert, add_scalar output must keep the QuantStub's scale."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.add_scalar(x, 1.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
        self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)

    def test_mul_scalar_uses_input_qparams(self):
        """After convert, mul_scalar by 2.0 must give twice the QuantStub's scale."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.mul_scalar(x, 2.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
        self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)

    @override_qengines
    def test_qat_embedding_bag_errors(self):
        """Constructor and from_float of nnqat.EmbeddingBag must reject bad qconfigs/modules."""
        default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

        # Test constructor parameters checks here.
        with self.assertRaisesRegex(
            AssertionError, "qconfig must be provided for QAT module"
        ):
            nnqat.EmbeddingBag(10, 5, qconfig=None)

        with self.assertRaisesRegex(
            AssertionError,
            "Embedding Bag weights requires a qscheme of "
            + "torch.per_channel_affine_float_qparams",
        ):
            nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

        # Test from_float checks here.
        embed = nn.Embedding(10, 5)
        with self.assertRaisesRegex(
            AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"
        ):
            nnqat.EmbeddingBag.from_float(embed)
        embed_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(
            AssertionError, "Input float module must have qconfig defined"
        ):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = None
        with self.assertRaisesRegex(
            AssertionError, "Input float module must have a valid qconfig"
        ):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = default_qat_qconfig
        with self.assertRaisesRegex(
            AssertionError,
            "Embedding Bag weights requires a qscheme of "
            + "torch.per_channel_affine_float_qparams",
        ):
            nnqat.EmbeddingBag.from_float(embed_bag)

    def test_embedding_qat_qconfig_equal(self):
        # Embedding QAT uses a NoopObserver class for activation,
        # and a FakeQuant for weight, make sure that qconfig comparison
        # functions properly for a mix of partial function and class in
        # qconfig.
        model = ManualEmbeddingBagLinear().train()
        model = prepare_qat(model)

        self.assertTrue(
            qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)
        )
|
|
|
|
|
|
class TestQuantizeEagerQATNumerics(QuantizationTestCase):
|
|
def _test_activation_convert_numerics_impl(self, Act, data):
    """Check that ``Act`` gives the same output before and after ``convert``.

    ``Act`` is instantiated with no arguments and placed between a QuantStub
    and a DeQuantStub; the wrapped model is prepared for QAT with
    ``default_qat_qconfig`` and run on ``data`` before and after conversion.
    """

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.act = Act()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            return self.dequant(self.act(self.quant(x)))

    qat_model = M().train()
    qat_model.qconfig = default_qat_qconfig
    qat_model = prepare_qat(qat_model)
    before_convert = qat_model(data)

    quantized_model = convert(qat_model)
    after_convert = quantized_model(data)
    self.assertEqual(before_convert, after_convert)
|
|
|
|
def test_fixed_qparam_ops(self):
    """Sigmoid/Hardsigmoid/Tanh get FixedQParamsFakeQuantize in QAT and none after convert or in eval-mode prepare."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.tanh = torch.nn.Tanh()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.sigmoid(x)
            x = self.hardsigmoid(x)
            x = self.tanh(x)
            x = self.dequant(x)
            return x

    def checkNoFQModule(m):
        for attr in ["sigmoid", "hardsigmoid", "tanh"]:
            submodule = getattr(m, attr)
            # verify fake quant module is removed / was never inserted
            self.assertFalse(hasattr(submodule, "activation_post_process"))
            # verify that hooks are removed / were never registered
            self.assertTrue(len(submodule._forward_hooks.items()) == 0)

    m = M().train()
    m.qconfig = default_qat_qconfig
    m = prepare_qat(m)
    for attr in ["sigmoid", "hardsigmoid", "tanh"]:
        observer_type = type(getattr(m, attr).activation_post_process)
        self.assertEqual(observer_type, FixedQParamsFakeQuantize)
    data = torch.randn(1, 3, 2, 4)
    before_convert = m(data)
    m = convert(m)
    after_convert = m(data)
    self.assertEqual(before_convert, after_convert)
    # make sure activation post process is removed
    checkNoFQModule(m)

    # make sure no fake quantize module is inserted for eval mode
    m = M().eval()
    m.qconfig = default_qconfig
    m = prepare(m)
    checkNoFQModule(m)
    m = convert(m)
    checkNoFQModule(m)
|
|
|
|
def test_leaky_relu(self):
    """LeakyReLU must produce matching outputs before and after convert."""
    self._test_activation_convert_numerics_impl(
        nn.LeakyReLU, torch.randn(1, 3, 2, 4)
    )
|
|
|
|
def test_relu(self):
    """A plain ReLU module must survive prepare_qat + convert unchanged."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(x)
            return x

    m = M().train()
    m.qconfig = default_qconfig
    m = prepare_qat(m)
    # make sure no activation_post_process is inserted for relu
    # NOTE(review): this inspects the wrapper ``m`` rather than ``m.relu``;
    # kept as-is, confirm whether the child module was the intended target.
    self.assertFalse(hasattr(m, "activation_post_process"))
    m = convert(m)
    # make sure ReLU module is not changed
    # assertEqual, not assertTrue: assertTrue(type(...), nn.ReLU) would use
    # nn.ReLU as the failure message and pass for any module type.
    self.assertEqual(type(m.relu), nn.ReLU)
|
|
|
|
@given(
    batch_size=st.integers(2, 4),
    input_channels_per_group=st.sampled_from([2, 3, 4]),
    height=st.integers(5, 10),
    width=st.integers(5, 10),
    output_channels_per_group=st.sampled_from([2, 3]),
    groups=st.integers(1, 3),
    kernel_h=st.integers(1, 3),
    kernel_w=st.integers(1, 3),
    stride_h=st.integers(1, 2),
    stride_w=st.integers(1, 2),
    pad_h=st.integers(0, 2),
    pad_w=st.integers(0, 2),
    dilation=st.integers(1, 1),
    padding_mode=st.sampled_from(["zeros", "circular"]),
    use_relu=st.booleans(),
    eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
    momentum=st.sampled_from([0.1, 0.2, 0.3]),
    freeze_bn=st.booleans(),
    zero_gamma=st.booleans(),
    has_bias=st.booleans(),
    use_slow_fusion=st.booleans(),
)
def test_conv_bn_relu(
    self,
    batch_size,
    input_channels_per_group,
    height,
    width,
    output_channels_per_group,
    groups,
    kernel_h,
    kernel_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation,
    padding_mode,
    use_relu,
    eps,
    momentum,
    freeze_bn,
    zero_gamma,
    has_bias,
    use_slow_fusion,
):
    """Compare fused QAT ConvBn(ReLU)2d with an unfused Conv2d + BatchNorm2d (+ ReLU).

    Fake quantization is disabled on the QAT module, so outputs, gradients
    and BN running statistics are expected to match the float reference
    over two forward/backward iterations.
    """
    input_channels = input_channels_per_group * groups
    output_channels = output_channels_per_group * groups
    # Geometry shared by the reference conv and the fused QAT module.
    kernel_size = (kernel_h, kernel_w)
    stride = (stride_h, stride_w)
    padding = (pad_h, pad_w)
    dilations = (dilation, dilation)

    conv_op = Conv2d(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilations,
        groups,
        has_bias,
        padding_mode,
    ).to(dtype=torch.double)
    bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
    # Identity stands in for the activation when ReLU is not under test.
    if use_relu:
        relu_op = ReLU()
        qat_cls = ConvBnReLU2d
    else:

        def relu_op(x):
            return x

        qat_cls = ConvBn2d

    qat_op = qat_cls(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilations,
        groups,
        has_bias,
        padding_mode,
        eps,
        momentum,
        freeze_bn=True,
        qconfig=default_qat_qconfig,
    ).to(dtype=torch.double)
    qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion

    # the approximate fusion will not work if bn.weight has 0
    if zero_gamma and use_slow_fusion:
        torch.nn.init.zeros_(qat_op.bn.weight)

    qat_op.apply(torch.ao.quantization.disable_fake_quant)
    bn_mode_fn = (
        torch.ao.nn.intrinsic.qat.freeze_bn_stats
        if freeze_bn
        else torch.ao.nn.intrinsic.qat.update_bn_stats
    )
    qat_op.apply(bn_mode_fn)

    # align inputs and internal parameters
    input = torch.randn(
        batch_size,
        input_channels,
        height,
        width,
        dtype=torch.double,
        requires_grad=True,
    )
    conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
    if has_bias:
        conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
    bn_op.running_mean = qat_op.bn.running_mean.clone()
    bn_op.running_var = qat_op.bn.running_var.clone()
    bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
    bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())

    def ref_op(x):
        # Float reference: conv -> batch norm -> activation.
        x = conv_op(x)
        if freeze_bn:
            # Frozen BN: normalize with the running statistics by hand.
            x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * (
                bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)
            ).reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
        else:
            x = bn_op(x)
        return relu_op(x)

    input_clone = input.detach().clone().requires_grad_()
    precision = 1e-10
    for _ in range(2):
        result_ref = ref_op(input)
        result_actual = qat_op(input_clone)
        self.assertEqual(result_ref, result_actual)

        # backward
        dout = torch.randn(result_ref.size(), dtype=torch.double)
        (result_ref - dout).sum().backward()
        expected = [
            input.grad.cpu(),
            conv_op.weight.grad.cpu(),
            bn_op.weight.grad.cpu(),
            bn_op.bias.grad.cpu(),
            bn_op.num_batches_tracked,
            bn_op.running_mean,
            bn_op.running_var,
        ]
        (result_actual - dout).sum().backward()
        actual = [
            input_clone.grad.cpu(),
            qat_op.weight.grad.cpu(),
            qat_op.bn.weight.grad.cpu(),
            qat_op.bn.bias.grad.cpu(),
            qat_op.bn.num_batches_tracked,
            qat_op.bn.running_mean,
            qat_op.bn.running_var,
        ]
        # Order: input grad, weight grad, gamma grad, beta grad,
        # num_batches_tracked, running mean, running var.
        for ref_value, actual_value in zip(expected, actual):
            self.assertEqual(ref_value, actual_value, atol=precision, rtol=0)
|
|
|
|
@given(
    batch_size=st.integers(2, 4),
    input_channels_per_group=st.sampled_from([2, 3, 4]),
    height=st.integers(5, 10),
    width=st.integers(5, 10),
    output_channels_per_group=st.sampled_from([2, 3]),
    groups=st.integers(1, 3),
    kernel_h=st.integers(1, 3),
    kernel_w=st.integers(1, 3),
    stride_h=st.integers(1, 2),
    stride_w=st.integers(1, 2),
    pad_h=st.integers(0, 2),
    pad_w=st.integers(0, 2),
    dilation=st.integers(1, 1),
    padding_mode=st.sampled_from(["zeros", "circular"]),
    eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
    momentum=st.sampled_from([0.1, 0.2, 0.3]),
    freeze_bn=st.booleans(),
    bias=st.booleans(),
)
def test_conv_bn_folded_vs_unfolded(
    self,
    batch_size,
    input_channels_per_group,
    height,
    width,
    output_channels_per_group,
    groups,
    kernel_h,
    kernel_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation,
    padding_mode,
    eps,
    momentum,
    freeze_bn,
    bias,
):
    """Train ``ConvBn2d`` and ``_ReferenceConvBn2d`` side by side and compare.

    Both modules are built from identical hyper-parameters, have fake
    quantization disabled and start from the same weights and BN state.
    For five SGD steps the test checks that the forward outputs, the
    input/weight/gamma/beta gradients and the BN buffers
    (``running_mean``, ``running_var``, ``num_batches_tracked``) agree
    within an absolute tolerance of 1e-5.

    BN statistics are frozen on the last two iterations and observers are
    disabled on the last one.
    """
    in_channels = input_channels_per_group * groups
    out_channels = output_channels_per_group * groups

    # Constructor arguments shared by the module under test and the reference.
    ctor_args = (
        in_channels,
        out_channels,
        (kernel_h, kernel_w),
        (stride_h, stride_w),
        (pad_h, pad_w),
        (dilation, dilation),
        groups,
        bias,
        padding_mode,
        eps,
        momentum,
    )
    ctor_kwargs = {"freeze_bn": freeze_bn, "qconfig": default_qat_qconfig}

    # Build the module under test first so random initialisation happens in
    # the same order as before.
    qat_op = ConvBn2d(*ctor_args, **ctor_kwargs).to(dtype=torch.double)
    qat_ref_op = _ReferenceConvBn2d(*ctor_args, **ctor_kwargs).to(
        dtype=torch.double
    )

    for module in (qat_op, qat_ref_op):
        module.apply(torch.ao.quantization.disable_fake_quant)

    # Copy parameters and BN buffers from the module under test into the
    # reference so that both start from the same state.
    bn = qat_op.bn
    qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
    qat_ref_op.running_mean = bn.running_mean.clone()
    qat_ref_op.running_var = bn.running_var.clone()
    qat_ref_op.gamma = torch.nn.Parameter(bn.weight.detach().clone())
    qat_ref_op.beta = torch.nn.Parameter(bn.bias.detach().clone())
    if qat_op.bias is not None:
        qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())

    learning_rate = 0.01
    qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=learning_rate)
    qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=learning_rate)

    tolerance = 1e-5
    for step in range(5):
        # Calling train() on every iteration verifies that it does not
        # override the BN freeze setting.
        qat_op.train()
        qat_ref_op.train()

        qat_op_optim.zero_grad()
        qat_ref_op_optim.zero_grad()

        ref_input = torch.randn(
            batch_size,
            in_channels,
            height,
            width,
            dtype=torch.double,
            requires_grad=True,
        )
        actual_input = ref_input.detach().clone().requires_grad_()

        if step > 2:
            qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
            qat_ref_op.freeze_bn_stats()

        if step > 3:
            qat_op.apply(torch.ao.quantization.disable_observer)
            qat_ref_op.apply(torch.ao.quantization.disable_observer)

        result_ref = qat_ref_op(ref_input)
        result_actual = qat_op(actual_input)
        self.assertEqual(result_ref, result_actual)

        # Backward pass with a shared upstream target.
        dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0

        (result_ref - dout).sum().backward()
        reference_state = (
            ref_input.grad.cpu(),
            qat_ref_op.weight.grad.cpu(),
            qat_ref_op.gamma.grad.cpu(),
            qat_ref_op.beta.grad.cpu(),
            qat_ref_op.num_batches_tracked,
            qat_ref_op.running_mean,
            qat_ref_op.running_var,
        )

        (result_actual - dout).sum().backward()
        actual_state = (
            actual_input.grad.cpu(),
            qat_op.weight.grad.cpu(),
            qat_op.bn.weight.grad.cpu(),
            qat_op.bn.bias.grad.cpu(),
            qat_op.bn.num_batches_tracked,
            qat_op.bn.running_mean,
            qat_op.bn.running_var,
        )

        # Order: input grad, weight grad, gamma grad, beta grad,
        # num_batches_tracked, running_mean, running_var.
        for expected, observed in zip(reference_state, actual_state):
            self.assertEqual(expected, observed, atol=tolerance, rtol=0)

        qat_op_optim.step()
        qat_ref_op_optim.step()
|
|
|
|
@override_qengines
def test_linear_bn_numerics(self):
    """Fused Linear+BatchNorm1d QAT module must match the float model.

    A float ``Linear -> BatchNorm1d`` model is deep-copied, fused with
    ``fuse_modules_qat`` and converted to ``nniqat.LinearBn1d`` using the
    default QAT qconfig of the active quantized engine. With fake
    quantization disabled, both models must produce ``allclose`` outputs
    for the same random input.
    """
    engine = torch.backends.quantized.engine
    float_model = nn.Sequential(
        nn.Linear(4, 4),
        nn.BatchNorm1d(4),
    )

    fused_model = torch.ao.quantization.fuse_modules_qat(
        copy.deepcopy(float_model), [["0", "1"]]
    )
    fused_model[0].qconfig = torch.ao.quantization.get_default_qat_qconfig(engine)
    qat_module = nniqat.LinearBn1d.from_float(fused_model[0])

    # Without fake quantization the fused QAT module should match fp32.
    qat_module.apply(torch.ao.quantization.disable_fake_quant)

    sample = torch.randn(4, 4)
    expected = float_model(sample)
    actual = qat_module(sample)
    self.assertTrue(torch.allclose(expected, actual))
|
|
|
|
@skipIfNoXNNPACK
@override_qengines
def test_linear_bn_symm_numerics(self):
    """Same numerics check as the default-qconfig variant, symmetric qconfig.

    Builds ``nniqat.LinearBn1d`` from a fused float ``Linear -> BatchNorm1d``
    model using ``default_symmetric_qnnpack_qat_qconfig`` and verifies that,
    with fake quantization disabled, its output is ``allclose`` to the float
    model's output. The body is a no-op for every engine other than
    ``"qnnpack"``.
    """
    if torch.backends.quantized.engine != "qnnpack":
        # Only qnnpack supports symmetric quantization.
        return

    float_model = nn.Sequential(
        nn.Linear(4, 4),
        nn.BatchNorm1d(4),
    )

    fused_model = torch.ao.quantization.fuse_modules_qat(
        copy.deepcopy(float_model), [["0", "1"]]
    )
    fused_model[0].qconfig = default_symmetric_qnnpack_qat_qconfig
    qat_module = nniqat.LinearBn1d.from_float(fused_model[0])

    # Without fake quantization the fused QAT module should match fp32.
    qat_module.apply(torch.ao.quantization.disable_fake_quant)

    sample = torch.randn(4, 4)
    expected = float_model(sample)
    actual = qat_module(sample)
    self.assertTrue(torch.allclose(expected, actual))
|
|
|
|
@override_qengines
def test_linear_bn_workflow(self):
    """Run the eager QAT workflow end to end on ``Linear -> BatchNorm1d``.

    Steps: attach the engine's default QAT qconfig, fuse the linear and
    batch-norm layers, ``prepare_qat``, run one forward pass, then
    ``convert``. After conversion the fused slot must be exactly
    ``nnq.Linear`` and the slot that held the batch norm must be exactly
    ``nn.Identity``.
    """
    qengine = torch.backends.quantized.engine
    m = nn.Sequential(
        QuantStub(),
        nn.Linear(4, 4),
        nn.BatchNorm1d(4),
    )
    data = torch.randn(4, 4)
    m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
    m = torch.ao.quantization.fuse_modules_qat(m, [["1", "2"]])
    mp = prepare_qat(m)
    # One forward pass through the prepared model before conversion.
    mp(data)
    mq = convert(mp)
    # assertIs keeps the exact-type check (subclasses do not pass) and
    # reports the actual type on failure instead of "False is not true".
    self.assertIs(type(mq[1]), nnq.Linear)
    self.assertIs(type(mq[2]), nn.Identity)
|
|
|
|
@skipIfNoXNNPACK
@override_qengines
def test_linear_precomputed_fake_quant(self):
    """Check that ``use_precomputed_fake_quant=True`` changes the weight scale.

    Two quantized linears are created from the same float ``nn.Linear``:

    * ``m_ref_copy`` goes through the regular ``nnq.Linear.from_float`` path.
    * ``m_ref`` carries a ``weight_post_process`` whose ``min_val``/``max_val``
      are set by hand to -1/1 and is converted with
      ``use_precomputed_fake_quant=True``.

    The quantized weights of the two modules must end up with different
    scales. The body is a no-op for every engine other than ``"qnnpack"``.
    """
    qengine = torch.backends.quantized.engine
    if qengine != "qnnpack":
        # NOTE(review): the original comment says only qnnpack supports
        # symmetric quantization; the qconfig used below is `default_qconfig`.
        return
    m_ref = nn.Linear(4, 4)

    # Regular conversion path.
    m_ref_copy = copy.deepcopy(m_ref)
    qconfig = default_qconfig
    m_ref_copy.qconfig = qconfig
    activation = copy.deepcopy(qconfig.activation())
    activation(torch.randn(4, 4))
    m_ref_copy.activation_post_process = activation
    m_ref_copy = nnq.Linear.from_float(m_ref_copy)

    # Conversion path with a hand-set weight range of [-1, 1].
    weight_post_process = qconfig.weight()
    weight_post_process.min_val = torch.tensor(-1)
    weight_post_process.max_val = torch.tensor(1)
    m_ref.weight_post_process = weight_post_process
    m_ref.activation_post_process = activation
    m_ref.qconfig = qconfig
    m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)

    # `q_scale` is a method: it has to be called. Comparing the bound methods
    # themselves is always "not equal" and would make this check vacuous.
    self.assertNotEqual(
        m_ref._weight_bias()[0].q_scale(),
        m_ref_copy._weight_bias()[0].q_scale(),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file as a script is unsupported; fail immediately and
    # tell the user which entry point to use.
    message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(message)
|