[Easy] enable PYFMT for torch/quantization/eager (#150761)

All modifications were made with automated tooling; the exact command used is as follows:

```bash
lintrunner -a --take "PYFMT" --all-files
```
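
For reference, the same formatter can also be run against individual files. A minimal sketch of a typical lintrunner workflow (the specific file path below is only an illustrative example; it assumes a PyTorch checkout where `lintrunner init` has been run):

```bash
# One-time setup: install the environments for the configured linters.
lintrunner init

# Dry run: report PYFMT formatting issues for a specific file without changing it.
lintrunner --take PYFMT test/quantization/eager/test_fuse_eager.py

# Apply the suggested formatting patches in place (-a is short for --apply-patches).
lintrunner -a --take PYFMT test/quantization/eager/test_fuse_eager.py
```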
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150761
Approved by: https://github.com/jerryzh168
Author: FFFrog
Date: 2025-04-07 17:02:16 +08:00
Committed by: PyTorch MergeBot
Parent: 91b090c912
Commit: 8895c290f4
8 changed files with 1473 additions and 809 deletions

View File: .lintrunner.toml

@ -1165,14 +1165,6 @@ exclude_patterns = [
'test/quantization/core/test_utils.py',
'test/quantization/core/test_workflow_module.py',
'test/quantization/core/test_workflow_ops.py',
'test/quantization/eager/__init__.py',
'test/quantization/eager/test_bias_correction_eager.py',
'test/quantization/eager/test_equalize_eager.py',
'test/quantization/eager/test_fuse_eager.py',
'test/quantization/eager/test_model_numerics.py',
'test/quantization/eager/test_numeric_suite_eager.py',
'test/quantization/eager/test_quantize_eager_ptq.py',
'test/quantization/eager/test_quantize_eager_qat.py',
'test/quantization/fx/__init__.py',
'test/quantization/fx/test_equalize_fx.py',
'test/quantization/fx/test_model_report_fx.py',

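The hunk above is what actually opts these files into formatting: once a path is deleted from the PYFMT `exclude_patterns` list in `.lintrunner.toml`, lintrunner starts applying the formatter to it. As a rough sketch of how such a linter entry is laid out (the field values below are illustrative assumptions, not PyTorch's exact configuration):

```toml
# Illustrative lintrunner formatter entry; values are assumptions, not the real config.
[[linter]]
code = 'PYFMT'
include_patterns = ['**/*.py']
exclude_patterns = [
    # Files listed here are skipped by the formatter; removing an entry,
    # as this commit does for the eager-mode tests, brings the file under PYFMT.
    'test/quantization/fx/test_equalize_fx.py',
]
is_formatter = true
```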
View File: test/quantization/eager/test_bias_correction_eager.py

@ -1,24 +1,23 @@
# Owner(s): ["oncall: quantization"]
import copy
import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.ao.quantization import default_qconfig
from torch.ao.quantization import QuantWrapper
import torch.ao.ns._numeric_suite as ns
import torch.nn as nn
from torch.ao.quantization import default_qconfig, QuantWrapper
from torch.ao.quantization._correct_bias import (
_supported_modules,
_supported_modules_quantized,
bias_correction,
get_module,
get_param,
parent_child_names
parent_child_names,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM,
)
import copy
class TestBiasCorrectionEager(QuantizationTestCase):
@ -28,9 +27,9 @@ class TestBiasCorrectionEager(QuantizationTestCase):
return 20 * torch.log10(Ps / Pn)
def correct_artificial_bias_quantize(self, float_model, img_data):
''' Adding artificial bias and testing if bias persists after bias
correction. This test case changes the bias of a quantized submodule
'''
"""Adding artificial bias and testing if bias persists after bias
correction. This test case changes the bias of a quantized submodule
"""
artificial_model = copy.deepcopy(float_model)
artificial_model.qconfig = default_qconfig
torch.ao.quantization.prepare(artificial_model, inplace=True)
@ -41,12 +40,17 @@ class TestBiasCorrectionEager(QuantizationTestCase):
# manually changing bias
for name, submodule in artificial_model.named_modules():
if type(submodule) in _supported_modules:
x = get_param(submodule, 'bias')
weight = get_param(submodule, 'weight')
x = get_param(submodule, "bias")
weight = get_param(submodule, "weight")
if x is not None:
submodule.set_weight_bias(weight, x.data * 3)
bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)
bias_correction(
float_model,
artificial_model,
img_data,
target_modules=_supported_modules_quantized,
)
# Trims off the shadow module,
for name, submodule in artificial_model.named_modules():
@ -58,11 +62,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
for name, artificial_submodule in artificial_model.named_modules():
if type(artificial_submodule) in _supported_modules_quantized:
submodule = get_module(float_model, name)
float_bias = get_param(submodule, 'bias')
artificial_bias = get_param(artificial_submodule, 'bias')
float_bias = get_param(submodule, "bias")
artificial_bias = get_param(artificial_submodule, "bias")
self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
"Correcting quantized bias produced too much noise, sqnr score too low")
self.assertTrue(
self.compute_sqnr(float_bias, artificial_bias) > 30,
"Correcting quantized bias produced too much noise, sqnr score too low",
)
@skipIfNoFBGEMM
def test_linear_chain(self):
@ -78,9 +84,15 @@ class TestBiasCorrectionEager(QuantizationTestCase):
x = self.linear2(x)
x = self.linear3(x)
return x
float_model = QuantWrapper(LinearChain())
img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
for _ in range(50)]
img_data = [
(
torch.rand(10, 3, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(50)
]
self.correct_artificial_bias_quantize(float_model, img_data)
@skipIfNoFBGEMM
@ -97,7 +109,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
x = self.conv2d2(x)
x = self.conv2d3(x)
return x
float_model = QuantWrapper(ConvChain())
img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
for _ in range(50)]
img_data = [
(
torch.rand(10, 3, 125, 125, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(50)
]
self.correct_artificial_bias_quantize(float_model, img_data)

View File: test/quantization/eager/test_equalize_eager.py

@ -1,20 +1,19 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.ao.quantization.fuse_modules import fuse_modules
import torch.ao.quantization._equalize as _equalize
import copy
import torch
import torch.ao.quantization._equalize as _equalize
import torch.nn as nn
from torch.ao.quantization.fuse_modules import fuse_modules
from torch.testing._internal.common_quantization import QuantizationTestCase
class TestEqualizeEager(QuantizationTestCase):
def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
''' Checks the channel ranges of tensor1, tensor2 are the same,
"""Checks the channel ranges of tensor1, tensor2 are the same,
which is an indication that equalization has been applied correctly
'''
"""
output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)
@ -23,18 +22,17 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(output_channel_tensor1, input_channel_tensor2)
def getModule(self, model, name):
''' Given the name is a submodule to a model, return the submodule
'''
"""Given the name is a submodule to a model, return the submodule"""
curr = model
name = name.split('.')
name = name.split(".")
for subname in name:
curr = curr._modules[subname]
return curr
def test_cross_layer_equalization(self):
''' applies _equalize.cross_layer_equalization on two modules and checks
"""applies _equalize.cross_layer_equalization on two modules and checks
to make sure channels ranges are equivalent
'''
"""
module1 = nn.Conv2d(3, 4, 2)
module2 = nn.Linear(4, 4)
@ -45,13 +43,18 @@ class TestEqualizeEager(QuantizationTestCase):
mod_tensor1, mod_tensor2 = module1.weight, module2.weight
self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
self.checkChannelsEqualized(
mod_tensor1,
mod_tensor2,
module1_output_channel_axis,
module2_input_channel_axis,
)
def test_converged(self):
''' Sanity checks on _equalize.converged working
"""Sanity checks on _equalize.converged working
identical modules should return true
modules with high difference in weights should return false
'''
"""
module1 = nn.Linear(3, 3)
module2 = nn.Linear(3, 3)
@ -59,18 +62,19 @@ class TestEqualizeEager(QuantizationTestCase):
module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))
# input is a dictionary
dictionary_1 = {'linear1': module1}
dictionary_2 = {'linear1': module2}
dictionary_1 = {"linear1": module1}
dictionary_2 = {"linear1": module2}
self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))
def test_equalize(self):
''' First checks to see if _equalize.equalize can handle multiple
"""First checks to see if _equalize.equalize can handle multiple
pair modules as input
then checks correctness of the function by ensuring the equalized
and unequalized versions of the model yield the same output
given the same input
'''
"""
class ChainModule(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -83,13 +87,16 @@ class TestEqualizeEager(QuantizationTestCase):
x = self.linear2(x)
x = self.linear3(x)
return x
chain1 = ChainModule()
chain2 = copy.deepcopy(chain1)
_equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
linear1 = self.getModule(chain1, 'linear1')
linear2 = self.getModule(chain1, 'linear2')
linear3 = self.getModule(chain1, 'linear3')
_equalize.equalize(
chain1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
)
linear1 = self.getModule(chain1, "linear1")
linear2 = self.getModule(chain1, "linear2")
linear3 = self.getModule(chain1, "linear3")
self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
@ -98,7 +105,7 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(chain1(input), chain2(input))
def test_equalize_fused_convrelu(self):
''' Checks to see if eager mode equalization supports fused
"""Checks to see if eager mode equalization supports fused
ConvReLU2d models
A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu
@ -106,7 +113,8 @@ class TestEqualizeEager(QuantizationTestCase):
equalization applied. Finally, we ensure that the channels have been
equalized and that the equalized and unequalized versions of the model
yield the same output given the same input
'''
"""
class M(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -128,13 +136,15 @@ class TestEqualizeEager(QuantizationTestCase):
model = M()
fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
fused_model1 = fuse_modules(
model, [["conv1", "relu1"], ["conv2", "relu2"], ["conv3", "relu3"]]
)
fused_model2 = copy.deepcopy(fused_model1)
_equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
conv1 = self.getModule(fused_model1, 'conv1')[0]
conv2 = self.getModule(fused_model1, 'conv2')[0]
conv3 = self.getModule(fused_model1, 'conv3')[0]
_equalize.equalize(fused_model1, [["conv1", "conv2"], ["conv2", "conv3"]], 1e-6)
conv1 = self.getModule(fused_model1, "conv1")[0]
conv2 = self.getModule(fused_model1, "conv2")[0]
conv3 = self.getModule(fused_model1, "conv3")[0]
self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)
@ -144,7 +154,7 @@ class TestEqualizeEager(QuantizationTestCase):
self.assertEqual(fused_model1(input), model(input))
def test_equalize_fused_linearrelu(self):
''' Checks to see if eager mode equalization supports fused
"""Checks to see if eager mode equalization supports fused
LinearReLU models
A model with 3 LinearReLU is constructed. Next, the linear and relu
@ -152,7 +162,8 @@ class TestEqualizeEager(QuantizationTestCase):
equalization applied. Finally, we ensure that the channels have been
equalized and that the equalized and unequalized versions of the model
yield the same output given the same input
'''
"""
class M(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -174,13 +185,17 @@ class TestEqualizeEager(QuantizationTestCase):
model = M()
fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
fused_model1 = fuse_modules(
model, [["linear1", "relu1"], ["linear2", "relu2"], ["linear3", "relu3"]]
)
fused_model2 = copy.deepcopy(fused_model1)
_equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
linear1 = self.getModule(fused_model1, 'linear1')[0]
linear2 = self.getModule(fused_model1, 'linear2')[0]
linear3 = self.getModule(fused_model1, 'linear3')[0]
_equalize.equalize(
fused_model1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
)
linear1 = self.getModule(fused_model1, "linear1")[0]
linear2 = self.getModule(fused_model1, "linear2")[0]
linear3 = self.getModule(fused_model1, "linear3")[0]
self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)

View File: test/quantization/eager/test_fuse_eager.py

@ -3,37 +3,35 @@
import copy
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import (
quantize,
prepare,
convert,
prepare_qat,
quantize_qat,
default_qat_qconfig,
default_qconfig,
fuse_modules,
fuse_modules_qat,
prepare,
prepare_qat,
QConfig,
default_qconfig,
default_qat_qconfig,
quantize,
quantize_qat,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelForFusion,
ModelWithSequentialFusion,
ModelForLinearBNFusion,
ModelForFusionWithBias,
ModelForConvTransposeBNFusion,
ModelForFusion,
ModelForFusionWithBias,
ModelForLinearBNFusion,
ModelWithSequentialFusion,
QuantizationTestCase,
SingleLayerLinearModel,
skipIfNoFBGEMM,
test_only_eval_fn,
test_only_train_fn,
skipIfNoFBGEMM,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
@ -45,23 +43,38 @@ class TestFuseEager(QuantizationTestCase):
def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train()
# Test step by step fusion
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu first layer")
self.assertEqual(type(model.bn1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped BN)")
self.assertEqual(type(model.relu1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped Relu)")
model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
self.assertEqual(
type(model.conv1),
nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu first layer",
)
self.assertEqual(
type(model.bn1),
torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped BN)",
)
self.assertEqual(
type(model.relu1),
torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped Relu)",
)
self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
msg="Fused submodule Conv + BN")
self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
msg="Fused submodule Conv + BN (skipped BN)")
self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
self.assertEqual(
type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN"
)
self.assertEqual(
type(model.sub1.bn),
torch.nn.Identity,
msg="Fused submodule Conv + BN (skipped BN)",
)
self.assertEqual(
type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv"
)
self.assertEqual(
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
)
model = prepare_qat(model)
self.checkObservers(model)
@ -89,69 +102,121 @@ class TestFuseEager(QuantizationTestCase):
test_only_eval_fn(model, self.img_data_1d)
self.checkNoQconfig(model)
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
with self.assertRaisesRegex(
RuntimeError,
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
):
checkQuantized(model)
model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules_qat(
model,
[['conv1', 'bn1', 'relu1'],
['sub1.conv', 'sub1.bn']])
model, [["conv1", "bn1", "relu1"], ["sub1.conv", "sub1.bn"]]
)
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
with self.assertRaisesRegex(
RuntimeError,
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
):
checkQuantized(model)
def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig)
model.eval()
model = fuse_modules(
model,
[['conv3', 'bn3', 'relu4'],
['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn']])
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)")
self.assertEqual(type(model.bn1), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped BN)")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv2), nni.ConvReLU3d,
msg="Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.bn2), nni.BNReLU3d,
msg="Fused BN + Relu first layer (Relu is folded))")
self.assertEqual(type(model.relu3), nn.Identity,
msg="Fused BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv2[0]), nn.Conv3d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
self.assertEqual(type(model.conv2[1]), nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)")
self.assertEqual(type(model.relu2), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
[
["conv3", "bn3", "relu4"],
["conv1", "bn1", "relu1"],
["conv2", "relu2"],
["bn2", "relu3"],
["sub1.conv", "sub1.bn"],
],
)
self.assertEqual(
type(model.conv1),
nni.ConvReLU2d,
msg="Fused Conv + BN + Relu first layer (BN is folded)",
)
self.assertEqual(
type(model.conv1[0]),
nn.Conv2d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)",
)
self.assertEqual(
type(model.conv1[1]),
nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)",
)
self.assertEqual(
type(model.bn1),
nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped BN)",
)
self.assertEqual(
type(model.relu1),
nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(
type(model.conv2),
nni.ConvReLU3d,
msg="Fused Conv + BN + Relu first layer (BN is folded)",
)
self.assertEqual(
type(model.bn2),
nni.BNReLU3d,
msg="Fused BN + Relu first layer (Relu is folded))",
)
self.assertEqual(
type(model.relu3),
nn.Identity,
msg="Fused BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(
type(model.conv2[0]),
nn.Conv3d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)",
)
self.assertEqual(
type(model.conv2[1]),
nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)",
)
self.assertEqual(
type(model.relu2),
nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(type(model.conv3), nni.ConvReLU1d,
msg="Fused Conv + Relu for Conv1d (folded BN)")
self.assertEqual(type(model.conv3[0]), nn.Conv1d,
msg="Fused Conv + Relu for Conv1d ")
self.assertEqual(type(model.conv3[1]), nn.ReLU,
msg="Fused Conv + Relu for Conv1d")
self.assertEqual(type(model.bn3), nn.Identity,
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
self.assertEqual(
type(model.conv3),
nni.ConvReLU1d,
msg="Fused Conv + Relu for Conv1d (folded BN)",
)
self.assertEqual(
type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d "
)
self.assertEqual(
type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d"
)
self.assertEqual(
type(model.bn3),
nn.Identity,
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)",
)
self.assertEqual(type(model.sub1.conv), nn.Conv2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.sub1.bn), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.sub2.conv), nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
self.assertEqual(
type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN"
)
self.assertEqual(
type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)"
)
self.assertEqual(
type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv"
)
self.assertEqual(
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
)
model = prepare(model)
self.checkObservers(model)
@ -176,11 +241,14 @@ class TestFuseEager(QuantizationTestCase):
model = ModelForFusion(default_qconfig).eval()
model = fuse_modules(
model,
[['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn'],
['conv3', 'bn3', 'relu4']])
[
["conv1", "bn1", "relu1"],
["conv2", "relu2"],
["bn2", "relu3"],
["sub1.conv", "sub1.bn"],
["conv3", "bn3", "relu4"],
],
)
model = quantize(model, test_only_eval_fn, [self.img_data_1d])
checkQuantized(model)
@ -190,27 +258,46 @@ class TestFuseEager(QuantizationTestCase):
model = ModelWithSequentialFusion().train()
model.to(torch.float)
fuse_modules_qat(
model, [['conv1', 'relu1'] ,
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
model,
[
["conv1", "relu1"],
["features.0.0", "features.0.1", "features.0.2"],
["features.1.0", "features.1.1", "features.1.2"],
["features.2.0", "features.2.1", "features.2.2"],
["classifier.0", "classifier.1"],
],
inplace=True,
)
self.assertEqual(
type(model.conv1),
nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d",
)
self.assertEqual(
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
)
self.assertEqual(
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
)
self.assertEqual(
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(
type(model.features[i][0]),
nni.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN",
)
self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
@ -218,17 +305,26 @@ class TestFuseEager(QuantizationTestCase):
self.checkObservers(model)
model(self.img_data_2d[0][0])
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(
type(model.features[i][0]),
nniqat.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN",
)
self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
@ -245,27 +341,45 @@ class TestFuseEager(QuantizationTestCase):
model.to(torch.float)
fuse_modules(
model,
[['conv1', 'relu1'],
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
[
["conv1", "relu1"],
["features.0.0", "features.0.1", "features.0.2"],
["features.1.0", "features.1.1", "features.1.2"],
["features.2.0", "features.2.1", "features.2.2"],
["classifier.0", "classifier.1"],
],
inplace=True,
)
self.assertEqual(
type(model.conv1),
nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d",
)
self.assertEqual(
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
)
self.assertEqual(
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
)
self.assertEqual(
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(
type(model.features[i][0]),
nni.ConvReLU2d,
msg="Fused submodule Conv + folded BN",
)
self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
@ -297,12 +411,12 @@ class TestFuseEager(QuantizationTestCase):
out_ref = model_ref(self.img_data_2d[0][0])
# fused model
model_orig.qconfig = QConfig(activation=torch.nn.Identity,
weight=torch.nn.Identity)
model_orig.qconfig = QConfig(
activation=torch.nn.Identity, weight=torch.nn.Identity
)
model = fuse_modules_qat(
model_orig,
[["conv1", "bn1", "relu1"],
["conv2", "bn2"]])
model_orig, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]
)
prep_model = prepare_qat(model, inplace=False)
# output with fusion but no observers.
out_fused = prep_model(self.img_data_2d[0][0])
@ -332,7 +446,6 @@ class TestFuseEager(QuantizationTestCase):
checkQAT(model)
def test_fusion_linear_bn_eval(self):
model = ModelForLinearBNFusion().train()
inp1 = torch.randn(8, 20)
@ -357,7 +470,9 @@ class TestFuseEager(QuantizationTestCase):
model.eval()
golden = model(inp2)
model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
model = fuse_modules(
model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]
)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.bn2), nn.Identity)
self.assertEqual(type(model.bn3), nn.Identity)
@ -384,50 +499,68 @@ class TestFuseEager(QuantizationTestCase):
model = ModelForFusion(default_qat_qconfig).train()
counter = {
'pre_forwards': 0,
'forwards': 0,
"pre_forwards": 0,
"forwards": 0,
}
fused = False
def fw_pre_hook(fused_module_class, h_module, input):
if fused:
self.assertEqual(type(h_module), fused_module_class,
"After fusion owner of the first module's forward pre hook is not a fused module")
counter['pre_forwards'] += 1
self.assertEqual(
type(h_module),
fused_module_class,
"After fusion owner of the first module's forward pre hook is not a fused module",
)
counter["pre_forwards"] += 1
def fw_hook(fused_module_class, h_module, input, output):
if fused:
self.assertEqual(type(h_module), fused_module_class,
"After fusion owner of the last module's forward hook is not a fused module")
counter['forwards'] += 1
self.assertEqual(
type(h_module),
fused_module_class,
"After fusion owner of the last module's forward hook is not a fused module",
)
counter["forwards"] += 1
# Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
model.conv1.register_forward_pre_hook(
lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)
)
model.sub1.conv.register_forward_pre_hook(
lambda *args: fw_pre_hook(nni.ConvBn2d, *args)
)
model.relu1.register_forward_hook(
lambda *args: fw_hook(nni.ConvBnReLU2d, *args)
)
model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
self.assertEqual(counter["pre_forwards"], 2 * len(self.img_data_1d))
self.assertEqual(counter["forwards"], 2 * len(self.img_data_1d))
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
fused = True
before_fusion_pre_count = counter['pre_forwards']
before_fusion_post_count = counter['forwards']
before_fusion_pre_count = counter["pre_forwards"]
before_fusion_post_count = counter["forwards"]
test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))
self.assertEqual(
counter["pre_forwards"] - before_fusion_pre_count, 2 * len(self.img_data_1d)
)
self.assertEqual(
counter["forwards"] - before_fusion_post_count, 2 * len(self.img_data_1d)
)
def test_fuse_modules_with_nested_hooks(self):
r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
in the future.
"""
def myhook(*x):
return ""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval()
@ -435,28 +568,32 @@ class TestFuseEager(QuantizationTestCase):
for sub_model in model.modules():
if isinstance(sub_model, nn.Sequential):
for layer in sub_model:
if hasattr(layer, 'register_forward_hook'):
if hasattr(layer, "register_forward_hook"):
layer.register_forward_hook(myhook)
fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
fuse_modules(
model,
[["features.0.0", "features.0.1", "features.0.2"]],
inplace=True,
)
self.assertEqual(
type(model.features[0][0]),
nni.ConvReLU2d,
msg="Fused submodule Conv + folded BN"
msg="Fused submodule Conv + folded BN",
)
self.assertEqual(
type(model.features[0][1]),
nn.Identity,
msg="Fused submodule (skipped BN)"
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[0][2]),
nn.Identity,
msg="Non-fused submodule Conv"
msg="Non-fused submodule Conv",
)
if __name__ == '__main__':
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"

View File: test/quantization/eager/test_model_numerics.py

@ -1,17 +1,17 @@
# Owner(s): ["oncall: quantization"]
import torch
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelMultipleOps,
ModelMultipleOpsNoAvgPool,
QuantizationTestCase,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
)
class TestModelNumericsEager(QuantizationTestCase):
def test_float_quant_compare_per_tensor(self):
for qengine in supported_qengines:
@ -25,16 +25,24 @@ class TestModelNumericsEager(QuantizationTestCase):
qModel = torch.ao.quantization.QuantWrapper(my_model)
qModel.eval()
qModel.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules(
qModel.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare(qModel, inplace=True)
qModel(calib_data)
torch.ao.quantization.convert(qModel, inplace=True)
out_q = qModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
SQNRdB = 20 * torch.log10(
torch.norm(out_ref) / torch.norm(out_ref - out_q)
)
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired
# output
self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
self.assertGreater(
SQNRdB,
30,
msg="Quantized model numerics diverge from float, expect SQNR > 30 dB",
)
def test_float_quant_compare_per_channel(self):
# Test for per-channel Quant
@ -47,7 +55,9 @@ class TestModelNumericsEager(QuantizationTestCase):
q_model = torch.ao.quantization.QuantWrapper(my_model)
q_model.eval()
q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig
torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules(
q_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare(q_model)
q_model(calib_data)
torch.ao.quantization.convert(q_model)
@ -55,7 +65,11 @@ class TestModelNumericsEager(QuantizationTestCase):
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
self.assertGreater(
SQNRdB,
35,
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
)
def test_fake_quant_true_quant_compare(self):
for qengine in supported_qengines:
@ -69,7 +83,9 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules_qat(
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)
@ -78,14 +94,26 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model.apply(torch.ao.quantization.enable_fake_quant)
fq_model.apply(torch.ao.quantization.disable_observer)
out_fq = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
SQNRdB = 20 * torch.log10(
torch.norm(out_ref) / torch.norm(out_ref - out_fq)
)
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
self.assertGreater(
SQNRdB,
35,
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
)
torch.ao.quantization.convert(fq_model)
out_q = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
SQNRdB = 20 * torch.log10(
torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)
)
self.assertGreater(
SQNRdB,
60,
msg="Fake quant and true quant numerics diverge, expect SQNR > 60 dB",
)
# Test to compare weight only quantized model numerics and
# activation only quantized model numerics with float
@ -95,8 +123,10 @@ class TestModelNumericsEager(QuantizationTestCase):
torch.manual_seed(67)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
torch.ao.quantization.default_activation_only_qconfig}
qconfigset = {
torch.ao.quantization.default_weight_only_qconfig,
torch.ao.quantization.default_activation_only_qconfig,
}
SQNRTarget = [35, 45]
for idx, qconfig in enumerate(qconfigset):
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
@ -105,7 +135,9 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = qconfig
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.ao.quantization.fuse_modules_qat(
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant)
@ -114,11 +146,19 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model.apply(torch.ao.quantization.enable_fake_quant)
fq_model.apply(torch.ao.quantization.disable_observer)
out_fq = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
SQNRdB = 20 * torch.log10(
torch.norm(out_ref) / torch.norm(out_ref - out_fq)
)
self.assertGreater(
SQNRdB,
SQNRTarget[idx],
msg="Quantized model numerics diverge from float",
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)

View File: test/quantization/eager/test_numeric_suite_eager.py

@ -2,43 +2,45 @@
# ruff: noqa: F841
import unittest
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import (
DeQuantStub,
QuantStub,
convert,
default_qconfig,
prepare,
quantize,
quantize_dynamic,
)
import torch.nn as nn
from torch.ao.ns._numeric_suite import (
OutputLogger,
Shadow,
ShadowLogger,
compare_model_outputs,
compare_model_stub,
compare_weights,
prepare_model_outputs,
get_matching_activations,
OutputLogger,
prepare_model_outputs,
Shadow,
ShadowLogger,
)
from torch.ao.quantization import (
convert,
default_qconfig,
DeQuantStub,
prepare,
quantize,
quantize_dynamic,
QuantStub,
)
from torch.testing._internal.common_quantization import (
AnnotatedConvBnReLUModel,
AnnotatedConvModel,
AnnotatedConvTransposeModel,
AnnotatedSingleLayerLinearModel,
LSTMwithHiddenDynamicModel,
AnnotatedTwoLayerLinearModel,
LSTMwithHiddenDynamicModel,
QuantizationTestCase,
SingleLayerLinearDynamicModel,
test_only_eval_fn,
skip_if_no_torchvision,
test_only_eval_fn,
)
from torch.testing._internal.common_quantized import override_qengines
from torch.testing._internal.common_utils import IS_ARM64
class SubModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -200,10 +202,18 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]):
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
model_list = [AnnotatedConvModel(qengine),
AnnotatedConvTransposeModel("qnnpack"), # ConvT cannot use per channel weights
AnnotatedConvBnReLUModel(qengine)]
module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d, nn.ConvTranspose2d]
model_list = [
AnnotatedConvModel(qengine),
AnnotatedConvTransposeModel(
"qnnpack"
), # ConvT cannot use per channel weights
AnnotatedConvBnReLUModel(qengine),
]
module_swap_list = [
nn.Conv2d,
nn.intrinsic.modules.fused.ConvReLU2d,
nn.ConvTranspose2d,
]
for model in model_list:
model.eval()
if hasattr(model, "fuse_model"):
@ -279,7 +289,6 @@ class TestNumericSuiteEager(QuantizationTestCase):
self.assertTrue(isinstance(q_model.mod1, Shadow))
self.assertFalse(isinstance(q_model.conv, Shadow))
@override_qengines
def test_compare_model_stub_functional_static(self):
r"""Compare the output of static quantized functional layer and its float shadow module"""
@ -486,7 +495,9 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]):
self.assertTrue(len(v["float"][i]) == len(v["quantized"][i]))
if i == 0:
self.assertTrue(v["float"][i][0].shape == v["quantized"][i][0].shape)
self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape
)
else:
self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape
@ -540,12 +551,23 @@ class TestNumericSuiteEager(QuantizationTestCase):
@skip_if_no_torchvision
def _test_vision_model(self, float_model):
float_model.to('cpu')
float_model.to("cpu")
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.ao.quantization.default_qconfig
img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
qmodel = quantize(float_model, torch.ao.quantization.default_eval_fn, [img_data], inplace=False)
img_data = [
(
torch.rand(2, 3, 224, 224, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(2)
]
qmodel = quantize(
float_model,
torch.ao.quantization.default_eval_fn,
[img_data],
inplace=False,
)
wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
@ -560,9 +582,11 @@ class TestNumericSuiteEager(QuantizationTestCase):
# 'quantized', containing the activations of floating point and quantized model at matching locations.
act_compare_dict = compare_model_outputs(float_model, qmodel, data)
for key in act_compare_dict:
compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize())
compute_error(
act_compare_dict[key]["float"][0],
act_compare_dict[key]["quantized"][0].dequantize(),
)
prepare_model_outputs(float_model, qmodel)
@ -579,10 +603,12 @@ class TestNumericSuiteEager(QuantizationTestCase):
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v2(self):
from torchvision.models.quantization import mobilenet_v2
self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
@skip_if_no_torchvision
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v3(self):
from torchvision.models.quantization import mobilenet_v3_large
self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))

View File: test/quantization/eager/test_quantize_eager_ptq.py (diff suppressed because it is too large)

View File: test/quantization/eager/test_quantize_eager_qat.py

@ -3,6 +3,8 @@
import copy
import math
from hypothesis import given, strategies as st
import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
@ -12,8 +14,6 @@ import torch.ao.nn.quantized.dynamic as nnqd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu
from hypothesis import given, strategies as st
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
convert,
@ -50,42 +50,63 @@ from torch.testing._internal.common_quantization import (
test_only_train_fn,
TwoLayerLinearModel,
)
from torch.testing._internal.common_quantized import (
override_qengines,
override_quantized_engine,
supported_qengines,
)
from torch.testing._internal.common_utils import skipIfNoXNNPACK
hu.assert_deadline_disabled()
from functools import reduce
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
"""
Conv-BN fusion implemented with explicit folding. Useful
to verify numerical equivalency with non-folded version.
"""
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
def __init__(
self,
# ConvNd args
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
nn.modules.conv._ConvNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
False,
padding_mode,
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.eps = eps
self.momentum = momentum
@ -103,7 +124,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
self.reset_bn_parameters()
def reset_running_stats(self):
@ -123,7 +144,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
def reset_parameters(self):
super().reset_parameters()
# A hack to avoid resetting on undefined parameters
if hasattr(self, 'gamma'):
if hasattr(self, "gamma"):
self.reset_bn_parameters()
def update_bn_stats(self):
@ -161,33 +182,50 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)
zero_bias = torch.zeros(
self.out_channels, device=scaled_weight.device, dtype=input.dtype
)
conv = self._conv_forward(
input, self.weight_fake_quant(scaled_weight), zero_bias
)
if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
conv_orig = conv / scale_factor.reshape(
[1, -1, 1, 1]
) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
batch_rstd = torch.ones_like(
batch_var, memory_format=torch.contiguous_format
) / torch.sqrt(batch_var + self.eps)
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
self.running_mean = exponential_average_factor * batch_mean.detach() + \
(1 - exponential_average_factor) * self.running_mean
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
(1 - exponential_average_factor) * self.running_var
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + (
self.beta - self.gamma * batch_rstd * batch_mean
).reshape([1, -1, 1, 1])
self.running_mean = (
exponential_average_factor * batch_mean.detach()
+ (1 - exponential_average_factor) * self.running_mean
)
self.running_var = (
exponential_average_factor * unbiased_batch_var.detach()
+ (1 - exponential_average_factor) * self.running_var
)
else:
if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean /
running_std).reshape([1, -1, 1, 1])
conv = conv + (
self.beta - self.gamma * self.running_mean / running_std
).reshape([1, -1, 1, 1])
else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
conv = conv + (
self.gamma * (self.bias - self.running_mean) / running_std
+ self.beta
).reshape([1, -1, 1, 1])
return conv
def extra_repr(self):
@ -200,23 +238,37 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn = cls(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight
@ -226,41 +278,69 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
def __init__(
self,
# ConvNd args
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None,
padding_mode="zeros",
# BatchNorm2d args
# num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
_ReferenceConvBnNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_pair(0),
groups,
bias,
padding_mode,
eps,
momentum,
freeze_bn,
qconfig,
)
class TestQuantizeEagerQAT(QuantizationTestCase):
def setUp(self):
super().setUp()
self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float)]
for _ in range(2)]
self.embed_linear_data_train = [
[
torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float),
]
for _ in range(2)
]
self.embed_data = [[torch.randint(0, 10, (12, 1))]]
def test_manual(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
@ -279,8 +359,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)
model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
[self.train_data])
model = quantize_qat(
ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]
)
checkQuantized(model)
def test_dropout(self):
@ -301,8 +382,11 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)
model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
[self.train_data])
model = quantize_qat(
ManualDropoutQATModel(qengine),
test_only_train_fn,
[self.train_data],
)
checkQuantized(model)
def test_eval_only_fake_quant(self):
@ -342,7 +426,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)
model = ManualConvLinearQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)
@skipIfNoXNNPACK
@ -351,7 +437,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
Supported only with qengine=qnnpack, which uses symmetric
kernels from xnnpack library."""
for qengine in supported_qengines:
if qengine != 'qnnpack':
if qengine != "qnnpack":
continue
with override_quantized_engine(qengine):
model = ManualConvLinearSymmQATModel()
@ -373,17 +459,20 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)
model = ManualConvLinearSymmQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)
def test_dynamic_qat_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# Dynamic QAT without memoryless observers should fail
with self.assertRaisesRegex(ValueError,
"Dynamic QAT requires a memoryless observer." +
"This means a MovingAverage observer with averaging constant equal to 1"
):
with self.assertRaisesRegex(
ValueError,
"Dynamic QAT requires a memoryless observer."
+ "This means a MovingAverage observer with averaging constant equal to 1",
):
model = ManualLinearDynamicQATModel(default_qat_qconfig)
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
@ -409,14 +498,23 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
test_only_train_fn(model, self.embed_linear_data_train)
# make sure activation_post_process is inserted after Linear.
self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)
self.assertEqual(
type(model.linear.activation_post_process),
FusedMovingAvgObsFakeQuantize,
)
# make sure that Embedding has a noop for activation.
self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(model, mapping=get_embedding_static_quant_module_mappings())
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)
def checkQuantized(model):
# make sure Embedding is now a QuantizedEmbedding
@ -430,7 +528,6 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)
def test_embedding_bag_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
@ -442,9 +539,15 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
# make sure not activation_post_process is inserted for EmbeddingBag
self.assertFalse(hasattr(model, "activation_post_process"))
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
model = convert(model, mapping=get_embedding_static_quant_module_mappings())
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)
def checkQuantized(model):
# Make sure EmbeddingBag is now a quantized EmbeddingBag.
@ -505,7 +608,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
self.assertEqual(
set(model.state_dict().keys()), set(quant_state_dict.keys())
)
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
@ -513,20 +618,19 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
@override_qengines
def test_forward_hooks_preserved(self):
r"""Test QAT on preserving pre forward and post forward hooks of original model
"""
r"""Test QAT on preserving pre forward and post forward hooks of original model"""
qengine = torch.backends.quantized.engine
model = QuantStubModel()
counter = {
'pre_forwards': 0,
'forwards': 0,
"pre_forwards": 0,
"forwards": 0,
}
def fw_pre_hook(h_module, input):
counter['pre_forwards'] += 1
counter["pre_forwards"] += 1
def fw_hook(h_module, input, output):
counter['forwards'] += 1
counter["forwards"] += 1
model.fc.register_forward_pre_hook(fw_pre_hook)
model.fc.register_forward_hook(fw_hook)
@ -537,15 +641,24 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
def checkHooksIsPresent(model, before_convert=True):
forward_hooks = 1
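# before convert, fc carries the observer hook in addition to the user hook, so expect 2 forward hooks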
if before_convert:
self.assertEqual(len(model.quant._forward_hooks.values()), 1,
"Quantization observer hook has disappeared")
self.assertEqual(
len(model.quant._forward_hooks.values()),
1,
"Quantization observer hook has disappeared",
)
forward_hooks = 2
self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
"Extra pre forward hooks have appeared on a layer")
self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
"Extra post forward hooks have appeared on a layer")
self.assertEqual(
len(model.fc._forward_pre_hooks.values()),
1,
"Extra pre forward hooks have appeared on a layer",
)
self.assertEqual(
len(model.fc._forward_hooks.values()),
forward_hooks,
"Extra post forward hooks have appeared on a layer",
)
checkHooksIsPresent(model, True)
x = torch.rand(2, 5, dtype=torch.float)
@ -600,32 +713,40 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
# Test constructor parameters checks here.
with self.assertRaisesRegex(AssertionError,
"qconfig must be provided for QAT module"):
with self.assertRaisesRegex(
AssertionError, "qconfig must be provided for QAT module"
):
nnqat.EmbeddingBag(10, 5, qconfig=None)
with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)
# Test from_float checks here.
embed = nn.Embedding(10, 5)
with self.assertRaisesRegex(AssertionError,
"qat.EmbeddingBag.from_float only works for EmbeddingBag"):
with self.assertRaisesRegex(
AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"
):
nnqat.EmbeddingBag.from_float(embed)
embed_bag = nn.EmbeddingBag(10, 5)
with self.assertRaisesRegex(AssertionError,
"Input float module must have qconfig defined"):
with self.assertRaisesRegex(
AssertionError, "Input float module must have qconfig defined"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = None
with self.assertRaisesRegex(AssertionError,
"Input float module must have a valid qconfig"):
with self.assertRaisesRegex(
AssertionError, "Input float module must have a valid qconfig"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = default_qat_qconfig
with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag.from_float(embed_bag)
def test_embedding_qat_qconfig_equal(self):
@ -636,8 +757,10 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model)
self.assertTrue(qconfig_equals(model.emb.qconfig,
default_embedding_qat_qconfig))
self.assertTrue(
qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)
)
class TestQuantizeEagerQATNumerics(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data):
@ -683,24 +806,26 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertEqual(
type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize
)
data = torch.randn(1, 3, 2, 4)
before_convert = m(data)
m = convert(m)
after_convert = m(data)
self.assertEqual(before_convert, after_convert)
# make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
# verify fake quant module is removed
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
# verify that hooks are removed
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
# make sure no fake quantize module is inserted for eval mode
def checkNoFQModule(m):
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
@ -734,50 +859,52 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
# make sure ReLU module is not changed
self.assertEqual(type(m.relu), nn.ReLU)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans())
@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans(),
)
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
@ -792,7 +919,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
(dilation_h, dilation_w),
groups,
has_bias,
padding_mode
padding_mode,
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()
@ -811,7 +938,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)
qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion
@ -826,7 +953,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
if has_bias:
conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
@ -840,17 +974,20 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
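# compose([f, g, h]) returns a function that applies f, g, then h: x -> h(g(f(x)))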
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu:
def relu_op(x): # noqa: F811
return x
if freeze_bn:
def ref_op(x):
x = conv_op(x)
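# with BN frozen, normalize with the running statistics (eval-mode batch norm) instead of batch statistics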
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * (
bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)
).reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x
else:
ref_op = compose([conv_op, bn_op, relu_op])
@ -882,51 +1019,64 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans())
@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans(),
)
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
@ -945,7 +1095,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)
qat_ref_op = _ReferenceConvBn2d(
@ -961,7 +1111,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
qconfig=default_qat_qconfig,
).to(dtype=torch.double)
qat_op.apply(torch.ao.quantization.disable_fake_quant)
@ -981,7 +1131,6 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)
for i in range(5):
# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()
@ -990,7 +1139,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
input_clone = input.detach().clone().requires_grad_()
if i > 2:
@ -1030,12 +1186,23 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)
qat_op_optim.step()
qat_ref_op_optim.step()
@ -1048,7 +1215,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1071,7 +1238,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = default_symmetric_qnnpack_qat_qconfig
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1093,14 +1260,13 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
)
data = torch.randn(4, 4)
m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
m = torch.ao.quantization.fuse_modules_qat(m, [["1", "2"]])
mp = prepare_qat(m)
mp(data)
mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)
@skipIfNoXNNPACK
@override_qengines
def test_linear_precomputed_fake_quant(self):
@ -1124,10 +1290,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m_ref.activation_post_process = activation
m_ref.qconfig = qconfig
m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale)
self.assertTrue(
m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)