# Owner(s): ["oncall: quantization"]
|
|
# ruff: noqa: F841
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.ao.quantization.quantize_fx as quantize_fx
|
|
import torch.nn.functional as F
|
|
from torch.ao.quantization import QConfig, QConfigMapping
|
|
from torch.ao.quantization.fx._model_report.detector import (
|
|
DynamicStaticDetector,
|
|
InputWeightEqualizationDetector,
|
|
PerChannelDetector,
|
|
OutlierDetector,
|
|
)
|
|
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
|
|
from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
|
|
from torch.ao.quantization.fx._model_report.model_report import ModelReport
|
|
from torch.ao.quantization.observer import (
|
|
HistogramObserver,
|
|
default_per_channel_weight_observer,
|
|
default_observer
|
|
)
|
|
from torch.ao.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
|
|
from torch.testing._internal.common_quantization import (
|
|
ConvModel,
|
|
QuantizationTestCase,
|
|
SingleLayerLinearModel,
|
|
TwoLayerLinearModel,
|
|
skipIfNoFBGEMM,
|
|
skipIfNoQNNPACK,
|
|
override_quantized_engine,
|
|
)
|
|
from torch.testing._internal.common_utils import raise_on_run_directly
|
|
|
|
|
|
"""
|
|
Partition of input domain:
|
|
|
|
Model contains: conv or linear, both conv and linear
|
|
Model contains: ConvTransposeNd (not supported for per_channel)
|
|
|
|
Model is: post training quantization model, quantization aware training model
|
|
Model is: composed with nn.Sequential, composed in class structure
|
|
|
|
QConfig utilizes per_channel weight observer, backend uses non per_channel weight observer
|
|
QConfig_dict uses only one default qconfig, Qconfig dict uses > 1 unique qconfigs
|
|
|
|
Partition on output domain:
|
|
|
|
There are possible changes / suggestions, there are no changes / suggestions
|
|
"""
|
|
|
|
# Default output for string if no optimizations are possible
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
|
|
"Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
|
|
)
|
|
|
|
# Example Sequential Model with multiple Conv and Linear with nesting involved
|
|
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
|
|
torch.nn.Conv2d(3, 3, 2, 1),
|
|
torch.nn.Sequential(torch.nn.Linear(9, 27), torch.nn.ReLU()),
|
|
torch.nn.Linear(27, 27),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Conv2d(3, 3, 2, 1),
|
|
)
|
|
|
|
# Example Sequential Model with Conv sub-class example
|
|
LAZY_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
|
|
torch.nn.LazyConv2d(3, 3, 2, 1),
|
|
torch.nn.Sequential(torch.nn.Linear(5, 27), torch.nn.ReLU()),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(27, 27),
|
|
torch.nn.ReLU(),
|
|
torch.nn.LazyConv2d(3, 3, 2, 1),
|
|
)
|
|
|
|
# Example Sequential Model with Fusion directly built into model
|
|
FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
|
|
ConvReLU2d(torch.nn.Conv2d(3, 3, 2, 1), torch.nn.ReLU()),
|
|
torch.nn.Sequential(LinearReLU(torch.nn.Linear(9, 27), torch.nn.ReLU())),
|
|
LinearReLU(torch.nn.Linear(27, 27), torch.nn.ReLU()),
|
|
torch.nn.Conv2d(3, 3, 2, 1),
|
|
)
|
|
|
|
# Test class
|
|
# example model to use for tests
|
|
class ThreeOps(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(3, 3)
|
|
self.bn = nn.BatchNorm2d(3)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.bn(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
def get_example_inputs(self):
|
|
return (torch.randn(1, 3, 3, 3),)
|
|
|
|
class TwoThreeOps(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.block1 = ThreeOps()
|
|
self.block2 = ThreeOps()
|
|
|
|
def forward(self, x):
|
|
x = self.block1(x)
|
|
y = self.block2(x)
|
|
z = x + y
|
|
z = F.relu(z)
|
|
return z
|
|
|
|
def get_example_inputs(self):
|
|
return (torch.randn(1, 3, 3, 3),)
|
|
|
|
class TestFxModelReportDetector(QuantizationTestCase):
|
|
|
|
"""Prepares and calibrate the model"""
|
|
|
|
def _prepare_model_and_run_input(self, model, q_config_mapping, input):
|
|
model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input) # prep model
|
|
model_prep(input).sum() # calibrate the model
|
|
return model_prep
|
|
|
|
"""Case includes:
|
|
one conv or linear
|
|
post training quantization
|
|
composed as module
|
|
qconfig uses per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has no changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_simple_conv(self):
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
input = torch.randn(1, 3, 10, 10)
|
|
prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# no optims possible and there should be nothing in per_channel_status
|
|
self.assertEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# there should only be one conv in this model
|
|
self.assertEqual(per_channel_info["conv"]["backend"], torch.backends.quantized.engine)
|
|
self.assertEqual(len(per_channel_info), 1)
|
|
self.assertEqual(next(iter(per_channel_info)), "conv")
|
|
self.assertEqual(
|
|
per_channel_info["conv"]["per_channel_quantization_supported"],
|
|
True,
|
|
)
|
|
self.assertEqual(per_channel_info["conv"]["per_channel_quantization_used"], True)
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
post training quantization
|
|
composed as module
|
|
qconfig doesn't use per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoQNNPACK
|
|
def test_multi_linear_model_without_per_channel(self):
|
|
|
|
with override_quantized_engine('qnnpack'):
|
|
torch.backends.quantized.engine = "qnnpack"
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
prepared_model = self._prepare_model_and_run_input(
|
|
TwoLayerLinearModel(),
|
|
q_config_mapping,
|
|
TwoLayerLinearModel().get_example_inputs()[0],
|
|
)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# there should be optims possible
|
|
self.assertNotEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
# pick a random key to look at
|
|
rand_key: str = next(iter(per_channel_info.keys()))
|
|
self.assertEqual(per_channel_info[rand_key]["backend"], torch.backends.quantized.engine)
|
|
self.assertEqual(len(per_channel_info), 2)
|
|
|
|
# for each linear layer, should be supported but not used
|
|
for linear_key in per_channel_info.keys():
|
|
module_entry = per_channel_info[linear_key]
|
|
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], False)
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
post training quantization
|
|
composed as Module
|
|
qconfig doesn't use per_channel weight observer
|
|
More than 1 qconfig in qconfig dict
|
|
Output has possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoQNNPACK
|
|
def test_multiple_q_config_options(self):
|
|
|
|
with override_quantized_engine('qnnpack'):
|
|
torch.backends.quantized.engine = "qnnpack"
|
|
|
|
# qconfig with support for per_channel quantization
|
|
per_channel_qconfig = QConfig(
|
|
activation=HistogramObserver.with_args(reduce_range=True),
|
|
weight=default_per_channel_weight_observer,
|
|
)
|
|
|
|
# we need to design the model
|
|
class ConvLinearModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 2, 1)
|
|
self.fc1 = torch.nn.Linear(9, 27)
|
|
self.relu = torch.nn.ReLU()
|
|
self.fc2 = torch.nn.Linear(27, 27)
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 2, 1)
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
x = self.fc2(x)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(
|
|
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
|
|
).set_object_type(torch.nn.Conv2d, per_channel_qconfig)
|
|
|
|
prepared_model = self._prepare_model_and_run_input(
|
|
ConvLinearModel(),
|
|
q_config_mapping,
|
|
torch.randn(1, 3, 10, 10),
|
|
)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# the only suggestions should be to linear layers
|
|
|
|
# there should be optims possible
|
|
self.assertNotEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# to ensure it got into the nested layer
|
|
self.assertEqual(len(per_channel_info), 4)
|
|
|
|
# for each layer, per_channel quantization should be supported
|
|
for key in per_channel_info.keys():
|
|
module_entry = per_channel_info[key]
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
|
|
# linear layers should be False and conv layers True, because conv2d uses the per_channel qconfig
|
|
if "fc" in key:
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], False)
|
|
elif "conv" in key:
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], True)
|
|
else:
|
|
raise ValueError("Should only contain conv and linear layers as key values")
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
post training quantization
|
|
composed as sequential
|
|
qconfig doesn't use per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoQNNPACK
|
|
def test_sequential_model_format(self):
|
|
|
|
with override_quantized_engine('qnnpack'):
|
|
torch.backends.quantized.engine = "qnnpack"
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
prepared_model = self._prepare_model_and_run_input(
|
|
NESTED_CONV_LINEAR_EXAMPLE,
|
|
q_config_mapping,
|
|
torch.randn(1, 3, 10, 10),
|
|
)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# there should be optims possible
|
|
self.assertNotEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# to ensure it got into the nested layer
|
|
self.assertEqual(len(per_channel_info), 4)
|
|
|
|
# for each layer, should be supported but not used
|
|
for key in per_channel_info.keys():
|
|
module_entry = per_channel_info[key]
|
|
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], False)
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
post training quantization
|
|
composed as sequential
|
|
qconfig doesn't use per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoQNNPACK
|
|
def test_conv_sub_class_considered(self):
|
|
|
|
with override_quantized_engine('qnnpack'):
|
|
torch.backends.quantized.engine = "qnnpack"
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
prepared_model = self._prepare_model_and_run_input(
|
|
LAZY_CONV_LINEAR_EXAMPLE,
|
|
q_config_mapping,
|
|
torch.randn(1, 3, 10, 10),
|
|
)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# there should be optims possible
|
|
self.assertNotEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# to ensure it got into the nested layer and it considered the lazyConv2d
|
|
self.assertEqual(len(per_channel_info), 4)
|
|
|
|
# for each layer, should be supported but not used
|
|
for key in per_channel_info.keys():
|
|
module_entry = per_channel_info[key]
|
|
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], False)
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
post training quantization
|
|
composed as sequential
|
|
qconfig uses per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has no possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_fusion_layer_in_sequential(self):
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
prepared_model = self._prepare_model_and_run_input(
|
|
FUSION_CONV_LINEAR_EXAMPLE,
|
|
q_config_mapping,
|
|
torch.randn(1, 3, 10, 10),
|
|
)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
|
|
|
|
# no optims possible and there should be nothing in per_channel_status
|
|
self.assertEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# to ensure it got into the nested layer and it considered all the nested fusion components
|
|
self.assertEqual(len(per_channel_info), 4)
|
|
|
|
# for each layer, per_channel quantization should be supported and used
|
|
for key in per_channel_info.keys():
|
|
module_entry = per_channel_info[key]
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], True)
|
|
|
|
"""Case includes:
|
|
Multiple conv or linear
|
|
quantization aware training
|
|
composed as model
|
|
qconfig does not use per_channel weight observer
|
|
Only 1 qconfig in qconfig dict
|
|
Output has possible changes / suggestions
|
|
"""
|
|
|
|
@skipIfNoQNNPACK
|
|
def test_qat_aware_model_example(self):
|
|
|
|
# first we want a QAT model
|
|
class QATConvLinearReluModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
# QuantStub converts tensors from floating point to quantized
|
|
self.quant = torch.ao.quantization.QuantStub()
|
|
self.conv = torch.nn.Conv2d(1, 1, 1)
|
|
self.bn = torch.nn.BatchNorm2d(1)
|
|
self.relu = torch.nn.ReLU()
|
|
# DeQuantStub converts tensors from quantized to floating point
|
|
self.dequant = torch.ao.quantization.DeQuantStub()
|
|
|
|
def forward(self, x):
|
|
x = self.quant(x)
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
x = self.relu(x)
|
|
x = self.dequant(x)
|
|
return x
|
|
|
|
with override_quantized_engine('qnnpack'):
|
|
# create a model instance
|
|
model_fp32 = QATConvLinearReluModel()
|
|
|
|
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
|
|
|
|
# model must be in eval mode for fusion
|
|
model_fp32.eval()
|
|
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])
|
|
|
|
# model must be set to train mode for QAT logic to work
|
|
model_fp32_fused.train()
|
|
|
|
# prepare the model for QAT, different than for post training quantization
|
|
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)
|
|
|
|
# run the detector
|
|
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
|
|
optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared)
|
|
|
|
# there should be optims possible
|
|
self.assertNotEqual(
|
|
optims_str,
|
|
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
|
|
)
|
|
|
|
# make sure it was able to find the single conv in the fused model
|
|
self.assertEqual(len(per_channel_info), 1)
|
|
|
|
# for the one conv, it should still give advice to use different qconfig
|
|
for key in per_channel_info.keys():
|
|
module_entry = per_channel_info[key]
|
|
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
|
|
self.assertEqual(module_entry["per_channel_quantization_used"], False)
|
|
|
|
|
|
"""
|
|
Partition on Domain / Things to Test
|
|
|
|
- All zero tensor
|
|
- Multiple tensor dimensions
|
|
- All of the outward facing functions
|
|
- Epoch min max are correctly updating
|
|
- Batch range is correctly averaging as expected
|
|
- Reset for each epoch is correctly resetting the values
|
|
|
|
Partition on Output
|
|
- the calculation of the ratio is occurring correctly
|
|
|
|
"""
|
|
|
|
|
|
class TestFxModelReportObserver(QuantizationTestCase):
|
|
class NestedModifiedSingleLayerLinear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.mod1 = SingleLayerLinearModel()
|
|
self.obs2 = ModelReportObserver()
|
|
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.mod1(x)
|
|
x = self.obs2(x)
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
|
|
# split up data into batches
|
|
split_up_data = torch.split(ex_input, batch_size)
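# (torch.split chunks along dim 0 by default, so each chunk is one batch of batch_size rows)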
|
|
for epoch in range(num_epochs):
|
|
# reset all model report obs
|
|
model.apply(
|
|
lambda module: module.reset_batch_and_epoch_values()
|
|
if isinstance(module, ModelReportObserver)
|
|
else None
|
|
)
|
|
|
|
# quick check that a reset occurred
|
|
self.assertEqual(
|
|
model.obs1.average_batch_activation_range,
|
|
torch.tensor(float(0)),
|
|
)
|
|
self.assertEqual(model.obs1.epoch_activation_min, torch.tensor(float("inf")))
|
|
self.assertEqual(model.obs1.epoch_activation_max, torch.tensor(float("-inf")))
|
|
|
|
# loop through the batches and run through
|
|
for index, batch in enumerate(split_up_data):
|
|
|
|
num_tracked_so_far = model.obs1.num_batches_tracked
|
|
self.assertEqual(num_tracked_so_far, index)
|
|
|
|
# get general info about the batch and the model to use later
|
|
batch_min, batch_max = torch.aminmax(batch)
|
|
current_average_range = model.obs1.average_batch_activation_range
|
|
current_epoch_min = model.obs1.epoch_activation_min
|
|
current_epoch_max = model.obs1.epoch_activation_max
|
|
|
|
# run input through
|
|
model(ex_input)
|
|
|
|
# check that average batch activation range updated correctly
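# (running-average update: new_avg = (old_avg * n + new_range) / (n + 1), where n is the number of batches tracked so far)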
|
|
correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
|
|
num_tracked_so_far + 1
|
|
)
|
|
self.assertEqual(
|
|
model.obs1.average_batch_activation_range,
|
|
correct_updated_value,
|
|
)
|
|
|
|
if current_epoch_max - current_epoch_min > 0:
|
|
self.assertEqual(
|
|
model.obs1.get_batch_to_epoch_ratio(),
|
|
correct_updated_value / (current_epoch_max - current_epoch_min),
|
|
)
|
|
|
|
"""Case includes:
|
|
all zero tensor
|
|
dim size = 2
|
|
run for 1 epoch
|
|
run for 10 batch
|
|
tests input data observer
|
|
"""
|
|
|
|
def test_zero_tensor_errors(self):
|
|
# initialize the model
|
|
model = self.NestedModifiedSingleLayerLinear()
|
|
|
|
# generate the desired input
|
|
ex_input = torch.zeros((10, 1, 5))
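# shape (10, 1, 5): with the batch size of 1 passed below, this becomes 10 single-row batches of zeros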
|
|
|
|
# run it through the model and do general tests
|
|
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
|
|
|
# make sure final values are all 0
|
|
self.assertEqual(model.obs1.epoch_activation_min, 0)
|
|
self.assertEqual(model.obs1.epoch_activation_max, 0)
|
|
self.assertEqual(model.obs1.average_batch_activation_range, 0)
|
|
|
|
# we should get an error if we try to calculate the ratio
|
|
with self.assertRaises(ValueError):
|
|
ratio_val = model.obs1.get_batch_to_epoch_ratio()
|
|
|
|
"""Case includes:
|
|
non-zero tensor
|
|
dim size = 2
|
|
run for 1 epoch
|
|
run for 1 batch
|
|
tests input data observer
|
|
"""
|
|
|
|
def test_single_batch_of_ones(self):
|
|
# initialize the model
|
|
model = self.NestedModifiedSingleLayerLinear()
|
|
|
|
# generate the desired input
|
|
ex_input = torch.ones((1, 1, 5))
|
|
|
|
# run it through the model and do general tests
|
|
self.run_model_and_common_checks(model, ex_input, 1, 1)
|
|
|
|
# min and max should both be 1, and the average batch range should be 0
|
|
self.assertEqual(model.obs1.epoch_activation_min, 1)
|
|
self.assertEqual(model.obs1.epoch_activation_max, 1)
|
|
self.assertEqual(model.obs1.average_batch_activation_range, 0)
|
|
|
|
# we should get an error if we try to calculate the ratio
|
|
with self.assertRaises(ValueError):
|
|
ratio_val = model.obs1.get_batch_to_epoch_ratio()
|
|
|
|
"""Case includes:
|
|
non-zero tensor
|
|
dim size = 2
|
|
run for 10 epoch
|
|
run for 15 batch
|
|
tests non input data observer
|
|
"""
|
|
|
|
def test_observer_after_relu(self):
|
|
|
|
# model specific to this test
|
|
class NestedModifiedObserverAfterRelu(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.mod1 = SingleLayerLinearModel()
|
|
self.obs2 = ModelReportObserver()
|
|
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.mod1(x)
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
x = self.obs2(x)
|
|
return x
|
|
|
|
# initialize the model
|
|
model = NestedModifiedObserverAfterRelu()
|
|
|
|
# generate the desired input
|
|
ex_input = torch.randn((15, 1, 5))
|
|
|
|
# run it through the model and do general tests
|
|
self.run_model_and_common_checks(model, ex_input, 10, 15)
|
|
|
|
"""Case includes:
|
|
non-zero tensor
|
|
dim size = 2
|
|
run for multiple epoch
|
|
run for multiple batch
|
|
tests input data observer
|
|
"""
|
|
|
|
def test_random_epochs_and_batches(self):
|
|
|
|
# set up a basic model
|
|
class TinyNestModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
|
|
self.relu = torch.nn.ReLU()
|
|
self.obs2 = ModelReportObserver()
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
x = self.obs2(x)
|
|
return x
|
|
|
|
class LargerIncludeNestModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.nested = TinyNestModule()
|
|
self.fc1 = SingleLayerLinearModel()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.nested(x)
|
|
x = self.fc1(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class ModifiedThreeOps(torch.nn.Module):
|
|
def __init__(self, batch_norm_dim):
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.linear = torch.nn.Linear(7, 3, 2)
|
|
self.obs2 = ModelReportObserver()
|
|
|
|
if batch_norm_dim == 2:
|
|
self.bn = torch.nn.BatchNorm2d(2)
|
|
elif batch_norm_dim == 3:
|
|
self.bn = torch.nn.BatchNorm3d(4)
|
|
else:
|
|
raise ValueError("Dim should only be 2 or 3")
|
|
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.linear(x)
|
|
x = self.obs2(x)
|
|
x = self.bn(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class HighDimensionNet(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.obs1 = ModelReportObserver()
|
|
self.fc1 = torch.nn.Linear(3, 7)
|
|
self.block1 = ModifiedThreeOps(3)
|
|
self.fc2 = torch.nn.Linear(3, 7)
|
|
self.block2 = ModifiedThreeOps(3)
|
|
self.fc3 = torch.nn.Linear(3, 7)
|
|
|
|
def forward(self, x):
|
|
x = self.obs1(x)
|
|
x = self.fc1(x)
|
|
x = self.block1(x)
|
|
x = self.fc2(x)
|
|
y = self.block2(x)
|
|
y = self.fc3(y)
|
|
z = x + y
|
|
z = F.relu(z)
|
|
return z
|
|
|
|
# the purpose of this test is to give the observers a variety of data examples
|
|
# initialize the model
|
|
models = [
|
|
self.NestedModifiedSingleLayerLinear(),
|
|
LargerIncludeNestModel(),
|
|
ModifiedThreeOps(2),
|
|
HighDimensionNet(),
|
|
]
|
|
|
|
# get some number of epochs and batches
|
|
num_epochs = 10
|
|
num_batches = 15
|
|
|
|
input_shapes = [(1, 5), (1, 5), (2, 3, 7), (4, 1, 8, 3)]
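# one input shape per model above; the loop below pairs them up by index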
|
|
|
|
# generate the desired inputs
|
|
inputs = []
|
|
for shape in input_shapes:
|
|
ex_input = torch.randn((num_batches, *shape))
|
|
inputs.append(ex_input)
|
|
|
|
# run it through the model and do general tests
|
|
for index, model in enumerate(models):
|
|
self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)
|
|
|
|
|
|
"""
|
|
Partition on domain / things to test
|
|
|
|
There is only a single test case for now.
|
|
|
|
This will be more thoroughly tested with the implementation of the full end to end tool coming soon.
|
|
"""
|
|
|
|
|
|
class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
|
|
@skipIfNoFBGEMM
|
|
def test_nested_detection_case(self):
|
|
class SingleLinear(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3)
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
return x
|
|
|
|
class TwoBlockNet(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.block1 = SingleLinear()
|
|
self.block2 = SingleLinear()
|
|
|
|
def forward(self, x):
|
|
x = self.block1(x)
|
|
y = self.block2(x)
|
|
z = x + y
|
|
z = F.relu(z)
|
|
return z
|
|
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
# create model, example input, and qconfig mapping
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
model = TwoBlockNet()
|
|
example_input = torch.randint(-10, 0, (1, 3, 3, 3))
|
|
example_input = example_input.to(torch.float)
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
|
|
|
|
# prep model and select observer
|
|
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
|
|
obs_ctr = ModelReportObserver
|
|
|
|
# find layer to attach to and store
|
|
linear_fqn = "block2.linear" # fqn of target linear
|
|
|
|
target_linear = None
|
|
for node in model_prep.graph.nodes:
|
|
if node.target == linear_fqn:
|
|
target_linear = node
|
|
break
|
|
|
|
# insert into both module and graph pre and post
|
|
|
|
# set up to insert before target_linear (pre_observer)
|
|
with model_prep.graph.inserting_before(target_linear):
|
|
obs_to_insert = obs_ctr()
|
|
pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
|
|
model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
|
|
model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)
|
|
|
|
# set up and insert after the target_linear (post_observer)
|
|
with model_prep.graph.inserting_after(target_linear):
|
|
obs_to_insert = obs_ctr()
|
|
post_obs_fqn = linear_fqn + ".model_report_post_observer"
|
|
model_prep.add_submodule(post_obs_fqn, obs_to_insert)
|
|
model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))
|
|
|
|
# need to recompile module after submodule added and pass input through
|
|
model_prep.recompile()
|
|
|
|
num_iterations = 10
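# alternate between all-negative and all-positive batches so the observed data distribution
# shifts from batch to batch, which should lead to the dynamic recommendation asserted below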
|
|
for i in range(num_iterations):
|
|
if i % 2 == 0:
|
|
example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
|
|
else:
|
|
example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
|
|
model_prep(example_input)
|
|
|
|
# run it through the dynamic vs static detector
|
|
dynamic_vs_static_detector = DynamicStaticDetector()
|
|
dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep)
|
|
|
|
# one of the stats should be stationary, and the other non-stationary
|
|
# as a result, dynamic should be recommended
|
|
data_dist_info = [
|
|
dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.PRE_OBS_DATA_DIST_KEY],
|
|
dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.POST_OBS_DATA_DIST_KEY],
|
|
]
|
|
|
|
self.assertTrue("stationary" in data_dist_info)
|
|
self.assertTrue("non-stationary" in data_dist_info)
|
|
self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
|
|
|
|
class TestFxModelReportClass(QuantizationTestCase):
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_constructor(self):
|
|
"""
|
|
Tests the constructor of the ModelReport class.
|
|
Specifically looks at:
|
|
- The desired reports
|
|
- Ensures that the observers of interest are properly initialized
|
|
"""
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
# set the backend for this test
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
backend = torch.backends.quantized.engine
|
|
|
|
# create a model
|
|
model = ThreeOps()
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])
|
|
|
|
# make an example set of detectors
|
|
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
|
|
# initialize the ModelReport with the detector set
|
|
model_report = ModelReport(model_prep, test_detector_set)
|
|
|
|
# make sure the internal set of valid reports matches
|
|
detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
|
|
self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)
|
|
|
|
# now attempt with no valid reports, should raise error
|
|
with self.assertRaises(ValueError):
|
|
model_report = ModelReport(model, set())
|
|
|
|
# number of expected obs of interest entries
|
|
num_expected_entries = len(test_detector_set)
|
|
self.assertEqual(len(model_report.get_observers_of_interest()), num_expected_entries)
|
|
|
|
for value in model_report.get_observers_of_interest().values():
|
|
self.assertEqual(len(value), 0)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_prepare_model_callibration(self):
|
|
"""
|
|
Tests model_report.prepare_detailed_calibration, which prepares the model for calibration
|
|
Specifically looks at:
|
|
- Whether observers are properly inserted into regular nn.Module
|
|
- Whether the target and the arguments of the observers are proper
|
|
- Whether the internal representation of observers of interest is updated
|
|
"""
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
# create model report object
|
|
|
|
# create model
|
|
model = TwoThreeOps()
|
|
# make an example set of detectors
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
backend = torch.backends.quantized.engine
|
|
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
|
|
# initialize with an empty detector
|
|
|
|
# prepare the model
|
|
example_input = model.get_example_inputs()[0]
|
|
current_backend = torch.backends.quantized.engine
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
|
|
|
|
model_report = ModelReport(model_prep, test_detector_set)
|
|
|
|
# prepare the model for calibration
|
|
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
|
|
|
|
# check whether observers were properly inserted into the regular nn.Module
|
|
# there should be 4 observers present in this case
|
|
modules_observer_cnt = 0
|
|
for fqn, module in prepared_for_callibrate_model.named_modules():
|
|
if isinstance(module, ModelReportObserver):
|
|
modules_observer_cnt += 1
|
|
|
|
self.assertEqual(modules_observer_cnt, 4)
|
|
|
|
model_report_str_check = "model_report"
|
|
# also make sure arguments for observers in the graph are proper
|
|
for node in prepared_for_callibrate_model.graph.nodes:
|
|
# not all node targets are strings, so check
|
|
if isinstance(node.target, str) and model_report_str_check in node.target:
|
|
# if pre-observer has same args as the linear (next node)
|
|
if "pre_observer" in node.target:
|
|
self.assertEqual(node.args, node.next.args)
|
|
# if post-observer, args are the target linear (previous node)
|
|
if "post_observer" in node.target:
|
|
self.assertEqual(node.args, (node.prev,))
|
|
|
|
# ensure model_report observers of interest updated
|
|
# there should be two entries
|
|
self.assertEqual(len(model_report.get_observers_of_interest()), 2)
|
|
for detector in test_detector_set:
|
|
self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys())
|
|
|
|
# get number of entries for this detector
|
|
detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()]
|
|
|
|
# assert that the per channel detector has 0 and the dynamic static has 4
|
|
if isinstance(detector, PerChannelDetector):
|
|
self.assertEqual(len(detector_obs_of_interest_fqns), 0)
|
|
elif isinstance(detector, DynamicStaticDetector):
|
|
self.assertEqual(len(detector_obs_of_interest_fqns), 4)
|
|
|
|
# ensure that we can only prepare for calibration once
|
|
with self.assertRaises(ValueError):
|
|
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
|
|
|
|
|
|
def get_module_and_graph_cnts(self, callibrated_fx_module):
|
|
r"""
|
|
Calculates the number of ModelReportObserver modules in the model as well as the number of model_report nodes in the graph.
|
|
Returns a tuple of two elements:
|
|
int: The number of ModelReportObservers found in the model
|
|
int: The number of model_report nodes found in the graph
|
|
"""
|
|
# get the number of observers stored as modules
|
|
modules_observer_cnt = 0
|
|
for fqn, module in callibrated_fx_module.named_modules():
|
|
if isinstance(module, ModelReportObserver):
|
|
modules_observer_cnt += 1
|
|
|
|
# get number of observers in the graph
|
|
model_report_str_check = "model_report"
|
|
graph_observer_cnt = 0
|
|
# also make sure arguments for observers in the graph are proper
|
|
for node in callibrated_fx_module.graph.nodes:
|
|
# not all node targets are strings, so check
|
|
if isinstance(node.target, str) and model_report_str_check in node.target:
|
|
# increment if we found a graph observer
|
|
graph_observer_cnt += 1
|
|
|
|
return (modules_observer_cnt, graph_observer_cnt)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_generate_report(self):
|
|
"""
|
|
Tests model_report.generate_model_report to ensure report generation
|
|
Specifically looks at:
|
|
- Whether correct number of reports are being generated
|
|
- Whether observers are being properly removed if specified
|
|
- Whether report generation is correctly blocked from running twice if observers were removed
|
|
"""
|
|
|
|
with override_quantized_engine('fbgemm'):
|
|
# set the backend for this test
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
|
|
# check whether the correct number of reports are being generated
|
|
filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
|
|
single_detector_set = {DynamicStaticDetector()}
|
|
|
|
# create our models
|
|
model_full = TwoThreeOps()
|
|
model_single = TwoThreeOps()
|
|
|
|
# prepare and calibrate two different instances of the same model
|
|
# prepare the model
|
|
example_input = model_full.get_example_inputs()[0]
|
|
current_backend = torch.backends.quantized.engine
|
|
q_config_mapping = QConfigMapping()
|
|
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
|
|
|
|
model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
|
|
model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)
|
|
|
|
# initialize one with filled detector
|
|
model_report_full = ModelReport(model_prep_full, filled_detector_set)
|
|
# initialize another with a single detector set
|
|
model_report_single = ModelReport(model_prep_single, single_detector_set)
|
|
|
|
# prepare the models for calibration
|
|
prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration()
|
|
prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration()
|
|
|
|
# now calibrate the two models
|
|
num_iterations = 10
|
|
for i in range(num_iterations):
|
|
example_input = torch.randint(100, (1, 3, 3, 3)).to(torch.float)
|
|
prepared_for_callibrate_model_full(example_input)
|
|
prepared_for_callibrate_model_single(example_input)
|
|
|
|
# now generate the reports
|
|
model_full_report = model_report_full.generate_model_report(True)
|
|
model_single_report = model_report_single.generate_model_report(False)
|
|
|
|
# check that sizes are appropriate
|
|
self.assertEqual(len(model_full_report), len(filled_detector_set))
|
|
self.assertEqual(len(model_single_report), len(single_detector_set))
|
|
|
|
# make sure observers are being properly removed for full report since we put flag in
|
|
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
|
|
self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules
|
|
self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph
|
|
|
|
# make sure observers aren't being removed for single report since not specified
|
|
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
|
|
self.assertNotEqual(modules_observer_cnt, 0)
|
|
self.assertNotEqual(graph_observer_cnt, 0)
|
|
|
|
# make sure an error is raised when we try to rerun report generation for the full report, but not for the single report
|
|
with self.assertRaises(Exception):
|
|
model_full_report = model_report_full.generate_model_report(
|
|
prepared_for_callibrate_model_full, False
|
|
)
|
|
|
|
# make sure we don't run into error for single report
|
|
model_single_report = model_report_single.generate_model_report(False)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_generate_visualizer(self):
|
|
"""
|
|
Tests that the ModelReport class can properly create the ModelReportVisualizer instance
|
|
Checks that:
|
|
- Correct number of modules are represented
|
|
- Modules are sorted
|
|
- Correct number of features for each module
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
# set the backend for this test
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
# test with multiple detectors
|
|
detector_set = set()
|
|
detector_set.add(OutlierDetector(reference_percentile=0.95))
|
|
detector_set.add(InputWeightEqualizationDetector(0.5))
|
|
|
|
model = TwoThreeOps()
|
|
|
|
# get test model and calibrate
|
|
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
|
|
model, detector_set, model.get_example_inputs()[0]
|
|
)
|
|
|
|
# now we actually calibrate the model
|
|
example_input = model.get_example_inputs()[0]
|
|
example_input = example_input.to(torch.float)
|
|
|
|
prepared_for_callibrate_model(example_input)
|
|
|
|
# try to visualize without generating report, should throw error
|
|
with self.assertRaises(Exception):
|
|
mod_rep_visualization = mod_report.generate_visualizer()
|
|
|
|
# now get the report by running it through ModelReport instance
|
|
generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
|
|
|
|
# now getting the visualizer should not error
|
|
mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()
|
|
|
|
# since we tested with outlier detector, which looks at every base level module
|
|
# should be six entries in the ordered dict
|
|
mod_fqns_to_features = mod_rep_visualizer.generated_reports
|
|
|
|
self.assertEqual(len(mod_fqns_to_features), 6)
|
|
|
|
# outlier detector has 9 features per module
|
|
# input-weight has 12 features per module
|
|
# there is 1 common data point, so there should be 12 + 9 - 1 = 20 unique features per common module
|
|
# all linears will be common
|
|
for module_fqn in mod_fqns_to_features:
|
|
if ".linear" in module_fqn:
|
|
linear_info = mod_fqns_to_features[module_fqn]
|
|
self.assertEqual(len(linear_info), 20)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_qconfig_mapping_generation(self):
|
|
"""
|
|
Tests for generation of qconfigs by ModelReport API
|
|
- Tests that qconfigmapping is generated
|
|
- Tests that mappings include information for relevant modules
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
# set the backend for this test
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
# test with multiple detectors
|
|
detector_set = set()
|
|
detector_set.add(PerChannelDetector())
|
|
detector_set.add(DynamicStaticDetector())
|
|
|
|
model = TwoThreeOps()
|
|
|
|
# get test model and calibrate
|
|
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
|
|
model, detector_set, model.get_example_inputs()[0]
|
|
)
|
|
|
|
# now we actually calibrate the models
|
|
example_input = model.get_example_inputs()[0]
|
|
example_input = example_input.to(torch.float)
|
|
|
|
prepared_for_callibrate_model(example_input)
|
|
|
|
|
|
# get the mapping without error
|
|
qconfig_mapping = mod_report.generate_qconfig_mapping()
|
|
|
|
# now get the report by running it through ModelReport instance
|
|
generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
|
|
|
|
# get the visualizer so we can get access to reformatted reports by module fqn
|
|
mod_reports_by_fqn = mod_report.generate_visualizer().generated_reports
|
|
|
|
# compare the entries of the mapping to those of the report
|
|
# we should have the same number of entries
|
|
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), len(mod_reports_by_fqn))
|
|
|
|
# for the non-empty mapping, we should have 2 entries because only the linear layers are applicable,
# so there should be a suggestion for each of those modules
|
|
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
|
|
|
|
# only two linears; weights should use the per-channel min-max observer since the backend is fbgemm,
# and activations should use the static default observer since this was a simple single calibration
|
|
for key in qconfig_mapping.module_name_qconfigs:
|
|
config = qconfig_mapping.module_name_qconfigs[key]
|
|
self.assertEqual(config.weight, default_per_channel_weight_observer)
|
|
self.assertEqual(config.activation, default_observer)
|
|
|
|
# make sure these can actually be used to prepare the model
|
|
prepared = quantize_fx.prepare_fx(TwoThreeOps(), qconfig_mapping, example_input)
|
|
|
|
# now convert the model to ensure no errors in conversion
|
|
converted = quantize_fx.convert_fx(prepared)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_equalization_mapping_generation(self):
|
|
"""
|
|
Tests for generation of qconfigs by ModelReport API
|
|
- Tests that an equalization config is generated when the input-weight equalization detector is used
|
|
- Tests that mappings include information for relevant modules
|
|
"""
|
|
with override_quantized_engine('fbgemm'):
|
|
# set the backend for this test
|
|
torch.backends.quantized.engine = "fbgemm"
|
|
# test with the input-weight equalization detector
|
|
detector_set = set()
|
|
detector_set.add(InputWeightEqualizationDetector(0.6))
|
|
|
|
model = TwoThreeOps()
|
|
|
|
# get test model and calibrate
|
|
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
|
|
model, detector_set, model.get_example_inputs()[0]
|
|
)
|
|
|
|
# now we actually calibrate the models
|
|
example_input = model.get_example_inputs()[0]
|
|
example_input = example_input.to(torch.float)
|
|
|
|
prepared_for_callibrate_model(example_input)
|
|
|
|
|
|
# get the mapping without error
|
|
qconfig_mapping = mod_report.generate_qconfig_mapping()
|
|
equalization_mapping = mod_report.generate_equalization_mapping()
|
|
|
|
# the checks are much simpler for the equalization mapping
|
|
|
|
# shouldn't have any equalization suggestions for this case
|
|
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
|
|
|
|
|
|
# make sure these can actually be used to prepare the model
|
|
prepared = quantize_fx.prepare_fx(
|
|
TwoThreeOps(),
|
|
qconfig_mapping,
|
|
example_input,
|
|
_equalization_config=equalization_mapping
|
|
)
|
|
|
|
# now convert the model to ensure no errors in conversion
|
|
converted = quantize_fx.convert_fx(prepared)
|
|
|
|
class TestFxDetectInputWeightEqualization(QuantizationTestCase):
|
|
|
|
class SimpleConv(torch.nn.Module):
|
|
def __init__(self, con_dims):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
self.conv = torch.nn.Conv2d(con_dims[0], con_dims[1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class TwoBlockComplexNet(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32))
|
|
self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3))
|
|
self.conv = torch.nn.Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
|
|
self.linear = torch.nn.Linear(768, 10)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.block1(x)
|
|
x = self.conv(x)
|
|
y = self.block2(x)
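# block2's stride-2 conv halves the spatial size, so repeat(1, 1, 2, 2) tiles y back up to x's spatial shape before the residual add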
|
|
y = y.repeat(1, 1, 2, 2)
|
|
z = x + y
|
|
z = z.flatten(start_dim=1)
|
|
z = self.linear(z)
|
|
z = self.relu(z)
|
|
return z
|
|
|
|
def get_fusion_modules(self):
|
|
return [['conv', 'relu']]
|
|
|
|
def get_example_inputs(self):
|
|
return (torch.randn((1, 3, 28, 28)),)
|
|
|
|
class ReluOnly(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
def get_example_inputs(self):
|
|
return (torch.arange(27).reshape((1, 3, 3, 3)),)
|
|
|
|
def _get_prepped_for_calibration_model(self, model, detector_set, fused=False):
|
|
r"""Returns a model that has been prepared for callibration and corresponding model_report"""
|
|
|
|
# pass in necessary inputs to helper
|
|
example_input = model.get_example_inputs()[0]
|
|
return _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused)
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_input_weight_equalization_determine_points(self):
|
|
# use fbgemm and create our model instance
|
|
# then create model report instance with detector
|
|
with override_quantized_engine('fbgemm'):
|
|
|
|
detector_set = {InputWeightEqualizationDetector(0.5)}
|
|
|
|
# get test model and calibrate
|
|
non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)
|
|
fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True)
|
|
|
|
# reporter should still give same counts even for fused model
|
|
for prepared_for_callibrate_model, mod_report in [non_fused, fused]:
|
|
|
|
# supported modules to check
|
|
mods_to_check = {nn.Linear, nn.Conv2d}
|
|
|
|
# get the set of all nodes in the graph and their fqns
|
|
node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}
|
|
|
|
# there should be 4 node fqns that have the observer inserted
|
|
correct_number_of_obs_inserted = 4
|
|
number_of_obs_found = 0
|
|
obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
|
|
|
|
for node in prepared_for_callibrate_model.graph.nodes:
|
|
# if the obs name is inside the target, we found an observer
|
|
if obs_name_to_find in str(node.target):
|
|
number_of_obs_found += 1
|
|
|
|
self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
|
|
|
|
# assert that each of the desired modules have the observers inserted
|
|
for fqn, module in prepared_for_callibrate_model.named_modules():
|
|
# check if module is a supported module
|
|
is_in_include_list = sum(isinstance(module, x) for x in mods_to_check) > 0
|
|
|
|
if is_in_include_list:
|
|
# make sure it has the observer attribute
|
|
self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
|
|
else:
|
|
# if it's not a supported type, it shouldn't have observer attached
|
|
self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_input_weight_equalization_report_gen(self):
|
|
# use fbgemm and create our model instance
|
|
# then create model report instance with detector
|
|
with override_quantized_engine('fbgemm'):
|
|
|
|
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
|
|
detector_set = {test_input_weight_detector}
|
|
model = self.TwoBlockComplexNet()
|
|
# prepare the model for calibration
|
|
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(
|
|
model, detector_set
|
|
)
|
|
|
|
# now we actually calibrate the model
|
|
example_input = model.get_example_inputs()[0]
|
|
example_input = example_input.to(torch.float)
|
|
|
|
prepared_for_callibrate_model(example_input)
|
|
|
|
# now get the report by running it through ModelReport instance
|
|
generated_report = model_report.generate_model_report(True)
|
|
|
|
# check that sizes are appropriate only 1 detector
|
|
self.assertEqual(len(generated_report), 1)
|
|
|
|
# get the specific report for input weight equalization
|
|
input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
|
|
|
|
# we should have 4 layers looked at since there are 4 conv / linear layers
|
|
self.assertEqual(len(input_weight_dict), 4)
|
|
|
|
# we can validate that the max and min values of the detector were recorded properly for the first one
|
|
# this is because no data has been processed yet, so it should be values from original input
|
|
|
|
example_input = example_input.reshape((3, 28, 28)) # reshape input
|
|
for module_fqn in input_weight_dict:
|
|
# look for the first linear
|
|
if "block1.linear" in module_fqn:
|
|
block_1_lin_recs = input_weight_dict[module_fqn]
|
|
# get input range info and the channel axis
|
|
ch_axis = block_1_lin_recs[InputWeightEqualizationDetector.CHANNEL_KEY]
|
|
|
|
# ensure that the min and max values extracted match properly
|
|
example_min, example_max = torch.aminmax(example_input, dim=ch_axis)
|
|
dimension_min = torch.amin(example_min, dim=ch_axis)
|
|
dimension_max = torch.amax(example_max, dim=ch_axis)
|
|
|
|
# make sure per channel min and max are as expected
|
|
min_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
|
|
min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
|
|
|
|
max_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
|
|
max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
|
|
|
|
per_channel_min = block_1_lin_recs[min_per_key]
|
|
per_channel_max = block_1_lin_recs[max_per_key]
|
|
self.assertEqual(per_channel_min, dimension_min)
|
|
self.assertEqual(per_channel_max, dimension_max)
|
|
|
|
# build the keys for the activation's global min and max
|
|
min_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
|
|
min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
|
|
|
|
max_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
|
|
max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
|
|
|
|
# make sure the global min and max were correctly recorded and presented
|
|
global_min = block_1_lin_recs[min_key]
|
|
global_max = block_1_lin_recs[max_key]
|
|
self.assertEqual(global_min, min(dimension_min))
|
|
self.assertEqual(global_max, max(dimension_max))
|
|
|
|
input_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
|
|
# ensure the comparison stat passed back is computed from the sqrt of the range ratios
|
|
# need to get the weight ratios first
|
|
|
|
# build the keys for the weight's per-channel min and max
|
|
min_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
|
|
min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
|
|
|
|
max_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
|
|
max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
|
|
|
|
# get weight per channel and global info
|
|
per_channel_min = block_1_lin_recs[min_per_key]
|
|
per_channel_max = block_1_lin_recs[max_per_key]
|
|
|
|
# build the keys for the weight's global min and max
|
|
min_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
|
|
min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
|
|
|
|
max_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
|
|
max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
|
|
|
|
global_min = block_1_lin_recs[min_key]
|
|
global_max = block_1_lin_recs[max_key]
|
|
|
|
weight_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
|
|
|
|
# also get comp stat for this specific layer
|
|
comp_stat = block_1_lin_recs[InputWeightEqualizationDetector.COMP_METRIC_KEY]
|
|
|
|
weight_to_input_ratio = weight_ratio / input_ratio
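# this recomputes the detector's comparison stat by hand:
# comp_stat = sqrt(weight per-channel range / weight global range) / sqrt(input per-channel range / input global range)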
|
|
|
|
self.assertEqual(comp_stat, weight_to_input_ratio)
|
|
# only looking at the first example so can break
|
|
break
|
|
|
|
@skipIfNoFBGEMM
|
|
def test_input_weight_equalization_report_gen_empty(self):
|
|
# tests report gen on a model that doesn't have any layers
|
|
# use fbgemm and create our model instance
|
|
# then create model report instance with detector
|
|
with override_quantized_engine('fbgemm'):
|
|
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
|
|
detector_set = {test_input_weight_detector}
|
|
model = self.ReluOnly()
|
|
# prepare the model for calibration
|
|
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)
|
|
|
|
# now we actually calibrate the model
|
|
example_input = model.get_example_inputs()[0]
|
|
example_input = example_input.to(torch.float)
|
|
|
|
prepared_for_callibrate_model(example_input)
|
|
|
|
# now get the report by running it through ModelReport instance
|
|
generated_report = model_report.generate_model_report(True)
|
|
|
|
# check that sizes are appropriate only 1 detector
|
|
self.assertEqual(len(generated_report), 1)
|
|
|
|
# get the specific report for input weight equalization
|
|
input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
|
|
|
|
# we should have 0 layers since there is only a Relu
|
|
self.assertEqual(len(input_weight_dict), 0)
|
|
|
|
# make sure the string only contains two newlines, as it should when there are no suggestions
|
|
self.assertEqual(input_weight_str.count("\n"), 2)
|
|
|
|
|
|
class TestFxDetectOutliers(QuantizationTestCase):
|
|
|
|
class LargeBatchModel(torch.nn.Module):
|
|
def __init__(self, param_size):
|
|
super().__init__()
|
|
self.param_size = param_size
|
|
self.linear = torch.nn.Linear(param_size, param_size)
|
|
self.relu_1 = torch.nn.ReLU()
|
|
self.conv = torch.nn.Conv2d(param_size, param_size, 1)
|
|
self.relu_2 = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.relu_1(x)
|
|
x = self.conv(x)
|
|
x = self.relu_2(x)
|
|
return x
|
|
|
|
def get_example_inputs(self):
|
|
param_size = self.param_size
|
|
return (torch.randn((1, param_size, param_size, param_size)),)
|
|
|
|
def get_outlier_inputs(self):
|
|
param_size = self.param_size
|
|
random_vals = torch.randn((1, param_size, param_size, param_size))
|
|
# set one element in every other channel to a massive value
|
|
random_vals[:, 0:param_size:2, 0, 3] = torch.tensor([3.28e8])
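# (3.28e8 is far outside the rest of the standard-normal data, so these entries should register as outliers when the outlier inputs are used)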
|
|
return (random_vals,)
|
|
|
|
|
|
def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False):
|
|
r"""Returns a model that has been prepared for callibration and corresponding model_report"""
|
|
# call the general helper function to calibrate
|
|
example_input = model.get_example_inputs()[0]
|
|
|
|
# if we specifically want to test data with outliers replace input
|
|
if use_outlier_data:
|
|
example_input = model.get_outlier_inputs()[0]
|
|
|
|
return _get_prepped_for_calibration_model_helper(model, detector_set, example_input)
|
|
|
|
    @skipIfNoFBGEMM
    def test_outlier_detection_determine_points(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        # similar to the test for InputWeightEqualization, but with key differences that made refactoring not viable
        # fusion is not explicitly tested because the fx workflow handles it automatically
        with override_quantized_engine('fbgemm'):

            detector_set = {OutlierDetector(reference_percentile=0.95)}

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                self.LargeBatchModel(param_size=128), detector_set
            )

            # supported modules to check
            mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}

            # there should be 4 node fqns that have the observer inserted
            correct_number_of_obs_inserted = 4
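            # (LargeBatchModel has one Linear, one Conv2d, and two ReLUs, all of which
            # appear in mods_to_check, hence the 4 expected pre-observers)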
            number_of_obs_found = 0
            obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME

            number_of_obs_found = sum(
                1 for node in prepared_for_callibrate_model.graph.nodes if obs_name_to_find in str(node.target)
            )
            self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)

            # assert that each of the desired modules has the observer inserted
            for fqn, module in prepared_for_callibrate_model.named_modules():
                # check if the module is a supported module
                is_in_include_list = isinstance(module, tuple(mods_to_check))

                if is_in_include_list:
                    # make sure it has the observer attribute
                    self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
                else:
                    # if it's not a supported type, it shouldn't have the observer attached
                    self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))

    @skipIfNoFBGEMM
    def test_no_outlier_report_gen(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):

            # test with multiple detectors
            outlier_detector = OutlierDetector(reference_percentile=0.95)
            dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)

            param_size: int = 4
            detector_set = {outlier_detector, dynamic_static_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through the ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that the size is appropriate for only 2 detectors
            self.assertEqual(len(generated_report), 2)

            # get the specific report for the outlier detector
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: the linear, the conv, and the two relus
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # there really should not be any outliers since we used normally distributed inputs
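                # with reference_percentile=0.95 a channel is presumably only flagged when its
                # max value greatly exceeds its ~95th-percentile value, which standard normal
                # inputs should not produce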
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                self.assertEqual(sum(outlier_info), 0)

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)

    @skipIfNoFBGEMM
    def test_all_outlier_report_gen(self):
        # make the percentile 0 and the ratio 1, and then see that everything is an outlier according to it
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):
            # create the detector of interest
            outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)
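            # with ratio_threshold=1 and reference_percentile=0, the reference value is
            # effectively each channel's minimum, so essentially every channel is expected
            # to be flagged (an assumption about how the detector combines these parameters)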

            param_size: int = 16
            detector_set = {outlier_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through the ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that the size is appropriate for only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for the outlier detector
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: the linear, the conv, and the two relus
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # everything should be an outlier because we said that the max should equal the min for all of them
                # however we only require that most are, in case there are several zero-valued channels
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                assert sum(outlier_info) >= len(outlier_info) / 2

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)

    @skipIfNoFBGEMM
    def test_multiple_run_consistent_spike_outlier_report_gen(self):
        # specifically make one row consistently really high across the batches being tested
        # generate the report after many runs (30) and make sure the minimum-batch threshold is crossed
        with override_quantized_engine('fbgemm'):

            # detector of interest
            outlier_detector = OutlierDetector(reference_percentile=0.95)

            param_size: int = 8
            detector_set = {outlier_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set, use_outlier_data=True
            )

            # now we actually calibrate the model
            example_input = model.get_outlier_inputs()[0]
            example_input = example_input.to(torch.float)

            # now calibrate at least 30 times to get above the minimum batch threshold
            for i in range(30):
                example_input = model.get_outlier_inputs()[0]
                example_input = example_input.to(torch.float)

                # make 2 of the batches have a zeroed-out channel
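                # (the zeroed channel is what the CONSTANT_COUNTS_KEY check further below is
                # counting; zeroed batches are assumed to be recorded as "constant" channels)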
                if i % 14 == 0:
                    # make one channel constant
                    example_input[0][1] = torch.zeros_like(example_input[0][1])

                prepared_for_callibrate_model(example_input)

            # now get the report by running it through the ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that the size is appropriate for only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for the outlier detector
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: the linear, the conv, and the two relus
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # because we ran 30 times, at least a couple of the channels should have enough batches
                # could be fewer because some channels could possibly be all 0
                sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
                assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2

                # half of them should be outliers, because we set a really high value in every other channel
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                self.assertEqual(sum(outlier_info), len(outlier_info) / 2)

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)

                # for the first module ensure the per-channel max values are what we set
                if module_fqn == "linear.0":

                    # check the constant (zero) channel counts; at least 2 should be recorded
                    # for the first module
                    counts_info = module_dict[OutlierDetector.CONSTANT_COUNTS_KEY]
                    assert sum(counts_info) >= 2

                    # half of the recorded max values should be what we set
                    matched_max = sum(val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY])
                    self.assertEqual(matched_max, param_size / 2)


class TestFxModelReportVisualizer(QuantizationTestCase):

    def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report):
        r"""
        Calibrates the passed-in model, generates the report, and returns the visualizer
        """
        # now we actually calibrate the model
        example_input = model.get_example_inputs()[0]
        example_input = example_input.to(torch.float)

        prepared_for_callibrate_model(example_input)

        # now get the report by running it through the ModelReport instance
        generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

        # now we get the visualizer; this should not error
        mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()

        return mod_rep_visualizer

    @skipIfNoFBGEMM
    def test_get_modules_and_features(self):
        """
        Tests the get_all_unique_module_fqns and get_all_unique_feature_names methods of
        ModelReportVisualizer

        Checks whether returned sets are of proper size and filtered properly
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(OutlierDetector(reference_percentile=0.95))
            detector_set.add(InputWeightEqualizationDetector(0.5))

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
                model, prepared_for_callibrate_model, mod_report
            )

            # ensure the module fqns match the ones given by the get_all_unique_module_fqns method
            actual_model_fqns = set(mod_rep_visualizer.generated_reports.keys())
            returned_model_fqns = mod_rep_visualizer.get_all_unique_module_fqns()
            self.assertEqual(returned_model_fqns, actual_model_fqns)

            # now ensure that features are all properly returned
            # all the linears have all the features for the two detectors,
            # so we can use those as a check that the method is working reliably
            b_1_linear_features = mod_rep_visualizer.generated_reports["block1.linear"]

            # first test all features
            returned_all_feats = mod_rep_visualizer.get_all_unique_feature_names(False)
            self.assertEqual(returned_all_feats, set(b_1_linear_features.keys()))

            # now test plottable features
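            # the no-argument call below is expected to return only "plottable" features,
            # which this loop approximates as the features stored as torch.Tensors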
            plottable_set = set()

            for feature_name in b_1_linear_features:
                if type(b_1_linear_features[feature_name]) is torch.Tensor:
                    plottable_set.add(feature_name)

            returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()
            self.assertEqual(returned_plottable_feats, plottable_set)

    def _prep_visualizer_helper(self):
        r"""
        Returns a mod rep visualizer that we test in various ways
        """
        # set the backend for this test
        torch.backends.quantized.engine = "fbgemm"

        # test with multiple detectors
        detector_set = set()
        detector_set.add(OutlierDetector(reference_percentile=0.95))
        detector_set.add(InputWeightEqualizationDetector(0.5))

        model = TwoThreeOps()

        # get test model and calibrate
        prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
            model, detector_set, model.get_example_inputs()[0]
        )

        mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
            model, prepared_for_callibrate_model, mod_report
        )

        return mod_rep_visualizer

    @skipIfNoFBGEMM
    def test_generate_tables_match_with_report(self):
        """
        Tests the generate_filtered_tables() method of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):

            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            table_dict = mod_rep_visualizer.generate_filtered_tables()

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = table_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            # these two together should cover the same module keys as the generated report
            tensor_info_modules = {row[1] for row in tensor_table}
            channel_info_modules = {row[1] for row in channel_table}
            combined_modules: set = tensor_info_modules.union(channel_info_modules)

            generated_report_keys: set = set(mod_rep_visualizer.generated_reports.keys())
            self.assertEqual(combined_modules, generated_report_keys)

    @skipIfNoFBGEMM
    def test_generate_tables_no_match(self):
        """
        Tests the generate_filtered_tables() method of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):
            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            # try a random filter and make sure that there are no rows for either table
            empty_tables_dict = mod_rep_visualizer.generate_filtered_tables(module_fqn_filter="random not there module")

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            tensor_info_modules = {row[1] for row in tensor_table}
            channel_info_modules = {row[1] for row in channel_table}
            combined_modules: set = tensor_info_modules.union(channel_info_modules)
            self.assertEqual(len(combined_modules), 0)  # should be no matching modules

    @skipIfNoFBGEMM
    def test_generate_tables_single_feat_match(self):
        """
        Tests the generate_filtered_tables() method of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):
            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            # try a matching filter for a feature and make sure only that feature shows up
            # if we filter to a very specific feature name, there should only be 1 additional column in each table row
            single_feat_dict = mod_rep_visualizer.generate_filtered_tables(feature_filter=OutlierDetector.MAX_VALS_KEY)

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = single_feat_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = single_feat_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            # get the number of features in each of these
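            # the channel table carries a few fixed non-feature columns (module fqn, channel
            # index, etc.), so those are subtracted out below; the tensor table is assumed to
            # produce no headers at all when no tensor-level feature matches the filter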
            tensor_info_features = len(tensor_headers)
            channel_info_features = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS

            # make sure that there are no tensor features, and that there is one channel-level feature
            self.assertEqual(tensor_info_features, 0)
            self.assertEqual(channel_info_features, 1)


def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False):
    r"""Returns a model that has been prepared for calibration and the corresponding model_report"""
    # set the backend for this test
    torch.backends.quantized.engine = "fbgemm"

    # create model instance and prepare it
    example_input = example_input.to(torch.float)
    q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()

    # if the fusion parameter was passed in, make sure to test that
    if fused:
        model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules())

    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

    model_report = ModelReport(model_prep, detector_set)
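    # rough end-to-end flow exercised by these tests:
    # prepare_fx -> ModelReport(prepared_model, detectors) -> prepare_detailed_calibration()
    # -> run calibration inputs -> generate_model_report()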

    # prepare the model for calibration
    prepared_for_callibrate_model = model_report.prepare_detailed_calibration()

    return (prepared_for_callibrate_model, model_report)


if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")