# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]
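
# Tests for on-device dynamic post-training quantization of scripted modules
# via the _prepare_ondevice_dynamic_jit / _quantize_ondevice_dynamic_jit APIs.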

import torch

from torch.ao.quantization import (
    default_dynamic_qconfig,
    per_channel_dynamic_qconfig,
)

from torch.ao.quantization.quantize_jit import (
    prepare_dynamic_jit,
    convert_dynamic_jit,
    _prepare_ondevice_dynamic_jit,
    _quantize_ondevice_dynamic_jit,
)

from torch.testing._internal.common_utils import TestCase

from torch.testing._internal.common_quantization import (
    get_script_module,
    LinearAddModel,
)

from torch.jit.mobile import _load_for_lite_interpreter

from torch.testing import FileCheck

import io


class myMod(torch.nn.Module):
    def __init__(self, weight):
        super(myMod, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        return self.fc2(self.fc1(x))


class MyConvLinearModule(torch.nn.Module):
    def __init__(self):
        super(MyConvLinearModule, self).__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(weight)

    def forward(self, x):
        conv_output = self.conv(x)
        y = self.mymod(conv_output)
        z = torch.nn.functional.linear(y, self.weight1)
        return z

    def get_example_inputs(self):
        return (torch.rand(1, 3, 12, 7),)


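# Helpers shared by the tests below: script a model, run the on-device
# prepare/quantize passes, and inspect the resulting graphs.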
class OnDevicePTQUtils(object):
    observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

    @staticmethod
    def insert_observers(model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

    @staticmethod
    def ptq_dynamic_quantize(model, qconfig_dict):
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
        return m

    @staticmethod
    def find_observer_modules(m):
        observer_modules = []
        for child_module in m.children():
            if child_module.original_name in OnDevicePTQUtils.observer_module_name:
                observer_modules.append(child_module)
        return observer_modules

    @staticmethod
    def is_value_type_observer(value):
        type_name = value.type()
        for observer_type in OnDevicePTQUtils.observer_module_name:
            if observer_type in type_name.str():
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        if node.kind() == "prim::CallMethod":
            if node.s('name') == "calculate_qparams":
                return True
        return False

    @staticmethod
    def get_linear_packed_param_fp_weight(node):
        weight = node.inputsAt(0).node()
        if weight.kind() != "aten::quantize_per_tensor" and weight.kind() != "aten::quantize_per_channel":
            raise ValueError("Quantized weight must be produced.")
        fp_weight = weight.inputsAt(0).node()
        assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module."
        fp_weight_name = fp_weight.s('name')
        return fp_weight_name

    @staticmethod
    def is_per_channel_quantized_packed_param(node):
        assert node.kind() == 'quantized::linear_prepack', "Node must correspond to linear_prepack."
        weight = node.inputsAt(0).node()
        assert weight.kind() in ("aten::quantize_per_tensor", "aten::quantize_per_channel")
        return weight.kind() != "aten::quantize_per_tensor"


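# Checks that preparing a scripted model inserts the expected observer modules
# and generates observe_forward / reset_observers_forward methods.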
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'MinMaxObserver')

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'PerChannelMinMaxObserver')

    def _check_observer_method(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # exact graph matching is difficult so just resorting to # of lines
        # instead of implementing graph matching
        self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines()))
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count("prim::CallMethod[name=\"forward\"](%_observer",
                                num_observers, exactly=True).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer", num_observers, exactly=True).run(reset_observers_method)

    def _observer_is_weight_only(self, node):
        if (node.kind() == "prim::CallMethod") and node.s("name") == "forward":
            if (OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0))):
                return (node.inputsAt(1).node().kind() == "prim::GetAttr")
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if (self._observer_is_weight_only(node)):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)


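# Checks the generated quantize_forward method: quantize/dequantize nodes,
# calculate_qparams calls, and no leftover observer forward calls.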
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self._validate_no_observer_forward(m)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self._validate_no_observer_forward(m)

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        # First must run observe forward to record the stats to produce
        # correct scales and zero points
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)


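# Checks finalization of quantize_forward (packed params stored via SetAttr,
# fp weights reset, no linear_unpack) and the generated quantized_forward,
# which should rely on quantized::linear_dynamic.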
class TestOnDeviceDynamicPTQFinalize(TestCase):
    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'quantized::linear_unpack':
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param)
                    fp_weight_names.append(weight_name)

        for n in quantize_forward_graph.nodes():
            # This is basically detecting
            # %x = prim::Constant
            # = prim::SetAttr(<weight_name>)(module_value, x)
            # Thus making sure that the original fp weights are
            # reset
            if n.kind() == 'prim::SetAttr':
                weight_name = n.s('name')
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == 'prim::Constant':
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == 'prim::GetAttr':
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == 'prim::SetAttr':
                num_setattr += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        self._validate_no_linear_unpack(m)
        self._validate_setattr_fp_weights(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self._validate_no_linear_unpack(m)
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

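    # Compare on-device dynamic PTQ numerics against the reference
    # prepare_dynamic_jit / convert_dynamic_jit flow.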
    def _check_against_ref_dynamic_ptq(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        # Calling the original forward is expected to fail after
        # quantize_forward, since the original fp weights have been reset.
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

        # test with per channel quant
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

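    # Round-trip through torch.jit.save/load and the lite interpreter and
    # check that the quantized output still matches the reference model.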
    def _check_serialization_deserialization(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        m.reset_observers_forward()
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))

        # check for lite interpreter
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        m = _load_for_lite_interpreter(buffer)
        m.run_method("reset_observers_forward")
        m.run_method("observe_forward", *inputs)
        m.run_method("quantize_forward", *inputs)
        output = m.run_method("quantized_forward", *inputs)
        self.assertTrue(torch.allclose(ref_output, output))

        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        m.reset_observers_forward()
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))

        # check for lite interpreter
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        m = _load_for_lite_interpreter(buffer)
        m.run_method("reset_observers_forward")
        m.run_method("observe_forward", *inputs)
        m.run_method("quantize_forward", *inputs)
        output = m.run_method("quantized_forward", *inputs)
        self.assertTrue(torch.allclose(ref_output, output))

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)