[Quant][Inductor] Enable lowering of dynamic qlinear for X86Inductor (#120605)

**Description**
Enable lowering of dynamic qlinear for X86Inductor. The pattern is `choose_qparams -> getitem -> q -> dq -> linear`. We only fuse `dq -> linear`, which yields `choose_qparams -> getitem -> q -> onednn.qlinear_pointwise`. In other words, we treat it as dynamic quantization of the activation followed by a statically quantized linear.
The previous implementation of `onednn.qlinear_pointwise` covers the case where `x_scale` and `x_zp` are scalars. Since `choose_qparams` returns tensors, we add a variation, `onednn.qlinear_pointwise.tensor`, to support that case.
This feature targets the PyTorch 2.3 release.
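
Below is a minimal end-to-end sketch of how this path is exercised (the module, shapes, and `torch._export` import path are illustrative and may differ across releases; the quantizer calls follow the test helper changed in this PR):

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


mod = M().eval()
inputs = (torch.randn(2, 4),)

exported = capture_pre_autograd_graph(mod, inputs)
quantizer = X86InductorQuantizer()
quantizer.set_global(
    xiq.get_default_x86_inductor_quantization_config(is_dynamic=True)
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*inputs)  # calibration pass
converted = convert_pt2e(prepared)

# During Inductor lowering, dq -> linear is fused into onednn.qlinear_pointwise;
# choose_qparams -> getitem -> q stays in the graph as dynamic quantization of
# the activation.
with torch.no_grad():
    _ = torch.compile(converted)(*inputs)
```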

**Test plan**
```
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_cpu
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_qat_cpu
python inductor/test_cpu_cpp_wrapper.py -k test_dynamic_qlinear
```

**Performance before and after lowering `choose_qparams` to Inductor**

| Shape | Latency before (ms) | Latency after (ms) |
| --- | --- | --- |
| (32, 32) | 0.151 | 0.049 |
| (128, 128) | 0.153 | 0.052 |
| (1024, 1024) | 0.247 | 0.133 |

Test method: a module with a single Linear layer is dynamically quantized and lowered via X86Inductor
Test env & config: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz, single instance, single core, using Intel OpenMP and Tcmalloc
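
For reference, a rough timing harness in this spirit (a sketch, not the exact script used: the fp32 Linear below stands in for the dynamically quantized, X86Inductor-lowered model described above, and single-core pinning, Intel OpenMP, and Tcmalloc are configured outside Python, e.g. via OMP_NUM_THREADS, numactl, and LD_PRELOAD):

```python
import time

import torch


def bench_ms(fn, args, warmup=50, iters=200):
    # Average wall-clock latency per call, in milliseconds.
    with torch.no_grad():
        for _ in range(warmup):
            fn(*args)
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        return (time.perf_counter() - t0) / iters * 1e3


for n in (32, 128, 1024):
    mod = torch.nn.Linear(n, n).eval()
    x = torch.randn(n, n)
    compiled = torch.compile(mod)  # replace with the quantized, lowered model
    print(f"latency for shape ({n}, {n}) = {bench_ms(compiled, (x,)):.3f} ms")
```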

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120605
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
Author: Xia, Weiwen
Date: 2024-03-02 05:11:13 +00:00
Committed by: PyTorch MergeBot
Parent: af5376c444
Commit: 83d848e1c7
9 changed files with 219 additions and 75 deletions


@@ -1158,6 +1158,33 @@ class QLinearOnednn final {
bias, output_scale, output_zero_point, output_dtype,
post_op_name, post_op_args, post_op_algorithm
);
#endif
TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
}
static Tensor run_pointwise_tensor(
Tensor act, // int8 CPU tensor, not QTensor
Tensor act_scale,
Tensor act_zero_point,
Tensor onednn_weight, // int8 tensor from MkldnnCPU
Tensor weight_scales,
Tensor weight_zero_points,
c10::optional<Tensor> bias,
double output_scale,
int64_t output_zero_point,
c10::optional<c10::ScalarType> output_dtype,
std::string post_op_name,
torch::List<c10::optional<at::Scalar>> post_op_args,
std::string post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
"onednn int8 linear: act scale/zp size should be 1");
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
onednn_weight, weight_scales, weight_zero_points,
bias, output_scale, output_zero_point, output_dtype,
post_op_name, post_op_args, post_op_algorithm
);
#endif
TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)");
}
@@ -1185,6 +1212,8 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
TORCH_FN(QLinearOnednn::run_pointwise));
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
TORCH_FN(QLinearOnednn::run_pointwise_tensor));
}
} // namespace


@@ -266,4 +266,5 @@ TORCH_LIBRARY(onednn, m) {
// Linear with unary postop
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
}


@@ -257,6 +257,18 @@ if RUN_CPU:
test_mkldnn_pattern_matcher.TestPatternMatcher(),
condition=torch.backends.mkldnn.is_available(),
),
BaseTest(
"test_dynamic_qlinear",
"cpu",
test_mkldnn_pattern_matcher.TestPatternMatcher(),
condition=torch.backends.mkldnn.is_available(),
),
BaseTest(
"test_dynamic_qlinear_qat",
"cpu",
test_mkldnn_pattern_matcher.TestPatternMatcher(),
condition=torch.backends.mkldnn.is_available(),
),
BaseTest("test_randint"),
BaseTest("test_randn_with_dtype_and_device"),
BaseTest("test_reduction1"), # Reduction


@@ -89,7 +89,9 @@ class TestPatternMatcherBase(TestCase):
return tuple(clone(x) for x in inputs)
def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
def _generate_qdq_quantized_model(
self, mod, inputs, is_qat=False, is_dynamic=False
):
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
export_model = capture_pre_autograd_graph(
@@ -98,7 +100,9 @@
)
quantizer = X86InductorQuantizer()
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=is_qat)
xiq.get_default_x86_inductor_quantization_config(
is_qat=is_qat, is_dynamic=is_dynamic
)
)
prepare_model = (
prepare_qat_pt2e(export_model, quantizer)
@@ -123,6 +127,7 @@
is_qat=False,
matcher_check_fn=None,
dtype=None,
is_dynamic=False,
):
counters.clear()
torch._dynamo.reset()
@@ -146,7 +151,9 @@
maybe_autocast = contextlib.nullcontext()
if check_quantization:
convert_model = self._generate_qdq_quantized_model(mod, inputs, is_qat)
convert_model = self._generate_qdq_quantized_model(
mod, inputs, is_qat, is_dynamic
)
with torch.no_grad(), maybe_autocast:
_ = torch.compile(convert_model)(*inputs)
if matcher_count is not None:
@@ -1195,6 +1202,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
do_permute=False,
matcher_check_fn=None,
bias=True,
is_dynamic=False,
is_qat=False,
):
class M(torch.nn.Module):
def __init__(self, use_bias, do_permute=False):
@@ -1223,6 +1232,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
matcher_check_fn=matcher_check_fn
if matcher_check_fn is not None
else _default_matcher_check_fn,
is_qat=is_qat,
is_dynamic=is_dynamic,
)
@skipIfNoDynamoSupport
@@ -1235,6 +1246,30 @@ class TestPatternMatcher(TestPatternMatcherBase):
for bias in [True, False]:
self._qlinear_cpu_test_helper((torch.randn((2, 4)),), bias=bias)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_dynamic_qlinear_cpu(self):
r"""
This testcase will quantize a single Linear Module.
"""
for bias in [True, False]:
self._qlinear_cpu_test_helper(
(torch.randn((2, 4)),), bias=bias, is_dynamic=True
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_dynamic_qlinear_qat_cpu(self):
r"""
This testcase will quantize a single Linear Module.
"""
for bias in [True, False]:
self._qlinear_cpu_test_helper(
(torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
)
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN


@@ -521,7 +521,9 @@ def dequantize_per_tensor_tensor_decomp_impl(
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
torch.float32
)
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
@@ -646,3 +648,15 @@ def masked_scatter(self, mask, source):
source_idx = mask.reshape(-1).cumsum(0) - 1
return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
return NotImplemented
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
min_val, max_val = torch.aminmax(input)
scale = (max_val - min_val) / float(quant_max - quant_min)
scale = torch.max(scale, torch.Tensor([eps]))
zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
return scale.to(torch.float64), zero_point.to(torch.int64)
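
For intuition, a hand-worked run of the decomposition added above (illustrative values; the math mirrors the hunk):

```python
import torch

x = torch.tensor([-1.0, 0.0, 2.0, 3.0])
quant_min, quant_max = -128, 127       # int8 range
eps = torch.finfo(torch.float32).eps

min_val, max_val = torch.aminmax(x)    # -1.0, 3.0
scale = (max_val - min_val) / float(quant_max - quant_min)   # 4 / 255 ~= 0.0157
scale = torch.max(scale, torch.tensor([eps]))
zero_point = quant_min - torch.round(min_val / scale).to(torch.int)  # -128 - (-64) = -64
zero_point = torch.clamp(zero_point, quant_min, quant_max)

print(scale.item(), zero_point.item())  # ~0.0157, -64
# The same values should come out of the op being decomposed:
#   torch.ops.quantized_decomposed.choose_qparams.tensor(
#       x, quant_min, quant_max, eps, torch.int8)
```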


@@ -150,22 +150,29 @@ def get_dequantize_qconv_pt2e_pattern(users=1):
)
qlinear_pt2e_pattern = CallFunction(
torch.ops.onednn.qlinear_pointwise.default,
KeywordArg("x"),
KeywordArg("x_scale"),
KeywordArg("x_zp"),
KeywordArg("packed_weight"),
KeywordArg("w_scale"),
KeywordArg("w_zp"),
KeywordArg("b"),
KeywordArg("output_scale"),
KeywordArg("output_zero_point"),
KeywordArg("output_dtype"),
KeywordArg("postop_name"),
KeywordArg("postop_args"),
KeywordArg("postop_algorithm"),
)
def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors):
qlinear_op = (
torch.ops.onednn.qlinear_pointwise.tensor
if x_scale_zp_are_tensors
else torch.ops.onednn.qlinear_pointwise.default
)
return CallFunction(
qlinear_op,
KeywordArg("x"),
KeywordArg("x_scale"),
KeywordArg("x_zp"),
KeywordArg("packed_weight"),
KeywordArg("w_scale"),
KeywordArg("w_zp"),
KeywordArg("b"),
KeywordArg("output_scale"),
KeywordArg("output_zero_point"),
KeywordArg("output_dtype"),
KeywordArg("postop_name"),
KeywordArg("postop_args"),
KeywordArg("postop_algorithm"),
)
dequantize_accum_pattern = CallFunction(
aten.mul.Tensor,
@@ -670,44 +677,46 @@ def _register_quantization_unary_fusion():
)
# QLinear
# Priority 1 to match: QLinear Unary pattern with int8 output
linear_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
qlinear_pt2e_pattern,
dtype=original_pattern_output_dtype,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(qlinear_pt2e_pattern, aten.relu.default),
dtype=original_pattern_output_dtype,
),
}
for x_scale_zp_are_tensors in (False, True):
qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
# Priority 1 to match: QLinear Unary pattern with int8 output
linear_unary_replace_patterns = {
UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
qlinear_pattern,
dtype=original_pattern_output_dtype,
),
UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
dtype=original_pattern_output_dtype,
),
}
for unary_attr, patterns in linear_unary_replace_patterns.items():
_register_quantized_linear_lowering(
patterns,
1, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
None, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
for unary_attr, patterns in linear_unary_replace_patterns.items():
_register_quantized_linear_lowering(
patterns,
1, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
None, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
linear_unary_replace_float_out_patterns = {
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
qlinear_pt2e_pattern, aten.relu.default
),
}
# Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
linear_unary_replace_float_out_patterns = {
UnaryAttr("relu", [], ""): generate_pattern_with_unary(
qlinear_pattern, aten.relu.default
),
}
for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
_register_quantized_linear_lowering(
patterns,
2, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
original_pattern_output_dtype, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
_register_quantized_linear_lowering(
patterns,
2, # pass_number
torch.ops.onednn.qlinear_pointwise, # computation_op
original_pattern_output_dtype, # output_dtype
unary_attr, # unary_attr
original_pattern_output_dtype=original_pattern_output_dtype,
)
def _register_quantization_binary_fusion():
@@ -1672,9 +1681,15 @@ def _register_qlinear_weight_prepack_pass(
[], # post op args
"", # post op algorithm
)
new_linear_node = graph.call_function(
torch.ops.onednn.qlinear_pointwise.default, args=new_args
)
Node = torch.fx.node.Node
if isinstance(x_scale, Node) and isinstance(x_zp, Node):
new_linear_node = graph.call_function(
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
)
else:
new_linear_node = graph.call_function(
torch.ops.onednn.qlinear_pointwise.default, args=new_args
)
if input_dim_exceeds_two:
if input_contiguous:
output_reshape_node.replace_all_uses_with(new_linear_node)


@@ -1034,6 +1034,7 @@ class GraphLowering(torch.fx.Interpreter):
torch.ops.onednn.qconv2d_pointwise.default,
torch.ops.onednn.qconv2d_pointwise.binary,
torch.ops.onednn.qlinear_pointwise.default,
torch.ops.onednn.qlinear_pointwise.tensor,
]
if torch._C.has_mkl:
need_fixed_layout += [torch.ops.mkl._mkl_linear.default]


@@ -6478,6 +6478,8 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
layout,
inputs,
constant_args=(),
has_bias=True,
x_scale_zp_are_tensors=False,
):
"""
if bias is not None
@@ -6489,21 +6491,32 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
- const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp,
fp32_output, unary_attr, unary_scalars, unary_algorithm]
"""
self.has_bias = len(inputs) == 5
self.has_bias = has_bias
self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
super().__init__(
layout,
inputs,
constant_args,
None,
python_kernel_name="torch.ops.onednn.qlinear_pointwise",
python_kernel_name=(
"torch.ops.onednn.qlinear_pointwise.tensor"
if x_scale_zp_are_tensors
else "torch.ops.onednn.qlinear_pointwise.default"
),
cpp_kernel_name="onednn::qlinear_pointwise",
)
self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else ""
self.cpp_kernel_key = "qlinear_pointwise"
self.cpp_op_schema = """
x_scale_type_str, x_zp_type_str = (
("at::Tensor", "at::Tensor")
if x_scale_zp_are_tensors
else ("double", "int64_t")
)
self.cpp_op_schema = f"""
at::Tensor(
at::Tensor act,
double act_scale,
int64_t act_zero_point,
{x_scale_type_str} act_scale,
{x_zp_type_str} act_zero_point,
at::Tensor weight,
at::Tensor weight_scales,
at::Tensor weight_zero_points,
@@ -6525,16 +6538,29 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
packed_weight = args[1]
bias = args[2] if self.has_bias else const_args[0]
w_scale, w_zp = args[-2], args[-1]
(
x_scale,
x_zp,
o_inv_scale,
o_zp,
output_dtype,
unary_attr,
unary_scalars,
unary_algorithm,
) = const_args[-8:]
if self.x_scale_zp_are_tensors:
assert len(args) >= 4
x_scale, x_zp = args[-4], args[-3]
(
o_inv_scale,
o_zp,
output_dtype,
unary_attr,
unary_scalars,
unary_algorithm,
) = const_args[-6:]
else:
assert len(const_args) >= 8
(
x_scale,
x_zp,
o_inv_scale,
o_zp,
output_dtype,
unary_attr,
unary_scalars,
unary_algorithm,
) = const_args[-8:]
codegen_args = (
x,
@@ -6557,6 +6583,7 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
codegen_args,
self.cpp_op_schema,
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
@@ -6585,12 +6612,19 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
bias,
)
if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
x_scale.realize()
x_zp.realize()
inputs = inputs + [x_scale, x_zp]
x_scale_zp_are_tensors = True
else:
assert isinstance(x_scale, float) and isinstance(x_zp, int)
constant_args = constant_args + [x_scale, x_zp]
x_scale_zp_are_tensors = False
w_scale.realize()
w_zp.realize()
inputs = inputs + [w_scale, w_zp]
constant_args = constant_args + [
x_scale,
x_zp,
o_inv_scale,
output_zero_point,
output_dtype,
@@ -6609,6 +6643,8 @@ class QLinearPointwisePT2E(ExternKernelAlloc):
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
has_bias=(bias is not None),
x_scale_zp_are_tensors=x_scale_zp_are_tensors,
)


@@ -2249,6 +2249,7 @@ if torch._C._has_mkldnn:
return out
@register_meta(torch.ops.onednn.qlinear_pointwise.default)
@register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
def meta_qlinear_pointwise(
x,
x_scale,