[Inductor][CPU] Fuse SmoothQuant int8 linear pattern (#139595)

**About the PR**
In the implementation of SmoothQuant in Torchao, quantized linear is computed by `_int_mm(a, b)` + `mul(b_scale)` + `mul(a_scale)` (+ optional `add` for bias) with `reshape` and `convert_dtype` in between.
This PR adds a pass to fuse the corresponding patterns:
- (no bias) `reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape`
- (with bias) `pattern_no_bias -> add -> reshape -> reshape`

The patterns are replaced by `onednn.qlinear_pointwise` and `onednn.qlinear_prepack`, the latter of which is evaluated and frozen during the freezing process of Inductor. The final graph contains `onednn.qlinear_pointwise` only with packed weight constants.

Note that `onednn.qlinear_pointwise` does not support per-channel quantization of activation, which is a limitation of oneDNN library, so in that case we set activation scale to 1 and bias to none and apply scales and add bias after `onednn.qlinear_pointwise`.

**Validation results**
Accuracy/perplexity is not changed with or without this fusion pass.
Latency is improved by >10% with the fusion pass.
Test method:
- Model: EleutherAI/gpt-j-6b
- Hardware: Intel(R) Xeon(R) Platinum 8490H, running on 1 socket, 60 cores
- Using Intel OMP and Tcmalloc
- Running [the example script of SmoothQuant in Torchao](https://github.com/pytorch/ao/blob/main/torchao/prototype/smoothquant/example.py) with `TORCHINDUCTOR_FREEZING=1 numactl -N1 python example.py -m EleutherAI/gpt-j-6b --device=cpu --quant-mode=dynamic --compile`

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_smooth_quant_with_int_mm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139595
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
Xia, Weiwen
2024-11-04 22:06:17 -08:00
committed by PyTorch MergeBot
parent d031d1bf4c
commit 22e89ea2aa
3 changed files with 284 additions and 9 deletions

View File

@ -932,8 +932,8 @@ static at::Tensor linear_int8_with_onednn_weight(
c10::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of weight should be int8 (char).");
TORCH_CHECK(
@ -1022,7 +1022,8 @@ static at::Tensor linear_int8_with_onednn_weight(
empty_tensor;
// Create onednn primitive
auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
auto weights_desc = packed_weight.get_desc();
auto dst_dtype = dst.get_data_type();
auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
@ -1119,12 +1120,14 @@ namespace at::native {
torch::List<std::optional<at::Scalar>> post_op_args,
c10::string_view post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
"onednn int8 linear: act scale/zp size should be 1");
// act_zero_point.numel() == 0 for symmetric quantization
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const c10::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
act, act_scale.item().toDouble(), act_zp,
onednn_weight, weight_scales, weight_zero_points,
bias, output_scale, output_zero_point, output_dtype,
other, /*other scale*/1.0, /*other zp*/0,
@ -1155,10 +1158,12 @@ namespace at::native {
torch::List<std::optional<at::Scalar>> unary_post_op_args,
c10::string_view unary_post_op_algorithm) {
#if AT_MKLDNN_ENABLED()
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
"onednn int8 linear: act scale/zp size should be 1");
// act_zero_point.numel() == 0 for symmetric quantization
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
act, act_scale.item().toDouble(), act_zp,
onednn_weight, weight_scales, weight_zero_points,
bias, output_scale, output_zero_point, output_dtype,
other, other_scale, other_zero_point,

View File

@ -2824,6 +2824,92 @@ class TestPatternMatcher(TestPatternMatcherBase):
rtol=0.07,
)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_smooth_quant_with_int_mm(self):
r"""
This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao.
The pattern is:
(no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
or
(with bias) pattern_no_bias -> add -> reshape -> reshape
"""
M = 16
in_feature = 64
out_feature = 128
q_min, q_max = -32, 31
class Mod(torch.nn.Module):
def __init__(
self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
):
super().__init__()
self.dtype = dtype
self.has_bias = has_bias
self.b = torch.randint(
q_min, q_max, [in_feature, out_feature], dtype=torch.int8
)
self.per_channel_quant = per_channel_quant
self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
self.b_scale = self.b_scale.to(dtype)
self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
def forward(self, a, a_scale_per_tensor, a_scale_per_channel):
out_shape = a.shape[:-1] + (self.b.size(-1),)
a_reshaped = a.reshape(-1, a.size(-1))
c = torch._int_mm(a_reshaped, self.b)
c = c.to(self.dtype)
c_shape = c.shape
a_scale = (
a_scale_per_channel
if self.per_channel_quant
else a_scale_per_tensor
)
a_scale = a_scale.expand(c.shape)
c = c * a_scale
c = c * self.b_scale
if self.has_bias:
c = c.reshape([1, *list(c_shape)])
c = c + self.bias
c = c.reshape(c_shape)
c = c.reshape(out_shape)
return c
has_bias_list = [True, False]
dype_list = (
[torch.float, torch.bfloat16]
if torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float]
)
per_channel_list = [True, False]
for has_bias, dtype, per_channel_quant in itertools.product(
has_bias_list, dype_list, per_channel_list
):
mod = Mod(dtype, has_bias, per_channel_quant).eval()
a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
a_scale_per_tensor, a_scale_per_channel = (
a_scale_per_tensor.to(dtype),
a_scale_per_channel.to(dtype),
)
def matcher_check_fn():
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
)
self.assertEqual(
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
10 if has_bias else 7,
)
self._test_common(
mod,
(a, a_scale_per_tensor, a_scale_per_channel),
matcher_check_fn=matcher_check_fn,
check_autocast=dtype,
)
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class TestDynamicPatternMatcher(TestPatternMatcherBase):

View File

@ -2529,6 +2529,187 @@ def _register_qlinear_weight_prepack():
)
def _register_smooth_quant_int_mm_pattern():
"""
The pattern is:
(no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
or
(with bias) pattern_no_bias -> add -> reshape -> reshape
"""
pattern_no_bias = CallFunction(
aten.reshape.default,
CallFunction(
aten.mul.Tensor,
CallFunction(
aten.mul.Tensor,
CallFunction(
prims.convert_element_type.default,
CallFunction(
aten._int_mm.default,
CallFunction(
aten.reshape.default,
KeywordArg("a"),
KeywordArg("in_shape"),
),
KeywordArg("b"),
),
KeywordArg("dtype"),
),
CallFunction(
aten.expand.default,
KeywordArg("x_scale"),
Arg(),
),
),
KeywordArg("w_scale"),
),
KeywordArg("out_shape_no_bias"),
)
pattern_with_bias = CallFunction(
aten.reshape.default,
CallFunction(
aten.reshape.default,
CallFunction(
aten.add.Tensor,
pattern_no_bias,
KeywordArg("bias"),
),
Arg(),
),
KeywordArg("out_shape_with_bias"),
)
def _validate_pattern(match: Match):
return len(match.nodes) in [7, 10]
for pattern in [pattern_with_bias, pattern_no_bias]:
@register_freezing_graph_pattern(
pattern,
extra_check=_validate_pattern,
pass_number=0,
)
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
bias = kwargs.get("bias", None)
if bias is not None:
if len(bias.meta.get("tensor_meta").shape) != 1:
# we expect bias is a vector
return
x = kwargs["a"]
weight = kwargs["b"]
dtype = kwargs["dtype"]
x_scale = kwargs["x_scale"]
w_scale = kwargs["w_scale"]
x_shape = x.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
# For dynamic shape case, we can't get activation shape ahead of runtime.
x_shape = None
out_node = match.output_node()
with match.graph.inserting_before(out_node):
transpose_node = match.graph.call_function(
aten.permute.default, args=(weight, [1, 0])
)
contig_node = match.graph.call_function(
aten.contiguous.default, args=(transpose_node,)
)
packed_weight_inputs = (
contig_node,
x_shape,
)
packed_weight_op = torch.ops.onednn.qlinear_prepack
prepack_weight_node = match.graph.call_function(
packed_weight_op, args=packed_weight_inputs
)
dummy_zp = match.graph.call_function(aten.empty, args=([0],))
w_scale = match.graph.call_function(
prims.convert_element_type.default, args=(w_scale, torch.float32)
)
x_scale_shape = x_scale.meta.get("tensor_meta").shape
x_scale_is_scalar = False
if not has_free_symbols(x_scale_shape):
prod = 1
for d in x_scale_shape:
prod *= d
x_scale_is_scalar = prod == 1
new_args: Tuple[Any, ...]
if x_scale_is_scalar:
# in this case, we can call onednn.qlinear directly
new_args = (
x,
x_scale,
dummy_zp, # x_zp
prepack_weight_node,
w_scale,
dummy_zp, # w_zp
bias,
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # post op name
[], # post op args
"", # post op algorithm
)
new_linear_node = match.graph.call_function(
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
)
out_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(out_node.meta)
else:
# onednn.qlinear does not support per-channel quantization of x
# so in this case, we have to apply x scale and add bias ourselves after qlinear
x_reshaped = match.graph.call_function(
aten.reshape.default, args=(x, kwargs["in_shape"])
)
new_args = (
x_reshaped,
1.0, # x_scale
0, # x_zp
prepack_weight_node,
w_scale,
dummy_zp, # w_zp
None, # bias
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # post op name
[], # post op args
"", # post op algorithm
)
new_linear_node = match.graph.call_function(
torch.ops.onednn.qlinear_pointwise, args=new_args
)
# apply x scale
new_out_node = match.graph.call_function(
aten.mul.Tensor, args=(new_linear_node, x_scale)
)
# Add bias and reshape
if bias is not None:
new_out_node = match.graph.call_function(
aten.add.Tensor, args=(new_out_node, bias)
)
new_out_node = match.graph.call_function(
aten.reshape.default,
args=(new_out_node, kwargs["out_shape_with_bias"]),
)
else:
new_out_node = match.graph.call_function(
aten.reshape.default,
args=(new_out_node, kwargs["out_shape_no_bias"]),
)
out_node.replace_all_uses_with(new_out_node)
new_out_node.meta.update(out_node.meta)
for node in reversed(match.nodes):
match.graph.erase_node(node)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
match.nodes
)
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
@ -2540,6 +2721,9 @@ def _register_quantization_weight_pack_pass():
# Step 3: QLinear weight prepack
_register_qlinear_weight_prepack()
# Step 4: weight prepack for SmoothQuant from Torchao
_register_smooth_quant_int_mm_pattern()
def quant_lift_up(graph_module: torch.fx.GraphModule):
"""