Compare commits

...

2 Commits

Author SHA1 Message Date
a7103e27f7 Disable max-autotune for UT 2025-01-12 23:47:24 -08:00
dfe4cfc8af Copy SmoothQuant UT from test_mkldnn_pattern_matcher.py to test_cpu_select_algorithm.py 2025-01-12 23:41:01 -08:00
           (TestSelectAlgorithmCPU.test_smooth_quant_with_int_mm_has_bias_True_bfloat16_per_channel_quant_True_dynamic_False_cpu_bfloat16 fails but passes in test_mkldnn_pattern_matcher.py)
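
For reference, the failing parametrization named above can be run in isolation. A minimal sketch, assuming the copied test module is importable as test_cpu_select_algorithm (the module path and runner setup are assumptions, not part of this change):

    import unittest

    # Hypothetical reproduction sketch: load only the parametrization reported
    # as failing in the commit message; the module path is an assumption.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "test_cpu_select_algorithm.TestSelectAlgorithmCPU."
        "test_smooth_quant_with_int_mm_has_bias_True_bfloat16_per_channel_quant_True_dynamic_False_cpu_bfloat16"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)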

View File

@@ -209,6 +209,75 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
        self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    # @patches
    @torch.no_grad
    @parametrize("has_bias", [True, False])
    @parametrize("dtype", [torch.bfloat16])
    @parametrize("per_channel_quant", [True, False])
    @parametrize("dynamic", [False])
    def test_smooth_quant_with_int_mm(
        self, has_bias, dtype, per_channel_quant, dynamic
    ):
        r"""
        This test case checks whether we can match the SmoothQuant int8 linear pattern from Torchao.
        The pattern is:
            (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
        or
            (with bias) pattern_no_bias -> add -> reshape -> reshape
        """
        if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            return
        M = 16
        in_feature = 32
        out_feature = 64
        q_min, q_max = -32, 31

        class Mod(torch.nn.Module):
            def __init__(
                self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
            ):
                super().__init__()
                self.dtype = dtype
                self.has_bias = has_bias
                self.b = torch.randint(
                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
                )
                self.per_channel_quant = per_channel_quant
                a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
                a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
                self.a_scale = (
                    a_scale_per_channel
                    if self.per_channel_quant
                    else a_scale_per_tensor
                )
                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
                self.b_scale = self.b_scale.to(dtype)
                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None

            def forward(self, a):
                out_shape = a.shape[:-1] + (self.b.size(-1),)
                a_reshaped = a.reshape(-1, a.size(-1))
                c = torch._int_mm(a_reshaped, self.b)
                c = c.to(self.dtype)
                c_shape = c.shape
                a_scale = self.a_scale.expand(c.shape)
                c = c * a_scale
                c = c * self.b_scale
                if self.has_bias:
                    c = c.reshape([1, *list(c_shape)])
                    c = c + self.bias
                    c = c.reshape(c_shape)
                c = c.reshape(out_shape)
                return c

        mod = Mod(dtype, has_bias, per_channel_quant).eval()
        a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
        self.common(
            mod,
            (a,),
        )

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
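
For clarity, the pattern described in the test's docstring can be written out eagerly, outside the test harness. The sketch below reuses the shapes, scales, and the with-bias branch from the new test; it is illustrative only and not part of the diff:

    import torch

    # Standalone eager-mode sketch of the SmoothQuant int8 linear pattern
    # (with bias, per-channel activation scale), using the test's shapes.
    M, in_feature, out_feature = 16, 32, 64
    q_min, q_max = -32, 31
    dtype = torch.bfloat16

    a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
    b = torch.randint(q_min, q_max, [in_feature, out_feature], dtype=torch.int8)
    a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01        # activation scale
    b_scale = (torch.rand([out_feature]) * 0.01 + 0.01).to(dtype)  # weight scale
    bias = torch.rand([out_feature], dtype=dtype)

    out_shape = a.shape[:-1] + (out_feature,)
    c = torch._int_mm(a.reshape(-1, in_feature), b)  # reshape -> _int_mm (int8 x int8 -> int32)
    c = c.to(dtype)                                  # convert_element_type
    c_shape = c.shape
    c = c * a_scale.expand(c.shape)                  # expand -> mul
    c = c * b_scale                                  # mul
    c = c.reshape([1, *c_shape]) + bias              # add (bias branch)
    c = c.reshape(c_shape).reshape(out_shape)        # reshape -> reshape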