# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]
import torch
import torch.nn as nn
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn.functional as F
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report.detector import (
DynamicStaticDetector,
InputWeightEqualizationDetector,
PerChannelDetector,
OutlierDetector,
)
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
from torch.ao.quantization.fx._model_report.model_report import ModelReport
from torch.ao.quantization.observer import (
HistogramObserver,
default_per_channel_weight_observer,
default_observer
)
from torch.ao.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
from torch.testing._internal.common_quantization import (
ConvModel,
QuantizationTestCase,
SingleLayerLinearModel,
TwoLayerLinearModel,
skipIfNoFBGEMM,
skipIfNoQNNPACK,
override_quantized_engine,
)
"""
Partition of input domain:
Model contains: conv or linear, both conv and linear
Model contains: ConvTransposeNd (not supported for per_channel)
Model is: post training quantization model, quantization aware training model
Model is: composed with nn.Sequential, composed in class structure
QConfig utilizes per_channel weight observer, backend uses non per_channel weight observer
QConfig dict uses only one default qconfig, QConfig dict uses > 1 unique qconfigs
Partition on output domain:
There are possible changes / suggestions, there are no changes / suggestions
"""
# Default output for string if no optimizations are possible
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
"Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
)
# Example Sequential Model with multiple Conv and Linear with nesting involved
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
torch.nn.Conv2d(3, 3, 2, 1),
torch.nn.Sequential(torch.nn.Linear(9, 27), torch.nn.ReLU()),
torch.nn.Linear(27, 27),
torch.nn.ReLU(),
torch.nn.Conv2d(3, 3, 2, 1),
)
# Example Sequential Model with Conv sub-class example
LAZY_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
torch.nn.LazyConv2d(3, 3, 2, 1),
torch.nn.Sequential(torch.nn.Linear(5, 27), torch.nn.ReLU()),
torch.nn.ReLU(),
torch.nn.Linear(27, 27),
torch.nn.ReLU(),
torch.nn.LazyConv2d(3, 3, 2, 1),
)
# Example Sequential Model with Fusion directly built into model
FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
ConvReLU2d(torch.nn.Conv2d(3, 3, 2, 1), torch.nn.ReLU()),
torch.nn.Sequential(LinearReLU(torch.nn.Linear(9, 27), torch.nn.ReLU())),
LinearReLU(torch.nn.Linear(27, 27), torch.nn.ReLU()),
torch.nn.Conv2d(3, 3, 2, 1),
)
# Example models to use for tests
class ThreeOps(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.bn = nn.BatchNorm2d(3)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.bn(x)
x = self.relu(x)
return x
def get_example_inputs(self):
return (torch.randn(1, 3, 3, 3),)
class TwoThreeOps(nn.Module):
def __init__(self):
super().__init__()
self.block1 = ThreeOps()
self.block2 = ThreeOps()
def forward(self, x):
x = self.block1(x)
y = self.block2(x)
z = x + y
z = F.relu(z)
return z
def get_example_inputs(self):
return (torch.randn(1, 3, 3, 3),)
class TestFxModelReportDetector(QuantizationTestCase):
"""Prepares and callibrate the model"""
def _prepare_model_and_run_input(self, model, q_config_mapping, input):
model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input) # prep model
model_prep(input).sum()  # calibrate the model
return model_prep
"""Case includes:
one conv or linear
post training quantization
composed as module
qconfig uses per_channel weight observer
Only 1 qconfig in qconfig dict
Output has no changes / suggestions
"""
@skipIfNoFBGEMM
def test_simple_conv(self):
with override_quantized_engine('fbgemm'):
torch.backends.quantized.engine = "fbgemm"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
input = torch.randn(1, 3, 10, 10)
prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# no optims possible and there should be nothing in per_channel_status
self.assertEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# there should only be one conv in this model
self.assertEqual(per_channel_info["conv"]["backend"], torch.backends.quantized.engine)
self.assertEqual(len(per_channel_info), 1)
self.assertEqual(list(per_channel_info)[0], "conv")
self.assertEqual(
per_channel_info["conv"]["per_channel_quantization_supported"],
True,
)
self.assertEqual(per_channel_info["conv"]["per_channel_quantization_used"], True)
"""Case includes:
Multiple conv or linear
post training quantization
composed as module
qconfig doesn't use per_channel weight observer
Only 1 qconfig in qconfig dict
Output has possible changes / suggestions
"""
@skipIfNoQNNPACK
def test_multi_linear_model_without_per_channel(self):
with override_quantized_engine('qnnpack'):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
TwoLayerLinearModel(),
q_config_mapping,
TwoLayerLinearModel().get_example_inputs()[0],
)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# there should be optims possible
self.assertNotEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# pick a random key to look at
rand_key: str = list(per_channel_info.keys())[0]
self.assertEqual(per_channel_info[rand_key]["backend"], torch.backends.quantized.engine)
self.assertEqual(len(per_channel_info), 2)
# for each linear layer, should be supported but not used
for linear_key in per_channel_info.keys():
module_entry = per_channel_info[linear_key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
self.assertEqual(module_entry["per_channel_quantization_used"], False)
"""Case includes:
Multiple conv or linear
post training quantization
composed as Module
qconfig doesn't use per_channel weight observer
More than 1 qconfig in qconfig dict
Output has possible changes / suggestions
"""
@skipIfNoQNNPACK
def test_multiple_q_config_options(self):
with override_quantized_engine('qnnpack'):
torch.backends.quantized.engine = "qnnpack"
# qconfig with support for per_channel quantization
per_channel_qconfig = QConfig(
activation=HistogramObserver.with_args(reduce_range=True),
weight=default_per_channel_weight_observer,
)
# we need to design the model
class ConvLinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 2, 1)
self.fc1 = torch.nn.Linear(9, 27)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(27, 27)
self.conv2 = torch.nn.Conv2d(3, 3, 2, 1)
def forward(self, x):
x = self.conv1(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.conv2(x)
return x
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(
torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
).set_object_type(torch.nn.Conv2d, per_channel_qconfig)
prepared_model = self._prepare_model_and_run_input(
ConvLinearModel(),
q_config_mapping,
torch.randn(1, 3, 10, 10),
)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# the only suggestions should be to linear layers
# there should be optims possible
self.assertNotEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# to ensure it got into the nested layer
self.assertEqual(len(per_channel_info), 4)
# for each layer, should be supported but not used
for key in per_channel_info.keys():
module_entry = per_channel_info[key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
# linear should be False, conv2d True since it uses a different qconfig
if "fc" in key:
self.assertEqual(module_entry["per_channel_quantization_used"], False)
elif "conv" in key:
self.assertEqual(module_entry["per_channel_quantization_used"], True)
else:
raise ValueError("Should only contain conv and linear layers as key values")
"""Case includes:
Multiple conv or linear
post training quantization
composed as sequential
qconfig doesn't use per_channel weight observer
Only 1 qconfig in qconfig dict
Output has possible changes / suggestions
"""
@skipIfNoQNNPACK
def test_sequential_model_format(self):
with override_quantized_engine('qnnpack'):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
NESTED_CONV_LINEAR_EXAMPLE,
q_config_mapping,
torch.randn(1, 3, 10, 10),
)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# there should be optims possible
self.assertNotEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# to ensure it got into the nested layer
self.assertEqual(len(per_channel_info), 4)
# for each layer, should be supported but not used
for key in per_channel_info.keys():
module_entry = per_channel_info[key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
self.assertEqual(module_entry["per_channel_quantization_used"], False)
"""Case includes:
Multiple conv or linear
post training quantization
composed as sequential
qconfig doesn't use per_channel weight observer
Only 1 qconfig in qconfig dict
Output has possible changes / suggestions
"""
@skipIfNoQNNPACK
def test_conv_sub_class_considered(self):
with override_quantized_engine('qnnpack'):
torch.backends.quantized.engine = "qnnpack"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
LAZY_CONV_LINEAR_EXAMPLE,
q_config_mapping,
torch.randn(1, 3, 10, 10),
)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# there should be optims possible
self.assertNotEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# to ensure it got into the nested layer and it considered the lazyConv2d
self.assertEqual(len(per_channel_info), 4)
# for each layer, should be supported but not used
for key in per_channel_info.keys():
module_entry = per_channel_info[key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
self.assertEqual(module_entry["per_channel_quantization_used"], False)
"""Case includes:
Multiple conv or linear
post training quantization
composed as sequential
qconfig uses per_channel weight observer
Only 1 qconfig in qconfig dict
Output has no possible changes / suggestions
"""
@skipIfNoFBGEMM
def test_fusion_layer_in_sequential(self):
with override_quantized_engine('fbgemm'):
torch.backends.quantized.engine = "fbgemm"
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
prepared_model = self._prepare_model_and_run_input(
FUSION_CONV_LINEAR_EXAMPLE,
q_config_mapping,
torch.randn(1, 3, 10, 10),
)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)
# no optims possible and there should be nothing in per_channel_status
self.assertEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# to ensure it got into the nested layer and it considered all the nested fusion components
self.assertEqual(len(per_channel_info), 4)
# for each layer, should be supported but not used
for key in per_channel_info.keys():
module_entry = per_channel_info[key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
self.assertEqual(module_entry["per_channel_quantization_used"], True)
"""Case includes:
Multiple conv or linear
quantization aware training
composed as model
qconfig does not use per_channel weight observer
Only 1 qconfig in qconfig dict
Output has possible changes / suggestions
"""
@skipIfNoQNNPACK
def test_qat_aware_model_example(self):
# first we want a QAT model
class QATConvLinearReluModel(torch.nn.Module):
def __init__(self):
super().__init__()
# QuantStub converts tensors from floating point to quantized
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.bn = torch.nn.BatchNorm2d(1)
self.relu = torch.nn.ReLU()
# DeQuantStub converts tensors from quantized to floating point
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.dequant(x)
return x
with override_quantized_engine('qnnpack'):
# create a model instance
model_fp32 = QATConvLinearReluModel()
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
# model must be in eval mode for fusion
model_fp32.eval()
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])
# model must be set to train mode for QAT logic to work
model_fp32_fused.train()
# prepare the model for QAT, different than for post training quantization
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)
# run the detector
per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared)
# there should be optims possible
self.assertNotEqual(
optims_str,
DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
)
# make sure it was able to find the single conv in the fused model
self.assertEqual(len(per_channel_info), 1)
# for the one conv, it should still give advice to use different qconfig
for key in per_channel_info.keys():
module_entry = per_channel_info[key]
self.assertEqual(module_entry["per_channel_quantization_supported"], True)
self.assertEqual(module_entry["per_channel_quantization_used"], False)
"""
Partition on Domain / Things to Test
- All zero tensor
- Multiple tensor dimensions
- All of the outward facing functions
- Epoch min max are correctly updating
- Batch range is correctly averaging as expected
- Reset for each epoch is correctly resetting the values
Partition on Output
- the calculation of the ratio occurs correctly
"""
class TestFxModelReportObserver(QuantizationTestCase):
class NestedModifiedSingleLayerLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.obs1 = ModelReportObserver()
self.mod1 = SingleLayerLinearModel()
self.obs2 = ModelReportObserver()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.obs1(x)
x = self.mod1(x)
x = self.obs2(x)
x = self.fc1(x)
x = self.relu(x)
return x
def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
# split up data into batches
split_up_data = torch.split(ex_input, batch_size)
for epoch in range(num_epochs):
# reset all model report obs
model.apply(
lambda module: module.reset_batch_and_epoch_values()
if isinstance(module, ModelReportObserver)
else None
)
# quick check that a reset occurred
self.assertEqual(
model.obs1.average_batch_activation_range,
torch.tensor(float(0)),
)
self.assertEqual(model.obs1.epoch_activation_min, torch.tensor(float("inf")))
self.assertEqual(model.obs1.epoch_activation_max, torch.tensor(float("-inf")))
# loop through the batches and run through
for index, batch in enumerate(split_up_data):
num_tracked_so_far = model.obs1.num_batches_tracked
self.assertEqual(num_tracked_so_far, index)
# get general info about the batch and the model to use later
batch_min, batch_max = torch.aminmax(batch)
current_average_range = model.obs1.average_batch_activation_range
current_epoch_min = model.obs1.epoch_activation_min
current_epoch_max = model.obs1.epoch_activation_max
# run input through
model(ex_input)
# check that average batch activation range updated correctly
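# the observer is expected to maintain an incremental mean of per-batch ranges:
#   new_avg = (old_avg * num_batches_tracked + (batch_max - batch_min)) / (num_batches_tracked + 1)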
correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
num_tracked_so_far + 1
)
self.assertEqual(
model.obs1.average_batch_activation_range,
correct_updated_value,
)
if current_epoch_max - current_epoch_min > 0:
self.assertEqual(
model.obs1.get_batch_to_epoch_ratio(),
correct_updated_value / (current_epoch_max - current_epoch_min),
)
"""Case includes:
all zero tensor
dim size = 2
run for 1 epoch
run for 10 batches
tests input data observer
"""
def test_zero_tensor_errors(self):
# initialize the model
model = self.NestedModifiedSingleLayerLinear()
# generate the desired input
ex_input = torch.zeros((10, 1, 5))
# run it through the model and do general tests
self.run_model_and_common_checks(model, ex_input, 1, 1)
# make sure final values are all 0
self.assertEqual(model.obs1.epoch_activation_min, 0)
self.assertEqual(model.obs1.epoch_activation_max, 0)
self.assertEqual(model.obs1.average_batch_activation_range, 0)
# we should get an error if we try to calculate the ratio
with self.assertRaises(ValueError):
ratio_val = model.obs1.get_batch_to_epoch_ratio()
"""Case includes:
non-zero tensor
dim size = 2
run for 1 epoch
run for 1 batch
tests input data observer
"""
def test_single_batch_of_ones(self):
# initialize the model
model = self.NestedModifiedSingleLayerLinear()
# generate the desired input
ex_input = torch.ones((1, 1, 5))
# run it through the model and do general tests
self.run_model_and_common_checks(model, ex_input, 1, 1)
# min and max should both be 1, and the average batch range should be 0
self.assertEqual(model.obs1.epoch_activation_min, 1)
self.assertEqual(model.obs1.epoch_activation_max, 1)
self.assertEqual(model.obs1.average_batch_activation_range, 0)
# we should get an error if we try to calculate the ratio
with self.assertRaises(ValueError):
ratio_val = model.obs1.get_batch_to_epoch_ratio()
"""Case includes:
non-zero tensor
dim size = 2
run for 10 epochs
run for 15 batches
tests non input data observer
"""
def test_observer_after_relu(self):
# model specific to this test
class NestedModifiedObserverAfterRelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.obs1 = ModelReportObserver()
self.mod1 = SingleLayerLinearModel()
self.obs2 = ModelReportObserver()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.obs1(x)
x = self.mod1(x)
x = self.fc1(x)
x = self.relu(x)
x = self.obs2(x)
return x
# initialize the model
model = NestedModifiedObserverAfterRelu()
# generate the desired input
ex_input = torch.randn((15, 1, 5))
# run it through the model and do general tests
self.run_model_and_common_checks(model, ex_input, 10, 15)
"""Case includes:
non-zero tensor
dim size = 2
run for multiple epochs
run for multiple batches
tests input data observer
"""
def test_random_epochs_and_batches(self):
# set up a basic model
class TinyNestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.obs1 = ModelReportObserver()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
self.obs2 = ModelReportObserver()
def forward(self, x):
x = self.obs1(x)
x = self.fc1(x)
x = self.relu(x)
x = self.obs2(x)
return x
class LargerIncludeNestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.obs1 = ModelReportObserver()
self.nested = TinyNestModule()
self.fc1 = SingleLayerLinearModel()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.obs1(x)
x = self.nested(x)
x = self.fc1(x)
x = self.relu(x)
return x
class ModifiedThreeOps(torch.nn.Module):
def __init__(self, batch_norm_dim):
super().__init__()
self.obs1 = ModelReportObserver()
self.linear = torch.nn.Linear(7, 3, 2)
self.obs2 = ModelReportObserver()
if batch_norm_dim == 2:
self.bn = torch.nn.BatchNorm2d(2)
elif batch_norm_dim == 3:
self.bn = torch.nn.BatchNorm3d(4)
else:
raise ValueError("Dim should only be 2 or 3")
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.obs1(x)
x = self.linear(x)
x = self.obs2(x)
x = self.bn(x)
x = self.relu(x)
return x
class HighDimensionNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.obs1 = ModelReportObserver()
self.fc1 = torch.nn.Linear(3, 7)
self.block1 = ModifiedThreeOps(3)
self.fc2 = torch.nn.Linear(3, 7)
self.block2 = ModifiedThreeOps(3)
self.fc3 = torch.nn.Linear(3, 7)
def forward(self, x):
x = self.obs1(x)
x = self.fc1(x)
x = self.block1(x)
x = self.fc2(x)
y = self.block2(x)
y = self.fc3(y)
z = x + y
z = F.relu(z)
return z
# the purpose of this test is to give the observers a variety of data examples
# initialize the model
models = [
self.NestedModifiedSingleLayerLinear(),
LargerIncludeNestModel(),
ModifiedThreeOps(2),
HighDimensionNet(),
]
# get some number of epochs and batches
num_epochs = 10
num_batches = 15
input_shapes = [(1, 5), (1, 5), (2, 3, 7), (4, 1, 8, 3)]
# generate the desired inputs
inputs = []
for shape in input_shapes:
ex_input = torch.randn((num_batches, *shape))
inputs.append(ex_input)
# run it through the model and do general tests
for index, model in enumerate(models):
self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)
"""
Partition on domain / things to test
There is only a single test case for now.
This will be more thoroughly tested with the implementation of the full end to end tool coming soon.
"""
class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
@skipIfNoFBGEMM
def test_nested_detection_case(self):
class SingleLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
def forward(self, x):
x = self.linear(x)
return x
class TwoBlockNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.block1 = SingleLinear()
self.block2 = SingleLinear()
def forward(self, x):
x = self.block1(x)
y = self.block2(x)
z = x + y
z = F.relu(z)
return z
with override_quantized_engine('fbgemm'):
# create model, example input, and qconfig mapping
torch.backends.quantized.engine = "fbgemm"
model = TwoBlockNet()
example_input = torch.randint(-10, 0, (1, 3, 3, 3))
example_input = example_input.to(torch.float)
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
# prep model and select observer
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
obs_ctr = ModelReportObserver
# find layer to attach to and store
linear_fqn = "block2.linear" # fqn of target linear
target_linear = None
for node in model_prep.graph.nodes:
if node.target == linear_fqn:
target_linear = node
break
# insert into both module and graph pre and post
# set up to insert before target_linear (pre_observer)
with model_prep.graph.inserting_before(target_linear):
obs_to_insert = obs_ctr()
pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)
# set up and insert after the target_linear (post_observer)
with model_prep.graph.inserting_after(target_linear):
obs_to_insert = obs_ctr()
post_obs_fqn = linear_fqn + ".model_report_post_observer"
model_prep.add_submodule(post_obs_fqn, obs_to_insert)
model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))
# need to recompile module after submodule added and pass input through
model_prep.recompile()
num_iterations = 10
for i in range(num_iterations):
if i % 2 == 0:
example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
else:
example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
model_prep(example_input)
# run it through the dynamic vs static detector
dynamic_vs_static_detector = DynamicStaticDetector()
dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep)
# one of the stats should be stationary, and the other non-stationary
# as a result, dynamic should be recommended
data_dist_info = [
dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.PRE_OBS_DATA_DIST_KEY],
dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.POST_OBS_DATA_DIST_KEY],
]
self.assertTrue("stationary" in data_dist_info)
self.assertTrue("non-stationary" in data_dist_info)
self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
class TestFxModelReportClass(QuantizationTestCase):
@skipIfNoFBGEMM
def test_constructor(self):
"""
Tests the constructor of the ModelReport class.
Specifically looks at:
- The desired reports
- Ensures that the observers of interest are properly initialized
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
backend = torch.backends.quantized.engine
# create a model
model = ThreeOps()
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])
# make an example set of detectors
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
# initialize the ModelReport with the detector set
model_report = ModelReport(model_prep, test_detector_set)
# make sure internal valid reports matches
detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)
# now attempt with no valid reports, should raise error
with self.assertRaises(ValueError):
model_report = ModelReport(model, set())
# number of expected obs of interest entries
num_expected_entries = len(test_detector_set)
self.assertEqual(len(model_report.get_observers_of_interest()), num_expected_entries)
for value in model_report.get_observers_of_interest().values():
self.assertEqual(len(value), 0)
@skipIfNoFBGEMM
def test_prepare_model_callibration(self):
"""
Tests model_report.prepare_detailed_calibration that prepares the model for calibration
Specifically looks at:
- Whether observers are properly inserted into regular nn.Module
- Whether the target and the arguments of the observers are proper
- Whether the internal representation of observers of interest is updated
"""
with override_quantized_engine('fbgemm'):
# create model report object
# create model
model = TwoThreeOps()
# make an example set of detectors
torch.backends.quantized.engine = "fbgemm"
backend = torch.backends.quantized.engine
test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
# prepare the model
example_input = model.get_example_inputs()[0]
current_backend = torch.backends.quantized.engine
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
model_report = ModelReport(model_prep, test_detector_set)
# prepare the model for calibration
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
# check whether observers were properly inserted into the regular nn.Module
# there should be 4 observers present in this case
modules_observer_cnt = 0
for fqn, module in prepared_for_callibrate_model.named_modules():
if isinstance(module, ModelReportObserver):
modules_observer_cnt += 1
self.assertEqual(modules_observer_cnt, 4)
model_report_str_check = "model_report"
# also make sure arguments for observers in the graph are proper
for node in prepared_for_callibrate_model.graph.nodes:
# not all node targets are strings, so check
if isinstance(node.target, str) and model_report_str_check in node.target:
# if pre-observer has same args as the linear (next node)
if "pre_observer" in node.target:
self.assertEqual(node.args, node.next.args)
# if post-observer, args are the target linear (previous node)
if "post_observer" in node.target:
self.assertEqual(node.args, (node.prev,))
# ensure model_report observers of interest updated
# there should be two entries
self.assertEqual(len(model_report.get_observers_of_interest()), 2)
for detector in test_detector_set:
self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys())
# get number of entries for this detector
detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()]
# assert that the per channel detector has 0 and the dynamic static has 4
if isinstance(detector, PerChannelDetector):
self.assertEqual(len(detector_obs_of_interest_fqns), 0)
elif isinstance(detector, DynamicStaticDetector):
self.assertEqual(len(detector_obs_of_interest_fqns), 4)
# ensure that we can prepare for calibration only once
with self.assertRaises(ValueError):
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
def get_module_and_graph_cnts(self, callibrated_fx_module):
r"""
Calculates the number of ModelReportObserver modules in the model as well as the number of model_report nodes in the graph.
Returns a tuple of two elements:
int: The number of ModelReportObservers found in the model
int: The number of model_report nodes found in the graph
"""
# get the number of observers stored as modules
modules_observer_cnt = 0
for fqn, module in callibrated_fx_module.named_modules():
if isinstance(module, ModelReportObserver):
modules_observer_cnt += 1
# get number of observers in the graph
model_report_str_check = "model_report"
graph_observer_cnt = 0
# count the observer nodes present in the graph
for node in callibrated_fx_module.graph.nodes:
# not all node targets are strings, so check
if isinstance(node.target, str) and model_report_str_check in node.target:
# increment if we found a graph observer
graph_observer_cnt += 1
return (modules_observer_cnt, graph_observer_cnt)
@skipIfNoFBGEMM
def test_generate_report(self):
"""
Tests model_report.generate_model_report to ensure report generation
Specifically looks at:
- Whether correct number of reports are being generated
- Whether observers are being properly removed if specified
- Whether correct blocking from generating report twice if obs removed
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# check whether the correct number of reports are being generated
filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
single_detector_set = {DynamicStaticDetector()}
# create our models
model_full = TwoThreeOps()
model_single = TwoThreeOps()
# prepare and calibrate two different instances of the same model
# prepare the model
example_input = model_full.get_example_inputs()[0]
current_backend = torch.backends.quantized.engine
q_config_mapping = QConfigMapping()
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)
# initialize one with filled detector
model_report_full = ModelReport(model_prep_full, filled_detector_set)
# initialize another with a single detector set
model_report_single = ModelReport(model_prep_single, single_detector_set)
# prepare the models for calibration
prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration()
prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration()
# now calibrate the two models
num_iterations = 10
for i in range(num_iterations):
example_input = torch.randint(100, (1, 3, 3, 3)).to(torch.float)
prepared_for_callibrate_model_full(example_input)
prepared_for_callibrate_model_single(example_input)
# now generate the reports
model_full_report = model_report_full.generate_model_report(True)
model_single_report = model_report_single.generate_model_report(False)
# check that sizes are appropriate
self.assertEqual(len(model_full_report), len(filled_detector_set))
self.assertEqual(len(model_single_report), len(single_detector_set))
# make sure observers are being properly removed for full report since we put flag in
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules
self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph
# make sure observers aren't being removed for single report since not specified
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
self.assertNotEqual(modules_observer_cnt, 0)
self.assertNotEqual(graph_observer_cnt, 0)
# make sure an error is raised when rerunning report generation for the full report (observers were removed)
with self.assertRaises(Exception):
model_full_report = model_report_full.generate_model_report(False)
# make sure we don't run into error for single report
model_single_report = model_report_single.generate_model_report(False)
@skipIfNoFBGEMM
def test_generate_visualizer(self):
"""
Tests that the ModelReport class can properly create the ModelReportVisualizer instance
Checks that:
- Correct number of modules are represented
- Modules are sorted
- Correct number of features for each module
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# test with multiple detectors
detector_set = set()
detector_set.add(OutlierDetector(reference_percentile=0.95))
detector_set.add(InputWeightEqualizationDetector(0.5))
model = TwoThreeOps()
# get test model and calibrate
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
model, detector_set, model.get_example_inputs()[0]
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# try to visualize without generating report, should throw error
with self.assertRaises(Exception):
mod_rep_visualization = mod_report.generate_visualizer()
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
# now getting the visualizer should not error
mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()
# since we tested with outlier detector, which looks at every base level module
# should be six entries in the ordered dict
mod_fqns_to_features = mod_rep_visualizer.generated_reports
self.assertEqual(len(mod_fqns_to_features), 6)
# the outlier detector has 9 features per module
# input-weight equalization has 12 features per module
# there is 1 common data point, so there should be 12 + 9 - 1 = 20 unique features for common modules
# all linears will be common
for module_fqn in mod_fqns_to_features:
if ".linear" in module_fqn:
linear_info = mod_fqns_to_features[module_fqn]
self.assertEqual(len(linear_info), 20)
@skipIfNoFBGEMM
def test_qconfig_mapping_generation(self):
"""
Tests for generation of qconfigs by ModelReport API
- Tests that a QConfigMapping is generated
- Tests that mappings include information for relevant modules
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# test with multiple detectors
detector_set = set()
detector_set.add(PerChannelDetector())
detector_set.add(DynamicStaticDetector())
model = TwoThreeOps()
# get test model and calibrate
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
model, detector_set, model.get_example_inputs()[0]
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# get the mapping without error
qconfig_mapping = mod_report.generate_qconfig_mapping()
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
# get the visualizer so we can get access to reformatted reports by module fqn
mod_reports_by_fqn = mod_report.generate_visualizer().generated_reports
# compare the entries of the mapping to those of the report
# we should have the same number of entries
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), len(mod_reports_by_fqn))
# we should have 2 entries because only the two linear layers are applicable
# so there should be a suggestion for each of those modules
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
# for the two linears, the weight should use the per channel min max observer since the backend is fbgemm
# and the activation should be static since this was a simple single calibration
for key in qconfig_mapping.module_name_qconfigs:
config = qconfig_mapping.module_name_qconfigs[key]
self.assertEqual(config.weight, default_per_channel_weight_observer)
self.assertEqual(config.activation, default_observer)
# make sure these can actually be used to prepare the model
prepared = quantize_fx.prepare_fx(TwoThreeOps(), qconfig_mapping, example_input)
# now convert the model to ensure no errors in conversion
converted = quantize_fx.convert_fx(prepared)
@skipIfNoFBGEMM
def test_equalization_mapping_generation(self):
"""
Tests for generation of qconfigs by ModelReport API
- Tests that an equalization config is generated when the input-weight equalization detector is used
- Tests that mappings include information for relevant modules
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# test with multiple detectors
detector_set = set()
detector_set.add(InputWeightEqualizationDetector(0.6))
model = TwoThreeOps()
# get test model and calibrate
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
model, detector_set, model.get_example_inputs()[0]
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# get the mapping without error
qconfig_mapping = mod_report.generate_qconfig_mapping()
equalization_mapping = mod_report.generate_equalization_mapping()
# the checks for the equalization mapping are much simpler
# there shouldn't be any equalization suggestions for this case
self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
# make sure these can actually be used to prepare the model
prepared = quantize_fx.prepare_fx(
TwoThreeOps(),
qconfig_mapping,
example_input,
_equalization_config=equalization_mapping
)
# now convert the model to ensure no errors in conversion
converted = quantize_fx.convert_fx(prepared)
class TestFxDetectInputWeightEqualization(QuantizationTestCase):
class SimpleConv(torch.nn.Module):
def __init__(self, con_dims):
super().__init__()
self.relu = torch.nn.ReLU()
self.conv = torch.nn.Conv2d(con_dims[0], con_dims[1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
return x
class TwoBlockComplexNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32))
self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3))
self.conv = torch.nn.Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
self.linear = torch.nn.Linear(768, 10)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.block1(x)
x = self.conv(x)
y = self.block2(x)
y = y.repeat(1, 1, 2, 2)
z = x + y
z = z.flatten(start_dim=1)
z = self.linear(z)
z = self.relu(z)
return z
def get_fusion_modules(self):
return [['conv', 'relu']]
def get_example_inputs(self):
return (torch.randn((1, 3, 28, 28)),)
class ReluOnly(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(x)
return x
def get_example_inputs(self):
return (torch.arange(27).reshape((1, 3, 3, 3)),)
def _get_prepped_for_calibration_model(self, model, detector_set, fused=False):
r"""Returns a model that has been prepared for callibration and corresponding model_report"""
# pass in necessary inputs to helper
example_input = model.get_example_inputs()[0]
return _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused)
@skipIfNoFBGEMM
def test_input_weight_equalization_determine_points(self):
# use fbgemm and create our model instance
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
detector_set = {InputWeightEqualizationDetector(0.5)}
# get test model and calibrate
non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)
fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True)
# reporter should still give same counts even for fused model
for prepared_for_callibrate_model, mod_report in [non_fused, fused]:
# supported modules to check
mods_to_check = {nn.Linear, nn.Conv2d}
# get the set of all nodes in the graph their fqns
node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}
# there should be 4 node fqns that have the observer inserted
correct_number_of_obs_inserted = 4
number_of_obs_found = 0
obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
for node in prepared_for_callibrate_model.graph.nodes:
# if the obs name is inside the target, we found an observer
if obs_name_to_find in str(node.target):
number_of_obs_found += 1
self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
# assert that each of the desired modules have the observers inserted
for fqn, module in prepared_for_callibrate_model.named_modules():
# check if module is a supported module
is_in_include_list = isinstance(module, tuple(mods_to_check))
if is_in_include_list:
# make sure it has the observer attribute
self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
else:
# if it's not a supported type, it shouldn't have observer attached
self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
@skipIfNoFBGEMM
def test_input_weight_equalization_report_gen(self):
# use fbgemm and create our model instance
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
detector_set = {test_input_weight_detector}
model = self.TwoBlockComplexNet()
# prepare the model for calibration
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(
model, detector_set
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = model_report.generate_model_report(True)
# check that sizes are appropriate only 1 detector
self.assertEqual(len(generated_report), 1)
# get the specific report for input weight equalization
input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
# we should have 4 layers looked at since there are 4 conv / linear layers
self.assertEqual(len(input_weight_dict), 4)
# we can validate that the max and min values of the detector were recorded properly for the first one
# since only a single input batch has been processed, the recorded values should match the original input
example_input = example_input.reshape((3, 28, 28)) # reshape input
for module_fqn in input_weight_dict:
# look for the first linear
if "block1.linear" in module_fqn:
block_1_lin_recs = input_weight_dict[module_fqn]
# get input range info and the channel axis
ch_axis = block_1_lin_recs[InputWeightEqualizationDetector.CHANNEL_KEY]
# ensure that the min and max values extracted match properly
example_min, example_max = torch.aminmax(example_input, dim=ch_axis)
dimension_min = torch.amin(example_min, dim=ch_axis)
dimension_max = torch.amax(example_max, dim=ch_axis)
# make sure per channel min and max are as expected
min_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
max_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
per_channel_min = block_1_lin_recs[min_per_key]
per_channel_max = block_1_lin_recs[max_per_key]
self.assertEqual(per_channel_min, dimension_min)
self.assertEqual(per_channel_max, dimension_max)
# make sure per channel min and max are as expected
min_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
max_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
# make sure the global min and max were correctly recorded and presented
global_min = block_1_lin_recs[min_key]
global_max = block_1_lin_recs[max_key]
self.assertEqual(global_min, min(dimension_min))
self.assertEqual(global_max, max(dimension_max))
input_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
# ensure the comparison stat passed back is based on the sqrt of range ratios
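# i.e. for each channel c the detector's metric is expected to be weight_ratio[c] / input_ratio[c],
# where each ratio is sqrt((per_channel_max[c] - per_channel_min[c]) / (global_max - global_min))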
# need to get the weight ratios first
# make sure per channel min and max are as expected
min_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
max_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
# get weight per channel and global info
per_channel_min = block_1_lin_recs[min_per_key]
per_channel_max = block_1_lin_recs[max_per_key]
# make sure per channel min and max are as expected
min_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
max_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
global_min = block_1_lin_recs[min_key]
global_max = block_1_lin_recs[max_key]
weight_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
# also get comp stat for this specific layer
comp_stat = block_1_lin_recs[InputWeightEqualizationDetector.COMP_METRIC_KEY]
weight_to_input_ratio = weight_ratio / input_ratio
self.assertEqual(comp_stat, weight_to_input_ratio)
# only looking at the first example so can break
break
@skipIfNoFBGEMM
def test_input_weight_equalization_report_gen_empty(self):
# tests report gen on a model that doesn't have any layers
# use fbgemm and create our model instance
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
test_input_weight_detector = InputWeightEqualizationDetector(0.4)
detector_set = {test_input_weight_detector}
model = self.ReluOnly()
# prepare the model for calibration
prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = model_report.generate_model_report(True)
# check that sizes are appropriate only 1 detector
self.assertEqual(len(generated_report), 1)
# get the specific report for input weight equalization
input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
# we should have 0 layers since there is only a Relu
self.assertEqual(len(input_weight_dict), 0)
# make sure that the string only has two lines, as should be if no suggestions
self.assertEqual(input_weight_str.count("\n"), 2)
class TestFxDetectOutliers(QuantizationTestCase):
class LargeBatchModel(torch.nn.Module):
def __init__(self, param_size):
super().__init__()
self.param_size = param_size
self.linear = torch.nn.Linear(param_size, param_size)
self.relu_1 = torch.nn.ReLU()
self.conv = torch.nn.Conv2d(param_size, param_size, 1)
self.relu_2 = torch.nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.relu_1(x)
x = self.conv(x)
x = self.relu_2(x)
return x
def get_example_inputs(self):
param_size = self.param_size
return (torch.randn((1, param_size, param_size, param_size)),)
def get_outlier_inputs(self):
param_size = self.param_size
random_vals = torch.randn((1, param_size, param_size, param_size))
# change one in some of them to be a massive value
random_vals[:, 0:param_size:2, 0, 3] = torch.tensor([3.28e8])
return (random_vals,)
def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False):
r"""Returns a model that has been prepared for callibration and corresponding model_report"""
# call the general helper function to callibrate
example_input = model.get_example_inputs()[0]
# if we specifically want to test data with outliers replace input
if use_outlier_data:
example_input = model.get_outlier_inputs()[0]
return _get_prepped_for_calibration_model_helper(model, detector_set, example_input)
@skipIfNoFBGEMM
def test_outlier_detection_determine_points(self):
# use fbgemm and create our model instance
# then create model report instance with detector
# similar to the test for InputWeightEqualization, but with key differences that made refactoring not viable
# not explicitly testing fusion because the fx workflow handles it automatically
with override_quantized_engine('fbgemm'):
detector_set = {OutlierDetector(reference_percentile=0.95)}
# get test model and calibrate
prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
self.LargeBatchModel(param_size=128), detector_set
)
# supported modules to check
mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}
# there should be 4 node fqns that have the observer inserted
correct_number_of_obs_inserted = 4
number_of_obs_found = 0
obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
number_of_obs_found = sum(
[1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes]
)
self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
# assert that each of the desired modules have the observers inserted
for fqn, module in prepared_for_callibrate_model.named_modules():
# check if module is a supported module
is_in_include_list = isinstance(module, tuple(mods_to_check))
if is_in_include_list:
# make sure it has the observer attribute
self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
else:
# if it's not a supported type, it shouldn't have observer attached
self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
@skipIfNoFBGEMM
def test_no_outlier_report_gen(self):
# use fbgemm and create our model instance
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
# test with multiple detectors
outlier_detector = OutlierDetector(reference_percentile=0.95)
dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)
param_size: int = 4
detector_set = {outlier_detector, dynamic_static_detector}
model = self.LargeBatchModel(param_size=param_size)
# get test model and calibrate
prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
model, detector_set
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(True)
# check that sizes are appropriate only 2 detectors
self.assertEqual(len(generated_report), 2)
# get the specific report for input weight equalization
outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
# we should have 4 layers looked at: the linear, the conv, and the two relus
self.assertEqual(len(outlier_dict), 4)
# assert the following are true for all the modules
for module_fqn in outlier_dict:
# get the info for the specific module
module_dict = outlier_dict[module_fqn]
# there really should not be any outliers since we used a normal distribution to perform this calculation
outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
self.assertEqual(sum(outlier_info), 0)
# ensure that the number of ratios and batches counted is the same as the number of params
self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
@skipIfNoFBGEMM
def test_all_outlier_report_gen(self):
# make the percentile 0 and the ratio 1, and then see that everything is outlier according to it
# use fbgemm and create our model instance
# then create model report instance with detector
with override_quantized_engine('fbgemm'):
# create detector of interest
outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)
param_size: int = 16
detector_set = {outlier_detector}
model = self.LargeBatchModel(param_size=param_size)
# get test model and calibrate
prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
model, detector_set
)
# now we actually calibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(True)
# check that sizes are appropriate only 1 detector
self.assertEqual(len(generated_report), 1)
# get the specific report for input weight equalization
outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
# we should have 4 layers looked at: the linear, the conv, and the two relus
self.assertEqual(len(outlier_dict), 4)
# assert the following are true for all the modules
for module_fqn in outlier_dict:
# get the info for the specific module
module_dict = outlier_dict[module_fqn]
# everything should be an outlier because with reference_percentile=0 and ratio_threshold=1 the max is effectively compared against the min
# however we only require that most are outliers, in case several channels have all-zero values
outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
assert sum(outlier_info) >= len(outlier_info) / 2
# ensure that the number of ratios and batches counted is the same as the number of params
self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
@skipIfNoFBGEMM
def test_multiple_run_consistent_spike_outlier_report_gen(self):
# specifically make some values consistently really high across the batches being tested
# generate the report after many runs (30) and make sure the sufficient-batches threshold is met
with override_quantized_engine('fbgemm'):
# detector of interest
outlier_detector = OutlierDetector(reference_percentile=0.95)
param_size: int = 8
detector_set = {outlier_detector}
model = self.LargeBatchModel(param_size=param_size)
# get test model and calibrate
prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
model, detector_set, use_outlier_data=True
)
# now we actually calibrate the model
example_input = model.get_outlier_inputs()[0]
example_input = example_input.to(torch.float)
# now callibrate minimum 30 times to make it above minimum threshold
for i in range(30):
example_input = model.get_outlier_inputs()[0]
example_input = example_input.to(torch.float)
                # zero out one channel in a few of the batches (i = 0, 14, 28)
if i % 14 == 0:
# make one channel constant
example_input[0][1] = torch.zeros_like(example_input[0][1])
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(True)
# check that sizes are appropriate only 1 detector
self.assertEqual(len(generated_report), 1)
            # get the specific report for the outlier detector
outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
            # we should have 4 layers looked at (the model's conv, linear, and relu layers)
self.assertEqual(len(outlier_dict), 4)
# assert the following are true for all the modules
for module_fqn in outlier_dict:
# get the info for the specific module
module_dict = outlier_dict[module_fqn]
                # because we ran 30 times, at least half of the channels should have recorded sufficient batches
                # could be fewer for some, since channels that are all 0 may not count
sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2
                # half of the channels should be outliers, because a really high value is set in every other channel
outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
self.assertEqual(sum(outlier_info), len(outlier_info) / 2)
# ensure that the number of ratios and batches counted is the same as the number of params
self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
# for the first one ensure the per channel max values are what we set
if module_fqn == "linear.0":
                    # check the constant-value batch counts for the first module:
                    # at least 2 should be recorded from the zeroed-out channel
counts_info = module_dict[OutlierDetector.CONSTANT_COUNTS_KEY]
assert sum(counts_info) >= 2
# half of the recorded max values should be what we set
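                    # (3.28e8 is assumed to be the spike value that get_outlier_inputs() places in every other channel)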
                    matched_max = sum(val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY])
self.assertEqual(matched_max, param_size / 2)
class TestFxModelReportVisualizer(QuantizationTestCase):
def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report):
r"""
        Callibrates the passed-in model, generates the report, and returns the visualizer
"""
# now we actually callibrate the model
example_input = model.get_example_inputs()[0]
example_input = example_input.to(torch.float)
prepared_for_callibrate_model(example_input)
# now get the report by running it through ModelReport instance
generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
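        # note: generate_model_report() must be called first so the visualizer has report data to work from
        # (generate_visualizer() is assumed to error otherwise)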
        # now we get the visualizer, which should not error
mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()
return mod_rep_visualizer
@skipIfNoFBGEMM
def test_get_modules_and_features(self):
"""
Tests the get_all_unique_module_fqns and get_all_unique_feature_names methods of
ModelReportVisualizer
Checks whether returned sets are of proper size and filtered properly
"""
with override_quantized_engine('fbgemm'):
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# test with multiple detectors
detector_set = set()
detector_set.add(OutlierDetector(reference_percentile=0.95))
detector_set.add(InputWeightEqualizationDetector(0.5))
model = TwoThreeOps()
            # get test model and callibrate
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
model, detector_set, model.get_example_inputs()[0]
)
mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
model, prepared_for_callibrate_model, mod_report
)
            # ensure the module fqns match the ones given by the get_all_unique_module_fqns method
actual_model_fqns = set(mod_rep_visualizer.generated_reports.keys())
returned_model_fqns = mod_rep_visualizer.get_all_unique_module_fqns()
self.assertEqual(returned_model_fqns, actual_model_fqns)
# now ensure that features are all properly returned
# all the linears have all the features for two detectors
            # can use those as a check that the method is working reliably
b_1_linear_features = mod_rep_visualizer.generated_reports["block1.linear"]
# first test all features
returned_all_feats = mod_rep_visualizer.get_all_unique_feature_names(False)
self.assertEqual(returned_all_feats, set(b_1_linear_features.keys()))
# now test plottable features
plottable_set = set()
for feature_name in b_1_linear_features:
if type(b_1_linear_features[feature_name]) == torch.Tensor:
plottable_set.add(feature_name)
returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()
self.assertEqual(returned_plottable_feats, plottable_set)
def _prep_visualizer_helper(self):
r"""
Returns a mod rep visualizer that we test in various ways
"""
# set backend for test
torch.backends.quantized.engine = "fbgemm"
# test with multiple detectors
detector_set = set()
detector_set.add(OutlierDetector(reference_percentile=0.95))
detector_set.add(InputWeightEqualizationDetector(0.5))
model = TwoThreeOps()
        # get test model and callibrate
prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
model, detector_set, model.get_example_inputs()[0]
)
mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
model, prepared_for_callibrate_model, mod_report
)
return mod_rep_visualizer
@skipIfNoFBGEMM
def test_generate_tables_match_with_report(self):
"""
        Tests the generate_filtered_tables() method of ModelReportVisualizer
        Checks whether the generated dict has the proper information
        A visual check that the tables look correct was performed during testing
"""
with override_quantized_engine('fbgemm'):
# get the visualizer
mod_rep_visualizer = self._prep_visualizer_helper()
table_dict = mod_rep_visualizer.generate_filtered_tables()
# test primarily the dict since it has same info as str
tensor_headers, tensor_table = table_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
# these two together should be the same as the generated report info in terms of keys
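            # (each table row is assumed to start with an index followed by the module fqn, so row[1] holds the fqn)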
tensor_info_modules = {row[1] for row in tensor_table}
channel_info_modules = {row[1] for row in channel_table}
            combined_modules: set = tensor_info_modules.union(channel_info_modules)
            generated_report_keys: set = set(mod_rep_visualizer.generated_reports.keys())
self.assertEqual(combined_modules, generated_report_keys)
@skipIfNoFBGEMM
def test_generate_tables_no_match(self):
"""
        Tests the generate_filtered_tables() method of ModelReportVisualizer
        Checks whether the generated dict has the proper information
        A visual check that the tables look correct was performed during testing
"""
with override_quantized_engine('fbgemm'):
# get the visualizer
mod_rep_visualizer = self._prep_visualizer_helper()
# try a random filter and make sure that there are no rows for either table
empty_tables_dict = mod_rep_visualizer.generate_filtered_tables(module_fqn_filter="random not there module")
# test primarily the dict since it has same info as str
tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
tensor_info_modules = {row[1] for row in tensor_table}
channel_info_modules = {row[1] for row in channel_table}
            combined_modules: set = tensor_info_modules.union(channel_info_modules)
self.assertEqual(len(combined_modules), 0) # should be no matching modules
@skipIfNoFBGEMM
def test_generate_tables_single_feat_match(self):
"""
        Tests the generate_filtered_tables() method of ModelReportVisualizer
        Checks whether the generated dict has the proper information
        A visual check that the tables look correct was performed during testing
"""
with override_quantized_engine('fbgemm'):
# get the visualizer
mod_rep_visualizer = self._prep_visualizer_helper()
# try a matching filter for feature and make sure only those features show up
            # if we filter to a very specific feature name, each table row should only have 1 additional column
single_feat_dict = mod_rep_visualizer.generate_filtered_tables(feature_filter=OutlierDetector.MAX_VALS_KEY)
# test primarily the dict since it has same info as str
tensor_headers, tensor_table = single_feat_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
channel_headers, channel_table = single_feat_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
# get the number of features in each of these
tensor_info_features = len(tensor_headers)
channel_info_features = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
# make sure that there are no tensor features, and that there is one channel level feature
self.assertEqual(tensor_info_features, 0)
self.assertEqual(channel_info_features, 1)
def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False):
r"""Returns a model that has been prepared for callibration and corresponding model_report"""
# set the backend for this test
torch.backends.quantized.engine = "fbgemm"
# create model instance and prepare it
example_input = example_input.to(torch.float)
q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()
    # if the fusion parameter was passed in, make sure to test the fused model
if fused:
model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules())
model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
model_report = ModelReport(model_prep, detector_set)
# prepare the model for callibration
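    # (prepare_detailed_calibration() is assumed to insert the detectors' observers into the prepared graph module)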
prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
return (prepared_for_callibrate_model, model_report)