onnx export of per channel fake quantize functions (#42835)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/39502.

This PR adds support for exporting **fake_quantize_per_channel_affine** as a pair of QuantizeLinear and DequantizeLinear ops. Per-tensor support was added by PR https://github.com/pytorch/pytorch/pull/39738. The `axis` attribute of QuantizeLinear and DequantizeLinear, which is required for per-channel support, was added in opset 13 by https://github.com/onnx/onnx/pull/2772.

[update 1/20/2021]: opset 13 is now supported on master, so the added function is properly tested. The code has also been rebased onto the new master.

The function was also tested offline with the following code:

```python
import torch
from torch import quantization
from torchvision import models

qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant,
    weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

dummy_input = torch.randn(16, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

input_names = [ "actual_input_1" ]
output_names = [ "output1" ]
torch.onnx.export(qat_resnet18, dummy_input, "quant_model.onnx", verbose=True, opset_version=13)
```

It generates the desired graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42835
Reviewed By: houseroad
Differential Revision: D26293823
Pulled By: SplitInfinity
fbshipit-source-id: 300498a2e24b7731b12fa2fbdea4e73dde80e7ea
Committed by: Facebook GitHub Bot
Parent: 159c48b19b
Commit: 7363da7c57
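As a quick sanity check on the export above, the resulting graph can be inspected for the per-channel QuantizeLinear/DequantizeLinear pairs. A minimal sketch (not part of this PR), assuming the `onnx` package is installed and `quant_model.onnx` was produced by the export call in the summary:

```python
import onnx

model = onnx.load("quant_model.onnx")
for node in model.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        # Per-channel Q/DQ nodes carry the opset-13 `axis` attribute;
        # per-tensor ones omit it.
        axis = [a.i for a in node.attribute if a.name == "axis"]
        print(node.op_type, "axis =", axis or "per-tensor")
```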
```diff
@@ -182,7 +182,7 @@ class TestModels(TestCase):
         self.exportTest(toC(FakeQuantNet()), toC(x))
 
     @skipIfUnsupportedMinOpsetVersion(10)
-    def test_qat_resnet(self):
+    def test_qat_resnet_pertensor(self):
         # Quantize ResNet50 model
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
         qat_resnet50 = resnet50()
```
```diff
@@ -202,6 +202,27 @@ class TestModels(TestCase):
         self.exportTest(toC(qat_resnet50), toC(x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_qat_resnet_per_channel(self):
+        # Quantize ResNet50 model
+        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
+        qat_resnet50 = resnet50()
+
+        qat_resnet50.qconfig = quantization.QConfig(
+            activation=quantization.default_fake_quant,
+            weight=quantization.default_per_channel_weight_fake_quant)
+        quantization.prepare_qat(qat_resnet50, inplace=True)
+        qat_resnet50.apply(torch.quantization.enable_observer)
+        qat_resnet50.apply(torch.quantization.enable_fake_quant)
+
+        _ = qat_resnet50(x)
+        for module in qat_resnet50.modules():
+            if isinstance(module, quantization.FakeQuantize):
+                module.calculate_qparams()
+        qat_resnet50.apply(torch.quantization.disable_observer)
+
+        self.exportTest(toC(qat_resnet50), toC(x))
+
     @disableScriptTest()  # None type in outputs
     def test_googlenet(self):
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
```
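For context (not part of the diff): after `prepare_qat` and the calibration forward pass, each per-channel weight fake-quant module holds one scale/zero-point pair per output channel, which is what the exporter maps onto the `axis` attribute. A hedged sketch, assuming `qat_resnet50` was prepared as in the test above (`conv1` and its 64 output channels are just an illustration):

```python
# The first conv of the prepared model; after prepare_qat it carries a
# per-channel FakeQuantize module on its weight.
wfq = qat_resnet50.conv1.weight_fake_quant
scale, zero_point = wfq.calculate_qparams()
print(scale.shape)       # e.g. torch.Size([64]) -- one scale per output channel
print(zero_point.shape)  # matches: one zero_point per output channel
```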
```diff
@@ -5998,6 +5998,20 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(6, 4, 3, 3)
         self.run_test(FakeQuantizePerTensorModel(), (x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_fake_quantize_per_channel(self):
+        class FakeQuantizePerChannelModel(torch.nn.Module):
+            def forward(self, input):
+                amax = torch.ones(4)
+                scale = amax / 127.
+                zero_point = torch.zeros_like(amax, dtype=torch.long)
+                # Quantize twice to test different branches
+                y = torch.fake_quantize_per_channel_affine(input, scale, zero_point, 1, 0, 255)
+                return torch.fake_quantize_per_channel_affine(y, scale, zero_point, 1, -128, 127)
+
+        x = torch.randn(6, 4, 3, 3)
+        self.run_test(FakeQuantizePerChannelModel(), (x))
+
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
```
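The two calls in the test exercise the two zero-point branches of the new symbolic below: `(0, 255)` takes the uint8 (`Byte`) cast and `(-128, 127)` the int8 (`Char`) cast. For reference, a sketch (not from the PR) of the round-trip each fake-quantize call performs, written out by hand for a single channel:

```python
import torch

x = torch.tensor([0.05, -0.3, 1.2])
scale, zero_point, quant_min, quant_max = 1.0 / 127, 0, -128, 127

# Quantize: scale, offset, round, and saturate to the integer range...
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
# ...then dequantize straight back to float. 1.2 saturates to 127 * scale = 1.0.
y = (q - zero_point) * scale
print(y)  # tensor([ 0.0472, -0.2992,  1.0000])
```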
```diff
@@ -121,6 +121,21 @@ def where(g, condition, self=None, other=None, _outputs=None):
         return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
     return g.op("Where", condition, self, other)
 
+@parse_args('v', 'v', 'v', 'i', 'i', 'i')
+def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
+    if quant_min not in [0, -128] or quant_max not in [127, 255]:
+        raise RuntimeError(
+            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
+
+    # ONNX defines zero_point to be int8 or uint8
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Byte'])
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Char'])
+    return g.op(
+        "DequantizeLinear",
+        g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
+        scale, zero_point, axis_i=axis)
+
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
```
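The symbolic thus emits QuantizeLinear followed immediately by DequantizeLinear on the same scale and zero-point, with `axis_i` selecting the channel dimension, so the exported model stays float-in/float-out. A hedged end-to-end check (not part of the PR), assuming `onnxruntime` is installed and `quant_model.onnx` was exported as in the summary:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("quant_model.onnx")
inp = np.random.randn(16, 3, 224, 224).astype(np.float32)

# The input name depends on the export call, so look it up rather than hard-coding it.
name = sess.get_inputs()[0].name
out = sess.run(None, {name: inp})[0]
print(out.shape)  # float logits, e.g. (16, 1000) for ResNet-18
```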