Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Easy] enable PYFMT for torch/quantization/eager (#150761)

All modifications were made with tooling; the exact command was:

```bash
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150761
Approved by: https://github.com/jerryzh168
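The hunks below are pure formatting churn applied by PYFMT: single quotes become double quotes, long calls are wrapped one argument per line with trailing commas, and imports are sorted. As an illustration only, here is a small self-contained sketch of the before/after shape of these changes; `bias_correction` here is a hypothetical stand-in, not the real torch.ao.quantization._correct_bias API.

```python
# Hypothetical stand-in used only to illustrate the formatting style enforced
# by PYFMT in this commit (quote normalization, wrapped call arguments).
def bias_correction(float_model, quantized_model, img_data, target_modules=None):
    """Pretend to correct the bias of the quantized model."""
    return float_model, quantized_model, img_data, target_modules


float_model = {"linear1": {"bias": 0.1}}
quantized_model = {"linear1": {"bias": 0.3}}
img_data = [(1.0, 0), (2.0, 1)]

# Before PYFMT (single quotes, one long line):
# bias_correction(float_model, quantized_model, img_data, target_modules=['linear1'])

# After PYFMT (double quotes, one argument per line, trailing comma):
bias_correction(
    float_model,
    quantized_model,
    img_data,
    target_modules=["linear1"],
)
```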
.lintrunner.toml

@@ -1165,14 +1165,6 @@ exclude_patterns = [
'test/quantization/core/test_utils.py',
'test/quantization/core/test_workflow_module.py',
'test/quantization/core/test_workflow_ops.py',
'test/quantization/eager/__init__.py',
'test/quantization/eager/test_bias_correction_eager.py',
'test/quantization/eager/test_equalize_eager.py',
'test/quantization/eager/test_fuse_eager.py',
'test/quantization/eager/test_model_numerics.py',
'test/quantization/eager/test_numeric_suite_eager.py',
'test/quantization/eager/test_quantize_eager_ptq.py',
'test/quantization/eager/test_quantize_eager_qat.py',
'test/quantization/fx/__init__.py',
'test/quantization/fx/test_equalize_fx.py',
'test/quantization/fx/test_model_report_fx.py',
test/quantization/eager/test_bias_correction_eager.py

@@ -1,24 +1,23 @@
# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

from torch.ao.quantization import default_qconfig
from torch.ao.quantization import QuantWrapper
import torch.ao.ns._numeric_suite as ns

import torch.nn as nn
from torch.ao.quantization import default_qconfig, QuantWrapper
from torch.ao.quantization._correct_bias import (
_supported_modules,
_supported_modules_quantized,
bias_correction,
get_module,
get_param,
parent_child_names
parent_child_names,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM,
)

import copy


class TestBiasCorrectionEager(QuantizationTestCase):
@@ -28,9 +27,9 @@ class TestBiasCorrectionEager(QuantizationTestCase):
return 20 * torch.log10(Ps / Pn)

def correct_artificial_bias_quantize(self, float_model, img_data):
''' Adding artificial bias and testing if bias persists after bias
correction. This test case changes the bias of a quantized submodule
'''
"""Adding artificial bias and testing if bias persists after bias
correction. This test case changes the bias of a quantized submodule
"""
artificial_model = copy.deepcopy(float_model)
artificial_model.qconfig = default_qconfig
torch.ao.quantization.prepare(artificial_model, inplace=True)
@@ -41,12 +40,17 @@ class TestBiasCorrectionEager(QuantizationTestCase):
# manually changing bias
for name, submodule in artificial_model.named_modules():
if type(submodule) in _supported_modules:
x = get_param(submodule, 'bias')
weight = get_param(submodule, 'weight')
x = get_param(submodule, "bias")
weight = get_param(submodule, "weight")
if x is not None:
submodule.set_weight_bias(weight, x.data * 3)

bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)
bias_correction(
float_model,
artificial_model,
img_data,
target_modules=_supported_modules_quantized,
)

# Trims off the shadow module,
for name, submodule in artificial_model.named_modules():
@@ -58,11 +62,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
for name, artificial_submodule in artificial_model.named_modules():
if type(artificial_submodule) in _supported_modules_quantized:
submodule = get_module(float_model, name)
float_bias = get_param(submodule, 'bias')
artificial_bias = get_param(artificial_submodule, 'bias')
float_bias = get_param(submodule, "bias")
artificial_bias = get_param(artificial_submodule, "bias")

self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
"Correcting quantized bias produced too much noise, sqnr score too low")
self.assertTrue(
self.compute_sqnr(float_bias, artificial_bias) > 30,
"Correcting quantized bias produced too much noise, sqnr score too low",
)

@skipIfNoFBGEMM
def test_linear_chain(self):
@@ -78,9 +84,15 @@ class TestBiasCorrectionEager(QuantizationTestCase):
x = self.linear2(x)
x = self.linear3(x)
return x

float_model = QuantWrapper(LinearChain())
img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
for _ in range(50)]
img_data = [
(
torch.rand(10, 3, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(50)
]
self.correct_artificial_bias_quantize(float_model, img_data)

@skipIfNoFBGEMM
@@ -97,7 +109,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
x = self.conv2d2(x)
x = self.conv2d3(x)
return x

float_model = QuantWrapper(ConvChain())
img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
for _ in range(50)]
img_data = [
(
torch.rand(10, 3, 125, 125, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(50)
]
self.correct_artificial_bias_quantize(float_model, img_data)
test/quantization/eager/test_equalize_eager.py

@@ -1,20 +1,19 @@
# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn

from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.ao.quantization.fuse_modules import fuse_modules

import torch.ao.quantization._equalize as _equalize

import copy

import torch
import torch.ao.quantization._equalize as _equalize
import torch.nn as nn
from torch.ao.quantization.fuse_modules import fuse_modules
from torch.testing._internal.common_quantization import QuantizationTestCase


class TestEqualizeEager(QuantizationTestCase):
def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
''' Checks the channel ranges of tensor1, tensor2 are the same,
"""Checks the channel ranges of tensor1, tensor2 are the same,
which is an indication that equalization has been applied correctly
'''
"""
output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)

@@ -23,18 +22,17 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(output_channel_tensor1, input_channel_tensor2)

def getModule(self, model, name):
''' Given the name is a submodule to a model, return the submodule
'''
"""Given the name is a submodule to a model, return the submodule"""
curr = model
name = name.split('.')
name = name.split(".")
for subname in name:
curr = curr._modules[subname]
return curr

def test_cross_layer_equalization(self):
''' applies _equalize.cross_layer_equalization on two modules and checks
"""applies _equalize.cross_layer_equalization on two modules and checks
to make sure channels ranges are equivalent
'''
"""
module1 = nn.Conv2d(3, 4, 2)
module2 = nn.Linear(4, 4)

@@ -45,13 +43,18 @@ class TestEqualizeEager(QuantizationTestCase):

mod_tensor1, mod_tensor2 = module1.weight, module2.weight

self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
self.checkChannelsEqualized(
mod_tensor1,
mod_tensor2,
module1_output_channel_axis,
module2_input_channel_axis,
)

def test_converged(self):
''' Sanity checks on _equalize.converged working
"""Sanity checks on _equalize.converged working
identical modules should return true
modules with high difference in weights should return false
'''
"""
module1 = nn.Linear(3, 3)
module2 = nn.Linear(3, 3)

@@ -59,18 +62,19 @@ class TestEqualizeEager(QuantizationTestCase):
module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))

# input is a dictionary
dictionary_1 = {'linear1': module1}
dictionary_2 = {'linear1': module2}
dictionary_1 = {"linear1": module1}
dictionary_2 = {"linear1": module2}
self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))

def test_equalize(self):
''' First checks to see if _equalize.equalize can handle multiple
"""First checks to see if _equalize.equalize can handle multiple
pair modules as input
then checks correctness of the function by ensuring the equalized
and unequalized versions of the model yield the same output
given the same input
'''
"""

class ChainModule(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -83,13 +87,16 @@ class TestEqualizeEager(QuantizationTestCase):
x = self.linear2(x)
x = self.linear3(x)
return x

chain1 = ChainModule()
chain2 = copy.deepcopy(chain1)

_equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
linear1 = self.getModule(chain1, 'linear1')
linear2 = self.getModule(chain1, 'linear2')
linear3 = self.getModule(chain1, 'linear3')
_equalize.equalize(
chain1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
)
linear1 = self.getModule(chain1, "linear1")
linear2 = self.getModule(chain1, "linear2")
linear3 = self.getModule(chain1, "linear3")

self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
@@ -98,7 +105,7 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(chain1(input), chain2(input))

def test_equalize_fused_convrelu(self):
''' Checks to see if eager mode equalization supports fused
"""Checks to see if eager mode equalization supports fused
ConvReLU2d models

A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu
@@ -106,7 +113,8 @@ class TestEqualizeEager(QuantizationTestCase):
equalization applied. Finally, we ensure that the channels have been
equalized and that the equalized and unequalized versions of the model
yield the same output given the same input
'''
"""

class M(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -128,13 +136,15 @@ class TestEqualizeEager(QuantizationTestCase):

model = M()

fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
fused_model1 = fuse_modules(
model, [["conv1", "relu1"], ["conv2", "relu2"], ["conv3", "relu3"]]
)
fused_model2 = copy.deepcopy(fused_model1)

_equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
conv1 = self.getModule(fused_model1, 'conv1')[0]
conv2 = self.getModule(fused_model1, 'conv2')[0]
conv3 = self.getModule(fused_model1, 'conv3')[0]
_equalize.equalize(fused_model1, [["conv1", "conv2"], ["conv2", "conv3"]], 1e-6)
conv1 = self.getModule(fused_model1, "conv1")[0]
conv2 = self.getModule(fused_model1, "conv2")[0]
conv3 = self.getModule(fused_model1, "conv3")[0]

self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)
@@ -144,7 +154,7 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(fused_model1(input), model(input))

def test_equalize_fused_linearrelu(self):
''' Checks to see if eager mode equalization supports fused
"""Checks to see if eager mode equalization supports fused
LinearReLU models

A model with 3 LinearReLU is constructed. Next, the linear and relu
@@ -152,7 +162,8 @@ class TestEqualizeEager(QuantizationTestCase):
equalization applied. Finally, we ensure that the channels have been
equalized and that the equalized and unequalized versions of the model
yield the same output given the same input
'''
"""

class M(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -174,13 +185,17 @@ class TestEqualizeEager(QuantizationTestCase):

model = M()

fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
fused_model1 = fuse_modules(
model, [["linear1", "relu1"], ["linear2", "relu2"], ["linear3", "relu3"]]
)
fused_model2 = copy.deepcopy(fused_model1)

_equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
linear1 = self.getModule(fused_model1, 'linear1')[0]
linear2 = self.getModule(fused_model1, 'linear2')[0]
linear3 = self.getModule(fused_model1, 'linear3')[0]
_equalize.equalize(
fused_model1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
)
linear1 = self.getModule(fused_model1, "linear1")[0]
linear2 = self.getModule(fused_model1, "linear2")[0]
linear3 = self.getModule(fused_model1, "linear3")[0]

self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
test/quantization/eager/test_fuse_eager.py

@@ -3,37 +3,35 @@
import copy

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import (
quantize,
prepare,
convert,
prepare_qat,
quantize_qat,
default_qat_qconfig,
default_qconfig,
fuse_modules,
fuse_modules_qat,
prepare,
prepare_qat,
QConfig,
default_qconfig,
default_qat_qconfig,
quantize,
quantize_qat,
)

from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelForFusion,
ModelWithSequentialFusion,
ModelForLinearBNFusion,
ModelForFusionWithBias,
ModelForConvTransposeBNFusion,
ModelForFusion,
ModelForFusionWithBias,
ModelForLinearBNFusion,
ModelWithSequentialFusion,
QuantizationTestCase,
SingleLayerLinearModel,
skipIfNoFBGEMM,
test_only_eval_fn,
test_only_train_fn,
skipIfNoFBGEMM,
)

from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
@ -45,23 +43,38 @@ class TestFuseEager(QuantizationTestCase):
|
||||
def test_fuse_module_train(self):
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
# Test step by step fusion
|
||||
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
|
||||
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer")
|
||||
self.assertEqual(type(model.bn1), torch.nn.Identity,
|
||||
msg="Fused Conv + BN + Relu (skipped BN)")
|
||||
self.assertEqual(type(model.relu1), torch.nn.Identity,
|
||||
msg="Fused Conv + BN + Relu (skipped Relu)")
|
||||
model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
|
||||
model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
|
||||
self.assertEqual(
|
||||
type(model.conv1),
|
||||
nni.ConvBnReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.bn1),
|
||||
torch.nn.Identity,
|
||||
msg="Fused Conv + BN + Relu (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu1),
|
||||
torch.nn.Identity,
|
||||
msg="Fused Conv + BN + Relu (skipped Relu)",
|
||||
)
|
||||
|
||||
self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
|
||||
msg="Fused submodule Conv + BN")
|
||||
self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
|
||||
msg="Fused submodule Conv + BN (skipped BN)")
|
||||
self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
|
||||
msg="Non-fused submodule Conv")
|
||||
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
|
||||
msg="Non-fused submodule ReLU")
|
||||
self.assertEqual(
|
||||
type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub1.bn),
|
||||
torch.nn.Identity,
|
||||
msg="Fused submodule Conv + BN (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
|
||||
)
|
||||
model = prepare_qat(model)
|
||||
self.checkObservers(model)
|
||||
|
||||
@ -89,69 +102,121 @@ class TestFuseEager(QuantizationTestCase):
|
||||
test_only_eval_fn(model, self.img_data_1d)
|
||||
self.checkNoQconfig(model)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
|
||||
):
|
||||
checkQuantized(model)
|
||||
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
model = fuse_modules_qat(
|
||||
model,
|
||||
[['conv1', 'bn1', 'relu1'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
model, [["conv1", "bn1", "relu1"], ["sub1.conv", "sub1.bn"]]
|
||||
)
|
||||
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
|
||||
):
|
||||
checkQuantized(model)
|
||||
|
||||
|
||||
def test_fuse_module_eval(self):
|
||||
model = ModelForFusion(default_qconfig)
|
||||
model.eval()
|
||||
model = fuse_modules(
|
||||
model,
|
||||
[['conv3', 'bn3', 'relu4'],
|
||||
['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn']])
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer (BN is folded)")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
|
||||
self.assertEqual(type(model.conv1[1]), nn.ReLU,
|
||||
msg="Fused Conv + BN + Relu second layer (Relu only)")
|
||||
self.assertEqual(type(model.bn1), nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped BN)")
|
||||
self.assertEqual(type(model.relu1), nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
|
||||
self.assertEqual(type(model.conv2), nni.ConvReLU3d,
|
||||
msg="Fused Conv + BN + Relu first layer (BN is folded)")
|
||||
self.assertEqual(type(model.bn2), nni.BNReLU3d,
|
||||
msg="Fused BN + Relu first layer (Relu is folded))")
|
||||
self.assertEqual(type(model.relu3), nn.Identity,
|
||||
msg="Fused BN + Relu second layer (Skipped Relu)")
|
||||
self.assertEqual(type(model.conv2[0]), nn.Conv3d,
|
||||
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
|
||||
self.assertEqual(type(model.conv2[1]), nn.ReLU,
|
||||
msg="Fused Conv + BN + Relu second layer (Relu only)")
|
||||
self.assertEqual(type(model.relu2), nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
|
||||
[
|
||||
["conv3", "bn3", "relu4"],
|
||||
["conv1", "bn1", "relu1"],
|
||||
["conv2", "relu2"],
|
||||
["bn2", "relu3"],
|
||||
["sub1.conv", "sub1.bn"],
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1),
|
||||
nni.ConvReLU2d,
|
||||
msg="Fused Conv + BN + Relu first layer (BN is folded)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[0]),
|
||||
nn.Conv2d,
|
||||
msg="Fused Conv + BN + Relu (Conv + folded BN only)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[1]),
|
||||
nn.ReLU,
|
||||
msg="Fused Conv + BN + Relu second layer (Relu only)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.bn1),
|
||||
nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu1),
|
||||
nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv2),
|
||||
nni.ConvReLU3d,
|
||||
msg="Fused Conv + BN + Relu first layer (BN is folded)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.bn2),
|
||||
nni.BNReLU3d,
|
||||
msg="Fused BN + Relu first layer (Relu is folded))",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu3),
|
||||
nn.Identity,
|
||||
msg="Fused BN + Relu second layer (Skipped Relu)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv2[0]),
|
||||
nn.Conv3d,
|
||||
msg="Fused Conv + BN + Relu (Conv + folded BN only)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv2[1]),
|
||||
nn.ReLU,
|
||||
msg="Fused Conv + BN + Relu second layer (Relu only)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu2),
|
||||
nn.Identity,
|
||||
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
|
||||
)
|
||||
|
||||
self.assertEqual(type(model.conv3), nni.ConvReLU1d,
|
||||
msg="Fused Conv + Relu for Conv1d (folded BN)")
|
||||
self.assertEqual(type(model.conv3[0]), nn.Conv1d,
|
||||
msg="Fused Conv + Relu for Conv1d ")
|
||||
self.assertEqual(type(model.conv3[1]), nn.ReLU,
|
||||
msg="Fused Conv + Relu for Conv1d")
|
||||
self.assertEqual(type(model.bn3), nn.Identity,
|
||||
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
|
||||
self.assertEqual(
|
||||
type(model.conv3),
|
||||
nni.ConvReLU1d,
|
||||
msg="Fused Conv + Relu for Conv1d (folded BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d "
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.bn3),
|
||||
nn.Identity,
|
||||
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)",
|
||||
)
|
||||
|
||||
self.assertEqual(type(model.sub1.conv), nn.Conv2d,
|
||||
msg="Fused submodule Conv + folded BN")
|
||||
self.assertEqual(type(model.sub1.bn), nn.Identity,
|
||||
msg="Fused submodule (skipped BN)")
|
||||
self.assertEqual(type(model.sub2.conv), nn.Conv2d,
|
||||
msg="Non-fused submodule Conv")
|
||||
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
|
||||
msg="Non-fused submodule ReLU")
|
||||
self.assertEqual(
|
||||
type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
|
||||
)
|
||||
|
||||
model = prepare(model)
|
||||
self.checkObservers(model)
|
||||
@ -176,11 +241,14 @@ class TestFuseEager(QuantizationTestCase):
|
||||
model = ModelForFusion(default_qconfig).eval()
|
||||
model = fuse_modules(
|
||||
model,
|
||||
[['conv1', 'bn1', 'relu1'],
|
||||
['conv2', 'relu2'],
|
||||
['bn2', 'relu3'],
|
||||
['sub1.conv', 'sub1.bn'],
|
||||
['conv3', 'bn3', 'relu4']])
|
||||
[
|
||||
["conv1", "bn1", "relu1"],
|
||||
["conv2", "relu2"],
|
||||
["bn2", "relu3"],
|
||||
["sub1.conv", "sub1.bn"],
|
||||
["conv3", "bn3", "relu4"],
|
||||
],
|
||||
)
|
||||
model = quantize(model, test_only_eval_fn, [self.img_data_1d])
|
||||
checkQuantized(model)
|
||||
|
||||
@ -190,27 +258,46 @@ class TestFuseEager(QuantizationTestCase):
|
||||
model = ModelWithSequentialFusion().train()
|
||||
model.to(torch.float)
|
||||
fuse_modules_qat(
|
||||
model, [['conv1', 'relu1'] ,
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']],
|
||||
inplace=True)
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
msg="Fused Conv + Relu: Conv2d")
|
||||
self.assertEqual(type(model.conv1[1]), nn.ReLU,
|
||||
msg="Fused Conv + Relu: Relu")
|
||||
self.assertEqual(type(model.relu1), nn.Identity,
|
||||
msg="Fused Conv + Relu: Identity")
|
||||
model,
|
||||
[
|
||||
["conv1", "relu1"],
|
||||
["features.0.0", "features.0.1", "features.0.2"],
|
||||
["features.1.0", "features.1.1", "features.1.2"],
|
||||
["features.2.0", "features.2.1", "features.2.2"],
|
||||
["classifier.0", "classifier.1"],
|
||||
],
|
||||
inplace=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1),
|
||||
nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
|
||||
)
|
||||
for i in range(3):
|
||||
self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
|
||||
msg="Fused submodule Conv + folded BN")
|
||||
self.assertEqual(type(model.features[i][1]), nn.Identity,
|
||||
msg="Fused submodule (skipped BN)")
|
||||
self.assertEqual(type(model.features[i][2]), nn.Identity,
|
||||
msg="Non-fused submodule Conv")
|
||||
self.assertEqual(
|
||||
type(model.features[i][0]),
|
||||
nni.ConvBnReLU2d,
|
||||
msg="Fused submodule Conv + folded BN",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][1]),
|
||||
nn.Identity,
|
||||
msg="Fused submodule (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][2]),
|
||||
nn.Identity,
|
||||
msg="Non-fused submodule Conv",
|
||||
)
|
||||
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
|
||||
self.assertEqual(type(model.classifier[1]), nn.Identity)
|
||||
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
|
||||
@ -218,17 +305,26 @@ class TestFuseEager(QuantizationTestCase):
|
||||
self.checkObservers(model)
|
||||
model(self.img_data_2d[0][0])
|
||||
|
||||
|
||||
def checkQAT(model):
|
||||
self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
|
||||
self.assertEqual(type(model.relu1), nn.Identity)
|
||||
|
||||
for i in range(3):
|
||||
self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
|
||||
msg="Fused submodule Conv + folded BN")
|
||||
self.assertEqual(type(model.features[i][1]), nn.Identity,
|
||||
msg="Fused submodule (skipped BN)")
|
||||
self.assertEqual(type(model.features[i][2]), nn.Identity,
|
||||
msg="Non-fused submodule Conv")
|
||||
self.assertEqual(
|
||||
type(model.features[i][0]),
|
||||
nniqat.ConvBnReLU2d,
|
||||
msg="Fused submodule Conv + folded BN",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][1]),
|
||||
nn.Identity,
|
||||
msg="Fused submodule (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][2]),
|
||||
nn.Identity,
|
||||
msg="Non-fused submodule Conv",
|
||||
)
|
||||
self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
|
||||
self.assertEqual(type(model.classifier[1]), nn.Identity)
|
||||
|
||||
@ -245,27 +341,45 @@ class TestFuseEager(QuantizationTestCase):
|
||||
model.to(torch.float)
|
||||
fuse_modules(
|
||||
model,
|
||||
[['conv1', 'relu1'],
|
||||
['features.0.0', 'features.0.1', 'features.0.2'],
|
||||
['features.1.0', 'features.1.1', 'features.1.2'],
|
||||
['features.2.0', 'features.2.1', 'features.2.2'],
|
||||
['classifier.0', 'classifier.1']],
|
||||
inplace=True)
|
||||
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d")
|
||||
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
|
||||
msg="Fused Conv + Relu: Conv2d")
|
||||
self.assertEqual(type(model.conv1[1]), nn.ReLU,
|
||||
msg="Fused Conv + Relu: Relu")
|
||||
self.assertEqual(type(model.relu1), nn.Identity,
|
||||
msg="Fused Conv + Relu: Identity")
|
||||
[
|
||||
["conv1", "relu1"],
|
||||
["features.0.0", "features.0.1", "features.0.2"],
|
||||
["features.1.0", "features.1.1", "features.1.2"],
|
||||
["features.2.0", "features.2.1", "features.2.2"],
|
||||
["classifier.0", "classifier.1"],
|
||||
],
|
||||
inplace=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1),
|
||||
nni.ConvReLU2d,
|
||||
msg="Fused Conv + Relu: nni.ConvReLU2d",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
|
||||
)
|
||||
for i in range(3):
|
||||
self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
|
||||
msg="Fused submodule Conv + folded BN")
|
||||
self.assertEqual(type(model.features[i][1]), nn.Identity,
|
||||
msg="Fused submodule (skipped BN)")
|
||||
self.assertEqual(type(model.features[i][2]), nn.Identity,
|
||||
msg="Non-fused submodule Conv")
|
||||
self.assertEqual(
|
||||
type(model.features[i][0]),
|
||||
nni.ConvReLU2d,
|
||||
msg="Fused submodule Conv + folded BN",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][1]),
|
||||
nn.Identity,
|
||||
msg="Fused submodule (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[i][2]),
|
||||
nn.Identity,
|
||||
msg="Non-fused submodule Conv",
|
||||
)
|
||||
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
|
||||
self.assertEqual(type(model.classifier[1]), nn.Identity)
|
||||
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
|
||||
@ -297,12 +411,12 @@ class TestFuseEager(QuantizationTestCase):
|
||||
out_ref = model_ref(self.img_data_2d[0][0])
|
||||
|
||||
# fused model
|
||||
model_orig.qconfig = QConfig(activation=torch.nn.Identity,
|
||||
weight=torch.nn.Identity)
|
||||
model_orig.qconfig = QConfig(
|
||||
activation=torch.nn.Identity, weight=torch.nn.Identity
|
||||
)
|
||||
model = fuse_modules_qat(
|
||||
model_orig,
|
||||
[["conv1", "bn1", "relu1"],
|
||||
["conv2", "bn2"]])
|
||||
model_orig, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]
|
||||
)
|
||||
prep_model = prepare_qat(model, inplace=False)
|
||||
# output with fusion but no observers.
|
||||
out_fused = prep_model(self.img_data_2d[0][0])
|
||||
@ -332,7 +446,6 @@ class TestFuseEager(QuantizationTestCase):
|
||||
|
||||
checkQAT(model)
|
||||
|
||||
|
||||
def test_fusion_linear_bn_eval(self):
|
||||
model = ModelForLinearBNFusion().train()
|
||||
inp1 = torch.randn(8, 20)
|
||||
@ -357,7 +470,9 @@ class TestFuseEager(QuantizationTestCase):
|
||||
model.eval()
|
||||
golden = model(inp2)
|
||||
|
||||
model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
|
||||
model = fuse_modules(
|
||||
model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]
|
||||
)
|
||||
self.assertEqual(type(model.bn1), nn.Identity)
|
||||
self.assertEqual(type(model.bn2), nn.Identity)
|
||||
self.assertEqual(type(model.bn3), nn.Identity)
|
||||
@ -384,50 +499,68 @@ class TestFuseEager(QuantizationTestCase):
|
||||
model = ModelForFusion(default_qat_qconfig).train()
|
||||
|
||||
counter = {
|
||||
'pre_forwards': 0,
|
||||
'forwards': 0,
|
||||
"pre_forwards": 0,
|
||||
"forwards": 0,
|
||||
}
|
||||
fused = False
|
||||
|
||||
def fw_pre_hook(fused_module_class, h_module, input):
|
||||
if fused:
|
||||
self.assertEqual(type(h_module), fused_module_class,
|
||||
"After fusion owner of the first module's forward pre hook is not a fused module")
|
||||
counter['pre_forwards'] += 1
|
||||
self.assertEqual(
|
||||
type(h_module),
|
||||
fused_module_class,
|
||||
"After fusion owner of the first module's forward pre hook is not a fused module",
|
||||
)
|
||||
counter["pre_forwards"] += 1
|
||||
|
||||
def fw_hook(fused_module_class, h_module, input, output):
|
||||
if fused:
|
||||
self.assertEqual(type(h_module), fused_module_class,
|
||||
"After fusion owner of the last module's forward hook is not a fused module")
|
||||
counter['forwards'] += 1
|
||||
self.assertEqual(
|
||||
type(h_module),
|
||||
fused_module_class,
|
||||
"After fusion owner of the last module's forward hook is not a fused module",
|
||||
)
|
||||
counter["forwards"] += 1
|
||||
|
||||
# Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
|
||||
model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
|
||||
model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
|
||||
model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
|
||||
model.conv1.register_forward_pre_hook(
|
||||
lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)
|
||||
)
|
||||
model.sub1.conv.register_forward_pre_hook(
|
||||
lambda *args: fw_pre_hook(nni.ConvBn2d, *args)
|
||||
)
|
||||
model.relu1.register_forward_hook(
|
||||
lambda *args: fw_hook(nni.ConvBnReLU2d, *args)
|
||||
)
|
||||
model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
|
||||
|
||||
test_only_eval_fn(model, self.img_data_1d)
|
||||
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
|
||||
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
|
||||
self.assertEqual(counter["pre_forwards"], 2 * len(self.img_data_1d))
|
||||
self.assertEqual(counter["forwards"], 2 * len(self.img_data_1d))
|
||||
|
||||
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
|
||||
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
|
||||
model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
|
||||
model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
|
||||
|
||||
fused = True
|
||||
before_fusion_pre_count = counter['pre_forwards']
|
||||
before_fusion_post_count = counter['forwards']
|
||||
before_fusion_pre_count = counter["pre_forwards"]
|
||||
before_fusion_post_count = counter["forwards"]
|
||||
test_only_eval_fn(model, self.img_data_1d)
|
||||
self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
|
||||
self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))
|
||||
self.assertEqual(
|
||||
counter["pre_forwards"] - before_fusion_pre_count, 2 * len(self.img_data_1d)
|
||||
)
|
||||
self.assertEqual(
|
||||
counter["forwards"] - before_fusion_post_count, 2 * len(self.img_data_1d)
|
||||
)
|
||||
|
||||
def test_fuse_modules_with_nested_hooks(self):
|
||||
r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
|
||||
can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
|
||||
in the future.
|
||||
"""
|
||||
|
||||
def myhook(*x):
|
||||
return ""
|
||||
|
||||
for qengine in supported_qengines:
|
||||
with override_quantized_engine(qengine):
|
||||
model = ModelWithSequentialFusion().eval()
|
||||
@ -435,28 +568,32 @@ class TestFuseEager(QuantizationTestCase):
|
||||
for sub_model in model.modules():
|
||||
if isinstance(sub_model, nn.Sequential):
|
||||
for layer in sub_model:
|
||||
if hasattr(layer, 'register_forward_hook'):
|
||||
if hasattr(layer, "register_forward_hook"):
|
||||
layer.register_forward_hook(myhook)
|
||||
|
||||
fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
|
||||
fuse_modules(
|
||||
model,
|
||||
[["features.0.0", "features.0.1", "features.0.2"]],
|
||||
inplace=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[0][0]),
|
||||
nni.ConvReLU2d,
|
||||
msg="Fused submodule Conv + folded BN"
|
||||
msg="Fused submodule Conv + folded BN",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[0][1]),
|
||||
nn.Identity,
|
||||
msg="Fused submodule (skipped BN)"
|
||||
msg="Fused submodule (skipped BN)",
|
||||
)
|
||||
self.assertEqual(
|
||||
type(model.features[0][2]),
|
||||
nn.Identity,
|
||||
msg="Non-fused submodule Conv"
|
||||
msg="Non-fused submodule Conv",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_quantization.py TESTNAME\n\n"
|
||||
|
test/quantization/eager/test_model_numerics.py

@@ -1,17 +1,17 @@
# Owner(s): ["oncall: quantization"]

import torch

from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelMultipleOps,
ModelMultipleOpsNoAvgPool,
QuantizationTestCase,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
)


class TestModelNumericsEager(QuantizationTestCase):
def test_float_quant_compare_per_tensor(self):
for qengine in supported_qengines:
@ -25,16 +25,24 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
qModel = torch.ao.quantization.QuantWrapper(my_model)
|
||||
qModel.eval()
|
||||
qModel.qconfig = torch.ao.quantization.default_qconfig
|
||||
torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules(
|
||||
qModel.module, [["conv1", "bn1", "relu1"]], inplace=True
|
||||
)
|
||||
torch.ao.quantization.prepare(qModel, inplace=True)
|
||||
qModel(calib_data)
|
||||
torch.ao.quantization.convert(qModel, inplace=True)
|
||||
out_q = qModel(eval_data)
|
||||
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
|
||||
SQNRdB = 20 * torch.log10(
|
||||
torch.norm(out_ref) / torch.norm(out_ref - out_q)
|
||||
)
|
||||
# Quantized model output should be close to floating point model output numerically
|
||||
# Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired
|
||||
# output
|
||||
self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
|
||||
self.assertGreater(
|
||||
SQNRdB,
|
||||
30,
|
||||
msg="Quantized model numerics diverge from float, expect SQNR > 30 dB",
|
||||
)
|
||||
|
||||
def test_float_quant_compare_per_channel(self):
|
||||
# Test for per-channel Quant
|
||||
@ -47,7 +55,9 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
q_model = torch.ao.quantization.QuantWrapper(my_model)
|
||||
q_model.eval()
|
||||
q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig
|
||||
torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules(
|
||||
q_model.module, [["conv1", "bn1", "relu1"]], inplace=True
|
||||
)
|
||||
torch.ao.quantization.prepare(q_model)
|
||||
q_model(calib_data)
|
||||
torch.ao.quantization.convert(q_model)
|
||||
@ -55,7 +65,11 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
|
||||
# Quantized model output should be close to floating point model output numerically
|
||||
# Setting target SQNR to be 35 dB
|
||||
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
|
||||
self.assertGreater(
|
||||
SQNRdB,
|
||||
35,
|
||||
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
|
||||
)
|
||||
|
||||
def test_fake_quant_true_quant_compare(self):
|
||||
for qengine in supported_qengines:
|
||||
@ -69,7 +83,9 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
fq_model = torch.ao.quantization.QuantWrapper(my_model)
|
||||
fq_model.train()
|
||||
fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
|
||||
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules_qat(
|
||||
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
|
||||
)
|
||||
torch.ao.quantization.prepare_qat(fq_model)
|
||||
fq_model.eval()
|
||||
fq_model.apply(torch.ao.quantization.disable_fake_quant)
|
||||
@ -78,14 +94,26 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
fq_model.apply(torch.ao.quantization.enable_fake_quant)
|
||||
fq_model.apply(torch.ao.quantization.disable_observer)
|
||||
out_fq = fq_model(eval_data)
|
||||
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
|
||||
SQNRdB = 20 * torch.log10(
|
||||
torch.norm(out_ref) / torch.norm(out_ref - out_fq)
|
||||
)
|
||||
# Quantized model output should be close to floating point model output numerically
|
||||
# Setting target SQNR to be 35 dB
|
||||
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
|
||||
self.assertGreater(
|
||||
SQNRdB,
|
||||
35,
|
||||
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
|
||||
)
|
||||
torch.ao.quantization.convert(fq_model)
|
||||
out_q = fq_model(eval_data)
|
||||
SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
|
||||
self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
|
||||
SQNRdB = 20 * torch.log10(
|
||||
torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)
|
||||
)
|
||||
self.assertGreater(
|
||||
SQNRdB,
|
||||
60,
|
||||
msg="Fake quant and true quant numerics diverge, expect SQNR > 60 dB",
|
||||
)
|
||||
|
||||
# Test to compare weight only quantized model numerics and
|
||||
# activation only quantized model numerics with float
|
||||
@ -95,8 +123,10 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
torch.manual_seed(67)
|
||||
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
|
||||
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
|
||||
qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
|
||||
torch.ao.quantization.default_activation_only_qconfig}
|
||||
qconfigset = {
|
||||
torch.ao.quantization.default_weight_only_qconfig,
|
||||
torch.ao.quantization.default_activation_only_qconfig,
|
||||
}
|
||||
SQNRTarget = [35, 45]
|
||||
for idx, qconfig in enumerate(qconfigset):
|
||||
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
|
||||
@ -105,7 +135,9 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
fq_model = torch.ao.quantization.QuantWrapper(my_model)
|
||||
fq_model.train()
|
||||
fq_model.qconfig = qconfig
|
||||
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
|
||||
torch.ao.quantization.fuse_modules_qat(
|
||||
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
|
||||
)
|
||||
torch.ao.quantization.prepare_qat(fq_model)
|
||||
fq_model.eval()
|
||||
fq_model.apply(torch.ao.quantization.disable_fake_quant)
|
||||
@ -114,11 +146,19 @@ class TestModelNumericsEager(QuantizationTestCase):
|
||||
fq_model.apply(torch.ao.quantization.enable_fake_quant)
|
||||
fq_model.apply(torch.ao.quantization.disable_observer)
|
||||
out_fq = fq_model(eval_data)
|
||||
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
|
||||
self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
|
||||
SQNRdB = 20 * torch.log10(
|
||||
torch.norm(out_ref) / torch.norm(out_ref - out_fq)
|
||||
)
|
||||
self.assertGreater(
|
||||
SQNRdB,
|
||||
SQNRTarget[idx],
|
||||
msg="Quantized model numerics diverge from float",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_quantization.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_quantization.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
@ -2,43 +2,45 @@
|
||||
# ruff: noqa: F841
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.ao.nn.quantized as nnq
|
||||
from torch.ao.quantization import (
|
||||
DeQuantStub,
|
||||
QuantStub,
|
||||
convert,
|
||||
default_qconfig,
|
||||
prepare,
|
||||
quantize,
|
||||
quantize_dynamic,
|
||||
)
|
||||
import torch.nn as nn
|
||||
from torch.ao.ns._numeric_suite import (
|
||||
OutputLogger,
|
||||
Shadow,
|
||||
ShadowLogger,
|
||||
compare_model_outputs,
|
||||
compare_model_stub,
|
||||
compare_weights,
|
||||
prepare_model_outputs,
|
||||
get_matching_activations,
|
||||
OutputLogger,
|
||||
prepare_model_outputs,
|
||||
Shadow,
|
||||
ShadowLogger,
|
||||
)
|
||||
from torch.ao.quantization import (
|
||||
convert,
|
||||
default_qconfig,
|
||||
DeQuantStub,
|
||||
prepare,
|
||||
quantize,
|
||||
quantize_dynamic,
|
||||
QuantStub,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import (
|
||||
AnnotatedConvBnReLUModel,
|
||||
AnnotatedConvModel,
|
||||
AnnotatedConvTransposeModel,
|
||||
AnnotatedSingleLayerLinearModel,
|
||||
LSTMwithHiddenDynamicModel,
|
||||
AnnotatedTwoLayerLinearModel,
|
||||
LSTMwithHiddenDynamicModel,
|
||||
QuantizationTestCase,
|
||||
SingleLayerLinearDynamicModel,
|
||||
test_only_eval_fn,
|
||||
skip_if_no_torchvision,
|
||||
test_only_eval_fn,
|
||||
)
|
||||
from torch.testing._internal.common_quantized import override_qengines
|
||||
from torch.testing._internal.common_utils import IS_ARM64
|
||||
|
||||
|
||||
class SubModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -200,10 +202,18 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
for i, val in enumerate(v["quantized"]):
|
||||
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
|
||||
|
||||
model_list = [AnnotatedConvModel(qengine),
|
||||
AnnotatedConvTransposeModel("qnnpack"), # ConvT cannot use per channel weights
|
||||
AnnotatedConvBnReLUModel(qengine)]
|
||||
module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d, nn.ConvTranspose2d]
|
||||
model_list = [
|
||||
AnnotatedConvModel(qengine),
|
||||
AnnotatedConvTransposeModel(
|
||||
"qnnpack"
|
||||
), # ConvT cannot use per channel weights
|
||||
AnnotatedConvBnReLUModel(qengine),
|
||||
]
|
||||
module_swap_list = [
|
||||
nn.Conv2d,
|
||||
nn.intrinsic.modules.fused.ConvReLU2d,
|
||||
nn.ConvTranspose2d,
|
||||
]
|
||||
for model in model_list:
|
||||
model.eval()
|
||||
if hasattr(model, "fuse_model"):
|
||||
@ -279,7 +289,6 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
self.assertTrue(isinstance(q_model.mod1, Shadow))
|
||||
self.assertFalse(isinstance(q_model.conv, Shadow))
|
||||
|
||||
|
||||
@override_qengines
|
||||
def test_compare_model_stub_functional_static(self):
|
||||
r"""Compare the output of static quantized functional layer and its float shadow module"""
|
||||
@ -486,7 +495,9 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
for i, val in enumerate(v["quantized"]):
|
||||
self.assertTrue(len(v["float"][i]) == len(v["quantized"][i]))
|
||||
if i == 0:
|
||||
self.assertTrue(v["float"][i][0].shape == v["quantized"][i][0].shape)
|
||||
self.assertTrue(
|
||||
v["float"][i][0].shape == v["quantized"][i][0].shape
|
||||
)
|
||||
else:
|
||||
self.assertTrue(
|
||||
v["float"][i][0].shape == v["quantized"][i][0].shape
|
||||
@ -540,12 +551,23 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
|
||||
@skip_if_no_torchvision
|
||||
def _test_vision_model(self, float_model):
|
||||
float_model.to('cpu')
|
||||
float_model.to("cpu")
|
||||
float_model.eval()
|
||||
float_model.fuse_model()
|
||||
float_model.qconfig = torch.ao.quantization.default_qconfig
|
||||
img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
|
||||
qmodel = quantize(float_model, torch.ao.quantization.default_eval_fn, [img_data], inplace=False)
|
||||
img_data = [
|
||||
(
|
||||
torch.rand(2, 3, 224, 224, dtype=torch.float),
|
||||
torch.randint(0, 1, (2,), dtype=torch.long),
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
qmodel = quantize(
|
||||
float_model,
|
||||
torch.ao.quantization.default_eval_fn,
|
||||
[img_data],
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
|
||||
|
||||
@ -560,9 +582,11 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
# 'quantized', containing the activations of floating point and quantized model at matching locations.
|
||||
act_compare_dict = compare_model_outputs(float_model, qmodel, data)
|
||||
|
||||
|
||||
for key in act_compare_dict:
|
||||
compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize())
|
||||
compute_error(
|
||||
act_compare_dict[key]["float"][0],
|
||||
act_compare_dict[key]["quantized"][0].dequantize(),
|
||||
)
|
||||
|
||||
prepare_model_outputs(float_model, qmodel)
|
||||
|
||||
@ -579,10 +603,12 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
|
||||
def test_mobilenet_v2(self):
|
||||
from torchvision.models.quantization import mobilenet_v2
|
||||
|
||||
self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
|
||||
|
||||
@skip_if_no_torchvision
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
|
||||
def test_mobilenet_v3(self):
|
||||
from torchvision.models.quantization import mobilenet_v3_large
|
||||
|
||||
self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))
|
||||
|
File diff suppressed because it is too large
test/quantization/eager/test_quantize_eager_qat.py

@@ -3,6 +3,8 @@
import copy
import math

from hypothesis import given, strategies as st

import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
@ -12,8 +14,6 @@ import torch.ao.nn.quantized.dynamic as nnqd
|
||||
import torch.backends.mkldnn
|
||||
import torch.nn as nn
|
||||
import torch.testing._internal.hypothesis_utils as hu
|
||||
|
||||
from hypothesis import given, strategies as st
|
||||
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
|
||||
from torch.ao.quantization import (
|
||||
convert,
|
||||
@ -50,42 +50,63 @@ from torch.testing._internal.common_quantization import (
|
||||
test_only_train_fn,
|
||||
TwoLayerLinearModel,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_quantized import (
|
||||
override_qengines,
|
||||
override_quantized_engine,
|
||||
supported_qengines,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import skipIfNoXNNPACK
|
||||
|
||||
|
||||
hu.assert_deadline_disabled()
|
||||
from functools import reduce
|
||||
|
||||
|
||||
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
|
||||
"""
|
||||
Conv-BN fusion implemented with explicit folding. Useful
|
||||
to verify numerical equivalency with non-folded version.
|
||||
"""
|
||||
def __init__(self,
|
||||
# ConvNd args
|
||||
in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, transposed, output_padding,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode,
|
||||
# BatchNormNd args
|
||||
# num_features: out_channels
|
||||
eps=1e-05, momentum=0.1,
|
||||
# affine: True
|
||||
# track_running_stats: True
|
||||
# Args for this module
|
||||
freeze_bn=False,
|
||||
qconfig=None):
|
||||
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
|
||||
stride, padding, dilation, transposed,
|
||||
output_padding, groups, False, padding_mode)
|
||||
assert qconfig, 'qconfig must be provided for QAT module'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# ConvNd args
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode,
|
||||
# BatchNormNd args
|
||||
# num_features: out_channels
|
||||
eps=1e-05,
|
||||
momentum=0.1,
|
||||
# affine: True
|
||||
# track_running_stats: True
|
||||
# Args for this module
|
||||
freeze_bn=False,
|
||||
qconfig=None,
|
||||
):
|
||||
nn.modules.conv._ConvNd.__init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
False,
|
||||
padding_mode,
|
||||
)
|
||||
assert qconfig, "qconfig must be provided for QAT module"
|
||||
self.qconfig = qconfig
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
@ -103,7 +124,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
self.reset_bn_parameters()
|
||||
|
||||
def reset_running_stats(self):
|
||||
@ -123,7 +144,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
|
||||
def reset_parameters(self):
|
||||
super().reset_parameters()
|
||||
# A hack to avoid resetting on undefined parameters
|
||||
if hasattr(self, 'gamma'):
|
||||
if hasattr(self, "gamma"):
|
||||
self.reset_bn_parameters()
|
||||
|
||||
def update_bn_stats(self):
|
||||
@ -161,33 +182,50 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)
zero_bias = torch.zeros(
self.out_channels, device=scaled_weight.device, dtype=input.dtype
)
conv = self._conv_forward(
input, self.weight_fake_quant(scaled_weight), zero_bias
)

if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
conv_orig = conv / scale_factor.reshape(
[1, -1, 1, 1]
) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
batch_rstd = torch.ones_like(
batch_var, memory_format=torch.contiguous_format
) / torch.sqrt(batch_var + self.eps)

conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
self.running_mean = exponential_average_factor * batch_mean.detach() + \
(1 - exponential_average_factor) * self.running_mean
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
(1 - exponential_average_factor) * self.running_var
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + (
self.beta - self.gamma * batch_rstd * batch_mean
).reshape([1, -1, 1, 1])
self.running_mean = (
exponential_average_factor * batch_mean.detach()
+ (1 - exponential_average_factor) * self.running_mean
)
self.running_var = (
exponential_average_factor * unbiased_batch_var.detach()
+ (1 - exponential_average_factor) * self.running_var
)
else:
if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean /
running_std).reshape([1, -1, 1, 1])
conv = conv + (
self.beta - self.gamma * self.running_mean / running_std
).reshape([1, -1, 1, 1])
else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
conv = conv + (
self.gamma * (self.bias - self.running_mean) / running_std
+ self.beta
).reshape([1, -1, 1, 1])
return conv

def extra_repr(self):
@ -200,23 +238,37 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn = cls(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight
@ -226,41 +278,69 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn


class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
def __init__(
self,
# ConvNd args
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None,
padding_mode="zeros",
# BatchNorm2d args
# num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
_ReferenceConvBnNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_pair(0),
groups,
bias,
padding_mode,
eps,
momentum,
freeze_bn,
qconfig,
)

class TestQuantizeEagerQAT(QuantizationTestCase):
def setUp(self):
super().setUp()

self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float)]
for _ in range(2)]
self.embed_linear_data_train = [
[
torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float),
]
for _ in range(2)
]
self.embed_data = [[torch.randint(0, 10, (12, 1))]]


def test_manual(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
@ -279,8 +359,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)

model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
[self.train_data])
model = quantize_qat(
ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]
)
checkQuantized(model)

def test_dropout(self):
@ -301,8 +382,11 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)

model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
[self.train_data])
model = quantize_qat(
ManualDropoutQATModel(qengine),
test_only_train_fn,
[self.train_data],
)
checkQuantized(model)

def test_eval_only_fake_quant(self):
@ -342,7 +426,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)

model = ManualConvLinearQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)

@skipIfNoXNNPACK
@ -351,7 +437,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
Supported only with qengine=qnnpack, which uses symmetric
kernels from xnnpack library."""
for qengine in supported_qengines:
if qengine != 'qnnpack':
if qengine != "qnnpack":
continue
with override_quantized_engine(qengine):
model = ManualConvLinearSymmQATModel()
@ -373,17 +459,20 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)

model = ManualConvLinearSymmQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)

def test_dynamic_qat_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# Dynamic QAT without memoryless observers should fail
with self.assertRaisesRegex(ValueError,
"Dynamic QAT requires a memoryless observer." +
"This means a MovingAverage observer with averaging constant equal to 1"
):
with self.assertRaisesRegex(
ValueError,
"Dynamic QAT requires a memoryless observer."
+ "This means a MovingAverage observer with averaging constant equal to 1",
):
model = ManualLinearDynamicQATModel(default_qat_qconfig)
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

@ -409,14 +498,23 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

test_only_train_fn(model, self.embed_linear_data_train)
# make sure activation_post_process is inserted after Linear.
self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)
self.assertEqual(
type(model.linear.activation_post_process),
FusedMovingAvgObsFakeQuantize,
)
# make sure that Embedding has a noop for activation.
self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)

model = convert(model, mapping=get_embedding_static_quant_module_mappings())
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)

def checkQuantized(model):
# make sure Embedding is now a QuantizedEmbedding
@ -430,7 +528,6 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)


def test_embedding_bag_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
@ -442,9 +539,15 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
# make sure not activation_post_process is inserted for EmbeddingBag
self.assertFalse(hasattr(model, "activation_post_process"))
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
model = convert(model, mapping=get_embedding_static_quant_module_mappings())
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)

def checkQuantized(model):
# Make sure EmbeddingBag is now a quantized EmbeddingBag.
@ -505,7 +608,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
self.assertEqual(
set(model.state_dict().keys()), set(quant_state_dict.keys())
)
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
@ -513,20 +618,19 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

@override_qengines
def test_forward_hooks_preserved(self):
r"""Test QAT on preserving pre forward and post forward hooks of original model
"""
r"""Test QAT on preserving pre forward and post forward hooks of original model"""
qengine = torch.backends.quantized.engine
model = QuantStubModel()
counter = {
'pre_forwards': 0,
'forwards': 0,
"pre_forwards": 0,
"forwards": 0,
}

def fw_pre_hook(h_module, input):
counter['pre_forwards'] += 1
counter["pre_forwards"] += 1

def fw_hook(h_module, input, output):
counter['forwards'] += 1
counter["forwards"] += 1

model.fc.register_forward_pre_hook(fw_pre_hook)
model.fc.register_forward_hook(fw_hook)
@ -537,15 +641,24 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
def checkHooksIsPresent(model, before_convert=True):
forward_hooks = 1
if before_convert:
self.assertEqual(len(model.quant._forward_hooks.values()), 1,
"Quantization observer hook has disappeared")
self.assertEqual(
len(model.quant._forward_hooks.values()),
1,
"Quantization observer hook has disappeared",
)
forward_hooks = 2
self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
"Extra pre forward hooks have appeared on a layer")
self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
"Extra post forward hooks have appeared on a layer")
self.assertEqual(
len(model.fc._forward_pre_hooks.values()),
1,
"Extra pre forward hooks have appeared on a layer",
)
self.assertEqual(
len(model.fc._forward_hooks.values()),
forward_hooks,
"Extra post forward hooks have appeared on a layer",
)

checkHooksIsPresent(model, True)
x = torch.rand(2, 5, dtype=torch.float)
@ -600,32 +713,40 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

# Test constructor parameters checks here.
with self.assertRaisesRegex(AssertionError,
"qconfig must be provided for QAT module"):
with self.assertRaisesRegex(
AssertionError, "qconfig must be provided for QAT module"
):
nnqat.EmbeddingBag(10, 5, qconfig=None)

with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

# Test from_float checks here.
embed = nn.Embedding(10, 5)
with self.assertRaisesRegex(AssertionError,
"qat.EmbeddingBag.from_float only works for EmbeddingBag"):
with self.assertRaisesRegex(
AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"
):
nnqat.EmbeddingBag.from_float(embed)
embed_bag = nn.EmbeddingBag(10, 5)
with self.assertRaisesRegex(AssertionError,
"Input float module must have qconfig defined"):
with self.assertRaisesRegex(
AssertionError, "Input float module must have qconfig defined"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = None
with self.assertRaisesRegex(AssertionError,
"Input float module must have a valid qconfig"):
with self.assertRaisesRegex(
AssertionError, "Input float module must have a valid qconfig"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = default_qat_qconfig
with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag.from_float(embed_bag)

def test_embedding_qat_qconfig_equal(self):
@ -636,8 +757,10 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model)

self.assertTrue(qconfig_equals(model.emb.qconfig,
default_embedding_qat_qconfig))
self.assertTrue(
qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)
)


class TestQuantizeEagerQATNumerics(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data):
@ -683,24 +806,26 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertEqual(
type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize
)
data = torch.randn(1, 3, 2, 4)
before_convert = m(data)
m = convert(m)
after_convert = m(data)
self.assertEqual(before_convert, after_convert)
# make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
# verify fake quant module is removd
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
# verify that hooks are removed
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

# make sure no fake quantize module is inserted for eval mode

def checkNoFQModule(m):
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

@ -734,50 +859,52 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
# make sure ReLU module is not changed
self.assertTrue(type(m.relu), nn.ReLU)

@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans())
@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans(),
)
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
@ -792,7 +919,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
(dilation_h, dilation_w),
groups,
has_bias,
padding_mode
padding_mode,
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()
@ -811,7 +938,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)
qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion

@ -826,7 +953,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)

# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
if has_bias:
conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
@ -840,17 +974,20 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

if not use_relu:

def relu_op(x): # noqa: F811
return x

if freeze_bn:

def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * (
bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)
).reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x

else:
ref_op = compose([conv_op, bn_op, relu_op])

@ -882,51 +1019,64 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)

@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans())
@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans(),
)
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
@ -945,7 +1095,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)

qat_ref_op = _ReferenceConvBn2d(
@ -961,7 +1111,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)

qat_op.apply(torch.ao.quantization.disable_fake_quant)
@ -981,7 +1131,6 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)

for i in range(5):

# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()
@ -990,7 +1139,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()

input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
input_clone = input.detach().clone().requires_grad_()

if i > 2:
@ -1030,12 +1186,23 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):

precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)

qat_op_optim.step()
qat_ref_op_optim.step()
@ -1048,7 +1215,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1071,7 +1238,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = default_symmetric_qnnpack_qat_qconfig
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1093,14 +1260,13 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
)
data = torch.randn(4, 4)
m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
m = torch.ao.quantization.fuse_modules_qat(m, [["1", "2"]])
mp = prepare_qat(m)
mp(data)
mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)


@skipIfNoXNNPACK
@override_qengines
def test_linear_precomputed_fake_quant(self):
@ -1124,10 +1290,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m_ref.activation_post_process = activation
m_ref.qconfig = qconfig
m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale)
self.assertTrue(
m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale
)


if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)