Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)
**Summary**
Per the discussion in https://github.com/pytorch/pytorch/pull/123444, the `decomposed quant/dequant` patterns changed after https://github.com/pytorch/pytorch/pull/123445. We can move the optimization of `decomposed quant/dequant` from the inductor decomposition into the lowering phase to avoid depending on those changes. In this way, we can:

- Avoid the pattern-matcher failure introduced in https://github.com/pytorch/pytorch/pull/123445.
- Make the quantization pattern clearer in the pattern-matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move the optimization of `decomposed quant/dequant` from the inductor decomposition into the lowering phase.
- Make the corresponding changes in the quantization pattern matcher to ensure nothing is BC-breaking.

**Test Plan**

```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
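As context before the diff: the math being moved is simple. The sketch below is plain PyTorch, not the inductor source; the helper names and the sample scale/zero-point values are illustrative. It mirrors the per-tensor quantize/dequantize decompositions that this PR deletes from `torch/_inductor/decomposition.py`, whose computation the new lowerings now emit instead.

```python
import torch

def quantize_per_tensor(x, scale, zero_point, quant_min, quant_max, dtype):
    # bf16 inputs are promoted to fp32 before quantizing, as in the removed decomposition
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(x * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)

def dequantize_per_tensor(q, scale, zero_point):
    # (q - zero_point) * scale, computed in fp32
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(8)
q = quantize_per_tensor(x, scale=0.05, zero_point=128, quant_min=0, quant_max=255, dtype=torch.uint8)
x_hat = dequantize_per_tensor(q, scale=0.05, zero_point=128)
print(torch.max((x - x_hat).abs()))  # round-trip error is bounded by roughly scale / 2
```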
Committed by: PyTorch MergeBot
Parent: 1b1b18a7a4
Commit: 33e6791645
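Most of the test changes in the diff below only adjust expected matcher node counts. As a rough illustration of where the savings come from (a stand-in traced with `torch.fx`, not the inductor pattern matcher itself): the decomposed dequant used to surface as a to-fp32/sub/mul chain of nodes, whereas after this PR the matcher sees one `dequantize_per_tensor` node.

```python
import torch
import torch.fx as fx

# Stand-in for the decomposed dequant the old patterns had to match;
# tracing it shows the multi-node chain (to(fp32), sub, mul).
def decomposed_dequant(q, scale, zero_point):
    return (q.to(torch.float32) - zero_point) * scale

gm = fx.symbolic_trace(decomposed_dequant)
print(gm.graph)
# After this PR the same work appears as a single dequantize_per_tensor node,
# so the expected matcher node counts in the tests drop accordingly.
```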
@@ -1671,7 +1671,7 @@ static at::Tensor _quantized_convolution_onednn(
  }
  tensor src_scales_t = tensor(ideep::scale_t(1, act_scale));
  tensor wei_scales_t = tensor(weights_scales);
  tensor dst_scales_t = tensor(ideep::scale_t(1, 1.0/inv_output_scale));
  tensor dst_scales_t = tensor(ideep::scale_t(1, inv_output_scale));
  tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point));
  tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
  if (act_scale != 1.0f) {
@@ -1707,7 +1707,7 @@ static at::Tensor _quantized_convolution_onednn(
  ideep::convolution_forward::prepare(
      params, src, packed_weight, expected_bias, dst_dims, dst,
      stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups,
      src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
      src_scales, weights_scales, ideep::scale_t(1, 1.0f / inv_output_scale),
      src_zero_points, dst_zero_points,
      op_attr, dnnl::algorithm::convolution_direct,
      dnnl::prop_kind::forward_inference,
@@ -931,7 +931,6 @@ static at::Tensor linear_int8_with_onednn_weight(
    c10::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  const int64_t dim = input.dim();
  output_scale = 1.0f / output_scale;
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
      "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
@ -555,15 +555,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
|
||||
# int8_mixed_fp32: [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# int8_mixed_bf16: [convert_element_type_1, sub, mul_1, optional(convert_element_type_4),
|
||||
# int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
|
||||
# int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
|
||||
# dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
|
||||
16 if int8_mixed_bf16 else 12,
|
||||
12 if int8_mixed_bf16 else 8,
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
@ -683,14 +683,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->Hardtanh pattern.
|
||||
Match.nodes:
|
||||
[qconv2d_pointwise_default_1, convert_element_type_5, clamp_min_1, clamp_max_1, mul_2, round_2, add_1, clamp_min_2,
|
||||
clamp_max_1, mul_2, round_2, add_1, clamp_min_2, clamp_max_2, convert_element_type_8
|
||||
[qconv2d_pointwise_default, convert_element_type_13, clamp_min_3, clamp_max_3]
|
||||
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
|
||||
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
|
||||
"""
|
||||
self._qconv2d_unary_cpu_test_helper(
|
||||
unary_op=torch.nn.Hardtanh(),
|
||||
int8_mixed_bf16=True,
|
||||
qconv2d_unary_matcher_nodes=14,
|
||||
qconv2d_unary_matcher_nodes=11,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -710,14 +709,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->Hardswish pattern.
|
||||
Match.nodes:
|
||||
[qconv2d_pointwise_default_1, convert_element_type_5, add_1, clamp_min_1,
|
||||
clamp_max_1, mul_2, div, mul_3, round_2, add_2, clamp_min_2, clamp_max_2, convert_element_type_8]
|
||||
[qconv2d_pointwise_default, convert_element_type_13, add_3, clamp_min_3, clamp_max_3, mul_5, div_1]
|
||||
[qconv2d_pointwise_default, convert_element_type, add, clamp_min,
|
||||
clamp_max, mul, div, convert_element_type, quantize_per_tensor]
|
||||
[qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
|
||||
"""
|
||||
self._qconv2d_unary_cpu_test_helper(
|
||||
unary_op=torch.nn.Hardswish(),
|
||||
int8_mixed_bf16=True,
|
||||
qconv2d_unary_matcher_nodes=20,
|
||||
qconv2d_unary_matcher_nodes=17,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -737,14 +736,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->SiLU pattern.
|
||||
Match.nodes:
|
||||
[qconv2d_pointwise_default_1, convert_element_type_5, sigmoid, mul_2,
|
||||
mul_3, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_8]
|
||||
[qconv2d_pointwise_default, convert_element_type_13, sigmoid_1, mul_5]
|
||||
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
|
||||
convert_element_type, quantize_per_tensor]
|
||||
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
|
||||
"""
|
||||
self._qconv2d_unary_cpu_test_helper(
|
||||
unary_op=torch.nn.SiLU(),
|
||||
int8_mixed_bf16=True,
|
||||
qconv2d_unary_matcher_nodes=14,
|
||||
qconv2d_unary_matcher_nodes=11,
|
||||
)
|
||||
|
||||
def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
|
||||
@ -1028,17 +1027,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant-conv pattern matched in quantization weight prepack * 1
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4
|
||||
)
|
||||
# 2. QConv2D Unary fusion in post-grad fusion pass * 1
|
||||
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
|
||||
# [qconv2d_pointwise_default, quantize_per_tensor]
|
||||
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 7)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
@ -1107,7 +1106,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
|
||||
"""
|
||||
|
||||
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1117,7 +1115,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
|
||||
"""
|
||||
|
||||
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1127,7 +1124,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->SiLU pattern with qat flow.
|
||||
"""
|
||||
|
||||
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1137,7 +1133,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
r"""
|
||||
This testcase will quantize Conv2d->Hardswish pattern with qat flow.
|
||||
"""
|
||||
|
||||
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1176,18 +1171,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
|
||||
)
|
||||
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
|
||||
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
|
||||
# clamp_min_3, clamp_max_3, convert_element_type_6]
|
||||
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
@ -1236,18 +1230,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
|
||||
)
|
||||
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
|
||||
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4,
|
||||
# clamp_min_3, clamp_max_3, convert_element_type_6]
|
||||
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
@ -1294,16 +1287,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant pattern matcher for dequant promotion * 1
|
||||
# [convert_element_type_3, sub_1, mul_3]
|
||||
# [dequantize_per_tensor]
|
||||
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
|
||||
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 3)
|
||||
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
|
||||
# 2. Dequant-conv pattern matched in quantization weight prepack * 3
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 18
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
|
||||
)
|
||||
# 3. Qconv2d Binary fusion in post-grad fusion pass * 1
|
||||
# [qconv2d_pointwise_default_1, add_3]
|
||||
@ -1445,7 +1438,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
||||
17 if bias else 16,
|
||||
13 if bias else 12,
|
||||
)
|
||||
|
||||
self._qlinear_cpu_test_helper(
|
||||
@ -1473,7 +1466,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
||||
21 if bias else 20,
|
||||
17 if bias else 16,
|
||||
)
|
||||
|
||||
self._qlinear_cpu_test_helper(
|
||||
@ -1722,12 +1715,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
x1 = torch.randn((2, 4))
|
||||
x2 = torch.randn((2, 5))
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
(x1, x2),
|
||||
2,
|
||||
8,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1763,22 +1760,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
|
||||
1
|
||||
)
|
||||
# Totally 6 pattern_matcher_count, 31 pattern_matcher_nodes
|
||||
# 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# 2. Dequant-conv pattern matched in quantization weight prepack * 1
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# 3. qconv2d_relu fusion in post-grad fusion pass * 1
|
||||
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
|
||||
# 4. qmaxpool2d * 1
|
||||
# [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2,
|
||||
# clamp_min_2, clamp_max_2, convert_element_type_4]
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
6,
|
||||
31,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@ -1852,22 +1846,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
|
||||
mod = M().eval()
|
||||
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
|
||||
# Totally 10 pattern_matcher_count, 49 pattern_matcher_nodes
|
||||
# 1. Pair of to_int8 and to_fp32 * 5, matched in pointless_convert pass at
|
||||
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
|
||||
# 2. Dequant-conv pattern matched in quantization weight prepack * 2
|
||||
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
|
||||
# 3. qconv2d fusion in post-grad fusion pass * 2
|
||||
# [qconv2d_pointwise_default, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
|
||||
# 4. qcat * 1
|
||||
# [convert_element_type_3, sub_1, mul_3, convert_element_type_7, sub_3, mul_7, cat, mul_8, round_5,
|
||||
# add_4, clamp_min_4, clamp_max_4, convert_element_type_8]
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
|
||||
)
|
||||
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
(v,),
|
||||
10,
|
||||
49,
|
||||
check_quantization=True,
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/99841.
|
||||
|
@ -4309,7 +4309,7 @@ class TestQuantizedLinear(TestCase):
|
||||
if post_op in ("none", "relu", "gelu"):
|
||||
qy_cpu = qlinear_op(
|
||||
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
if post_op == "relu":
|
||||
@ -4330,7 +4330,7 @@ class TestQuantizedLinear(TestCase):
|
||||
accum = accum.bfloat16()
|
||||
qy_cpu = qlinear_op(
|
||||
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
accum, x2_scale, x2_zp, "sum", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
@ -4348,7 +4348,7 @@ class TestQuantizedLinear(TestCase):
|
||||
binary_alpha = 1.0 # we only support alpha=1.0 now
|
||||
qy_cpu = qlinear_op(
|
||||
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
x2, 1.0, 0, "add", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
@ -6796,7 +6796,7 @@ class TestQuantizedConv(TestCase):
|
||||
pads,
|
||||
dilations,
|
||||
groups,
|
||||
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
|
||||
Y_scale,
|
||||
Y_zero_point,
|
||||
qconv_output_dtype,
|
||||
post_op.binary_attr,
|
||||
@ -6818,7 +6818,7 @@ class TestQuantizedConv(TestCase):
|
||||
pads,
|
||||
dilations,
|
||||
groups,
|
||||
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
|
||||
Y_scale,
|
||||
Y_zero_point,
|
||||
qconv_output_dtype,
|
||||
post_op.unary_attr,
|
||||
|
@ -478,70 +478,6 @@ def linear_dynamic_fp16_unpacked_weight(
|
||||
)
|
||||
|
||||
|
||||
# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
|
||||
# scale and zero_point is scalar or scalar tensor
|
||||
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
|
||||
def quantize_per_tensor_default_decomp_impl(
|
||||
input: torch.Tensor,
|
||||
scale: float,
|
||||
zero_point: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
if input.dtype == torch.bfloat16:
|
||||
input = input.to(torch.float32)
|
||||
inv_scale = 1.0 / scale
|
||||
return torch.clamp(
|
||||
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
|
||||
).to(dtype)
|
||||
|
||||
|
||||
# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
|
||||
# scale and zero_point is scalar or scalar tensor
|
||||
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
|
||||
def dequantize_per_tensor_default_decomp_impl(
|
||||
input: torch.Tensor,
|
||||
scale: float,
|
||||
zero_point: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return (input.to(torch.float32) - zero_point) * scale
|
||||
|
||||
|
||||
@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
|
||||
def quantize_per_tensor_tensor_decomp_impl(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
if input.dtype == torch.bfloat16:
|
||||
input = input.to(torch.float32)
|
||||
inv_scale = 1.0 / scale
|
||||
return torch.clamp(
|
||||
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
|
||||
).to(dtype)
|
||||
|
||||
|
||||
@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
|
||||
def dequantize_per_tensor_tensor_decomp_impl(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
|
||||
torch.float32
|
||||
)
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
|
||||
def q_embedding_bag_byte_unpack_decomp(packed):
|
||||
def bitcast_u8_to_f32(u8):
|
||||
|
@@ -49,10 +49,25 @@ quantization.
"""


def _get_pattern_output_dtype(match: Match):
    """
    Get the pattern's output dtype from node's meta
    Assume only 1 output node in this matched pattern.
    """
    pattern_output_nodes = match.output_nodes()
    assert len(pattern_output_nodes) == 1
    output_node = pattern_output_nodes[0]
    assert isinstance(output_node, torch.fx.Node)
    output_dtype = output_node.meta["val"].dtype
    if output_dtype is torch.uint8:
        output_dtype = None
    return output_dtype


def _may_generate_pattern_with_dtype_convert(
    pattern, dtype=Arg(), dtype_convert=True, users=1
    pattern, dtype=Arg(), with_dtype_convert=True, users=1
):
    if dtype_convert:
    if with_dtype_convert:
        return CallFunction(
            prims.convert_element_type.default,
            pattern,
@ -94,30 +109,25 @@ def _generate_linear_t_pattern(
|
||||
def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
|
||||
# only insert to_dtype if is_bf16 is True
|
||||
computation_call = _may_generate_pattern_with_dtype_convert(
|
||||
call_fn, dtype=KeywordArg("to_float"), dtype_convert=is_bf16, users=users
|
||||
call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
|
||||
)
|
||||
return unary_fusion(computation_call)
|
||||
|
||||
|
||||
"""
|
||||
dequantize activation:
|
||||
x = x.to(fp32)
|
||||
x = x - zero_point
|
||||
x = x * scale
|
||||
"""
|
||||
dequantize_per_tensor_activation_pattern = CallFunction(
|
||||
aten.mul.Tensor,
|
||||
CallFunction(
|
||||
aten.sub.Tensor,
|
||||
CallFunction(
|
||||
prims.convert_element_type.default,
|
||||
KeywordArg("x"),
|
||||
KeywordArg("x_dq_dtype"),
|
||||
),
|
||||
def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
|
||||
dequantize_per_tensor_activation_pattern = CallFunction(
|
||||
quantized_decomposed.dequantize_per_tensor.tensor
|
||||
if is_tensor_overload
|
||||
else quantized_decomposed.dequantize_per_tensor.default,
|
||||
KeywordArg("x"),
|
||||
KeywordArg("x_scale"),
|
||||
KeywordArg("x_zp"),
|
||||
),
|
||||
KeywordArg("x_scale"),
|
||||
)
|
||||
KeywordArg("x_quant_min"),
|
||||
KeywordArg("x_quant_max"),
|
||||
KeywordArg("x_dq_dtype"),
|
||||
)
|
||||
return dequantize_per_tensor_activation_pattern
|
||||
|
||||
|
||||
dequantize_per_channel_weight_pattern = CallFunction(
|
||||
quantized_decomposed.dequantize_per_channel.default,
|
||||
@ -200,17 +210,13 @@ def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
|
||||
|
||||
|
||||
dequantize_accum_pattern = CallFunction(
|
||||
aten.mul.Tensor,
|
||||
CallFunction(
|
||||
aten.sub.Tensor,
|
||||
CallFunction(
|
||||
prims.convert_element_type.default,
|
||||
KeywordArg("accum"),
|
||||
KeywordArg("accum_dq_dtype"),
|
||||
),
|
||||
KeywordArg("accum_zp"),
|
||||
),
|
||||
quantized_decomposed.dequantize_per_tensor.default,
|
||||
KeywordArg("accum"),
|
||||
KeywordArg("accum_scale"),
|
||||
KeywordArg("accum_zp"),
|
||||
Arg(),
|
||||
Arg(),
|
||||
KeywordArg("accum_dq_dtype"),
|
||||
)
|
||||
|
||||
|
||||
@ -241,43 +247,18 @@ def generate_pattern_with_unary(computation_call, unary_post_op):
|
||||
return computation_call
|
||||
|
||||
|
||||
def generate_pattern_with_output_quant(
|
||||
computation_call, has_to_fp32_before_quant=False
|
||||
):
|
||||
"""
|
||||
quantize output:
|
||||
output = round(output * o_inv_scale)
|
||||
output = output + zero_point
|
||||
output = clamp_min(output, 0)
|
||||
output = clamp_max(output, 127)
|
||||
output = output.to(uint8)
|
||||
"""
|
||||
def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
|
||||
quantized_op_output_pattern_pt2e = CallFunction(
|
||||
prims.convert_element_type.default,
|
||||
CallFunction(
|
||||
aten.clamp_max.default,
|
||||
CallFunction(
|
||||
aten.clamp_min.default,
|
||||
CallFunction(
|
||||
aten.add.Tensor,
|
||||
CallFunction(
|
||||
aten.round.default,
|
||||
CallFunction(
|
||||
aten.mul.Tensor,
|
||||
_may_generate_pattern_with_dtype_convert(
|
||||
computation_call,
|
||||
KeywordArg("autocast_output_quant_dtype"),
|
||||
has_to_fp32_before_quant,
|
||||
),
|
||||
KeywordArg("o_inv_scale"),
|
||||
),
|
||||
),
|
||||
KeywordArg("o_zp"),
|
||||
),
|
||||
KeywordArg("o_qmin"),
|
||||
),
|
||||
KeywordArg("o_qmax"),
|
||||
quantized_decomposed.quantize_per_tensor.default,
|
||||
_may_generate_pattern_with_dtype_convert(
|
||||
computation_call,
|
||||
Arg(),
|
||||
with_dtype_convert,
|
||||
),
|
||||
KeywordArg("o_inv_scale"),
|
||||
KeywordArg("o_zp"),
|
||||
KeywordArg("o_qmin"),
|
||||
KeywordArg("o_qmax"),
|
||||
KeywordArg("o_dtype"),
|
||||
)
|
||||
return quantized_op_output_pattern_pt2e
|
||||
@ -293,8 +274,9 @@ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_val
|
||||
return actual_value == expected_value
|
||||
|
||||
|
||||
def _is_valid_quantized_conv2d_optimization_pattern(output_dtype):
|
||||
def _is_valid_quantized_conv2d_optimization_pattern():
|
||||
def fn(match):
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
if output_dtype is not None:
|
||||
# Only keep matched pattern with same output_dtype
|
||||
qconv_node_after_weight_prepack = filter_nodes(
|
||||
@ -312,13 +294,11 @@ def _register_quantized_conv_lowering(
|
||||
pattern,
|
||||
pass_number,
|
||||
computation_op,
|
||||
output_dtype,
|
||||
unary_attr,
|
||||
original_pattern_output_dtype=torch.float32,
|
||||
):
|
||||
@register_lowering_pattern(
|
||||
pattern,
|
||||
extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype),
|
||||
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def qconv(match: Match, *args, **kwargs):
|
||||
@ -342,13 +322,11 @@ def _register_quantized_conv_lowering(
|
||||
kwargs["dilation"],
|
||||
kwargs["groups"],
|
||||
)
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
assert output_dtype in [None, torch.float32, torch.bfloat16]
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
|
||||
o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
|
||||
assert (
|
||||
kwargs["output_dtype"] is original_pattern_output_dtype
|
||||
) # Expected int8-in fp32-out qconv in weight prepack phase
|
||||
assert (
|
||||
kwargs["attr"] == "none"
|
||||
) # Expected no post op fused in weight prepack phase
|
||||
@ -383,8 +361,9 @@ def _register_quantized_conv_lowering(
|
||||
return qconv
|
||||
|
||||
|
||||
def _is_valid_quantized_linear_optimization_pattern(output_dtype):
|
||||
def _is_valid_quantized_linear_optimization_pattern():
|
||||
def fn(match):
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
if output_dtype is not None:
|
||||
# Only keep matched pattern with same output_dtype
|
||||
qlinear_node_after_weight_prepack = filter_nodes(
|
||||
@ -402,16 +381,15 @@ def _register_quantized_linear_lowering(
|
||||
pattern,
|
||||
pass_number,
|
||||
computation_op,
|
||||
output_dtype,
|
||||
unary_attr,
|
||||
original_pattern_output_dtype=torch.float32,
|
||||
):
|
||||
@register_lowering_pattern(
|
||||
pattern,
|
||||
extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype),
|
||||
extra_check=_is_valid_quantized_linear_optimization_pattern(),
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def qlinear(match: Match, *args, **kwargs):
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
# Activation QParams
|
||||
x, x_scale, x_zp = (
|
||||
kwargs["x"],
|
||||
@ -431,9 +409,6 @@ def _register_quantized_linear_lowering(
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
|
||||
o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
|
||||
assert (
|
||||
kwargs["output_dtype"] is original_pattern_output_dtype
|
||||
) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase
|
||||
assert (
|
||||
kwargs["postop_name"] == "none"
|
||||
) # Expected no post op fused in weight prepack phase
|
||||
@ -460,7 +435,7 @@ def _register_quantized_linear_lowering(
|
||||
return qlinear
|
||||
|
||||
|
||||
def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
|
||||
def _is_valid_quantized_conv_binary_optimization_pattern():
|
||||
# Check if it's a valid Conv Binary Pattern:
|
||||
# * qconv2d_pointwise should only has one users
|
||||
# * Extra input of binary node comes from dequant pattern
|
||||
@ -470,6 +445,7 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
|
||||
# ancestor nodes of the compute node, except for the binary node
|
||||
# connected to the compute node.
|
||||
def fn(match):
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0]
|
||||
# qconv2d_pointwise should only have one user
|
||||
if len(compute_node.users) != 1:
|
||||
@ -485,7 +461,8 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
|
||||
assert extra_input_of_binary_node is not None
|
||||
# Extra input of binary node comes from dequant pattern
|
||||
if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or (
|
||||
extra_input_of_binary_node.target != aten.mul.Tensor
|
||||
extra_input_of_binary_node.target
|
||||
!= quantized_decomposed.dequantize_per_tensor.default
|
||||
):
|
||||
return False
|
||||
|
||||
@ -536,15 +513,15 @@ def _register_quantized_conv_binary_lowering(
|
||||
pattern,
|
||||
pass_number,
|
||||
computation_op,
|
||||
output_dtype,
|
||||
binary_unary_attr,
|
||||
):
|
||||
@register_lowering_pattern(
|
||||
pattern,
|
||||
extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype),
|
||||
extra_check=_is_valid_quantized_conv_binary_optimization_pattern(),
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def qconv_binary(match: Match, *args, **kwargs):
|
||||
output_dtype = _get_pattern_output_dtype(match)
|
||||
x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
|
||||
accum = (
|
||||
kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"]
|
||||
@ -629,13 +606,11 @@ def _register_quantization_unary_fusion():
|
||||
conv_unary_replace_patterns = {
|
||||
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
has_to_fp32_before_quant=is_bf16,
|
||||
),
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
|
||||
generate_pattern_with_unary(
|
||||
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
has_to_fp32_before_quant=is_bf16,
|
||||
),
|
||||
UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
@ -644,7 +619,7 @@ def _register_quantization_unary_fusion():
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
has_to_fp32_before_quant=False,
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
@ -653,7 +628,7 @@ def _register_quantization_unary_fusion():
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
has_to_fp32_before_quant=False,
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
@ -662,7 +637,7 @@ def _register_quantization_unary_fusion():
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
has_to_fp32_before_quant=False,
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
@ -672,9 +647,7 @@ def _register_quantization_unary_fusion():
|
||||
patterns,
|
||||
1, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise, # computation_op
|
||||
None, # output_dtype, None is the default value for int8 output
|
||||
unary_attr, # unary_attr
|
||||
original_pattern_output_dtype=original_pattern_output_dtype,
|
||||
)
|
||||
|
||||
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
|
||||
@ -682,22 +655,34 @@ def _register_quantization_unary_fusion():
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
|
||||
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
|
||||
),
|
||||
UnaryAttr("hardtanh", [], ""): _unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
1,
|
||||
UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardtanh_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1),
|
||||
1,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
UnaryAttr("hardswish", [], ""): _unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_hardswish_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
UnaryAttr("swish", [], ""): _unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_silu_fusion,
|
||||
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
}
|
||||
@ -708,9 +693,7 @@ def _register_quantization_unary_fusion():
|
||||
patterns,
|
||||
2, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise, # computation_op
|
||||
original_pattern_output_dtype, # output_dtype
|
||||
unary_attr, # unary_attr
|
||||
original_pattern_output_dtype=original_pattern_output_dtype,
|
||||
)
|
||||
|
||||
# QLinear
|
||||
@ -720,11 +703,9 @@ def _register_quantization_unary_fusion():
|
||||
linear_unary_replace_patterns = {
|
||||
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
|
||||
qlinear_pattern,
|
||||
has_to_fp32_before_quant=is_bf16,
|
||||
),
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
|
||||
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
|
||||
has_to_fp32_before_quant=is_bf16,
|
||||
),
|
||||
UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
@ -735,18 +716,18 @@ def _register_quantization_unary_fusion():
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
has_to_fp32_before_quant=False,
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
|
||||
_unary_fusion_pattern(
|
||||
_gelu_fusion_tanh,
|
||||
get_qlinear_pt2e_pattern(
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 2
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 4
|
||||
),
|
||||
4,
|
||||
is_bf16,
|
||||
),
|
||||
has_to_fp32_before_quant=False,
|
||||
with_dtype_convert=is_bf16,
|
||||
),
|
||||
}
|
||||
|
||||
@ -755,9 +736,7 @@ def _register_quantization_unary_fusion():
|
||||
patterns,
|
||||
1, # pass_number
|
||||
torch.ops.onednn.qlinear_pointwise, # computation_op
|
||||
None, # output_dtype
|
||||
unary_attr, # unary_attr
|
||||
original_pattern_output_dtype=original_pattern_output_dtype,
|
||||
)
|
||||
|
||||
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
|
||||
@ -765,20 +744,28 @@ def _register_quantization_unary_fusion():
|
||||
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
|
||||
qlinear_pattern, aten.relu.default
|
||||
),
|
||||
UnaryAttr("gelu", [], "none"): _unary_fusion_pattern(
|
||||
_gelu_fusion_erf,
|
||||
get_qlinear_pt2e_pattern(
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 2
|
||||
UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_gelu_fusion_erf,
|
||||
get_qlinear_pt2e_pattern(
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 2
|
||||
),
|
||||
2,
|
||||
is_bf16,
|
||||
),
|
||||
2,
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
UnaryAttr("gelu", [], "tanh"): _unary_fusion_pattern(
|
||||
_gelu_fusion_tanh,
|
||||
get_qlinear_pt2e_pattern(
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 4
|
||||
UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
|
||||
_unary_fusion_pattern(
|
||||
_gelu_fusion_tanh,
|
||||
get_qlinear_pt2e_pattern(
|
||||
x_scale_zp_are_tensors, 1 if is_bf16 else 4
|
||||
),
|
||||
4,
|
||||
is_bf16,
|
||||
),
|
||||
4,
|
||||
Arg(),
|
||||
is_bf16,
|
||||
),
|
||||
}
|
||||
@ -788,9 +775,7 @@ def _register_quantization_unary_fusion():
|
||||
patterns,
|
||||
2, # pass_number
|
||||
torch.ops.onednn.qlinear_pointwise, # computation_op
|
||||
original_pattern_output_dtype, # output_dtype
|
||||
unary_attr, # unary_attr
|
||||
original_pattern_output_dtype=original_pattern_output_dtype,
|
||||
)
|
||||
|
||||
|
||||
@ -822,7 +807,6 @@ def _register_quantization_binary_fusion():
|
||||
dequantize_accum_pattern,
|
||||
int8_mixed_bf16_with_inplace_add,
|
||||
),
|
||||
has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add,
|
||||
),
|
||||
BinaryUnaryAttr(
|
||||
"sum", 1.0, "relu", [], ""
|
||||
@ -836,7 +820,6 @@ def _register_quantization_binary_fusion():
|
||||
),
|
||||
aten.relu.default,
|
||||
),
|
||||
has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add,
|
||||
),
|
||||
}
|
||||
|
||||
@ -845,7 +828,6 @@ def _register_quantization_binary_fusion():
|
||||
patterns,
|
||||
0, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
||||
None, # output_dtype
|
||||
binary_unary_attr, # binary_unary_attr
|
||||
)
|
||||
|
||||
@ -871,11 +853,6 @@ def _register_quantization_binary_fusion():
|
||||
patterns,
|
||||
0, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
||||
# Note that for int8-mixed-bf16 and non-inplace add, because we have
|
||||
# q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs,
|
||||
# the output dtype will be float32.
|
||||
# For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output.
|
||||
torch.bfloat16,
|
||||
binary_unary_attr, # binary_unary_attr
|
||||
)
|
||||
else:
|
||||
@ -883,7 +860,6 @@ def _register_quantization_binary_fusion():
|
||||
patterns,
|
||||
1, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
||||
torch.float32,
|
||||
binary_unary_attr, # binary_unary_attr
|
||||
)
|
||||
|
||||
@ -905,8 +881,6 @@ def _register_quantization_binary_fusion():
|
||||
patterns,
|
||||
1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number
|
||||
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
|
||||
# Same output dtype setting as conv-add-relu pattern
|
||||
torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32,
|
||||
binary_unary_attr, # binary_unary_attr
|
||||
)
|
||||
|
||||
@ -962,6 +936,8 @@ def _register_quantized_maxpool2d_lowering(
|
||||
ceil_mode,
|
||||
)
|
||||
computation_args, _ = require_channels_last(computation_op, *computation_args)
|
||||
counters["inductor"]["qmaxpool2d_matcher_count"] += 1
|
||||
counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
|
||||
return L[computation_op](*computation_args)
|
||||
|
||||
return qmaxpool2d
|
||||
@ -996,7 +972,7 @@ def _register_quantization_maxpool2d():
|
||||
for max_pool2d_args in max_pool2d_args_list:
|
||||
dequantize_maxpool2d_pattern = CallFunction(
|
||||
aten.max_pool2d_with_indices.default,
|
||||
dequantize_per_tensor_activation_pattern,
|
||||
get_dequantize_per_tensor_activation_pattern(),
|
||||
KeywordArg("kernel_size"),
|
||||
*max_pool2d_args,
|
||||
)
|
||||
@ -1033,26 +1009,23 @@ def _is_input_output_same_scale_zp(check_node):
|
||||
def fn(match):
|
||||
# Ensure all the inputs and output has same scale and zero point
|
||||
# Step 1: Check inputs/output zero point
|
||||
sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor)
|
||||
zero_points = [node.args[1] for node in sub_nodes]
|
||||
add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
|
||||
assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern"
|
||||
zero_points.append(add_nodes[0].args[1])
|
||||
# Get dequant nodes at input
|
||||
dequant_nodes = filter_nodes(
|
||||
match.nodes, quantized_decomposed.dequantize_per_tensor.default
|
||||
)
|
||||
zero_points = [node.args[2] for node in dequant_nodes]
|
||||
# Get quant nodes at output
|
||||
quant_nodes = filter_nodes(
|
||||
match.nodes, quantized_decomposed.quantize_per_tensor.default
|
||||
)
|
||||
assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern"
|
||||
zero_points.append(quant_nodes[0].args[2])
|
||||
if not all(zero_point == zero_points[0] for zero_point in zero_points):
|
||||
return False
|
||||
|
||||
# Step 2: Check inputs/output scale
|
||||
mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor)
|
||||
# We need to find mul node at output since the scale value is reciprocal to input scale.
|
||||
# Mul node at output should connect to cat node directly.
|
||||
scales = [
|
||||
(
|
||||
mul_node.args[1]
|
||||
if mul_node.args[0].target is check_node # type: ignore[union-attr]
|
||||
else 1.0 / mul_node.args[1] # type: ignore[operator]
|
||||
)
|
||||
for mul_node in mul_nodes
|
||||
]
|
||||
scales = [node.args[1] for node in dequant_nodes]
|
||||
scales.append(quant_nodes[0].args[1])
|
||||
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type]
|
||||
return False
|
||||
|
||||
@ -1072,22 +1045,20 @@ def _register_quantized_cat_lowering(
|
||||
def qcat(match: Match, inputs, dim, **kwargs):
|
||||
# inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
|
||||
uint8_inputs = [input[0] for input in inputs]
|
||||
counters["inductor"]["qcat_matcher_count"] += 1
|
||||
counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
|
||||
return L[computation_op](uint8_inputs, dim)
|
||||
|
||||
return qcat
|
||||
|
||||
|
||||
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
|
||||
aten.mul.Tensor,
|
||||
CallFunction(
|
||||
aten.sub.Tensor,
|
||||
CallFunction(
|
||||
prims.convert_element_type.default,
|
||||
Arg(),
|
||||
Arg(),
|
||||
),
|
||||
Arg(),
|
||||
),
|
||||
quantized_decomposed.dequantize_per_tensor.default,
|
||||
Arg(),
|
||||
Arg(),
|
||||
Arg(),
|
||||
Arg(),
|
||||
Arg(),
|
||||
Arg(),
|
||||
)
|
||||
|
||||
@ -1125,7 +1096,7 @@ def _register_quantized_reshape_lowering(
|
||||
def _register_quantization_reshape():
|
||||
dequantize_reshape_pattern = CallFunction(
|
||||
torch.ops.aten.reshape.default,
|
||||
dequantize_per_tensor_activation_pattern,
|
||||
get_dequantize_per_tensor_activation_pattern(),
|
||||
KeywordArg("shape"),
|
||||
)
|
||||
_register_quantized_reshape_lowering(
|
||||
@ -1274,35 +1245,33 @@ def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
|
||||
assert dtype in [torch.float32, torch.bfloat16]
|
||||
dequant_pattern_end_node = match.output_node()
|
||||
if dequant_pattern_end_node.target not in [
|
||||
aten.mul.Tensor,
|
||||
quantized_decomposed.dequantize_per_tensor.default,
|
||||
prims.convert_element_type.default,
|
||||
aten.reshape.default,
|
||||
]:
|
||||
return False
|
||||
|
||||
if dequant_pattern_end_node.target is aten.reshape.default:
|
||||
mul_node = (
|
||||
dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul
|
||||
dequant_node = (
|
||||
dequant_pattern_end_node.args[
|
||||
0
|
||||
] # pattern: linear <- reshape <- dequant
|
||||
if dtype == torch.float32
|
||||
else dequant_pattern_end_node.args[0].args[
|
||||
0
|
||||
] # pattern: linear <- reshape <- to_bf16 <- mul
|
||||
] # pattern: linear <- reshape <- to_bf16 <- dequant
|
||||
)
|
||||
else:
|
||||
mul_node = (
|
||||
dequant_pattern_end_node # pattern: linear <- mul
|
||||
dequant_node = (
|
||||
dequant_pattern_end_node # pattern: linear <- dequant
|
||||
if dtype == torch.float32
|
||||
else dequant_pattern_end_node.args[
|
||||
0
|
||||
] # pattern: linear <- to_bf16 <- mul
|
||||
] # pattern: linear <- to_bf16 <- dequant
|
||||
)
|
||||
|
||||
sub_node = mul_node.args[0]
|
||||
to_fp32_node = sub_node.args[0]
|
||||
if (
|
||||
mul_node.target is aten.mul.Tensor
|
||||
and sub_node.target is aten.sub.Tensor
|
||||
and to_fp32_node.target is prims.convert_element_type.default
|
||||
dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
|
||||
and len(list(dequant_pattern_end_node.users)) > 1
|
||||
):
|
||||
# If dequant pattern has more than 1 users, then do dequant promoted
|
||||
@ -1363,10 +1332,10 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
||||
|
||||
# Find the start node and end node of a dequant pattern
|
||||
# * End node should be the match.output_node()
|
||||
# * Start node should be the node of dtype convert to float32
|
||||
# * Start node should be the node of dequantize_per_tensor
|
||||
dequant_pattern_end_node = match.output_node()
|
||||
assert dequant_pattern_end_node.target in [
|
||||
aten.mul.Tensor,
|
||||
quantized_decomposed.dequantize_per_tensor.default,
|
||||
prims.convert_element_type.default,
|
||||
aten.reshape.default,
|
||||
]
|
||||
@ -1374,15 +1343,10 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
||||
# For a dequant pattern, we should expect see the node list as:
|
||||
# * OPT(aten.reshape.default)
|
||||
# * OPT(prims.convert_element_type.default) (to_bf16)
|
||||
# * aten.mul
|
||||
# * aten.sub
|
||||
# * prims.convert_element_type.default (to_fp32)
|
||||
# * dequantize_per_tensor
|
||||
def _find_first_node_in_dequant_pattern(_node):
|
||||
if (
|
||||
_node.target is prims.convert_element_type.default
|
||||
and _node.args[1] == torch.float32
|
||||
):
|
||||
# For a dequant pattern, we expect the start node is a to_fp32 node
|
||||
if _node.target is quantized_decomposed.dequantize_per_tensor.default:
|
||||
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
|
||||
return _node
|
||||
else:
|
||||
assert (
|
||||
@ -1394,6 +1358,11 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
||||
dequant_pattern_end_node
|
||||
)
|
||||
|
||||
assert (
|
||||
dequant_pattern_start_node.target
|
||||
is quantized_decomposed.dequantize_per_tensor.default
|
||||
)
|
||||
|
||||
# Clone the dequant pattern for each user node
|
||||
graph = match.graph
|
||||
user_node_list = list(dequant_pattern_end_node.users)
|
||||
@ -1429,22 +1398,14 @@ def _is_valid_dequant_conv2d_pattern(dtype):
|
||||
return False
|
||||
|
||||
assert dtype in [torch.float32, torch.bfloat16]
|
||||
|
||||
if dtype == torch.float32:
|
||||
mul_node = conv_node.args[0]
|
||||
dequant_node = conv_node.args[0]
|
||||
else:
|
||||
convert_to_bf16 = conv_node.args[0]
|
||||
mul_node = convert_to_bf16.args[0]
|
||||
sub_node = mul_node.args[0]
|
||||
to_fp32_node = sub_node.args[0]
|
||||
dequant_node = convert_to_bf16.args[0]
|
||||
|
||||
assert to_fp32_node.target is prims.convert_element_type.default
|
||||
assert sub_node.target is aten.sub.Tensor
|
||||
assert mul_node.target is aten.mul.Tensor
|
||||
if (
|
||||
len(list(to_fp32_node.users)) != 1
|
||||
or len(list(sub_node.users)) != 1
|
||||
or len(list(mul_node.users)) != 1
|
||||
):
|
||||
if len(list(dequant_node.users)) != 1:
|
||||
# Ensure the dequant pattern only has 1 user
|
||||
# since we will delete the dequant pattern here
|
||||
return False
|
||||
@ -1477,12 +1438,10 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
|
||||
conv_node = match.output_node()
|
||||
assert conv_node.target is aten.convolution.default
|
||||
if dtype == torch.float32:
|
||||
mul_node = conv_node.args[0]
|
||||
dequant_node = conv_node.args[0]
|
||||
else:
|
||||
convert_to_bf16 = conv_node.args[0]
|
||||
mul_node = convert_to_bf16.args[0] # type: ignore[union-attr]
|
||||
sub_node = mul_node.args[0] # type: ignore[union-attr]
|
||||
to_fp32_node = sub_node.args[0] # type: ignore[union-attr]
|
||||
dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr]
|
||||
has_clone_to_channel_last_node_in_pattern = (
|
||||
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
|
||||
)
|
||||
@ -1585,10 +1544,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
|
||||
# Erase the dequant pattern
|
||||
if dtype == torch.bfloat16:
|
||||
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
|
||||
# Erase the dequant pattern
|
||||
graph.erase_node(mul_node)
|
||||
graph.erase_node(sub_node)
|
||||
graph.erase_node(to_fp32_node)
|
||||
graph.erase_node(dequant_node)
|
||||
# Erase the dequant per channel pattern
|
||||
if clone_node is not None:
|
||||
graph.erase_node(clone_node)
|
||||
@ -1608,7 +1564,7 @@ def _generate_dequant_convolution_node_pattern(
|
||||
dequant_convolution_node_pattern = CallFunction(
|
||||
aten.convolution.default,
|
||||
_may_generate_pattern_with_dtype_convert(
|
||||
dequantize_per_tensor_activation_pattern,
|
||||
get_dequantize_per_tensor_activation_pattern(),
|
||||
KeywordArg("autocast_act_dtype"),
|
||||
dtype == torch.bfloat16,
|
||||
),
|
||||
@ -1668,7 +1624,7 @@ def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
|
||||
return linear_node, output_reshape_node
|
||||
|
||||
|
||||
def _get_linear_dq_mul_node(
|
||||
def _get_linear_dq_node(
|
||||
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
||||
):
|
||||
act_reshape_node = None
|
||||
@ -1679,30 +1635,30 @@ def _get_linear_dq_mul_node(
|
||||
act_reshape_node = linear_node.args[input_index]
|
||||
assert act_reshape_node.target is aten.reshape.default
|
||||
if dtype == torch.float32:
|
||||
# pattern: linear -> reshape -> mul
|
||||
mul_node = act_reshape_node.args[0]
|
||||
# pattern: linear -> reshape -> dequant
|
||||
dequant_node = act_reshape_node.args[0]
|
||||
else:
|
||||
# pattern: linear -> reshape -> to_bf16 -> mul
|
||||
# pattern: linear -> reshape -> to_bf16 -> dequant
|
||||
activation_to_bf16_node = act_reshape_node.args[0]
|
||||
mul_node = activation_to_bf16_node.args[0]
|
||||
dequant_node = activation_to_bf16_node.args[0]
|
||||
else:
|
||||
# bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
|
||||
act_expand_node = linear_node.args[input_index]
|
||||
assert act_expand_node.target is aten.expand.default
|
||||
if dtype == torch.float32:
|
||||
mul_node = act_expand_node.args[0]
|
||||
dequant_node = act_expand_node.args[0]
|
||||
else:
|
||||
activation_to_bf16_node = act_expand_node.args[0]
|
||||
mul_node = activation_to_bf16_node.args[0]
|
||||
dequant_node = activation_to_bf16_node.args[0]
|
||||
else:
|
||||
if dtype == torch.float32:
|
||||
# pattern: linear -> mul
|
||||
mul_node = linear_node.args[input_index]
|
||||
# pattern: linear -> dequant
|
||||
dequant_node = linear_node.args[input_index]
|
||||
else:
|
||||
# pattern: linear -> to_bf16 -> mul
|
||||
# pattern: linear -> to_bf16 -> dequant
|
||||
activation_to_bf16_node = linear_node.args[input_index]
|
||||
mul_node = activation_to_bf16_node.args[0]
|
||||
return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node
|
||||
dequant_node = activation_to_bf16_node.args[0]
|
||||
return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
|
||||
|
||||
|
||||
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
|
||||
@ -1715,27 +1671,21 @@ def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contigu
|
||||
|
||||
input_index = 1 if linear_node.target is aten.addmm.default else 0
|
||||
assert dtype in [torch.float32, torch.bfloat16]
|
||||
|
||||
(
|
||||
mul_node,
|
||||
dequant_node,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = _get_linear_dq_mul_node(
|
||||
) = _get_linear_dq_node(
|
||||
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
||||
)
|
||||
|
||||
sub_node = mul_node.args[0]
|
||||
to_fp32_node = sub_node.args[0]
|
||||
assert dequant_node.target in [
|
||||
quantized_decomposed.dequantize_per_tensor.default,
|
||||
quantized_decomposed.dequantize_per_tensor.tensor,
|
||||
]
|
||||
|
||||
assert to_fp32_node.target is prims.convert_element_type.default
|
||||
assert sub_node.target is aten.sub.Tensor
|
||||
assert mul_node.target is aten.mul.Tensor
|
||||
if (
|
||||
len(list(to_fp32_node.users)) != 1
|
||||
or len(list(sub_node.users)) != 1
|
||||
or len(list(mul_node.users)) != 1
|
||||
):
|
||||
if len(list(dequant_node.users)) != 1:
|
||||
# Ensure the dequant pattern only has 1 user
|
||||
# since we will delete the dequant pattern here
|
||||
return False
|
||||
@ -1820,17 +1770,14 @@ def _register_qlinear_weight_prepack_pass(
|
||||
weight_index = input_index + 1
|
||||
|
||||
(
|
||||
mul_node,
|
||||
dequant_node,
|
||||
act_reshape_node,
|
||||
activation_to_bf16_node,
|
||||
act_expand_node,
|
||||
) = _get_linear_dq_mul_node(
|
||||
) = _get_linear_dq_node(
|
||||
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
|
||||
)
|
||||
|
||||
sub_node = mul_node.args[0]
|
||||
to_fp32_node = sub_node.args[0]
|
||||
|
||||
if input_dim_exceeds_two and not input_contiguous:
|
||||
wgt_expand_node = linear_node.args[weight_index]
|
||||
assert wgt_expand_node.target is aten.expand.default
|
||||
@ -1938,9 +1885,7 @@ def _register_qlinear_weight_prepack_pass(
|
||||
if dtype == torch.bfloat16:
|
||||
graph.erase_node(activation_to_bf16_node)
|
||||
# Erase the dequant pattern
|
||||
graph.erase_node(mul_node)
|
||||
graph.erase_node(sub_node)
|
||||
graph.erase_node(to_fp32_node)
|
||||
graph.erase_node(dequant_node)
|
||||
# Erase the dequant per channel pattern
|
||||
graph.erase_node(t_node)
|
||||
if dtype == torch.bfloat16:
|
||||
@ -1954,7 +1899,10 @@ def _register_qlinear_weight_prepack_pass(
|
||||
|
||||
|
||||
def _generate_dequant_linear_node_pattern(
|
||||
_dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False
|
||||
_dequant_per_channel_pattern,
|
||||
dtype=torch.float32,
|
||||
input_dim_exceeds_two=False,
|
||||
is_tensor_overload=False,
|
||||
):
|
||||
assert dtype in [torch.float32, torch.bfloat16]
|
||||
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@ -1964,7 +1912,7 @@ def _generate_dequant_linear_node_pattern(
KeywordArg("b"),
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -1981,7 +1929,7 @@ def _generate_dequant_linear_node_pattern(
aten.mm.default,
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2000,6 +1948,7 @@ def _generate_dequant_bmm_node_pattern(
_dequant_per_channel_pattern,
dtype=torch.float32,
with_bias=False,
is_tensor_overload=False,
):
# When activation of linear dim exceed 2 and not contiguous
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@ -2010,7 +1959,7 @@ def _generate_dequant_bmm_node_pattern(
CallFunction(
aten.expand.default,
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2041,16 +1990,21 @@ def _generate_qlinear_weight_prepack_patterns(
input_dim_exceeds_two=False,
input_contiguous=True,
with_bias=False,
is_tensor_overload=False,
):
if input_dim_exceeds_two and not input_contiguous:
return _generate_dequant_bmm_node_pattern(
dequantize_per_channel_weight_pattern,
dtype,
with_bias,
is_tensor_overload,
)
else:
return _generate_dequant_linear_node_pattern(
dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two
dequantize_per_channel_weight_pattern,
dtype,
input_dim_exceeds_two,
is_tensor_overload,
)


@ -2082,7 +2036,7 @@ def _register_dequant_promotion():
_register_dequant_promotion_pass(
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2140,13 +2094,15 @@ def _register_qlinear_weight_prepack():
# | OPT(add) |

linear_weight_prepack_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False]
[torch.float32, torch.bfloat16], [True, False], [True, False]
)

# Step 1: register patterns from mm and addmm
for dtype, input_dim_exceeds_two in linear_weight_prepack_cases:
for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
dtype, input_dim_exceeds_two
dtype,
input_dim_exceeds_two,
is_tensor_overload=is_tensor_overload,
)
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
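The product above now spans a third axis for the `.tensor` overload, so every (dtype, input-dim, overload) combination gets its own registered pattern group. A small standalone sketch of the enumeration (variable names assumed for illustration):

```
import itertools

import torch

cases = list(
    itertools.product(
        [torch.float32, torch.bfloat16],  # dtype
        [True, False],                    # input_dim_exceeds_two
        [True, False],                    # is_tensor_overload
    )
)
# 2 * 2 * 2 = 8 combinations are registered instead of the previous 4.
assert len(cases) == 8
```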
@ -2163,14 +2119,15 @@ def _register_qlinear_weight_prepack():
# https://github.com/pytorch/pytorch/blob/
# 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
# in this case, we can convert it back to qlinear
for dtype, with_bias in itertools.product(
[torch.float32, torch.bfloat16], [True, False]
for dtype, with_bias, is_tensor_overload in itertools.product(
[torch.float32, torch.bfloat16], [True, False], [True, False]
):
bmm_pattern = _generate_qlinear_weight_prepack_patterns(
dtype=dtype,
input_dim_exceeds_two=True,
input_contiguous=False,
with_bias=with_bias,
is_tensor_overload=is_tensor_overload,
)
_register_qlinear_weight_prepack_pass(
bmm_pattern,

@ -1144,6 +1144,170 @@ def quantized_decomposed_dequantize_per_channel(
)


@register_lowering(
quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_default(
input: TensorBox,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"

input_loader = input.make_loader()

def inner_fn(idx, scale, zero_point):
input = input_loader(idx)
inv_scale, zero_point = _create_constants(
1.0 / scale, zero_point, dtype=torch.float32
)
val = ops.round(input * inv_scale) + zero_point
qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, dtype)

return Pointwise.create(
device=input.get_device(),
dtype=dtype,
inner_fn=functools.partial(
inner_fn, scale=float(scale), zero_point=int(zero_point)
),
ranges=input.get_size(),
)
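For intuition, the elementwise ops emitted above follow the usual per-tensor affine quantization formula. A hedged eager-mode reference of that math (a sketch of the semantics, not the lowering itself):

```
import torch


def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
    # round(x / scale) + zero_point, clamped to [quant_min, quant_max], then cast
    x = x.to(torch.float32)
    q = torch.round(x * (1.0 / scale)) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)


q = quantize_per_tensor_ref(torch.randn(4), 0.05, 128, 0, 255, torch.uint8)
```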


@register_lowering(
quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_default(
input: TensorBox,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"

input_loader = input.make_loader()

def inner_fn(idx, scale, zero_point):
input = input_loader(idx)
scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32)
val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
return val

return Pointwise.create(
device=input.get_device(),
dtype=torch.float32,
inner_fn=functools.partial(
inner_fn, scale=float(scale), zero_point=int(zero_point)
),
ranges=input.get_size(),
)
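The matching eager-mode reference for the dequantize lowering above, again only as a sketch of the semantics with made-up scale and zero-point values:

```
import torch


def dequantize_per_tensor_ref(q, scale, zero_point):
    # (q - zero_point) * scale, computed in float32
    return (q.to(torch.float32) - zero_point) * scale


x = torch.linspace(-1.0, 1.0, steps=5)
q = torch.clamp(torch.round(x / 0.05) + 128, 0, 255).to(torch.uint8)
x_hat = dequantize_per_tensor_ref(q, 0.05, 128)
assert torch.allclose(x, x_hat, atol=0.05)  # error bounded by one quantization step
```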


@register_lowering(
quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_tensor(
input: TensorBox,
scale: TensorBox,
zero_point: TensorBox,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"

input_loader = input.make_loader()
scale_loader = scale.make_loader()
zero_point_loader = zero_point.make_loader()

def inner_fn(idx):
input = input_loader(idx)
_scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
_zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
if scale.dtype != torch.float32:
_scale = ops.to_dtype(_scale, torch.float32)
if zero_point.dtype != torch.float32:
_zero_point = ops.to_dtype(_zero_point, torch.float32)
val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, dtype)

return Pointwise.create(
device=input.get_device(),
dtype=dtype,
inner_fn=inner_fn,
ranges=input.get_size(),
)
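The `.tensor` overload differs from `.default` only in carrying scale and zero_point as 0-d (or single-element) tensors, which is what the asserts above enforce. A hedged sketch of invoking the decomposed op directly, assuming the `quantized_decomposed` op library has been registered by importing `torch.ao.quantization.fx._decomposed`:

```
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (assumed to register quantized_decomposed ops)

x = torch.randn(4)
scale = torch.tensor(0.05)                        # scalar tensor
zero_point = torch.tensor(128, dtype=torch.int64)  # scalar tensor
q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zero_point, 0, 255, torch.uint8
)
```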


@register_lowering(
quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_tensor(
input: TensorBox,
scale: TensorBox,
zero_point: TensorBox,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"

input_loader = input.make_loader()
scale_loader = scale.make_loader()
zero_point_loader = zero_point.make_loader()

def inner_fn(idx):
input = input_loader(idx)
_scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
_zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
if scale.dtype != torch.float32:
_scale = ops.to_dtype(_scale, torch.float32)
if zero_point.dtype != torch.float32:
_zero_point = ops.to_dtype(_zero_point, torch.float32)
val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
return val

return Pointwise.create(
device=input.get_device(),
dtype=torch.float32,
inner_fn=inner_fn,
ranges=input.get_size(),
)
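Since these per-tensor ops are now handled at lowering time rather than decomposed, a quant/dequant round trip under `torch.compile` should exercise the lowerings above. A hedged end-to-end sketch (op registration via the `_decomposed` module is an assumption of this example):

```
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401

qd = torch.ops.quantized_decomposed


def roundtrip(x, scale, zero_point):
    q = qd.quantize_per_tensor.tensor(x, scale, zero_point, 0, 255, torch.uint8)
    return qd.dequantize_per_tensor.tensor(q, scale, zero_point, 0, 255, torch.uint8)


compiled = torch.compile(roundtrip)
out = compiled(torch.randn(8), torch.tensor(0.1), torch.tensor(64, dtype=torch.int64))
```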


@register_lowering(aten.cat)
def cat(inputs, dim=0):
cpu_device = inputs[0].get_device().type == "cpu"