Summary: Until we add quant_{min, max} args to `torch.quantize_per_{channel, tensor}`, this patch makes sure we honor the observer's restrictions on quantized values.

Test Plan: Added new tests; run with `buck run caffe2/test:quantization -- quantization.core.test_utils`

Differential Revision: D38624119
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83438
Approved by: https://github.com/andrewor14
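For a quick illustration of the behavior this change enforces, here is a minimal sketch distilled from the tests added below (the tensor values and the [-10, 10] quantized range are illustrative only, not part of the change): after this patch, `_quantize_weight` clamps the quantized values to the observer's `quant_min`/`quant_max`, even when the float weights drift outside the range the observer has recorded.

import torch
from torch.ao.quantization import MovingAverageMinMaxObserver
from torch.nn.quantized.modules.utils import _quantize_weight

# Observer restricted to the quantized range [-10, 10] (illustrative values).
observer = MovingAverageMinMaxObserver(
    averaging_constant=1.0,
    dtype=torch.qint8,
    quant_min=-10,
    quant_max=10,
    qscheme=torch.per_tensor_symmetric,
)

weight = torch.tensor([-1000.0, 1000.0])
observer(weight)  # record min/max statistics

# Quantize weights that drifted outside the observed range; with this patch
# the integer representation is still clamped to [-10, 10].
quantized = _quantize_weight(weight * 1.2, observer)
assert quantized.int_repr().min().item() >= -10
assert quantized.int_repr().max().item() <= 10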
# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver

class TestUtils(TestCase):
    def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        for fqn, expected_dims in expected_fqn_to_dim.items():
            assert fqn in fqn_to_example_inputs
            example_inputs = fqn_to_example_inputs[fqn]
            for example_input, expected_dim in zip(example_inputs, expected_dims):
                assert example_input.dim() == expected_dim
    def test_get_fqn_to_example_inputs_simple(self):
        class Sub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x)
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            "sub": (2,),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
    def test_get_fqn_to_example_inputs_default_kwargs(self):
        """ Test that we can get example inputs for functions with default keyword arguments
        """
        class Sub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                # only override `key2`; `key1` will use its default value
                x = self.sub(x, key2=torch.rand(1, 2))
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            # second arg is `key1`, which uses the default argument
            # third arg is `key2`, overridden at the call site
            "sub": (2, 1, 2),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
    def test_get_fqn_to_example_inputs_complex_args(self):
        """ Test that we can record complex example inputs such as lists and dicts
        """
        class Sub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, list_arg, dict_arg):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x, [x], {"3": x})
                return x

        example_inputs = (torch.rand(1, 5),)
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        assert "sub" in fqn_to_example_inputs
        assert isinstance(fqn_to_example_inputs["sub"][1], list)
        assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
            "3" in fqn_to_example_inputs["sub"][2]
    def test_quantize_weight_clamping_per_tensor(self):
        """ Test that quant_{min, max} from the per tensor observer is honored by the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([fp_min, fp_max])

        observer = MovingAverageMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_tensor_symmetric,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val]
        # range when a moving average observer is used
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min
    def test_quantize_weight_clamping_per_channel(self):
        """ Test that quant_{min, max} from the per channel observer is honored by the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([[fp_min, fp_max]])

        observer = MovingAveragePerChannelMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val]
        # range when a moving average observer is used
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min