[Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)

**Summary**
Per the discussion in https://github.com/pytorch/pytorch/pull/123444, the decomposed `quant/dequant` patterns changed after https://github.com/pytorch/pytorch/pull/123445. We can move the optimization of `quant/dequant` from the Inductor decomposition phase into the lowering phase to avoid depending on those decomposed patterns. In this way, we can:

- Avoid the pattern-matcher failures introduced by https://github.com/pytorch/pytorch/pull/123445.
- Make the quantization patterns clearer in the pattern-matcher phase, since the `quant/dequant` nodes are no longer decomposed.

**Changes in this PR**

- Move the optimization of `quant/dequant` from the Inductor decomposition phase into the lowering phase; a minimal sketch of the quant/dequant math involved follows this list.
- Make the corresponding changes in the quantization pattern matchers to avoid any BC-breaking change.
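For reference, here is a minimal eager-mode sketch of the per-tensor quant/dequant math that the removed Inductor decompositions implemented and that the new lowerings now emit at lowering time. The scale and zero-point values are arbitrary, chosen only for illustration; this is not the PR's code.

```python
import torch

def quantize_per_tensor_ref(x, scale, zero_point, quant_min, quant_max, dtype):
    # Mirrors the lowering: upcast bf16 to fp32, scale, round, shift, clamp, cast.
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(x * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)

def dequantize_per_tensor_ref(x, scale, zero_point):
    # Mirrors the lowering: upcast to fp32, subtract zero point, multiply by scale.
    return (x.to(torch.float32) - zero_point) * scale

x = torch.randn(4, 8)
q = quantize_per_tensor_ref(x, 0.05, 128, 0, 255, torch.uint8)
x_dq = dequantize_per_tensor_ref(q, 0.05, 128)
print((x - x_dq).abs().max())  # small quantization error
```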

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```
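Beyond the pattern-matcher tests above, the new lowerings can also be exercised directly. The snippet below is an illustrative check, not part of the PR's test plan; it uses arbitrary qparams and assumes the `quantized_decomposed` ops are registered, which importing `torch.ao.quantization.fx._decomposed` does in recent PyTorch builds.

```python
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  (assumed to register quantized_decomposed ops)

qd = torch.ops.quantized_decomposed

def roundtrip(x):
    # quantize_per_tensor.default followed by dequantize_per_tensor.default,
    # which this PR handles as Inductor lowerings instead of decompositions.
    q = qd.quantize_per_tensor(x, 0.05, 128, 0, 255, torch.uint8)
    return qd.dequantize_per_tensor(q, 0.05, 128, 0, 255, torch.uint8)

x = torch.randn(2, 3, 8, 8)
torch.testing.assert_close(torch.compile(roundtrip)(x), roundtrip(x))
```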

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
Author: leslie-fang-intel
Date: 2024-05-06 19:37:22 -07:00
Committed by: PyTorch MergeBot
Parent: 1b1b18a7a4
Commit: 33e6791645
7 changed files with 434 additions and 387 deletions


@ -1671,7 +1671,7 @@ static at::Tensor _quantized_convolution_onednn(
}
tensor src_scales_t = tensor(ideep::scale_t(1, act_scale));
tensor wei_scales_t = tensor(weights_scales);
tensor dst_scales_t = tensor(ideep::scale_t(1, 1.0/inv_output_scale));
tensor dst_scales_t = tensor(ideep::scale_t(1, inv_output_scale));
tensor src_zp_t = tensor(ideep::zero_point_t(1, act_zero_point));
tensor dst_zp_t = tensor(ideep::zero_point_t(1, output_zero_point));
if (act_scale != 1.0f) {
@ -1707,7 +1707,7 @@ static at::Tensor _quantized_convolution_onednn(
ideep::convolution_forward::prepare(
params, src, packed_weight, expected_bias, dst_dims, dst,
stride.vec(), dilation.vec(), padding.vec(), padding.vec(), groups,
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
src_scales, weights_scales, ideep::scale_t(1, 1.0f / inv_output_scale),
src_zero_points, dst_zero_points,
op_attr, dnnl::algorithm::convolution_direct,
dnnl::prop_kind::forward_inference,


@ -931,7 +931,6 @@ static at::Tensor linear_int8_with_onednn_weight(
c10::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
output_scale = 1.0f / output_scale;
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,


@ -555,15 +555,15 @@ class TestPatternMatcher(TestPatternMatcherBase):
def matcher_check_fn():
# 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
# int8_mixed_fp32: [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# int8_mixed_bf16: [convert_element_type_1, sub, mul_1, optional(convert_element_type_4),
# int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
# int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
# dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"],
16 if int8_mixed_bf16 else 12,
12 if int8_mixed_bf16 else 8,
)
self._test_common(
@ -683,14 +683,13 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->Hardtanh pattern.
Match.nodes:
[qconv2d_pointwise_default_1, convert_element_type_5, clamp_min_1, clamp_max_1, mul_2, round_2, add_1, clamp_min_2,
clamp_max_1, mul_2, round_2, add_1, clamp_min_2, clamp_max_2, convert_element_type_8
[qconv2d_pointwise_default, convert_element_type_13, clamp_min_3, clamp_max_3]
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardtanh(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=14,
qconv2d_unary_matcher_nodes=11,
)
@skipIfNoDynamoSupport
@ -710,14 +709,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->Hardswish pattern.
Match.nodes:
[qconv2d_pointwise_default_1, convert_element_type_5, add_1, clamp_min_1,
clamp_max_1, mul_2, div, mul_3, round_2, add_2, clamp_min_2, clamp_max_2, convert_element_type_8]
[qconv2d_pointwise_default, convert_element_type_13, add_3, clamp_min_3, clamp_max_3, mul_5, div_1]
[qconv2d_pointwise_default, convert_element_type, add, clamp_min,
clamp_max, mul, div, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.Hardswish(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=20,
qconv2d_unary_matcher_nodes=17,
)
@skipIfNoDynamoSupport
@ -737,14 +736,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->SiLU pattern.
Match.nodes:
[qconv2d_pointwise_default_1, convert_element_type_5, sigmoid, mul_2,
mul_3, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_8]
[qconv2d_pointwise_default, convert_element_type_13, sigmoid_1, mul_5]
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
"""
self._qconv2d_unary_cpu_test_helper(
unary_op=torch.nn.SiLU(),
int8_mixed_bf16=True,
qconv2d_unary_matcher_nodes=14,
qconv2d_unary_matcher_nodes=11,
)
def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
@ -1028,17 +1027,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 1
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 6
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4
)
# 2. QConv2D Unary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# [qconv2d_pointwise_default, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 7)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_nodes"], 2)
self._test_common(
mod,
@ -1107,7 +1106,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
"""
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())
@skipIfNoDynamoSupport
@ -1117,7 +1115,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
"""
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())
@skipIfNoDynamoSupport
@ -1127,7 +1124,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->SiLU pattern with qat flow.
"""
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())
@skipIfNoDynamoSupport
@ -1137,7 +1133,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
r"""
This testcase will quantize Conv2d->Hardswish pattern with qat flow.
"""
self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())
@skipIfNoDynamoSupport
@ -1176,18 +1171,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, mul_6, round_4, add_4,
# clamp_min_3, clamp_max_3, convert_element_type_6]
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 11)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 4)
self._test_common(
mod,
@ -1236,18 +1230,17 @@ class TestPatternMatcher(TestPatternMatcherBase):
def matcher_check_fn():
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8
)
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, convert_element_type_5, sub_2, mul_5, add_3, relu, mul_6, round_4, add_4,
# clamp_min_3, clamp_max_3, convert_element_type_6]
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_count"], 1)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 12)
self.assertEqual(counters["inductor"]["qconv2d_binary_matcher_nodes"], 5)
self._test_common(
mod,
@ -1294,16 +1287,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
def matcher_check_fn():
# 1. Dequant pattern matcher for dequant promotion * 1
# [convert_element_type_3, sub_1, mul_3]
# [dequantize_per_tensor]
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 3)
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
# 2. Dequant-conv pattern matched in quantization weight prepack * 3
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3
)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 18
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12
)
# 3. Qconv2d Binary fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default_1, add_3]
@ -1445,7 +1438,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
17 if bias else 16,
13 if bias else 12,
)
self._qlinear_cpu_test_helper(
@ -1473,7 +1466,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
21 if bias else 20,
17 if bias else 16,
)
self._qlinear_cpu_test_helper(
@ -1722,12 +1715,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
x1 = torch.randn((2, 4))
x2 = torch.randn((2, 5))
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
)
self._test_common(
mod,
(x1, x2),
2,
8,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@ -1763,22 +1760,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)
# Totally 6 pattern_matcher_count, 31 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 * 3, matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant-conv pattern matched in quantization weight prepack * 1
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d_relu fusion in post-grad fusion pass * 1
# [qconv2d_pointwise_default, relu, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 4. qmaxpool2d * 1
# [convert_element_type_3, sub_1, mul_3, max_pool2d_with_indices, getitem, mul_4, round_3, add_2,
# clamp_min_2, clamp_max_2, convert_element_type_4]
def matcher_check_fn():
self.assertEqual(counters["inductor"]["qmaxpool2d_matcher_count"], 1)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 1)
self._test_common(
mod,
(v,),
6,
31,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
@skipIfNoDynamoSupport
@ -1852,22 +1846,19 @@ class TestPatternMatcher(TestPatternMatcherBase):
mod = M().eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
# Totally 10 pattern_matcher_count, 49 pattern_matcher_nodes
# 1. Pair of to_int8 and to_fp32 * 5, matched in pointless_convert pass at
# torch/_inductor/fx_passes/joint_graph.py: [convert_element_type, convert_element_type_1]
# 2. Dequant-conv pattern matched in quantization weight prepack * 2
# [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
# 3. qconv2d fusion in post-grad fusion pass * 2
# [qconv2d_pointwise_default, mul_2, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
# 4. qcat * 1
# [convert_element_type_3, sub_1, mul_3, convert_element_type_7, sub_3, mul_7, cat, mul_8, round_5,
# add_4, clamp_min_4, clamp_max_4, convert_element_type_8]
def matcher_check_fn():
self.assertEqual(counters["inductor"]["qcat_matcher_count"], 1)
self.assertEqual(
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2
)
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 2)
self._test_common(
mod,
(v,),
10,
49,
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)
# https://github.com/pytorch/pytorch/issues/99841.


@ -4309,7 +4309,7 @@ class TestQuantizedLinear(TestCase):
if post_op in ("none", "relu", "gelu"):
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
b, used_y_scale, used_y_zp, output_dtype,
post_op, unary_post_op_args, post_op_algo
)
if post_op == "relu":
@ -4330,7 +4330,7 @@ class TestQuantizedLinear(TestCase):
accum = accum.bfloat16()
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
b, used_y_scale, used_y_zp, output_dtype,
accum, x2_scale, x2_zp, "sum", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@ -4348,7 +4348,7 @@ class TestQuantizedLinear(TestCase):
binary_alpha = 1.0 # we only support alpha=1.0 now
qy_cpu = qlinear_op(
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, 1.0 / used_y_scale, used_y_zp, output_dtype,
b, used_y_scale, used_y_zp, output_dtype,
x2, 1.0, 0, "add", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
@ -6796,7 +6796,7 @@ class TestQuantizedConv(TestCase):
pads,
dilations,
groups,
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
Y_scale,
Y_zero_point,
qconv_output_dtype,
post_op.binary_attr,
@ -6818,7 +6818,7 @@ class TestQuantizedConv(TestCase):
pads,
dilations,
groups,
1.0 / Y_scale, # Kernel expects pass in reciprocal of scale in fake quant
Y_scale,
Y_zero_point,
qconv_output_dtype,
post_op.unary_attr,


@ -478,70 +478,6 @@ def linear_dynamic_fp16_unpacked_weight(
)
# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
# scale and zero_point is scalar or scalar tensor
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
# scale and zero_point is scalar or scalar tensor
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
torch.float32
)
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed):
def bitcast_u8_to_f32(u8):


@ -49,10 +49,25 @@ quantization.
"""
def _get_pattern_output_dtype(match: Match):
"""
Get the pattern's output dtype from node's meta
Assume only 1 output node in this matched pattern.
"""
pattern_output_nodes = match.output_nodes()
assert len(pattern_output_nodes) == 1
output_node = pattern_output_nodes[0]
assert isinstance(output_node, torch.fx.Node)
output_dtype = output_node.meta["val"].dtype
if output_dtype is torch.uint8:
output_dtype = None
return output_dtype
def _may_generate_pattern_with_dtype_convert(
pattern, dtype=Arg(), dtype_convert=True, users=1
pattern, dtype=Arg(), with_dtype_convert=True, users=1
):
if dtype_convert:
if with_dtype_convert:
return CallFunction(
prims.convert_element_type.default,
pattern,
@ -94,30 +109,25 @@ def _generate_linear_t_pattern(
def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
# only insert to_dtype if is_bf16 is True
computation_call = _may_generate_pattern_with_dtype_convert(
call_fn, dtype=KeywordArg("to_float"), dtype_convert=is_bf16, users=users
call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
)
return unary_fusion(computation_call)
"""
dequantize activation:
x = x.to(fp32)
x = x - zero_point
x = x * scale
"""
dequantize_per_tensor_activation_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
KeywordArg("x"),
KeywordArg("x_dq_dtype"),
),
def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
dequantize_per_tensor_activation_pattern = CallFunction(
quantized_decomposed.dequantize_per_tensor.tensor
if is_tensor_overload
else quantized_decomposed.dequantize_per_tensor.default,
KeywordArg("x"),
KeywordArg("x_scale"),
KeywordArg("x_zp"),
),
KeywordArg("x_scale"),
)
KeywordArg("x_quant_min"),
KeywordArg("x_quant_max"),
KeywordArg("x_dq_dtype"),
)
return dequantize_per_tensor_activation_pattern
dequantize_per_channel_weight_pattern = CallFunction(
quantized_decomposed.dequantize_per_channel.default,
@ -200,17 +210,13 @@ def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
dequantize_accum_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
KeywordArg("accum"),
KeywordArg("accum_dq_dtype"),
),
KeywordArg("accum_zp"),
),
quantized_decomposed.dequantize_per_tensor.default,
KeywordArg("accum"),
KeywordArg("accum_scale"),
KeywordArg("accum_zp"),
Arg(),
Arg(),
KeywordArg("accum_dq_dtype"),
)
@ -241,43 +247,18 @@ def generate_pattern_with_unary(computation_call, unary_post_op):
return computation_call
def generate_pattern_with_output_quant(
computation_call, has_to_fp32_before_quant=False
):
"""
quantize output:
output = round(output * o_inv_scale)
output = output + zero_point
output = clamp_min(output, 0)
output = clamp_max(output, 127)
output = output.to(uint8)
"""
def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
quantized_op_output_pattern_pt2e = CallFunction(
prims.convert_element_type.default,
CallFunction(
aten.clamp_max.default,
CallFunction(
aten.clamp_min.default,
CallFunction(
aten.add.Tensor,
CallFunction(
aten.round.default,
CallFunction(
aten.mul.Tensor,
_may_generate_pattern_with_dtype_convert(
computation_call,
KeywordArg("autocast_output_quant_dtype"),
has_to_fp32_before_quant,
),
KeywordArg("o_inv_scale"),
),
),
KeywordArg("o_zp"),
),
KeywordArg("o_qmin"),
),
KeywordArg("o_qmax"),
quantized_decomposed.quantize_per_tensor.default,
_may_generate_pattern_with_dtype_convert(
computation_call,
Arg(),
with_dtype_convert,
),
KeywordArg("o_inv_scale"),
KeywordArg("o_zp"),
KeywordArg("o_qmin"),
KeywordArg("o_qmax"),
KeywordArg("o_dtype"),
)
return quantized_op_output_pattern_pt2e
@ -293,8 +274,9 @@ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_val
return actual_value == expected_value
def _is_valid_quantized_conv2d_optimization_pattern(output_dtype):
def _is_valid_quantized_conv2d_optimization_pattern():
def fn(match):
output_dtype = _get_pattern_output_dtype(match)
if output_dtype is not None:
# Only keep matched pattern with same output_dtype
qconv_node_after_weight_prepack = filter_nodes(
@ -312,13 +294,11 @@ def _register_quantized_conv_lowering(
pattern,
pass_number,
computation_op,
output_dtype,
unary_attr,
original_pattern_output_dtype=torch.float32,
):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_quantized_conv2d_optimization_pattern(output_dtype),
extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
pass_number=pass_number,
)
def qconv(match: Match, *args, **kwargs):
@ -342,13 +322,11 @@ def _register_quantized_conv_lowering(
kwargs["dilation"],
kwargs["groups"],
)
output_dtype = _get_pattern_output_dtype(match)
assert output_dtype in [None, torch.float32, torch.bfloat16]
# Output QParams
o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
assert (
kwargs["output_dtype"] is original_pattern_output_dtype
) # Expected int8-in fp32-out qconv in weight prepack phase
assert (
kwargs["attr"] == "none"
) # Expected no post op fused in weight prepack phase
@ -383,8 +361,9 @@ def _register_quantized_conv_lowering(
return qconv
def _is_valid_quantized_linear_optimization_pattern(output_dtype):
def _is_valid_quantized_linear_optimization_pattern():
def fn(match):
output_dtype = _get_pattern_output_dtype(match)
if output_dtype is not None:
# Only keep matched pattern with same output_dtype
qlinear_node_after_weight_prepack = filter_nodes(
@ -402,16 +381,15 @@ def _register_quantized_linear_lowering(
pattern,
pass_number,
computation_op,
output_dtype,
unary_attr,
original_pattern_output_dtype=torch.float32,
):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_quantized_linear_optimization_pattern(output_dtype),
extra_check=_is_valid_quantized_linear_optimization_pattern(),
pass_number=pass_number,
)
def qlinear(match: Match, *args, **kwargs):
output_dtype = _get_pattern_output_dtype(match)
# Activation QParams
x, x_scale, x_zp = (
kwargs["x"],
@ -431,9 +409,6 @@ def _register_quantized_linear_lowering(
# Output QParams
o_inv_scale = kwargs["o_inv_scale"] if output_dtype is None else 1.0
o_zero_point = kwargs["o_zp"] if output_dtype is None else 0
assert (
kwargs["output_dtype"] is original_pattern_output_dtype
) # Expected int8-in fp32/bf16-out qlinear in weight prepack phase
assert (
kwargs["postop_name"] == "none"
) # Expected no post op fused in weight prepack phase
@ -460,7 +435,7 @@ def _register_quantized_linear_lowering(
return qlinear
def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
def _is_valid_quantized_conv_binary_optimization_pattern():
# Check if it's a valid Conv Binary Pattern:
# * qconv2d_pointwise should only has one users
# * Extra input of binary node comes from dequant pattern
@ -470,6 +445,7 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
# ancestor nodes of the compute node, except for the binary node
# connected to the compute node.
def fn(match):
output_dtype = _get_pattern_output_dtype(match)
compute_node = filter_nodes(match.nodes, torch.ops.onednn.qconv2d_pointwise)[0]
# qconv2d_pointwise should only have one user
if len(compute_node.users) != 1:
@ -485,7 +461,8 @@ def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
assert extra_input_of_binary_node is not None
# Extra input of binary node comes from dequant pattern
if (not isinstance(extra_input_of_binary_node, torch.fx.Node)) or (
extra_input_of_binary_node.target != aten.mul.Tensor
extra_input_of_binary_node.target
!= quantized_decomposed.dequantize_per_tensor.default
):
return False
@ -536,15 +513,15 @@ def _register_quantized_conv_binary_lowering(
pattern,
pass_number,
computation_op,
output_dtype,
binary_unary_attr,
):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype),
extra_check=_is_valid_quantized_conv_binary_optimization_pattern(),
pass_number=pass_number,
)
def qconv_binary(match: Match, *args, **kwargs):
output_dtype = _get_pattern_output_dtype(match)
x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
accum = (
kwargs["accum"] if output_dtype is None else kwargs["accum_after_dequant"]
@ -629,13 +606,11 @@ def _register_quantization_unary_fusion():
conv_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
get_dequantize_qconv_pt2e_pattern(1),
has_to_fp32_before_quant=is_bf16,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
),
has_to_fp32_before_quant=is_bf16,
),
UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
@ -644,7 +619,7 @@ def _register_quantization_unary_fusion():
1,
is_bf16,
),
has_to_fp32_before_quant=False,
with_dtype_convert=is_bf16,
),
UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
@ -653,7 +628,7 @@ def _register_quantization_unary_fusion():
2,
is_bf16,
),
has_to_fp32_before_quant=False,
with_dtype_convert=is_bf16,
),
UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
_unary_fusion_pattern(
@ -662,7 +637,7 @@ def _register_quantization_unary_fusion():
2,
is_bf16,
),
has_to_fp32_before_quant=False,
with_dtype_convert=is_bf16,
),
}
@ -672,9 +647,7 @@ def _register_quantization_unary_fusion():
patterns,
1, # pass_number
torch.ops.onednn.qconv2d_pointwise, # computation_op
None, # output_dtype, None is the default value for int8 output
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
# Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
@ -682,22 +655,34 @@ def _register_quantization_unary_fusion():
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
),
UnaryAttr("hardtanh", [], ""): _unary_fusion_pattern(
_hardtanh_fusion,
get_dequantize_qconv_pt2e_pattern(1),
1,
UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardtanh_fusion,
get_dequantize_qconv_pt2e_pattern(1),
1,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("hardswish", [], ""): _unary_fusion_pattern(
_hardswish_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_hardswish_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
UnaryAttr("swish", [], ""): _unary_fusion_pattern(
_silu_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_silu_fusion,
get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
2,
is_bf16,
),
Arg(),
is_bf16,
),
}
@ -708,9 +693,7 @@ def _register_quantization_unary_fusion():
patterns,
2, # pass_number
torch.ops.onednn.qconv2d_pointwise, # computation_op
original_pattern_output_dtype, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
# QLinear
@ -720,11 +703,9 @@ def _register_quantization_unary_fusion():
linear_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
qlinear_pattern,
has_to_fp32_before_quant=is_bf16,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
has_to_fp32_before_quant=is_bf16,
),
UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
@ -735,18 +716,18 @@ def _register_quantization_unary_fusion():
2,
is_bf16,
),
has_to_fp32_before_quant=False,
with_dtype_convert=is_bf16,
),
UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
has_to_fp32_before_quant=False,
with_dtype_convert=is_bf16,
),
}
@ -755,9 +736,7 @@ def _register_quantization_unary_fusion():
patterns,
1, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
None, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
@ -765,20 +744,28 @@ def _register_quantization_unary_fusion():
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
qlinear_pattern, aten.relu.default
),
UnaryAttr("gelu", [], "none"): _unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_erf,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 2
),
2,
is_bf16,
),
2,
Arg(),
is_bf16,
),
UnaryAttr("gelu", [], "tanh"): _unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
_unary_fusion_pattern(
_gelu_fusion_tanh,
get_qlinear_pt2e_pattern(
x_scale_zp_are_tensors, 1 if is_bf16 else 4
),
4,
is_bf16,
),
4,
Arg(),
is_bf16,
),
}
@ -788,9 +775,7 @@ def _register_quantization_unary_fusion():
patterns,
2, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
original_pattern_output_dtype, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
@ -822,7 +807,6 @@ def _register_quantization_binary_fusion():
dequantize_accum_pattern,
int8_mixed_bf16_with_inplace_add,
),
has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add,
),
BinaryUnaryAttr(
"sum", 1.0, "relu", [], ""
@ -836,7 +820,6 @@ def _register_quantization_binary_fusion():
),
aten.relu.default,
),
has_to_fp32_before_quant=int8_mixed_bf16_with_inplace_add,
),
}
@ -845,7 +828,6 @@ def _register_quantization_binary_fusion():
patterns,
0, # pass_number
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
None, # output_dtype
binary_unary_attr, # binary_unary_attr
)
@ -871,11 +853,6 @@ def _register_quantization_binary_fusion():
patterns,
0, # pass_number
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
# Note that for int8-mixed-bf16 and non-inplace add, because we have
# q-dq inserted at extra input of add, so the non-inplace add has bf16 and fp32 inputs,
# the output dtype will be float32.
# For inplace add, there is a extra to_bf16 node at add output, so the fusion pattern has bfloat16 output.
torch.bfloat16,
binary_unary_attr, # binary_unary_attr
)
else:
@ -883,7 +860,6 @@ def _register_quantization_binary_fusion():
patterns,
1, # pass_number
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
torch.float32,
binary_unary_attr, # binary_unary_attr
)
@ -905,8 +881,6 @@ def _register_quantization_binary_fusion():
patterns,
1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number
torch.ops.onednn.qconv2d_pointwise.binary, # computation_op
# Same output dtype setting as conv-add-relu pattern
torch.bfloat16 if int8_mixed_bf16_with_inplace_add else torch.float32,
binary_unary_attr, # binary_unary_attr
)
@ -962,6 +936,8 @@ def _register_quantized_maxpool2d_lowering(
ceil_mode,
)
computation_args, _ = require_channels_last(computation_op, *computation_args)
counters["inductor"]["qmaxpool2d_matcher_count"] += 1
counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
return L[computation_op](*computation_args)
return qmaxpool2d
@ -996,7 +972,7 @@ def _register_quantization_maxpool2d():
for max_pool2d_args in max_pool2d_args_list:
dequantize_maxpool2d_pattern = CallFunction(
aten.max_pool2d_with_indices.default,
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("kernel_size"),
*max_pool2d_args,
)
@ -1033,26 +1009,23 @@ def _is_input_output_same_scale_zp(check_node):
def fn(match):
# Ensure all the inputs and output has same scale and zero point
# Step 1: Check inputs/output zero point
sub_nodes = filter_nodes(match.nodes, aten.sub.Tensor)
zero_points = [node.args[1] for node in sub_nodes]
add_nodes = filter_nodes(match.nodes, aten.add.Tensor)
assert len(add_nodes) == 1, "expect only 1 add node at output quant pattern"
zero_points.append(add_nodes[0].args[1])
# Get dequant nodes at input
dequant_nodes = filter_nodes(
match.nodes, quantized_decomposed.dequantize_per_tensor.default
)
zero_points = [node.args[2] for node in dequant_nodes]
# Get quant nodes at output
quant_nodes = filter_nodes(
match.nodes, quantized_decomposed.quantize_per_tensor.default
)
    assert len(quant_nodes) == 1, "expect only 1 quant node at output quant pattern"
zero_points.append(quant_nodes[0].args[2])
if not all(zero_point == zero_points[0] for zero_point in zero_points):
return False
# Step 2: Check inputs/output scale
mul_nodes = filter_nodes(match.nodes, aten.mul.Tensor)
# We need to find mul node at output since the scale value is reciprocal to input scale.
# Mul node at output should connect to cat node directly.
scales = [
(
mul_node.args[1]
if mul_node.args[0].target is check_node # type: ignore[union-attr]
else 1.0 / mul_node.args[1] # type: ignore[operator]
)
for mul_node in mul_nodes
]
scales = [node.args[1] for node in dequant_nodes]
scales.append(quant_nodes[0].args[1])
if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type]
return False
@ -1072,22 +1045,20 @@ def _register_quantized_cat_lowering(
def qcat(match: Match, inputs, dim, **kwargs):
# inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
uint8_inputs = [input[0] for input in inputs]
counters["inductor"]["qcat_matcher_count"] += 1
counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
return L[computation_op](uint8_inputs, dim)
return qcat
_raw_dequantize_per_tensor_activation_pattern = CallFunction(
aten.mul.Tensor,
CallFunction(
aten.sub.Tensor,
CallFunction(
prims.convert_element_type.default,
Arg(),
Arg(),
),
Arg(),
),
quantized_decomposed.dequantize_per_tensor.default,
Arg(),
Arg(),
Arg(),
Arg(),
Arg(),
Arg(),
)
@ -1125,7 +1096,7 @@ def _register_quantized_reshape_lowering(
def _register_quantization_reshape():
dequantize_reshape_pattern = CallFunction(
torch.ops.aten.reshape.default,
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("shape"),
)
_register_quantized_reshape_lowering(
@ -1274,35 +1245,33 @@ def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
assert dtype in [torch.float32, torch.bfloat16]
dequant_pattern_end_node = match.output_node()
if dequant_pattern_end_node.target not in [
aten.mul.Tensor,
quantized_decomposed.dequantize_per_tensor.default,
prims.convert_element_type.default,
aten.reshape.default,
]:
return False
if dequant_pattern_end_node.target is aten.reshape.default:
mul_node = (
dequant_pattern_end_node.args[0] # pattern: linear <- reshape <- mul
dequant_node = (
dequant_pattern_end_node.args[
0
] # pattern: linear <- reshape <- dequant
if dtype == torch.float32
else dequant_pattern_end_node.args[0].args[
0
] # pattern: linear <- reshape <- to_bf16 <- mul
] # pattern: linear <- reshape <- to_bf16 <- dequant
)
else:
mul_node = (
dequant_pattern_end_node # pattern: linear <- mul
dequant_node = (
dequant_pattern_end_node # pattern: linear <- dequant
if dtype == torch.float32
else dequant_pattern_end_node.args[
0
] # pattern: linear <- to_bf16 <- mul
] # pattern: linear <- to_bf16 <- dequant
)
sub_node = mul_node.args[0]
to_fp32_node = sub_node.args[0]
if (
mul_node.target is aten.mul.Tensor
and sub_node.target is aten.sub.Tensor
and to_fp32_node.target is prims.convert_element_type.default
dequant_node.target is quantized_decomposed.dequantize_per_tensor.default
and len(list(dequant_pattern_end_node.users)) > 1
):
# If dequant pattern has more than 1 users, then do dequant promoted
@ -1363,10 +1332,10 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
# Find the start node and end node of a dequant pattern
# * End node should be the match.output_node()
# * Start node should be the node of dtype convert to float32
# * Start node should be the node of dequantize_per_tensor
dequant_pattern_end_node = match.output_node()
assert dequant_pattern_end_node.target in [
aten.mul.Tensor,
quantized_decomposed.dequantize_per_tensor.default,
prims.convert_element_type.default,
aten.reshape.default,
]
@ -1374,15 +1343,10 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
# For a dequant pattern, we should expect see the node list as:
# * OPT(aten.reshape.default)
# * OPT(prims.convert_element_type.default) (to_bf16)
# * aten.mul
# * aten.sub
# * prims.convert_element_type.default (to_fp32)
# * dequantize_per_tensor
def _find_first_node_in_dequant_pattern(_node):
if (
_node.target is prims.convert_element_type.default
and _node.args[1] == torch.float32
):
# For a dequant pattern, we expect the start node is a to_fp32 node
if _node.target is quantized_decomposed.dequantize_per_tensor.default:
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
return _node
else:
assert (
@ -1394,6 +1358,11 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
dequant_pattern_end_node
)
assert (
dequant_pattern_start_node.target
is quantized_decomposed.dequantize_per_tensor.default
)
# Clone the dequant pattern for each user node
graph = match.graph
user_node_list = list(dequant_pattern_end_node.users)
@ -1429,22 +1398,14 @@ def _is_valid_dequant_conv2d_pattern(dtype):
return False
assert dtype in [torch.float32, torch.bfloat16]
if dtype == torch.float32:
mul_node = conv_node.args[0]
dequant_node = conv_node.args[0]
else:
convert_to_bf16 = conv_node.args[0]
mul_node = convert_to_bf16.args[0]
sub_node = mul_node.args[0]
to_fp32_node = sub_node.args[0]
dequant_node = convert_to_bf16.args[0]
assert to_fp32_node.target is prims.convert_element_type.default
assert sub_node.target is aten.sub.Tensor
assert mul_node.target is aten.mul.Tensor
if (
len(list(to_fp32_node.users)) != 1
or len(list(sub_node.users)) != 1
or len(list(mul_node.users)) != 1
):
if len(list(dequant_node.users)) != 1:
# Ensure the dequant pattern only has 1 user
# since we will delete the dequant pattern here
return False
@ -1477,12 +1438,10 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
conv_node = match.output_node()
assert conv_node.target is aten.convolution.default
if dtype == torch.float32:
mul_node = conv_node.args[0]
dequant_node = conv_node.args[0]
else:
convert_to_bf16 = conv_node.args[0]
mul_node = convert_to_bf16.args[0] # type: ignore[union-attr]
sub_node = mul_node.args[0] # type: ignore[union-attr]
to_fp32_node = sub_node.args[0] # type: ignore[union-attr]
dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr]
has_clone_to_channel_last_node_in_pattern = (
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
)
@ -1585,10 +1544,7 @@ def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float3
# Erase the dequant pattern
if dtype == torch.bfloat16:
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined]
# Erase the dequant pattern
graph.erase_node(mul_node)
graph.erase_node(sub_node)
graph.erase_node(to_fp32_node)
graph.erase_node(dequant_node)
# Erase the dequant per channel pattern
if clone_node is not None:
graph.erase_node(clone_node)
@ -1608,7 +1564,7 @@ def _generate_dequant_convolution_node_pattern(
dequant_convolution_node_pattern = CallFunction(
aten.convolution.default,
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -1668,7 +1624,7 @@ def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
return linear_node, output_reshape_node
def _get_linear_dq_mul_node(
def _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
):
act_reshape_node = None
@ -1679,30 +1635,30 @@ def _get_linear_dq_mul_node(
act_reshape_node = linear_node.args[input_index]
assert act_reshape_node.target is aten.reshape.default
if dtype == torch.float32:
# pattern: linear -> reshape -> mul
mul_node = act_reshape_node.args[0]
# pattern: linear -> reshape -> dequant
dequant_node = act_reshape_node.args[0]
else:
# pattern: linear -> reshape -> to_bf16 -> mul
# pattern: linear -> reshape -> to_bf16 -> dequant
activation_to_bf16_node = act_reshape_node.args[0]
mul_node = activation_to_bf16_node.args[0]
dequant_node = activation_to_bf16_node.args[0]
else:
# bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
act_expand_node = linear_node.args[input_index]
assert act_expand_node.target is aten.expand.default
if dtype == torch.float32:
mul_node = act_expand_node.args[0]
dequant_node = act_expand_node.args[0]
else:
activation_to_bf16_node = act_expand_node.args[0]
mul_node = activation_to_bf16_node.args[0]
dequant_node = activation_to_bf16_node.args[0]
else:
if dtype == torch.float32:
# pattern: linear -> mul
mul_node = linear_node.args[input_index]
# pattern: linear -> dequant
dequant_node = linear_node.args[input_index]
else:
# pattern: linear -> to_bf16 -> mul
# pattern: linear -> to_bf16 -> dequant
activation_to_bf16_node = linear_node.args[input_index]
mul_node = activation_to_bf16_node.args[0]
return mul_node, act_reshape_node, activation_to_bf16_node, act_expand_node
dequant_node = activation_to_bf16_node.args[0]
return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
@ -1715,27 +1671,21 @@ def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contigu
input_index = 1 if linear_node.target is aten.addmm.default else 0
assert dtype in [torch.float32, torch.bfloat16]
(
mul_node,
dequant_node,
_,
_,
_,
) = _get_linear_dq_mul_node(
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)
sub_node = mul_node.args[0]
to_fp32_node = sub_node.args[0]
assert dequant_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]
assert to_fp32_node.target is prims.convert_element_type.default
assert sub_node.target is aten.sub.Tensor
assert mul_node.target is aten.mul.Tensor
if (
len(list(to_fp32_node.users)) != 1
or len(list(sub_node.users)) != 1
or len(list(mul_node.users)) != 1
):
if len(list(dequant_node.users)) != 1:
# Ensure the dequant pattern only has 1 user
# since we will delete the dequant pattern here
return False
@ -1820,17 +1770,14 @@ def _register_qlinear_weight_prepack_pass(
weight_index = input_index + 1
(
mul_node,
dequant_node,
act_reshape_node,
activation_to_bf16_node,
act_expand_node,
) = _get_linear_dq_mul_node(
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)
sub_node = mul_node.args[0]
to_fp32_node = sub_node.args[0]
if input_dim_exceeds_two and not input_contiguous:
wgt_expand_node = linear_node.args[weight_index]
assert wgt_expand_node.target is aten.expand.default
@ -1938,9 +1885,7 @@ def _register_qlinear_weight_prepack_pass(
if dtype == torch.bfloat16:
graph.erase_node(activation_to_bf16_node)
# Erase the dequant pattern
graph.erase_node(mul_node)
graph.erase_node(sub_node)
graph.erase_node(to_fp32_node)
graph.erase_node(dequant_node)
# Erase the dequant per channel pattern
graph.erase_node(t_node)
if dtype == torch.bfloat16:
@ -1954,7 +1899,10 @@ def _register_qlinear_weight_prepack_pass(
def _generate_dequant_linear_node_pattern(
_dequant_per_channel_pattern, dtype=torch.float32, input_dim_exceeds_two=False
_dequant_per_channel_pattern,
dtype=torch.float32,
input_dim_exceeds_two=False,
is_tensor_overload=False,
):
assert dtype in [torch.float32, torch.bfloat16]
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@ -1964,7 +1912,7 @@ def _generate_dequant_linear_node_pattern(
KeywordArg("b"),
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -1981,7 +1929,7 @@ def _generate_dequant_linear_node_pattern(
aten.mm.default,
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2000,6 +1948,7 @@ def _generate_dequant_bmm_node_pattern(
_dequant_per_channel_pattern,
dtype=torch.float32,
with_bias=False,
is_tensor_overload=False,
):
# When activation of linear dim exceed 2 and not contiguous
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
@ -2010,7 +1959,7 @@ def _generate_dequant_bmm_node_pattern(
CallFunction(
aten.expand.default,
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2041,16 +1990,21 @@ def _generate_qlinear_weight_prepack_patterns(
input_dim_exceeds_two=False,
input_contiguous=True,
with_bias=False,
is_tensor_overload=False,
):
if input_dim_exceeds_two and not input_contiguous:
return _generate_dequant_bmm_node_pattern(
dequantize_per_channel_weight_pattern,
dtype,
with_bias,
is_tensor_overload,
)
else:
return _generate_dequant_linear_node_pattern(
dequantize_per_channel_weight_pattern, dtype, input_dim_exceeds_two
dequantize_per_channel_weight_pattern,
dtype,
input_dim_exceeds_two,
is_tensor_overload,
)
@ -2082,7 +2036,7 @@ def _register_dequant_promotion():
_register_dequant_promotion_pass(
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
dequantize_per_tensor_activation_pattern,
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
@ -2140,13 +2094,15 @@ def _register_qlinear_weight_prepack():
# | OPT(add) |
linear_weight_prepack_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False]
[torch.float32, torch.bfloat16], [True, False], [True, False]
)
# Step 1: register patterns from mm and addmm
for dtype, input_dim_exceeds_two in linear_weight_prepack_cases:
for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
dtype, input_dim_exceeds_two
dtype,
input_dim_exceeds_two,
is_tensor_overload=is_tensor_overload,
)
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
@ -2163,14 +2119,15 @@ def _register_qlinear_weight_prepack():
# https://github.com/pytorch/pytorch/blob/
# 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
# in this case, we can convert it back to qlinear
for dtype, with_bias in itertools.product(
[torch.float32, torch.bfloat16], [True, False]
for dtype, with_bias, is_tensor_overload in itertools.product(
[torch.float32, torch.bfloat16], [True, False], [True, False]
):
bmm_pattern = _generate_qlinear_weight_prepack_patterns(
dtype=dtype,
input_dim_exceeds_two=True,
input_contiguous=False,
with_bias=with_bias,
is_tensor_overload=is_tensor_overload,
)
_register_qlinear_weight_prepack_pass(
bmm_pattern,


@ -1144,6 +1144,170 @@ def quantized_decomposed_dequantize_per_channel(
)
@register_lowering(
quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_default(
input: TensorBox,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
input_loader = input.make_loader()
def inner_fn(idx, scale, zero_point):
input = input_loader(idx)
inv_scale, zero_point = _create_constants(
1.0 / scale, zero_point, dtype=torch.float32
)
val = ops.round(input * inv_scale) + zero_point
qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, dtype)
return Pointwise.create(
device=input.get_device(),
dtype=dtype,
inner_fn=functools.partial(
inner_fn, scale=float(scale), zero_point=int(zero_point)
),
ranges=input.get_size(),
)
@register_lowering(
quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_default(
input: TensorBox,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
input_loader = input.make_loader()
def inner_fn(idx, scale, zero_point):
input = input_loader(idx)
scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32)
val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
return val
return Pointwise.create(
device=input.get_device(),
dtype=torch.float32,
inner_fn=functools.partial(
inner_fn, scale=float(scale), zero_point=int(zero_point)
),
ranges=input.get_size(),
)
@register_lowering(
quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_tensor(
input: TensorBox,
scale: TensorBox,
zero_point: TensorBox,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"
input_loader = input.make_loader()
scale_loader = scale.make_loader()
zero_point_loader = zero_point.make_loader()
def inner_fn(idx):
input = input_loader(idx)
_scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
_zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
if scale.dtype != torch.float32:
_scale = ops.to_dtype(_scale, torch.float32)
if zero_point.dtype != torch.float32:
_zero_point = ops.to_dtype(_zero_point, torch.float32)
val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
clamped = ops.minimum(ops.maximum(val, qmin), qmax)
return ops.to_dtype(clamped, dtype)
return Pointwise.create(
device=input.get_device(),
dtype=dtype,
inner_fn=inner_fn,
ranges=input.get_size(),
)
@register_lowering(
quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_tensor(
input: TensorBox,
scale: TensorBox,
zero_point: TensorBox,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> TensorBox:
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
input_loader = input.make_loader()
scale_loader = scale.make_loader()
zero_point_loader = zero_point.make_loader()
def inner_fn(idx):
input = input_loader(idx)
_scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
_zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
if scale.dtype != torch.float32:
_scale = ops.to_dtype(_scale, torch.float32)
if zero_point.dtype != torch.float32:
_zero_point = ops.to_dtype(_zero_point, torch.float32)
val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
return val
return Pointwise.create(
device=input.get_device(),
dtype=torch.float32,
inner_fn=inner_fn,
ranges=input.get_size(),
)
@register_lowering(aten.cat)
def cat(inputs, dim=0):
cpu_device = inputs[0].get_device().type == "cpu"