mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][CPU] Fuse SmoothQuant int8 linear pattern (#139595)
**About the PR** In the implementation of SmoothQuant in Torchao, quantized linear is computed by `_int_mm(a, b)` + `mul(b_scale)` + `mul(a_scale)` (+ optional `add` for bias) with `reshape` and `convert_dtype` in between. This PR adds a pass to fuse the corresponding patterns: - (no bias) `reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape` - (with bias) `pattern_no_bias -> add -> reshape -> reshape` The patterns are replaced by `onednn.qlinear_pointwise` and `onednn.qlinear_prepack`, the latter of which is evaluated and frozen during the freezing process of Inductor. The final graph contains `onednn.qlinear_pointwise` only with packed weight constants. Note that `onednn.qlinear_pointwise` does not support per-channel quantization of activation, which is a limitation of oneDNN library, so in that case we set activation scale to 1 and bias to none and apply scales and add bias after `onednn.qlinear_pointwise`. **Validation results** Accuracy/perplexity is not changed with or without this fusion pass. Latency is improved by >10% with the fusion pass. Test method: - Model: EleutherAI/gpt-j-6b - Hardware: Intel(R) Xeon(R) Platinum 8490H, running on 1 socket, 60 cores - Using Intel OMP and Tcmalloc - Running [the example script of SmoothQuant in Torchao](https://github.com/pytorch/ao/blob/main/torchao/prototype/smoothquant/example.py) with `TORCHINDUCTOR_FREEZING=1 numactl -N1 python example.py -m EleutherAI/gpt-j-6b --device=cpu --quant-mode=dynamic --compile` **Test plan** ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_smooth_quant_with_int_mm ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/139595 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
d031d1bf4c
commit
22e89ea2aa
@ -932,8 +932,8 @@ static at::Tensor linear_int8_with_onednn_weight(
|
||||
c10::string_view& unary_post_op_algorithm) {
|
||||
using ideep::tensor;
|
||||
const int64_t dim = input.dim();
|
||||
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
|
||||
"qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
|
||||
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
|
||||
"qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
|
||||
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
|
||||
"qlinear with mkldnn tensor: data type of weight should be int8 (char).");
|
||||
TORCH_CHECK(
|
||||
@ -1022,7 +1022,8 @@ static at::Tensor linear_int8_with_onednn_weight(
|
||||
empty_tensor;
|
||||
|
||||
// Create onednn primitive
|
||||
auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
|
||||
auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
|
||||
auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
|
||||
auto weights_desc = packed_weight.get_desc();
|
||||
auto dst_dtype = dst.get_data_type();
|
||||
auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
|
||||
@ -1119,12 +1120,14 @@ namespace at::native {
|
||||
torch::List<std::optional<at::Scalar>> post_op_args,
|
||||
c10::string_view post_op_algorithm) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
|
||||
"onednn int8 linear: act scale/zp size should be 1");
|
||||
// act_zero_point.numel() == 0 for symmetric quantization
|
||||
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
|
||||
"onednn int8 linear: act scale/zp size should be 1/<=1");
|
||||
static std::optional<at::Tensor> other = std::nullopt;
|
||||
static const c10::string_view binary_post_op = "none";
|
||||
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
|
||||
return linear_int8_with_onednn_weight(
|
||||
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
|
||||
act, act_scale.item().toDouble(), act_zp,
|
||||
onednn_weight, weight_scales, weight_zero_points,
|
||||
bias, output_scale, output_zero_point, output_dtype,
|
||||
other, /*other scale*/1.0, /*other zp*/0,
|
||||
@ -1155,10 +1158,12 @@ namespace at::native {
|
||||
torch::List<std::optional<at::Scalar>> unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm) {
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
|
||||
"onednn int8 linear: act scale/zp size should be 1");
|
||||
// act_zero_point.numel() == 0 for symmetric quantization
|
||||
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
|
||||
"onednn int8 linear: act scale/zp size should be 1/<=1");
|
||||
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
|
||||
return linear_int8_with_onednn_weight(
|
||||
act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
|
||||
act, act_scale.item().toDouble(), act_zp,
|
||||
onednn_weight, weight_scales, weight_zero_points,
|
||||
bias, output_scale, output_zero_point, output_dtype,
|
||||
other, other_scale, other_zero_point,
|
||||
|
@ -2824,6 +2824,92 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
||||
rtol=0.07,
|
||||
)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
def test_smooth_quant_with_int_mm(self):
|
||||
r"""
|
||||
This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao.
|
||||
The pattern is:
|
||||
(no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
|
||||
or
|
||||
(with bias) pattern_no_bias -> add -> reshape -> reshape
|
||||
"""
|
||||
M = 16
|
||||
in_feature = 64
|
||||
out_feature = 128
|
||||
q_min, q_max = -32, 31
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(
|
||||
self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.has_bias = has_bias
|
||||
self.b = torch.randint(
|
||||
q_min, q_max, [in_feature, out_feature], dtype=torch.int8
|
||||
)
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
|
||||
self.b_scale = self.b_scale.to(dtype)
|
||||
self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
|
||||
|
||||
def forward(self, a, a_scale_per_tensor, a_scale_per_channel):
|
||||
out_shape = a.shape[:-1] + (self.b.size(-1),)
|
||||
a_reshaped = a.reshape(-1, a.size(-1))
|
||||
c = torch._int_mm(a_reshaped, self.b)
|
||||
c = c.to(self.dtype)
|
||||
c_shape = c.shape
|
||||
a_scale = (
|
||||
a_scale_per_channel
|
||||
if self.per_channel_quant
|
||||
else a_scale_per_tensor
|
||||
)
|
||||
a_scale = a_scale.expand(c.shape)
|
||||
c = c * a_scale
|
||||
c = c * self.b_scale
|
||||
if self.has_bias:
|
||||
c = c.reshape([1, *list(c_shape)])
|
||||
c = c + self.bias
|
||||
c = c.reshape(c_shape)
|
||||
c = c.reshape(out_shape)
|
||||
return c
|
||||
|
||||
has_bias_list = [True, False]
|
||||
dype_list = (
|
||||
[torch.float, torch.bfloat16]
|
||||
if torch.ops.mkldnn._is_mkldnn_bf16_supported()
|
||||
else [torch.float]
|
||||
)
|
||||
per_channel_list = [True, False]
|
||||
for has_bias, dtype, per_channel_quant in itertools.product(
|
||||
has_bias_list, dype_list, per_channel_list
|
||||
):
|
||||
mod = Mod(dtype, has_bias, per_channel_quant).eval()
|
||||
a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
|
||||
a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
|
||||
a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
|
||||
a_scale_per_tensor, a_scale_per_channel = (
|
||||
a_scale_per_tensor.to(dtype),
|
||||
a_scale_per_channel.to(dtype),
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
|
||||
)
|
||||
self.assertEqual(
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
||||
10 if has_bias else 7,
|
||||
)
|
||||
|
||||
self._test_common(
|
||||
mod,
|
||||
(a, a_scale_per_tensor, a_scale_per_channel),
|
||||
matcher_check_fn=matcher_check_fn,
|
||||
check_autocast=dtype,
|
||||
)
|
||||
|
||||
|
||||
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
|
||||
class TestDynamicPatternMatcher(TestPatternMatcherBase):
|
||||
|
@ -2529,6 +2529,187 @@ def _register_qlinear_weight_prepack():
|
||||
)
|
||||
|
||||
|
||||
def _register_smooth_quant_int_mm_pattern():
|
||||
"""
|
||||
The pattern is:
|
||||
(no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
|
||||
or
|
||||
(with bias) pattern_no_bias -> add -> reshape -> reshape
|
||||
"""
|
||||
pattern_no_bias = CallFunction(
|
||||
aten.reshape.default,
|
||||
CallFunction(
|
||||
aten.mul.Tensor,
|
||||
CallFunction(
|
||||
aten.mul.Tensor,
|
||||
CallFunction(
|
||||
prims.convert_element_type.default,
|
||||
CallFunction(
|
||||
aten._int_mm.default,
|
||||
CallFunction(
|
||||
aten.reshape.default,
|
||||
KeywordArg("a"),
|
||||
KeywordArg("in_shape"),
|
||||
),
|
||||
KeywordArg("b"),
|
||||
),
|
||||
KeywordArg("dtype"),
|
||||
),
|
||||
CallFunction(
|
||||
aten.expand.default,
|
||||
KeywordArg("x_scale"),
|
||||
Arg(),
|
||||
),
|
||||
),
|
||||
KeywordArg("w_scale"),
|
||||
),
|
||||
KeywordArg("out_shape_no_bias"),
|
||||
)
|
||||
pattern_with_bias = CallFunction(
|
||||
aten.reshape.default,
|
||||
CallFunction(
|
||||
aten.reshape.default,
|
||||
CallFunction(
|
||||
aten.add.Tensor,
|
||||
pattern_no_bias,
|
||||
KeywordArg("bias"),
|
||||
),
|
||||
Arg(),
|
||||
),
|
||||
KeywordArg("out_shape_with_bias"),
|
||||
)
|
||||
|
||||
def _validate_pattern(match: Match):
|
||||
return len(match.nodes) in [7, 10]
|
||||
|
||||
for pattern in [pattern_with_bias, pattern_no_bias]:
|
||||
|
||||
@register_freezing_graph_pattern(
|
||||
pattern,
|
||||
extra_check=_validate_pattern,
|
||||
pass_number=0,
|
||||
)
|
||||
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
|
||||
bias = kwargs.get("bias", None)
|
||||
if bias is not None:
|
||||
if len(bias.meta.get("tensor_meta").shape) != 1:
|
||||
# we expect bias is a vector
|
||||
return
|
||||
x = kwargs["a"]
|
||||
weight = kwargs["b"]
|
||||
dtype = kwargs["dtype"]
|
||||
x_scale = kwargs["x_scale"]
|
||||
w_scale = kwargs["w_scale"]
|
||||
x_shape = x.meta.get("tensor_meta").shape
|
||||
if has_free_symbols(x_shape):
|
||||
# For dynamic shape case, we can't get activation shape ahead of runtime.
|
||||
x_shape = None
|
||||
|
||||
out_node = match.output_node()
|
||||
with match.graph.inserting_before(out_node):
|
||||
transpose_node = match.graph.call_function(
|
||||
aten.permute.default, args=(weight, [1, 0])
|
||||
)
|
||||
contig_node = match.graph.call_function(
|
||||
aten.contiguous.default, args=(transpose_node,)
|
||||
)
|
||||
packed_weight_inputs = (
|
||||
contig_node,
|
||||
x_shape,
|
||||
)
|
||||
packed_weight_op = torch.ops.onednn.qlinear_prepack
|
||||
prepack_weight_node = match.graph.call_function(
|
||||
packed_weight_op, args=packed_weight_inputs
|
||||
)
|
||||
|
||||
dummy_zp = match.graph.call_function(aten.empty, args=([0],))
|
||||
w_scale = match.graph.call_function(
|
||||
prims.convert_element_type.default, args=(w_scale, torch.float32)
|
||||
)
|
||||
|
||||
x_scale_shape = x_scale.meta.get("tensor_meta").shape
|
||||
x_scale_is_scalar = False
|
||||
if not has_free_symbols(x_scale_shape):
|
||||
prod = 1
|
||||
for d in x_scale_shape:
|
||||
prod *= d
|
||||
x_scale_is_scalar = prod == 1
|
||||
|
||||
new_args: Tuple[Any, ...]
|
||||
if x_scale_is_scalar:
|
||||
# in this case, we can call onednn.qlinear directly
|
||||
new_args = (
|
||||
x,
|
||||
x_scale,
|
||||
dummy_zp, # x_zp
|
||||
prepack_weight_node,
|
||||
w_scale,
|
||||
dummy_zp, # w_zp
|
||||
bias,
|
||||
1.0, # output_scale
|
||||
0, # output_zero_point
|
||||
dtype, # output_dtype
|
||||
"none", # post op name
|
||||
[], # post op args
|
||||
"", # post op algorithm
|
||||
)
|
||||
new_linear_node = match.graph.call_function(
|
||||
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
|
||||
)
|
||||
out_node.replace_all_uses_with(new_linear_node)
|
||||
new_linear_node.meta.update(out_node.meta)
|
||||
else:
|
||||
# onednn.qlinear does not support per-channel quantization of x
|
||||
# so in this case, we have to apply x scale and add bias ourselves after qlinear
|
||||
x_reshaped = match.graph.call_function(
|
||||
aten.reshape.default, args=(x, kwargs["in_shape"])
|
||||
)
|
||||
new_args = (
|
||||
x_reshaped,
|
||||
1.0, # x_scale
|
||||
0, # x_zp
|
||||
prepack_weight_node,
|
||||
w_scale,
|
||||
dummy_zp, # w_zp
|
||||
None, # bias
|
||||
1.0, # output_scale
|
||||
0, # output_zero_point
|
||||
dtype, # output_dtype
|
||||
"none", # post op name
|
||||
[], # post op args
|
||||
"", # post op algorithm
|
||||
)
|
||||
new_linear_node = match.graph.call_function(
|
||||
torch.ops.onednn.qlinear_pointwise, args=new_args
|
||||
)
|
||||
# apply x scale
|
||||
new_out_node = match.graph.call_function(
|
||||
aten.mul.Tensor, args=(new_linear_node, x_scale)
|
||||
)
|
||||
# Add bias and reshape
|
||||
if bias is not None:
|
||||
new_out_node = match.graph.call_function(
|
||||
aten.add.Tensor, args=(new_out_node, bias)
|
||||
)
|
||||
new_out_node = match.graph.call_function(
|
||||
aten.reshape.default,
|
||||
args=(new_out_node, kwargs["out_shape_with_bias"]),
|
||||
)
|
||||
else:
|
||||
new_out_node = match.graph.call_function(
|
||||
aten.reshape.default,
|
||||
args=(new_out_node, kwargs["out_shape_no_bias"]),
|
||||
)
|
||||
out_node.replace_all_uses_with(new_out_node)
|
||||
new_out_node.meta.update(out_node.meta)
|
||||
for node in reversed(match.nodes):
|
||||
match.graph.erase_node(node)
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
|
||||
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
|
||||
match.nodes
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _register_quantization_weight_pack_pass():
|
||||
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
|
||||
@ -2540,6 +2721,9 @@ def _register_quantization_weight_pack_pass():
|
||||
# Step 3: QLinear weight prepack
|
||||
_register_qlinear_weight_prepack()
|
||||
|
||||
# Step 4: weight prepack for SmoothQuant from Torchao
|
||||
_register_smooth_quant_int_mm_pattern()
|
||||
|
||||
|
||||
def quant_lift_up(graph_module: torch.fx.GraphModule):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user