[ONNX] Support restricted quantized range for activation.

PyTorch restricts activations to the range (0, 127).
In ONNX, the supported ranges are (0, 255) and (-128, 127),
for uint8 and int8 respectively. This PR adds support for the
(0, 127) range by emitting an extra Clip node when that range is detected.
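
As a rough reference for the intended numerics (not part of the PR; the helper name and sample values below are illustrative), the exported QuantizeLinear -> Clip -> DequantizeLinear sequence for the (0, 127) case behaves like this numpy sketch:

```python
import numpy as np

def fake_quantize_0_127(x, scale, zero_point):
    # uint8 QuantizeLinear: round, add zero_point, saturate to [0, 255]
    q = np.clip(np.rint(x / scale) + zero_point, 0, 255).astype(np.uint8)
    # extra Clip emitted by this PR: restrict the activation range to [0, 127]
    q = np.minimum(q, np.uint8(127))
    # DequantizeLinear
    return (q.astype(np.float32) - zero_point) * scale

# e.g. scale=1.0, zero_point=0: 150 now saturates at 127 instead of passing through as 150
print(fake_quantize_0_127(np.array([150., 127., -5.]), 1.0, 0))  # [127. 127.   0.]
```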

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76055

Approved by: https://github.com/garymm
BowenBao
2022-04-22 15:08:22 -07:00
committed by PyTorch MergeBot
parent cada2cd3ae
commit 8d31706b9e
3 changed files with 49 additions and 10 deletions

View File

@@ -8731,6 +8731,32 @@ class _TestONNXRuntime:
         x = torch.randn(6, 4, 3, 3)
         self.run_test(FakeQuantizePerChannelModel(), (x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    @disableScriptTest()  # RuntimeError: Can't redefine method: forward on class: __torch__.torch.nn.modules.linear.Linear
+    def test_fake_quantize_activation(self):
+        from torch import quantization
+        m = torch.nn.Linear(1, 1)
+        m.qconfig = quantization.QConfig(
+            activation=quantization.default_fake_quant,
+            weight=quantization.default_per_channel_weight_fake_quant)
+        quantization.prepare_qat(m.train(), inplace=True)
+        m.apply(quantization.enable_observer)
+        m.apply(quantization.enable_fake_quant)
+        for module in m.modules():
+            if isinstance(module, quantization.FakeQuantize):
+                module.calculate_qparams()
+
+        m.apply(quantization.disable_observer)
+        m.eval()
+
+        # Fake quantize activation is a special case, as it restricts quantized range to be (0, 127),
+        # while standard 8bit quantization range is (-128, 127) or (0, 255).
+        # Set fixed weight, bias and inputs to test if ONNX handles the overflow correctly.
+        m.weight = torch.nn.Parameter(torch.tensor([[1.], [1.], [1.]]))
+        m.bias = torch.nn.Parameter(torch.tensor([0.]))
+        x = torch.tensor([[150.], [127.], [-5.]])
+        self.run_test(m, x)
+
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
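
The hunk above adds test_fake_quantize_activation. Outside the test harness, exporting a module prepared this way might look roughly as follows; the output path and tolerances are illustrative, and run_test is assumed to perform a similar export followed by an ONNX Runtime comparison:

```python
import numpy as np
import onnxruntime as ort
import torch

# `m` and `x` prepared as in the test above; opset 13 is required for the uint8 Clip.
torch.onnx.export(m, x, "fake_quant_activation.onnx", opset_version=13)

# Compare PyTorch eager output against ONNX Runtime.
sess = ort.InferenceSession("fake_quant_activation.onnx")
ort_out = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]
np.testing.assert_allclose(m(x).detach().numpy(), ort_out, rtol=1e-3, atol=1e-5)
```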

View File

@@ -300,6 +300,12 @@ def embedding_bag(g,
 
 @parse_args("v", "v", "v", "i", "i")
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
+    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) == (0, 127):
+        sym_help._onnx_opset_unsupported_detailed(
+            "fake_quantize_per_tensor_affine", 10, 13,
+            "Quantize range (0, 127) not supported, requires opset 13 Clip")
     if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
         raise RuntimeError(
             "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "

View File

@@ -7,7 +7,7 @@ import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
 from torch.onnx.symbolic_opset9 import (overload_by_arg_count, _maybe_cast_reduce_op_input,
                                         nonzero, expand, zeros, ones, size, linear, conv2d,
-                                        relu)
+                                        relu, unused)
 from torch.onnx.symbolic_opset11 import unsqueeze
 from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@@ -132,25 +132,29 @@ def where(g, condition, self=None, other=None, _outputs=None):
 
 @parse_args("v", "v", "v", "i", "i", "i")
 def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
-    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
         raise RuntimeError(
-            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
             "Got ({}, {})".format(quant_min, quant_max))
     # ONNX defines zero_point to be int8 or uint8
     if quant_min == 0:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
     else:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
-    return g.op(
-        "DequantizeLinear",
-        g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
-        scale, zero_point, axis_i=axis)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
+    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
 
 
 @parse_args("v", "v", "v", "i", "i")
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
-    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
+    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
         raise RuntimeError(
-            "For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
+            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
             "Got ({}, {})".format(quant_min, quant_max))
     if quant_min == 0:
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
@@ -158,7 +162,10 @@ def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128
         zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.INT8)
     if scale.type().scalarType() != "Float":
         scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-    return g.op("DequantizeLinear", g.op("QuantizeLinear", inputs, scale, zero_point), scale, zero_point)
+    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+    if (quant_min, quant_max) == (0, 127):
+        quantized = g.op("Clip", quantized, unused(g), g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)))
+    return g.op("DequantizeLinear", quantized, scale, zero_point)
 
 
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
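
For completeness, a hedged numpy reference for the per-channel path above, cross-checked against torch.fake_quantize_per_channel_affine; the helper name, shapes, and tolerances are illustrative, and this is a sketch of the intended semantics rather than the exporter code itself:

```python
import numpy as np
import torch

def fake_quantize_per_channel_0_127(x, scale, zero_point, axis):
    # Broadcast per-channel qparams along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    s = scale.reshape(shape)
    zp = zero_point.reshape(shape)
    q = np.clip(np.rint(x / s) + zp, 0, 255)  # uint8 QuantizeLinear
    q = np.minimum(q, 127)                    # extra Clip for the (0, 127) range
    return (q - zp) * s                       # DequantizeLinear

x = torch.randn(2, 3)
scale = torch.rand(3) * 0.1 + 0.01
zp = torch.zeros(3, dtype=torch.int32)
expected = torch.fake_quantize_per_channel_affine(x, scale, zp, 1, 0, 127)
got = fake_quantize_per_channel_0_127(x.numpy(), scale.numpy(), zp.numpy().astype(np.float64), 1)
print(np.allclose(expected.numpy(), got, atol=1e-5))
```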