pytorch/test/quantization/test_numeric_suite_fx.py
Vasiliy Kuznetsov 444b318a90 ns for fx: add linear-relu mod weight extraction (#55080)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55080

Adds support for extracting weights of linear-relu module pattern.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D27474701

fbshipit-source-id: 69ceaadc28d7fdcebd16d519367274d348b0dd29
2021-04-14 09:02:51 -07:00
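
As a quick orientation, here is a minimal sketch of how the new linear-relu weight extraction can be exercised through the public API. The model and shapes are illustrative; the `extract_weights` call mirrors its usage in the tests below.

```
import copy

import torch
import torch.nn as nn

from torch.quantization.quantize_fx import convert_fx, prepare_fx
from torch.quantization._numeric_suite_fx import extract_weights

# toy linear-relu pattern; FX graph mode quantization fuses it into one op
m = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mq = convert_fx(copy.deepcopy(mp))

# compare the fp32 weight of the fused pattern against its int8 counterpart
results = extract_weights('fp32_prepared', mp, 'int8', mq)
```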

import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import default_dynamic_qconfig
import torch.nn.quantized as nnq
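# short alias for quantized ops, used in the expected-type assertions below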
toq = torch.ops.quantized
from torch.quantization.quantize_fx import (
convert_fx,
prepare_fx,
prepare_qat_fx,
)
from torch.testing._internal.common_quantization import (
ConvBnModel,
ConvBnReLUModel,
ConvModel,
QuantizationTestCase,
skipIfNoFBGEMM,
SingleLayerLinearDynamicModel,
SingleLayerLinearModel,
LSTMwithHiddenDynamicModel,
SparseNNModel,
skip_if_no_torchvision,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import override_qengines
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
GraphMatchingException,
)
from torch.quantization._numeric_suite_fx import (
extract_weights,
_extract_weights_impl,
add_loggers,
OutputLogger,
add_shadow_loggers,
extract_logger_info,
extract_shadow_logger_info,
)


# Note: these models are not for use outside of this file. While it's good
# to reuse code, we also need to be able to iterate on tests
# quickly when debugging. If a test model has a large number of callsites
# across various files, debugging individual test cases becomes slower.
class LinearReluFunctional(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.relu(x)
return x


class TestFXGraphMatcher(QuantizationTestCase):
@override_qengines
def test_simple_mod(self):
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
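        # keys are subgraph names assigned by the graph matcher; each value is
        # ((start_node_type, end_node_type) in mp, (start_node_type, end_node_type) in mq)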
expected_types = {
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.empty(1, 4))
self.b = nn.Parameter(torch.zeros(1))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
return F.linear(x, self.w, self.b)
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.linear), (toq.linear, toq.linear))
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fusion(self):
m = LinearReluFunctional().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.relu), (toq.linear_relu, toq.linear_relu)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_mod_multi(self):
m = nn.Sequential(
nn.Sequential(
nn.Conv2d(1, 1, 1),
),
nn.Conv2d(1, 1, 1),
).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_simple_tensor_ops(self):
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
z = x + y
return z
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_matching_failure_node_count(self):
# verify that matching graphs with matching node types but
# different counts of matchable nodes fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_matching_failure_node_type(self):
# verify that matching graphs with non-matching node types fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Linear(1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_nodes_before_cat(self):
# verify that nodes before cat get matched
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
x2 = torch.cat([x1, y1])
return x2
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.cat_0': ((torch.cat, torch.cat), (toq.cat, toq.cat)),
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_dict_return_type(self):
# verify that we can traverse up nodes which return dictionaries
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
z1 = torch.add(x0, 1.0)
a1 = {'x1': x1, 'y1': (y1,), 'z1': [{'key': (z1,)}]}
return a1
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_2': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_nodes_with_equal_types_do_not_get_matched(self):
# verifies that by default, nodes with equivalent types do not get matched.
# This is important for user defined types, for which we do not know
# the weight extraction functions or input type. In the future, this can
# be made configurable.
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = torch.mul(x, x)
x = torch.sigmoid(x)
x = F.relu(x)
return x
m = M().eval()
# prevent conv2 from getting quantized, so we can test
# modules with equal types
qconfig_dict = {
'': torch.quantization.default_qconfig,
'module_name': [('conv2', None)],
}
mp = prepare_fx(m, qconfig_dict)
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
# Conv2 should not be matched because we disabled quantization for it,
# so its type is the same in mp and mq. sigmoid and relu should not be
# matched because they use the same function in mp and mq.
expected_types = {
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
'base_op_torch.mul_0': ((torch.mul, torch.mul), (toq.mul, toq.mul)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)


class TestFXGraphMatcherModels(QuantizationTestCase):
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2(self):
        # verify that the mobilenet_v2 graph can be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2_qat(self):
        # verify that the mobilenet_v2 graph can be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)


class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
def _test_extract_weights(self, m, results_len=0, qconfig_dict=None):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
mp = prepare_fx(m, qconfig_dict)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# test both the public API as well as the internal GraphModule API
for extract_weights_fun in (extract_weights, _extract_weights_impl):
results = extract_weights_fun('fp32_prepared', mp, 'int8', mq)
self.assertTrue(
len(results) == results_len,
f"expected len {results_len}, got len {len(results)}")
self.assert_ns_compare_dict_valid(results)
return results
def _test_match_activations(
self, m, data, prepared_expected_node_occurrence=None, results_len=0,
should_log_inputs=False,
qconfig_dict=None,
skip_scripting=False,
):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
mp = prepare_fx(m, qconfig_dict)
mp(*data)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_ns, mq_ns = add_loggers(
'fp32_prepared', mp, 'int8', mq, OutputLogger,
should_log_inputs=should_log_inputs)
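        # add_loggers returns both models instrumented with OutputLogger modules;
        # the model names passed in ('fp32_prepared', 'int8') label the two
        # sides of the comparison in the extracted results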
if prepared_expected_node_occurrence:
self.checkGraphModuleNodes(
mp_ns, expected_node_occurrence=prepared_expected_node_occurrence)
self.checkGraphModuleNodes(
mq_ns, expected_node_occurrence=prepared_expected_node_occurrence)
if not skip_scripting:
mp_ns = torch.jit.script(mp_ns)
mq_ns = torch.jit.script(mq_ns)
# calibrate
mp_ns(*data)
mq_ns(*data)
# check activation result correctness
act_compare_dict = extract_logger_info(mp_ns, mq_ns, OutputLogger)
self.assertTrue(
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
return act_compare_dict
def _test_match_shadow_activations(
self, m, data, prepared_expected_node_occurrence=None, results_len=0,
should_log_inputs=False, qconfig_dict=None, skip_scripting=False,
):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
mp = prepare_fx(m, qconfig_dict)
mp(*data)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_shadows_mq = add_shadow_loggers(
'fp32_prepared', mp, 'int8', mq, OutputLogger,
should_log_inputs=should_log_inputs)
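        # add_shadow_loggers returns a single combined model in which each
        # matched subgraph from one model shadows its counterpart from the
        # other (with dtype casts on the inputs), so a single forward pass
        # collects activations from both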
if prepared_expected_node_occurrence:
self.checkGraphModuleNodes(
mp_shadows_mq, expected_node_occurrence=prepared_expected_node_occurrence)
if not skip_scripting:
mp_shadows_mq = torch.jit.script(mp_shadows_mq)
# calibrate
mp_shadows_mq(*data)
# check activation result correctness
act_compare_dict = extract_shadow_logger_info(
mp_shadows_mq, OutputLogger)
self.assertTrue(
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
return act_compare_dict


class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
@skipIfNoFBGEMM
def test_extract_weights_mod(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
# conv1d
self.conv1d_0 = nn.Conv1d(1, 1, 1)
# conv1d - relu
self.conv1d_1 = nn.Conv1d(1, 1, 1)
self.relu_0 = nn.ReLU()
# conv2d
self.conv2d_0 = nn.Conv2d(1, 1, 1)
# conv2d - relu
self.conv2d_1 = nn.Conv2d(1, 1, 1)
self.relu_1 = nn.ReLU()
# conv3d
self.conv3d_0 = nn.Conv3d(1, 1, 1)
# conv3d - relu
self.conv3d_1 = nn.Conv3d(1, 1, 1)
self.relu_2 = nn.ReLU()
# linear
self.linear_0 = nn.Linear(1, 1)
# linear - relu
self.linear_1 = nn.Linear(1, 1)
self.relu_3 = nn.ReLU()
def forward(self, x):
x = self.conv1d_0(x)
x = self.conv1d_1(x)
x = self.relu_0(x)
x = x.reshape(1, 1, 1, 1)
x = self.conv2d_0(x)
x = self.conv2d_1(x)
x = self.relu_1(x)
x = x.reshape(1, 1, 1, 1, 1)
x = self.conv3d_0(x)
x = self.conv3d_1(x)
x = self.relu_2(x)
x = x.reshape(1, 1)
x = self.linear_0(x)
x = self.linear_1(x)
x = self.relu_3(x)
return x
m = M().eval()
self._test_extract_weights(m, results_len=8)
@skipIfNoFBGEMM
def test_extract_weights_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.empty(4, 4))
self.b = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w, self.b)
x = F.relu(x)
x = F.linear(x, self.w, self.b)
return x
m = M().eval()
self._test_extract_weights(m, results_len=2)
@skipIfNoFBGEMM
def test_match_activations_mod(self):
m = nn.Sequential(
torch.quantization.QuantStub(),
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self._test_match_activations(
m, (torch.randn(2, 1, 2, 2),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=2)
@override_qengines
def test_match_activations_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
self.w2 = nn.Parameter(torch.empty(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.linear(x, self.w2, self.b2)
x = F.relu(x)
return x
m = M().eval()
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self._test_match_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=2)
@skipIfNoFBGEMM
def test_add_shadow_loggers_mod(self):
m = nn.Sequential(
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
self._test_match_shadow_activations(
m, (torch.randn(1, 1, 4, 4),), results_len=2)
@skipIfNoFBGEMM
def test_add_shadow_loggers_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
self.w2 = nn.Parameter(torch.empty(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.linear(x, self.w2, self.b2)
x = F.relu(x)
return x
m = M().eval()
self._test_match_shadow_activations(
m, (torch.randn(4, 4),), results_len=2)
@skipIfNoFBGEMM
def test_add_shadow_loggers_multiple_dtype_casts(self):
"""
Verifies that for nodes where the first input arg is a list,
such as `cat`, we insert an individual dtype cast for each
arg of the list.
"""
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.cat([x, x, x], dim=0)
return x
m = M().eval()
expected_occurrence = {
# 3 dequantize function calls from the 3 dtype casts for [x, x, x]
ns.call_function(torch.dequantize): 3,
# 1 dequantize method call for module output
ns.call_method("dequantize"): 1,
}
self._test_match_shadow_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1)
@skipIfNoFBGEMM
def test_logging_inputs(self):
"""
Verifies that logging inputs works correctly
"""
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv(x)
x = torch.cat([x, x], dim=0)
return x
m = M().eval()
self._test_match_shadow_activations(
m, (torch.randn(1, 1, 4, 4),),
results_len=2,
should_log_inputs=True)
@skipIfNoFBGEMM
def test_linear_fp16_weights(self):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_linear_fp16_activations(self):
for should_log_inputs in (True, False):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
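            # one output logger per model; logging inputs adds one more per model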
num_loggers = 2 if should_log_inputs else 1
expected_occurrence = {
ns.call_module(OutputLogger): num_loggers,
}
res = self._test_match_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1,
qconfig_dict=qconfig_dict,
should_log_inputs=should_log_inputs)
@skipIfNoFBGEMM
def test_linear_fp16_shadow_activations(self):
for should_log_inputs in (True, False):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
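            # the shadow model logs the subgraph output from both copies
            # (2 loggers); logging inputs doubles this to 4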
num_loggers = 4 if should_log_inputs else 2
expected_occurrence = {
ns.call_module(OutputLogger): num_loggers,
}
res2 = self._test_match_shadow_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1,
qconfig_dict=qconfig_dict,
should_log_inputs=should_log_inputs)


class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
"""
Tests numeric suite core APIs on non-toy models.
"""
@skipIfNoFBGEMM
def test_compare_weights_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
self._test_extract_weights(m, results_len=1)
@skipIfNoFBGEMM
def test_compare_weights_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_extract_weights(
m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_weights_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
res = self._test_extract_weights(
m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_activations_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
res = self._test_match_activations(
m, (torch.randn(1, 3, 4, 4),), results_len=1)
@skipIfNoFBGEMM
def test_compare_activations_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_match_activations(
m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_activations_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
# TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
res = self._test_match_activations(
m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
skip_scripting=True)
@skipIfNoFBGEMM
def test_compare_shadow_activations_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
res = self._test_match_shadow_activations(
m, (torch.randn(1, 3, 4, 4),), results_len=1)
@skipIfNoFBGEMM
def test_compare_shadow_activations_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_match_shadow_activations(
m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_shadow_activations_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
# TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
res = self._test_match_shadow_activations(
m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
skip_scripting=True)
@skipIfNoFBGEMM
def test_sparsenn_compare_activations(self):
for should_log_inputs in (True, False):
sparse_nn = SparseNNModel().eval()
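            # EmbeddingBag-style inputs: flat indices, bag offsets, and a dense
            # feature tensor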
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
self._test_match_activations(
sparse_nn, (idx, offsets, x),
results_len=4,
should_log_inputs=should_log_inputs)
@skipIfNoFBGEMM
def test_sparsenn_shadow(self):
for should_log_inputs in (True, False):
sparse_nn = SparseNNModel().eval()
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
            self._test_match_shadow_activations(
sparse_nn, (idx, offsets, x),
results_len=4,
should_log_inputs=should_log_inputs)