# Owner(s): ["oncall: cpu inductor"]
import contextlib
import copy
import itertools
import unittest

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._inductor import config, metrics
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import (
    is_mkldnn_bf16_supported,
    is_mkldnn_fp16_supported,
    run_and_get_code,
)
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.nn import functional as F
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off
from torch.testing._internal.common_quantization import (
    _generate_qdq_quantized_model,
    skipIfNoDynamoSupport,
    skipIfNoONEDNN,
    skipIfNoONEDNNBF16,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_LINUX,
    IS_X86,
    MI300_ARCH,
    MI350_ARCH,
    parametrize,
    skipIfNoXPU,
    skipIfRocm,
    skipIfRocmArch,
    skipIfXpu,
    TEST_ACL,
    TEST_MKL,
    xfailIfACL,
)
from torch.testing._internal.inductor_utils import (
    _check_has_dynamic_shape,
    clone_preserve_strides_offset,
    HAS_CPU,
)


# The dict value is match_nodes(computation_op+unary_op)
unary_list = {
    torch.nn.ReLU(): 2,
    torch.nn.Sigmoid(): 2,
    torch.nn.Tanh(): 2,
    torch.nn.Hardswish(): 6,
    torch.nn.LeakyReLU(0.1, inplace=False): 4,
    # Use floats for min/max, otherwise they can get converted to symints
    torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3,
    torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3,
    torch.nn.GELU(approximate="none"): 6,
    torch.nn.GELU(approximate="tanh"): 10,
    torch.nn.ReLU6(): 3,
    torch.nn.SiLU(): 3,
    torch.nn.Hardsigmoid(): 5,
}

non_decomposed_unary_list = [
    torch.nn.ReLU,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
]

# The dict value is (match_count, match_nodes, inplace)
binary_list = {
    lambda x, y: torch.add(x, y): (1, 2, False),  # call_function
    lambda x, y: torch.add(y, x): (1, 2, False),  # call_function
    lambda x, y: x.add(y): (1, 2, False),  # call_method
    lambda x, y: x.add_(y): (1, 2, True),  # call_method
    lambda x, y: torch.sub(x, y): (1, 2, False),  # call_function
    lambda x, y: x.sub(y): (1, 2, False),  # call_method
    lambda x, y: x.sub_(y): (1, 2, True),  # call_method
}

quantization_add_fn_list = [
    lambda x, y: torch.add(x, y),
    lambda x, y: x.add(y),
]

quantization_inplace_add_fn_list = [
    lambda x, y: x.add_(y),
]


def get_default_quantizer(is_qat, is_dynamic):
    quantizer = X86InductorQuantizer()
    quantizer.set_global(
        xiq.get_default_x86_inductor_quantization_config(
            is_qat=is_qat, is_dynamic=is_dynamic
        )
    )
    return quantizer
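

# Illustrative sketch (not part of the upstream test file): how a quantizer like
# the one returned above is typically consumed by the PT2E flow that
# `_generate_qdq_quantized_model` wraps. The helper name is an assumption for
# illustration only, and the exact export API may differ across releases.
def _example_quantize_with_x86_quantizer(mod, example_inputs):
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Export the eager module to an FX graph (recent releases expose
    # torch.export.export_for_training for this purpose).
    exported = torch.export.export_for_training(mod, example_inputs).module()
    quantizer = get_default_quantizer(is_qat=False, is_dynamic=False)
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration
    converted = convert_pt2e(prepared)
    # The converted model carries Q/DQ nodes that Inductor's pattern matcher fuses.
    return torch.compile(converted)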
def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"):
    # This function decides how many kernels are generated
    # while testing conv2d/3d/deconv2d.
    # The assumptions are:
    # (1) There will be a to_dtype kernel for the input for low-precision dtypes.
    # (2) Inductor always uses the channels_last format, so there will be a
    #     to_channels_last conversion for the input.
    # (3) The to_dtype and to_channels_last conversions for the input can be fused.
    # (4) Inductor always gets channels_last output from mkldnn_conv_pointwise(binary)
    #     and forces the output to have the same stride as eager mode, so there will
    #     be a to_contiguous for the output if the eager output is contiguous.
    mod = copy.deepcopy(mod)
    mod = mod.to(device=device)
    input = input.clone()
    input = input.to(device)

    if dtype == torch.float32:
        maybe_autocast = contextlib.nullcontext()
    else:
        maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype)
    with torch.no_grad(), maybe_autocast:
        output = mod(input)
    input_kernel, output_kernel = 0, 0
    if (
        input.is_contiguous(memory_format=torch.contiguous_format)
        or dtype != torch.float32
        or (TEST_ACL and dim == 4)
    ):
        input_kernel = 1
    if output.is_contiguous(memory_format=torch.contiguous_format) or (
        TEST_ACL and dtype == torch.bfloat16
    ):
        output_kernel = 1

    return input_kernel + output_kernel
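

# Illustrative sketch (not part of the upstream test file): reading Inductor's
# generated-kernel counter for one compiled conv, which is what the helper above
# predicts. The module, shapes, and freezing patch mirror the tests' setup;
# treat this as an assumption-laden example, not a reference recipe.
def _example_count_generated_kernels():
    with config.patch({"freezing": True}):
        conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1).eval()
        x = torch.randn(1, 3, 56, 56)
        metrics.reset()
        with torch.no_grad():
            torch.compile(conv)(x)
        expected = cal_conv_generated_kernel_number(conv, x, torch.float32)
        return metrics.generated_kernel_count, expected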
class TestPatternMatcherBase(TestCase):
    def setUp(self):
        TestCase.setUp(self)
        self.ctx_stack = contextlib.ExitStack()
        self.ctx_stack.enter_context(config.patch({"freezing": True}))

    def tearDown(self):
        TestCase.tearDown(self)
        self.ctx_stack.close()

    def _check_unary_is_decomposed(self, unary_fn):
        return not any(
            isinstance(unary_fn, fn)
            for fn in [torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Tanh]
        )

    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

    def _test_common(
        self,
        mod,
        inputs,
        matcher_check_fn,
        atol=1e-5,
        rtol=1.3e-6,
        check_autocast=torch.float32,
        check_quantization=False,
        is_qat=False,
        dtype=None,
        is_dynamic=False,
        quantizer=None,
        compile_options={},  # noqa: B006
        quantization_with_autocast=False,
    ):
        if not hasattr(self, "device"):
            has_xpu = any(
                isinstance(input, torch.Tensor) and input.device.type == "xpu"
                for input in inputs
            )
            device = "xpu" if has_xpu else "cpu"
        else:
            device = self.device

        mod = mod.to(device=device)
        if device != "cpu":
            inputs = tuple(
                clone_preserve_strides_offset(x, device=device) for x in inputs
            )
        counters.clear()
        torch._dynamo.reset()
        if check_autocast == torch.bfloat16 and is_mkldnn_bf16_supported(device):
            maybe_autocast = torch.amp.autocast(
                device_type=device, dtype=torch.bfloat16
            )
            atol, rtol = 1e-2, 1e-2
        elif check_autocast == torch.float16 and (is_mkldnn_fp16_supported(device)):
            maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16)
            atol, rtol = 1e-2, 1e-2
        else:
            assert check_autocast == torch.float32
            maybe_autocast = contextlib.nullcontext()
        if check_quantization:
            if quantization_with_autocast:
                with maybe_autocast:
                    convert_model = _generate_qdq_quantized_model(
                        mod, inputs, is_qat, is_dynamic, quantizer
                    )
            else:
                convert_model = _generate_qdq_quantized_model(
                    mod, inputs, is_qat, is_dynamic, quantizer
                )
            with torch.no_grad(), maybe_autocast:
                _ = torch.compile(convert_model)(*inputs)
                matcher_check_fn()
        else:
            with torch.no_grad(), maybe_autocast:
                clone_inputs = self._clone_inputs(inputs)
                expected = mod(*inputs)
                actual = torch.compile(mod, **compile_options)(*clone_inputs)
                if self.precision != 0:
                    torch.testing.assert_close(
                        actual, expected, atol=self.precision, rtol=self.precision
                    )
                else:
                    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
                matcher_check_fn()

    def _test_code_common(
        self,
        mod,
        inputs,
        include_ops,
        exclude_ops,
        atol=1e-5,
        rtol=1.3e-6,
        check_quantization=False,
        check_dynamic=None,
        num_include_ops=None,
        quantizer=None,
    ):
        with torch.no_grad():
            clone_inputs = self._clone_inputs(inputs)
            if check_quantization:
                mod = _generate_qdq_quantized_model(mod, inputs, quantizer=quantizer)
            expected = mod(*inputs)
            actual, (source_code,) = run_and_get_code(
                torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
                *clone_inputs,
            )
            assert_keywords = ["assert_size_stride", "assert_alignment"]
            filtered_lines = [
                line
                for line in source_code.splitlines()
                if not any(assert_key in line for assert_key in assert_keywords)
            ]
            source_code = "\n".join(filtered_lines)

            for op in include_ops:
                self.assertIn(op, source_code)
            if num_include_ops is not None:
                assert len(include_ops) == len(num_include_ops)
                for i in range(len(include_ops)):
                    self.assertEqual(
                        source_code.count(include_ops[i]), num_include_ops[i]
                    )
            for op in exclude_ops:
                self.assertNotIn(op, source_code)
            if check_dynamic is not None:
                _check_has_dynamic_shape(self, source_code)
            if not check_quantization:
                # Skip due to reduce range setting for Quantization on preCI system.
                torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
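

# Illustrative sketch (not part of the upstream test file): the same
# "inspect the generated code" idea as _test_code_common above, reduced to a
# standalone helper. The default op substring is only an example of what a
# caller might grep for; pass whatever op name a given scenario expects.
def _example_generated_code_contains(
    mod, inputs, op_name="mkldnn._convolution_pointwise"
):
    with torch.no_grad(), config.patch({"freezing": True}):
        _, (source_code,) = run_and_get_code(
            torch.compile(mod, fullgraph=True), *inputs
        )
    return op_name in source_code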


class TestPatternMatcherGeneric(TestPatternMatcherBase):
    def _test_conv_unary_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                else:
                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.conv(x)
                return self.unary_fn(x)

        dtypes = [
            torch.float,
        ]
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        options = itertools.product(
            unary_list.keys(),
            [torch.contiguous_format, cl_format],
            dtypes,
        )

        for (
            unary_fn,
            memory_format,
            dtype,
        ) in options:
            if (
                dtype != torch.float32
                and torch.backends.mkldnn.matmul.fp32_precision == "tf32"
            ):
                continue
            metrics.reset()
            if dim == 4:
                x_shape = (1, 3, 56, 56)
            else:
                x_shape = (1, 3, 20, 56, 56)
            mod = M(unary_fn).to(memory_format=memory_format).eval()

            v = (
                torch.randn(x_shape, dtype=torch.float32)
                .add(1)
                .to(memory_format=memory_format)
            )

            def matcher_check_fn():
                match_nodes = unary_list[unary_fn]
                if dtype in (
                    torch.float16,
                    torch.bfloat16,
                ) and self._check_unary_is_decomposed(unary_fn):
                    # Has extra dtype conversion nodes for autocast.
                    match_nodes += 2
                self.assertEqual(
                    counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                    0 if TEST_ACL else match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
                )

            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            generated_kernel_count = cal_conv_generated_kernel_number(
                mod, v, dtype, dim, self.device
            )
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off()
    def test_conv2d_unary(self, device):
        self.device = device
        self._test_conv_unary_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off()
    def test_conv3d_unary(self, device):
        self.device = device
        self._test_conv_unary_base(dim=5)

    def _test_conv_transpose_unary_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv_transpose = torch.nn.ConvTranspose2d(
                        3, 16, 3, stride=2, padding=1
                    )
                else:
                    self.conv_transpose = torch.nn.ConvTranspose3d(
                        3, 16, 3, stride=2, padding=1
                    )
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.conv_transpose(x)
                return self.unary_fn(x)

        dtypes = [
            torch.float,
        ]
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)

        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        options = itertools.product(
            unary_list,
            [torch.contiguous_format, cl_format],
            dtypes,
        )

        for unary_fn, memory_format, dtype in options:
            metrics.reset()
            if dim == 4:
                x_shape = (1, 3, 28, 28)
            else:
                x_shape = (1, 3, 17, 28, 28)
            mod = M(unary_fn).eval()

            v = torch.randn(x_shape, dtype=torch.float32).to(
                memory_format=memory_format
            )

            def matcher_check_fn():
                match_nodes = unary_list[unary_fn]
                if dtype in (
                    torch.float16,
                    torch.bfloat16,
                ) and self._check_unary_is_decomposed(unary_fn):
                    # Has extra dtype conversion nodes for autocast.
                    match_nodes += 2
                self.assertEqual(
                    counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                    0 if TEST_ACL else match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
                )

            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            generated_kernel_count = cal_conv_generated_kernel_number(
                mod, v, dtype, dim, self.device
            )
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @skipIfXpu(
        msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device."
    )
    @reduced_f32_on_and_off()
    def test_conv_transpose2d_unary(self, device):
        self.device = device
        self._test_conv_transpose_unary_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @skipIfXpu(
        msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device."
    )
    @reduced_f32_on_and_off()
    def test_conv_transpose3d_unary(self, device):
        self.device = device
        self._test_conv_transpose_unary_base(dim=5)

    def _test_conv_binary_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                binary_fn,
                has_relu,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                    self.conv2 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                else:
                    self.conv1 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                    self.conv2 = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                self.binary_fn = binary_fn
                self.has_relu = has_relu

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                if has_relu:
                    return self.binary_fn(x1, x2).relu()
                else:
                    return self.binary_fn(x1, x2)

        dtypes = [
            torch.float,
        ]
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        test_memory_format = [torch.contiguous_format, cl_format]
        options = itertools.product(
            binary_list,
            [True, False],
            test_memory_format,
            dtypes,
        )

        for (
            binary_fn,
            has_relu,
            memory_format,
            dtype,
        ) in options:
            if (
                dtype != torch.float32
                and torch.backends.mkldnn.matmul.fp32_precision == "tf32"
            ):
                continue
            metrics.reset()
            if dim == 4:
                x_shape = (1, 3, 56, 56)
            else:
                x_shape = (1, 3, 20, 56, 56)
            mod = M(binary_fn, has_relu).eval()
            v = (
                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
                .add(1)
                .to(memory_format=memory_format)
            )

            def matcher_check_fn():
                match_nodes = binary_list[binary_fn][1]
                if has_relu:
                    match_nodes += 1
                self.assertEqual(
                    counters["inductor"][
                        "mkldnn_conv_binary_unary_fusion_matcher_nodes"
                    ],
                    0 if TEST_ACL else match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 2
                )

            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            generated_kernel_count = cal_conv_generated_kernel_number(
                mod, v, dtype, dim, self.device
            )
            self.assertEqual(metrics.generated_kernel_count, generated_kernel_count)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off(0.02)
    def test_conv2d_binary(self, device):
        self.device = device
        self._test_conv_binary_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off(0.02)
    def test_conv3d_binary(self, device):
        self.device = device
        self._test_conv_binary_base(dim=5)

    def _test_conv_binary_broadcast_shapes_base(self, dim=4):
        assert dim == 4 or dim == 5

        class M(torch.nn.Module):
            def __init__(
                self,
                binary_fn,
                has_relu,
                **kwargs,
            ):
                super().__init__()
                if dim == 4:
                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                else:
                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
                self.binary_fn = binary_fn
                self.has_relu = has_relu

            def forward(self, x, x2):
                x1 = self.conv(x)
                if has_relu:
                    return self.binary_fn(x1, x2).relu()
                else:
                    return self.binary_fn(x1, x2)

        dtypes = [
            torch.float,
        ]
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
        test_memory_format = [torch.contiguous_format, cl_format]
        if dim == 4:
            input_shapes = [
                [2, 3, 56, 56],
            ]
            other_shapes = [[2, 16, 1, 1], [1, 16, 1, 1], [1, 1, 1, 1]]
        else:
            input_shapes = [
                [2, 3, 20, 56, 56],
            ]
            other_shapes = [[2, 16, 1, 1, 1], [1, 16, 1, 1, 1], [1, 1, 1, 1, 1]]
        options = itertools.product(
            binary_list,
            input_shapes,
            other_shapes,
            [True, False],
            test_memory_format,
            dtypes,
        )

        for (
            binary_fn,
            x_shape,
            other_shape,
            has_relu,
            memory_format,
            dtype,
        ) in options:
            metrics.reset()
            mod = M(binary_fn, has_relu).eval()
            x = (
                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
                .add(1)
                .to(memory_format=memory_format)
            )
            other = (
                torch.randn(other_shape, dtype=torch.float32, requires_grad=True)
                .add(1)
                .to(memory_format=memory_format)
                .to(dtype)
            )

            def matcher_check_fn():
                match_nodes = binary_list[binary_fn][1]
                if has_relu:
                    match_nodes += 1
                self.assertEqual(
                    counters["inductor"][
                        "mkldnn_conv_binary_unary_fusion_matcher_nodes"
                    ],
                    0 if TEST_ACL else match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"], 1
                )

            self._test_common(mod, (x, other), matcher_check_fn, check_autocast=dtype)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off()
    def test_conv2d_binary_broadcast_shapes(self, device):
        self.device = device
        self._test_conv_binary_broadcast_shapes_base(dim=4)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @reduced_f32_on_and_off()
    def test_conv3d_binary_broadcast_shapes(self, device):
        self.device = device
        self._test_conv_binary_broadcast_shapes_base(dim=5)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    @unittest.skipIf(IS_FBCODE, "Failing in fbcode")
    @reduced_f32_on_and_off()
    def test_conv2d_linear_add_broadcast_shapes(self, device):
        self.device = device

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
                self.linear = torch.nn.Linear(3, 16)

            def forward(self, x1, x2):
                return self.conv(x1) + self.linear(x2)[:, :, None, None]

        metrics.reset()
        mod = M().eval()
        x1 = torch.randn(2, 3, 56, 56)
        x2 = torch.randn(2, 3)

        def matcher_check_fn():
            match_nodes = 0 if TEST_ACL else 2
            self.assertEqual(
                counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"],
                match_nodes,
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"], 1
            )

        self._test_common(mod, (x1, x2), matcher_check_fn)


class TestPatternMatcher(TestPatternMatcherBase):
    @reduced_f32_on_and_off()
    def test_linear_unary(self, device="cpu"):
        self.device = device

        class M(torch.nn.Module):
            def __init__(
                self,
                unary_fn,
                in_features,
                out_features,
                bias,
                **kwargs,
            ):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_features,
                    out_features,
                    bias,
                    **kwargs,
                )
                self.unary_fn = unary_fn

            def forward(self, x):
                x = self.linear(x)
                return self.unary_fn(x)

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]:
            dtypes.append(torch.float32)
        options = itertools.product(unary_list, [True, False], dtypes)
        for unary_fn, bias, dtype in options:
            if (
                dtype != torch.float32
                and torch.backends.mkldnn.matmul.fp32_precision == "tf32"
            ):
                continue
            metrics.reset()
            mod = M(unary_fn, 10, 30, bias=bias).eval()
            # only fuse for linear when the dtype is bf16
            mod = mod
            v = torch.randn(2, 10)

            def matcher_check_fn():
                match_nodes = unary_list[unary_fn]
                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                    # Has extra dtype conversion nodes for autocast.
                    match_nodes += 2
                self.assertEqual(
                    counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                    0 if TEST_ACL else match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
                )

            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            # only generated 1 kernel for "to_dtype"
            expected_kernel_count = 2 if TEST_ACL else 1
            if dtype == torch.float32:
                # In BF32, input is float32, will not generate kernel for "to_dtype"
                expected_kernel_count -= 1
            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

    @reduced_f32_on_and_off()
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    def test_linear_fp32(self, device="cpu"):
        self.device = device

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(10, 30, bias)

            def forward(self, x):
                return self.linear(x)

        for bias in [True, False]:
            mod = M(bias=bias).eval()
            v = torch.randn(2, 10)

            # packing pass.
            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
                )

            self._test_common(mod, (v,), matcher_check_fn)

    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    def test_linear_input_non_contiguous_3D_wo_bias(self, device="cpu"):
        self.device = device

        # Activation is 3D, non-contiguous and without Bias
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4096, 1024, bias=False)

            def forward(self, x):
                x = torch.ops.aten.permute.default(x, [0, 2, 1, 3])
                x = torch.ops.aten.reshape.default(x, [4, 1, 4096])
                return self.linear(x)

        mod = M().eval()
        v = torch.randn(4, 32, 1, 128)

        dtypes = [torch.float]
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)

        for dtype in dtypes:
            torch._dynamo.reset()
            autocast_enabled = (
                True if dtype in [torch.bfloat16, torch.float16] else False
            )
            with (
                torch.no_grad(),
                torch.autocast(
                    device_type="cpu",
                    enabled=autocast_enabled,
                    dtype=dtype,
                ),
            ):
                expected = mod(v)
                actual, (source_code,) = run_and_get_code(
                    torch.compile(mod, fullgraph=True),
                    v,
                )
                self.assertIn(
                    "torch.ops.mkldnn._linear_pointwise.default"
                    if autocast_enabled
                    else "torch.ops.mkl._mkl_linear.default",
                    source_code,
                )
                torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

    @skipIfXpu(
        msg="Different with CPU, two linears will be concat on XPU for better performance"
    )
    def test_linear_add_bias(self, device="cpu"):
        self.device = device

        class M(torch.nn.Module):
            def __init__(self, device, dtype, unary_fn, cast_bias):
                super().__init__()
                self.linear1 = torch.nn.Linear(10, 64, bias=False)
                self.bias1 = torch.randn(64, device=device)
                self.linear2 = torch.nn.Linear(10, 64, bias=False)
                self.bias2 = torch.randn(64, device=device)
                if cast_bias:
                    self.bias1 = self.bias1.to(dtype=dtype, device=device)
                    self.bias2 = self.bias2.to(dtype=dtype, device=device)
                self.unary_fn = unary_fn

            def forward(self, x):
                a = self.linear1(x) + self.bias1
                b = self.linear2(x) + self.bias2
                return self.unary_fn(a), self.unary_fn(b)

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        options = itertools.product(unary_list, dtypes)
        for unary_fn, dtype in options:
            metrics.reset()
            fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval()
            v = torch.randn(2, 10)

            def folder_matcher_check_fn():
                match_nodes = unary_list[unary_fn]
                if self._check_unary_is_decomposed(unary_fn):
                    # Has extra dtype conversion nodes for autocast.
                    match_nodes += 2
                # we have 2 linears, so we double the matcher_count/nodes
                self.assertEqual(
                    counters["inductor"]["mkldnn_unary_fusion_matcher_count"],
                    0 if TEST_ACL else 2,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                    0 if TEST_ACL else match_nodes * 2,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
                )

            self._test_common(
                fold_mod,
                (v,),
                folder_matcher_check_fn,
                check_autocast=dtype,
            )
            self.assertEqual(metrics.generated_kernel_count, 3 if TEST_ACL else 1)
            # we won't fold the bias if bias is not same dtype with weight
            # https://github.com/pytorch/pytorch/pull/129138
            metrics.reset()
            mod = M(self.device, dtype, unary_fn, cast_bias=False).eval()

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
                )

            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            # 1 kernel for "to_lowp", 2 kernels for unary ops
            self.assertEqual(metrics.generated_kernel_count, 3)

    @reduced_f32_on_and_off()
    def test_linear_binary(self, device="cpu"):
        self.device = device

        class M(torch.nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.binary_fn = binary_fn

            def forward(self, x, y):
                x = self.linear(x)
                x = self.binary_fn(x, y.clone())
                return x

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]:
            dtypes.append(torch.float32)
        options = itertools.product(
            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
        )
        out_feature = 30

        for binary_fn, input_shape, bias, dtype in options:
            metrics.reset()
            if (
                dtype != torch.float32
                and torch.backends.mkldnn.matmul.fp32_precision == "tf32"
            ):
                continue

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"][
                        "mkldnn_conv_binary_unary_fusion_matcher_nodes"
                    ],
                    0 if TEST_ACL else 2,
                )
                reshape_linear_reshape_match_nodes = 3 if len(input_shape) == 3 else 0
                self.assertEqual(
                    counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"],
                    reshape_linear_reshape_match_nodes,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 1
                )

            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
            v = torch.randn(input_shape)
            other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
            self._test_common(
                mod,
                (
                    v,
                    other,
                ),
                matcher_check_fn,
                check_autocast=dtype,
            )
            # only generated 1 kernel for "to_dtype"
            expected_kernel_count = 2 if TEST_ACL else 1
            if dtype == torch.float32:
                # In BF32, input is float32, will not generate kernel for "to_dtype"
                expected_kernel_count -= 1
            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

    def test_linear_binary_broadcast_shapes(self, device="cpu"):
        self.device = device

        class M(torch.nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
                super().__init__()
                self.linear = torch.nn.Linear(
                    in_channels, out_channels, bias=bias, **kwargs
                )
                self.binary_fn = binary_fn

            def forward(self, x, y):
                x = self.linear(x)
                x = self.binary_fn(x, y.clone())
                return x

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        options = itertools.product(
            binary_list,
            (
                ([2, 3, 10], [1, 1, 30]),
                ([2, 10], [1, 30]),
            ),
            (True, False),
            dtypes,
        )
        out_feature = 30

        for binary_fn, (input_shape, other_shape), bias, dtype in options:
            metrics.reset()
            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
            v = torch.randn(input_shape)
            other = torch.randn(other_shape).to(dtype)

            def matcher_check_fn():
                reshape_linear_reshape_match_nodes = 3 if len(input_shape) == 3 else 0
                self.assertEqual(
                    counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"],
                    reshape_linear_reshape_match_nodes,
                )
                self.assertEqual(
                    counters["inductor"][
                        "mkldnn_conv_binary_unary_fusion_matcher_nodes"
                    ],
                    0 if TEST_ACL else 2,
                )
                self.assertEqual(
                    counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"], 1
                )

            self._test_common(
                mod,
                (
                    v,
                    other,
                ),
                matcher_check_fn,
                check_autocast=dtype,
            )
            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)

    @skipIfXpu(
        msg="Different with CPU, two linears will be concat on XPU for better performance"
    )
    def test_multi_linear_share_same_input(self, device="cpu"):
        self.device = device

        # llama pattern.
        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.w1 = torch.nn.Linear(16, 16, bias=False)
                self.w2 = torch.nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return F.silu(self.w1(x)) * F.relu(self.w2(x))

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                0 if TEST_ACL else 7,
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_unary_fusion_matcher_count"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
            )

        for dtype in dtypes:
            mod = M().to(dtype).eval()
            v = torch.randn(2, 4, 16).to(dtype)
            self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)

    def _qconv2d_test_helper(
        self,
        device="cpu",
        int8_mixed_bf16=False,
        quantization_with_autocast=False,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1)
                self.conv3 = torch.nn.Conv2d(
                    128, 128, kernel_size=3, stride=1, groups=4
                )

            def forward(self, x):
                return self.conv3(self.conv2(self.conv(x)))

        mod = M().eval().to(device=device)
        v = (
            torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
            .add(1)
            .to(device=device)
        )

        def matcher_check_fn():
            # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
            # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
            # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
            # dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
            )
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_nodes"],
                (16 if quantization_with_autocast else 18) if int8_mixed_bf16 else 12,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            quantization_with_autocast=quantization_with_autocast,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qconv2d_cpu(self):
        r"""
        This testcase will quantize a single Conv2d module.
        """
        self._qconv2d_test_helper("cpu")

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_xpu(self):
        r"""
        This testcase will quantize a single Conv2d module.
        """
        self._qconv2d_test_helper("xpu")

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocmArch(MI300_ARCH + MI350_ARCH)
    def test_qconv2d_int8_mixed_bf16(self):
        r"""
        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
        """
        self._qconv2d_test_helper(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfRocmArch(MI300_ARCH + MI350_ARCH)
    def test_qconv2d_int8_mixed_bf16_use_autocast(self):
        r"""
        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
        """
        self._qconv2d_test_helper(int8_mixed_bf16=True, quantization_with_autocast=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_int8_mixed_bf16_xpu(self):
        r"""
        This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
        """
        self._qconv2d_test_helper(device="xpu", int8_mixed_bf16=True)

    def _qconv2d_unary_test_helper(
        self,
        device="cpu",
        int8_mixed_bf16=False,
        unary_op=torch.nn.ReLU(),
        qconv_unary_matcher_nodes=None,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.unary_fn = copy.deepcopy(unary_op)
                self.conv2 = torch.nn.Conv2d(
                    128, 128, kernel_size=3, stride=1, bias=False
                )
                self.unary_fn2 = copy.deepcopy(unary_op)

            def forward(self, x):
                tmp = self.unary_fn(self.conv(x))
                return self.unary_fn2(self.conv2(tmp))

        mod = M().eval().to(device=device)
        v = (
            torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
            .add(1)
            .to(device=device)
        )

        def matcher_check_fn():
            # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 2
            self.assertEqual(
                counters["inductor"]["qconv_unary_matcher_count"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2
            )
            if qconv_unary_matcher_nodes:
                self.assertEqual(
                    counters["inductor"]["qconv_unary_matcher_nodes"],
                    0 if TEST_ACL else qconv_unary_matcher_nodes,
                )

        self._test_common(
            mod,
            (v,),
            check_quantization=True,
            check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            matcher_check_fn=matcher_check_fn,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_relu_cpu(self):
        r"""
        This testcase will quantize Conv2d->ReLU pattern.
        """
        self._qconv2d_unary_test_helper(device="cpu")

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_relu_xpu(self):
        r"""
        This testcase will quantize Conv2d->ReLU pattern.
        """
        self._qconv2d_unary_test_helper(device="xpu")

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
        r"""
        This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
        """
        self._qconv2d_unary_test_helper(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_relu6_cpu(self):
        r"""
        This testcase will quantize Conv2d->ReLU6 pattern.
        """
        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_relu6_xpu(self):
        r"""
        This testcase will quantize Conv2d->ReLU6 pattern.
        """
        self._qconv2d_unary_test_helper(device="xpu", unary_op=torch.nn.ReLU6())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_hardtanh_cpu(self):
        r"""
        This testcase will quantize Conv2d->Hardtanh pattern.
        """
        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_hardtanh_xpu(self):
        r"""
        This testcase will quantize Conv2d->Hardtanh pattern.
        """
        self._qconv2d_unary_test_helper(device="xpu", unary_op=torch.nn.Hardtanh())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize Conv2d->Hardtanh pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            unary_op=torch.nn.Hardtanh(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=11,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self):
        r"""
        This testcase will quantize Conv2d->Hardtanh pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            device="xpu",
            unary_op=torch.nn.Hardtanh(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=11,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_hardswish_cpu(self):
        r"""
        This testcase will quantize Conv2d->Hardswish pattern.
        """
        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_hardswish_xpu(self):
        r"""
        This testcase will quantize Conv2d->Hardswish pattern.
        """
        self._qconv2d_unary_test_helper(device="xpu", unary_op=torch.nn.Hardswish())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize Conv2d->Hardswish pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            unary_op=torch.nn.Hardswish(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=17,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self):
        r"""
        This testcase will quantize Conv2d->Hardswish pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min,
             clamp_max, mul, div, convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            device="xpu",
            unary_op=torch.nn.Hardswish(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=17,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_silu_cpu(self):
        r"""
        This testcase will quantize Conv2d->SiLU pattern.
        """
        self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_silu_xpu(self):
        r"""
        This testcase will quantize Conv2d->SiLU pattern.
        """
        self._qconv2d_unary_test_helper(device="xpu", unary_op=torch.nn.SiLU())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
        r"""
        This testcase will quantize Conv2d->SiLU pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
             convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            unary_op=torch.nn.SiLU(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=11,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_silu_int8_mixed_bf16_xpu(self):
        r"""
        This testcase will quantize Conv2d->SiLU pattern.
        Match.nodes:
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
             convert_element_type, quantize_per_tensor]
            [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
        """
        self._qconv2d_unary_test_helper(
            device="xpu",
            unary_op=torch.nn.SiLU(),
            int8_mixed_bf16=True,
            qconv_unary_matcher_nodes=11,
        )

    def _qconv2d_add_test_helper(
        self, device="cpu", use_relu=False, int8_mixed_bf16=False
    ):
        r"""
        This testcase will quantize a Conv2d->Add pattern as:
                 X
               /   \
        Conv1(X)   Conv2(X)
               \   /
                Add
                 |
          Optional(relu)
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                use_relu,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU()
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False)
                self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU()
                self.use_relu = use_relu

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)
                tmp1 = self.conv3(tmp)
                tmp2 = self.conv4(tmp)
                res = self.add_fn2(tmp1, tmp2)
                if self.use_relu:
                    res = self.relu2(res)
                return res

        for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list:
            mod = M(add_fn, use_relu).eval().to(device=device)
            v = (
                torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
                .add(1)
                .to(device=device)
            )

            def matcher_check_fn():
                # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4
                self.assertEqual(
                    counters["inductor"]["qconv_weight_prepack_matcher_count"], 4
                )
                # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"],
                    0 if TEST_ACL else 2,
                )
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_lower_count"],
                    0 if TEST_ACL else 2,
                )

            self._test_common(
                mod,
                (v,),
                matcher_check_fn,
                check_quantization=True,
                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            )

    def _qconv2d_add_test_helper2(
        self, device="cpu", use_relu=False, int8_mixed_bf16=False
    ):
        r"""
        This testcase will quantize two Conv2d->Add patterns as:

            Conv(X)   extra input
                  \   /
                   Add
                    |
             Optional(relu)
                    |
                    Y

        , and

            extra input   Conv(X)
                      \   /
                       Add
                        |
                 Optional(relu)
                        |
                        Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                use_relu,
                swap_inputs,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU()
                self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU()
                self.use_relu = use_relu
                self.swap_inputs = swap_inputs

            def forward(self, x, x2, x3):
                x1 = self.conv1(x)
                if self.swap_inputs:
                    tmp = self.add_fn(x2, x1)
                else:
                    tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)
                tmp1 = self.conv2(tmp)
                if self.swap_inputs:
                    res = self.add_fn2(x3, tmp1)
                else:
                    res = self.add_fn2(tmp1, x3)
                if self.use_relu:
                    res = self.relu2(res)
                return res

        for add_fn, swap_inputs in itertools.product(
            quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
        ):
            mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
            x = torch.randn(
                (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
            )
            x2 = torch.randn(
                (1, 6, 6, 6), dtype=torch.float32, requires_grad=False, device=device
            )
            x3 = torch.randn(
                (1, 6, 4, 4), dtype=torch.float32, requires_grad=False, device=device
            )

            def matcher_check_fn():
                # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
                self.assertEqual(
                    counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
                )
                # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"],
                    0 if TEST_ACL else 2,
                )
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_lower_count"],
                    0 if TEST_ACL else 2,
                )

            self._test_common(
                mod,
                (x, x2, x3),
                matcher_check_fn,
                check_quantization=True,
                check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_add_cpu(self):
        self._qconv2d_add_test_helper()
        self._qconv2d_add_test_helper2()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_add_xpu(self):
        self._qconv2d_add_test_helper(device="xpu")
        self._qconv2d_add_test_helper2(device="xpu")

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_add_int8_mixed_bf16(self):
        self._qconv2d_add_test_helper(int8_mixed_bf16=True)
        self._qconv2d_add_test_helper2(int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_add_int8_mixed_bf16_xpu(self):
        self._qconv2d_add_test_helper(device="xpu", int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_add_relu_cpu(self):
        self._qconv2d_add_test_helper(use_relu=True)
        self._qconv2d_add_test_helper2(use_relu=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_add_relu_xpu(self):
        self._qconv2d_add_test_helper(device="xpu", use_relu=True)
        self._qconv2d_add_test_helper2(device="xpu", use_relu=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    def test_qconv2d_add_relu_int8_mixed_bf16(self):
        self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
        self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNNBF16
    @skipIfNoONEDNN
    @skipIfNoXPU
    def test_qconv2d_add_relu_int8_mixed_bf16_xpu(self):
        self._qconv2d_add_test_helper(device="xpu", use_relu=True, int8_mixed_bf16=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_add_broadcast_shapes_cpu(self):
        r"""
        This testcase will quantize Conv2d->add pattern using broadcast shape inputs.
        Conv2d->Add fusion will fail for the broadcast shape inputs case.
        """

        class M(torch.nn.Module):
            def __init__(self, use_bias):
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1)

            def forward(self, x1, x2):
                return torch.add(self.conv(x1), x2)

        bias_list = [True, False]
        for bias in bias_list:
            mod = M(bias).eval()
            x1 = torch.randn((2, 32, 9, 9))
            x2 = torch.randn((2, 32, 1, 1))

            def matcher_check_fn():
                # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1
                self.assertEqual(
                    counters["inductor"]["qconv_weight_prepack_matcher_count"], 1
                )
                # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 0
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"], 0
                )

            self._test_common(
                mod,
                (x1, x2),
                matcher_check_fn,
                check_quantization=True,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_with_concat_cpu(self):
        channel_1 = 32
        channel_2 = 16
        channel_3 = 8
        channel_4 = int(channel_2 * 2 + channel_3)

        class Model(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    channel_1, channel_2, 1, stride=1, dilation=1, padding=0
                )
                self.conv2 = torch.nn.Conv2d(
                    channel_1, channel_2, 1, stride=1, dilation=1, padding=0
                )
                self.conv3 = torch.nn.Conv2d(
                    channel_2, channel_3, 3, stride=1, dilation=1, padding=1
                )

                self.conv = torch.nn.Conv2d(
                    channel_4, channel_2, 1, stride=1, dilation=1, padding=0
                )

            def forward(self, x: torch.Tensor):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                x3 = self.conv3(x2)
                res = torch.cat([x1, x2, x3], dim=1)
                res = self.conv(res)
                return res

        mod = Model().eval()
        v = torch.randn(
            (8, channel_1, 40, 40), dtype=torch.float32, requires_grad=False
        )

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 4
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_matcher_count"],
                0 if TEST_ACL else 3,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 4
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_add_2(self):
        r"""
        This testcase prevents this pattern from being matched as a conv_binary fusion by mistake:
                Conv(X)   3
                      \   /
                       Add
        We see this pattern in MobileNet v3 large, where the add is decomposed from
        torch.nn.Hardswish or torch.nn.Hardsigmoid.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                post_op,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.post_op = post_op

            def forward(self, x):
                return self.post_op(self.conv(x))

        for post_op in [
            torch.nn.Hardswish(inplace=True),
            torch.nn.Hardsigmoid(inplace=True),
        ]:
            mod = M(post_op).eval()
            v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                1
            )

            def matcher_check_fn():
                # Shouldn't hit conv binary fusion
                self.assertEqual(
                    counters["inductor"]["qconv2d_binary_matcher_count"], 0
                )

            self._test_common(
                mod,
                (v,),
                matcher_check_fn,
                check_quantization=True,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qconv2d_add_3(self):
        r"""
        This testcase tests the model below:
                 x
               /   \
          conv1    maxpool
               \   /    \
                add     conv2
                  \     /
                    cat
        Based on the default recipe of X86InductorQuantizer, we will see this pattern after convert:
          qconv1    maxpool
               \      |
                \     q1
                 \   /  \
                  \ dq1  qconv2
                   \ /
                    add
                     |
                     q2
        Since q1 has 2 users and qconv2 is not an ancestor node of qconv1, we shouldn't fuse:
                    int8
                    /
          qconv1   dq1
               \   /
                add
                 |
                 q2
                 |
                int8
        Instead we can match and fuse this pattern into qconv_binary:
          qconv1   fp32
               \   /
                add
                 |
                fp32
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
                self.maxpool = torch.nn.MaxPool2d(
                    kernel_size=3, stride=1, padding=0, dilation=1
                )

            def forward(self, x):
                tmp1 = self.conv1(x)
                tmp2 = self.maxpool(x)
                add = torch.add(tmp1, tmp2)
                tmp3 = self.conv2(tmp2)
                return torch.cat((add, tmp3), dim=1)

        mod = M().eval()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qconv2d_binary_matcher_count"],
                0 if TEST_ACL else 1,
            )
            # The matched qconv binary pattern should have 2 nodes [qconv, add]
            # instead of 11 which has dequant in binary input and output quant
            self.assertEqual(
                counters["inductor"]["qconv2d_binary_matcher_nodes"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["qconv2d_binary_lower_count"],
                0 if TEST_ACL else 1,
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d(self):
        r"""
        This testcase will quantize a single Conv2d module with qat flow.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1)
                self.bn = torch.nn.BatchNorm2d(128)

            def forward(self, x):
                return self.bn(self.conv(x))

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 1
            # [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 1
            )
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 4
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 1
            # [qconv2d_pointwise_default, quantize_per_tensor]
            self.assertEqual(
                counters["inductor"]["qconv_unary_matcher_count"],
                0 if TEST_ACL else 1,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_matcher_nodes"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
            is_qat=True,
        )

    def _qat_qconv2d_unary_cpu_test_helper(
        self,
        unary_op=torch.nn.ReLU(),
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.unary_fn = copy.deepcopy(unary_op)
                self.bn = torch.nn.BatchNorm2d(3)
                self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=3, stride=1)
                self.unary_fn2 = copy.deepcopy(unary_op)
                self.bn2 = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                tmp = self.unary_fn(self.bn(self.conv(x)))
                return self.unary_fn2(self.bn2(self.conv2(tmp)))

        mod = M()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 1
            # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution]
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
            )
            # 2. QConv2D Unary fusion in post-grad fusion pass * 1
            # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2]
            self.assertEqual(
                counters["inductor"]["qconv_unary_matcher_count"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
            is_qat=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_qconv2d_relu(self):
        r"""
        This testcase will quantize Conv2d->ReLU pattern with qat flow.
        """

        self._qat_qconv2d_unary_cpu_test_helper()

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_qconv2d_relu6(self):
        r"""
        This testcase will quantize Conv2d->ReLU6 pattern with qat flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.ReLU6())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_qconv2d_hardtanh(self):
        r"""
        This testcase will quantize Conv2d->Hardtanh pattern with qat flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardtanh())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_qconv2d_silu(self):
        r"""
        This testcase will quantize Conv2d->SiLU pattern with qat flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.SiLU())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_qconv2d_hardswish(self):
        r"""
        This testcase will quantize Conv2d->Hardswish pattern with qat flow.
        """
        self._qat_qconv2d_unary_cpu_test_helper(unary_op=torch.nn.Hardswish())

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    @skipIfRocm
    def test_qat_qconv2d_add(self):
        r"""
        This testcase will quantize a Conv2d->Add pattern as:
                 X
               /   \
        Conv1(X)   Conv2(X)
               \   /
                Add
                 |
                 Y
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn1 = torch.nn.BatchNorm2d(6)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.bn2 = torch.nn.BatchNorm2d(6)

            def forward(self, x):
                x1 = self.bn1(self.conv1(x))
                x2 = self.bn2(self.conv2(x))
                return x1 + x2

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            # 1. Dequant-conv pattern matched in quantization weight prepack * 2
|
|
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8
|
|
)
|
|
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
|
|
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_nodes"],
|
|
0 if TEST_ACL else 4,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_lower_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
is_qat=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfRocm
|
|
def test_qat_qconv2d_add_relu(self):
|
|
r"""
|
|
This testcase will quantize a Conv2d->Add->ReLU pattern as:
|
|
X
|
|
/ \
|
|
Conv1(X) Conv2(X)
|
|
\ /
|
|
Add
|
|
|
|
|
ReLU
|
|
|
|
|
Y
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
|
|
self.bn1 = torch.nn.BatchNorm2d(6)
|
|
self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
|
|
self.bn2 = torch.nn.BatchNorm2d(6)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x1 = self.bn1(self.conv1(x))
|
|
x2 = self.bn2(self.conv2(x))
|
|
return self.relu(x1 + x2)
|
|
|
|
mod = M().train()
|
|
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
|
|
|
|
def matcher_check_fn():
|
|
# 1. Dequant-conv pattern matched in quantization weight prepack * 2
|
|
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8
|
|
)
|
|
# 2. Qconv2d Binary fusion in post-grad fusion pass * 1
|
|
# [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_nodes"],
|
|
0 if TEST_ACL else 5,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_lower_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
is_qat=True,
|
|
)
|
|
|
|
def _test_qconv2d_dequant_promotion_helper(self, device="cpu"):
|
|
r"""
|
|
This testcase tests if dequant node before conv2d is promoted correctly:
|
|
X
|
|
|
|
|
Conv1(X)
|
|
/ \
|
|
Conv2(X) Conv3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
|
|
self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
|
|
self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
|
|
|
|
def forward(self, x):
|
|
temp = self.conv1(x)
|
|
temp = self.conv2(temp) + self.conv3(temp)
|
|
return temp
|
|
|
|
mod = M().eval().to(device=device)
|
|
v = (
|
|
torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
|
|
.add(1)
|
|
.to(device=device)
|
|
)
|
|
|
|
def matcher_check_fn():
|
|
# 1. Dequant pattern matcher for dequant promotion * 1
|
|
# [dequantize_per_tensor]
|
|
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
|
|
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_nodes"], 1)
|
|
# 2. Dequant-conv pattern matched in quantization weight prepack * 3
|
|
# [dequantize_per_tensor, dequantize_per_channel, clone, convolution]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 12
|
|
)
|
|
# 3. Qconv2d Binary fusion in post-grad fusion pass * 1
|
|
# [qconv2d_pointwise_default_1, add_3]
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_matcher_nodes"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv2d_binary_lower_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfRocm
|
|
def test_qconv2d_dequant_promotion_cpu(self):
|
|
self._test_qconv2d_dequant_promotion_helper()
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfRocm
|
|
@skipIfNoXPU
|
|
def test_qconv2d_dequant_promotion_xpu(self):
|
|
self._test_qconv2d_dequant_promotion_helper(device="xpu")
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qconv1d_relu_cpu(self):
|
|
r"""
|
|
This testcase will quantize Conv1d->ReLU pattern.
|
|
"""
|
|
device = "cpu"
|
|
unary_op = torch.nn.ReLU()
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv1d(3, 128, kernel_size=3, stride=1)
|
|
self.unary_fn = copy.deepcopy(unary_op)
|
|
self.conv2 = torch.nn.Conv1d(
|
|
128, 128, kernel_size=3, stride=1, bias=False
|
|
)
|
|
self.unary_fn2 = copy.deepcopy(unary_op)
|
|
|
|
def forward(self, x):
|
|
tmp = self.unary_fn(self.conv(x))
|
|
return self.unary_fn2(self.conv2(tmp))
|
|
|
|
mod = M().eval().to(device=device)
|
|
v = (
|
|
torch.randn((1, 3, 8), dtype=torch.float32, requires_grad=False)
|
|
.add(1)
|
|
.to(device=device)
|
|
)
|
|
|
|
def matcher_check_fn():
|
|
# 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
|
|
)
|
|
# 2. QConv2D Unary fusion in post-grad fusion pass * 2
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_matcher_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
check_quantization=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
)
|
|
|
|
def _qlinear_test_helper(
|
|
self,
|
|
inputs,
|
|
device="cpu",
|
|
int8_mixed_bf16=False,
|
|
do_permute=False,
|
|
matcher_check_fn=None,
|
|
bias=True,
|
|
is_dynamic=False,
|
|
is_qat=False,
|
|
quantization_with_autocast=False,
|
|
):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_bias, do_permute=False):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 3, use_bias)
|
|
self.linear2 = torch.nn.Linear(3, 4, use_bias)
|
|
self.do_permute = do_permute
|
|
|
|
def forward(self, x):
|
|
if self.do_permute:
|
|
x = torch.reshape(torch.permute(x, (0, 2, 3, 1)), (2, 12, 4))
|
|
return self.linear2(self.linear(x))
|
|
|
|
mod = M(bias, do_permute=do_permute).eval().to(device=device)
|
|
assert isinstance(inputs, tuple)
|
|
|
|
def __convert_tensor_to_device(input, device):
|
|
return input.to(device=device) if isinstance(input, torch.Tensor) else input
|
|
|
|
inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs)
|
|
|
|
def _default_matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
inputs,
|
|
matcher_check_fn=(
|
|
matcher_check_fn
|
|
if matcher_check_fn is not None
|
|
else _default_matcher_check_fn
|
|
),
|
|
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
|
check_quantization=True,
|
|
is_qat=is_qat,
|
|
is_dynamic=is_dynamic,
|
|
quantization_with_autocast=quantization_with_autocast,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_cpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_xpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), device="xpu", bias=bias
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_dynamic_qlinear_cpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)),), bias=bias, is_dynamic=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_dynamic_qlinear_qat_cpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)),), bias=bias, is_dynamic=True, is_qat=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_dynamic_qlinear_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 3, 4)),), bias=bias, is_dynamic=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)),), int8_mixed_bf16=True, bias=bias
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16_use_autocast(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)),),
|
|
int8_mixed_bf16=True,
|
|
bias=bias,
|
|
quantization_with_autocast=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoXPU
|
|
def test_qlinear_int8_mixed_bf16_xpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),),
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
bias=bias,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),), device="xpu", bias=bias
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 3, 4)),), int8_mixed_bf16=True, bias=bias
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_use_autocast(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 3, 4)),),
|
|
int8_mixed_bf16=True,
|
|
bias=bias,
|
|
quantization_with_autocast=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization.
|
|
"""
|
|
for bias in [True, False]:
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),),
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
bias=bias,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_input_dim_exceeds_2_and_not_contiguous(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module.
|
|
* Input dim exceeds 2
|
|
* Input not contiguous
|
|
"""
|
|
for bias in [True, False]:
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
|
13 if bias else 12,
|
|
)
|
|
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4, 3, 4)),),
|
|
do_permute=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
bias=bias,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module for int8_bf16.
|
|
* Input dim exceeds 2
|
|
* Input not contiguous
|
|
"""
|
|
for bias in [True, False]:
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
|
17 if bias else 16,
|
|
)
|
|
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4, 3, 4)),),
|
|
int8_mixed_bf16=True,
|
|
do_permute=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
bias=bias,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous_use_autocast(
|
|
self,
|
|
):
|
|
r"""
|
|
This testcase will quantize a single Linear Module for int8_bf16.
|
|
* Input dim exceeds 2
|
|
* Input not contiguous
|
|
"""
|
|
for bias in [True, False]:
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
|
16 if bias else 15,
|
|
)
|
|
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4, 3, 4)),),
|
|
int8_mixed_bf16=True,
|
|
do_permute=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
bias=bias,
|
|
quantization_with_autocast=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous_xpu(self):
|
|
r"""
|
|
This testcase will quantize a single Linear Module for int8_bf16.
|
|
* Input dim exceeds 2
|
|
* Input not contiguous
|
|
"""
|
|
for bias in [True, False]:
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
|
17 if bias else 16,
|
|
)
|
|
|
|
self._qlinear_test_helper(
|
|
(torch.randn((2, 4, 3, 4)).to(device="xpu"),),
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
do_permute=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
bias=bias,
|
|
)
|
|
|
|
def _qlinear_unary_test_helper(
|
|
self, inputs, unary_op=torch.nn.ReLU(), device="cpu", int8_mixed_bf16=False
|
|
):
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_bias):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 4, use_bias)
|
|
self.unary_fn = copy.deepcopy(unary_op)
|
|
self.linear2 = torch.nn.Linear(4, 4, use_bias)
|
|
self.unary_fn2 = copy.deepcopy(unary_op)
|
|
|
|
def forward(self, x):
|
|
tmp = self.unary_fn(self.linear(x))
|
|
return self.unary_fn2(self.linear2(tmp))
|
|
|
|
bias_list = [True, False]
|
|
for bias in bias_list:
|
|
mod = M(bias).eval().to(device=device)
|
|
|
|
def matcher_check_fn():
|
|
# 1. dequant-linear pattern matched in quantization weight prepack
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 2
|
|
)
|
|
# 2. QLinear Unary fusion in post-grad fusion pass
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_unary_matcher_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_unary_lower_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
inputs,
|
|
matcher_check_fn,
|
|
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
|
check_quantization=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_relu_cpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern.
|
|
"""
|
|
self._qlinear_unary_test_helper((torch.randn((2, 4)),))
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_relu_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern.
|
|
"""
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), device="xpu"
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_relu_int8_mixed_bf16(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
self._qlinear_unary_test_helper((torch.randn((2, 4)),), int8_mixed_bf16=True)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_relu_int8_mixed_bf16_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), device="xpu", int8_mixed_bf16=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_relu_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern.
|
|
"""
|
|
self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),))
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_relu_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern.
|
|
"""
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),), device="xpu"
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
self._qlinear_unary_test_helper((torch.randn((2, 3, 4)),), int8_mixed_bf16=True)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_relu_int8_mixed_bf16_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->ReLU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),),
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_gelu_cpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->GELU pattern.
|
|
"""
|
|
for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
|
|
self._qlinear_unary_test_helper((torch.randn((2, 4)),), gelu)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_gelu_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->GELU pattern.
|
|
"""
|
|
for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), gelu, device="xpu"
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_gelu_int8_mixed_bf16(self):
|
|
r"""
|
|
This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 4)),), gelu, int8_mixed_bf16=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_gelu_int8_mixed_bf16_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->GELU pattern with int8_mixed_bf16 quantization.
|
|
"""
|
|
for gelu in [torch.nn.GELU("none"), torch.nn.GELU("tanh")]:
|
|
self._qlinear_unary_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),),
|
|
gelu,
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
)
|
|
|
|
def _qlinear_add_test_helper(
|
|
self,
|
|
device="cpu",
|
|
use_relu=False,
|
|
int8_mixed_bf16=False,
|
|
is_qat=True,
|
|
is_dynamic=True,
|
|
):
|
|
r"""
|
|
This testcase will quantize two consecutive Linear->Add(->relu) patterns as:
|
|
X
|
|
/ \
|
|
linear(X) linear(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Optional(relu)
|
|
/ \
|
|
linear(X) linear(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Optional(relu)
|
|
|
|
|
Y
|
|
"""
|
|
|
|
def fake_quant(x):
|
|
# to produce a float32 result as extra input
|
|
qlib = torch.ops.quantized_decomposed
|
|
if device == "cpu":
|
|
qmin, qmax, dtype = 0, 255, torch.uint8
|
|
else:
|
|
qmin, qmax, dtype = -128, 127, torch.int8
|
|
x = qlib.quantize_per_tensor.default(x, 0.0166785, 42, qmin, qmax, dtype)
|
|
x = qlib.dequantize_per_tensor.default(x, 0.0166785, 42, qmin, qmax, dtype)
|
|
return x
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
add_fn,
|
|
use_relu,
|
|
fake_quant_before_extra_input,
|
|
):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(4, 4)
|
|
self.linear2 = torch.nn.Linear(4, 4)
|
|
self.add_fn = add_fn
|
|
self.relu = torch.nn.ReLU()
|
|
self.linear3 = torch.nn.Linear(4, 4)
|
|
self.linear4 = torch.nn.Linear(4, 4)
|
|
self.add_fn2 = add_fn
|
|
self.relu2 = torch.nn.ReLU()
|
|
self.use_relu = use_relu
|
|
self.fake_quant_before_extra_input = fake_quant_before_extra_input
|
|
|
|
def forward(self, x):
|
|
x1 = self.linear1(x)
|
|
x2 = self.linear2(x)
|
|
if self.fake_quant_before_extra_input:
|
|
x2 = fake_quant(x2)
|
|
tmp = self.add_fn(x1, x2)
|
|
if self.use_relu:
|
|
tmp = self.relu(tmp)
|
|
tmp1 = self.linear3(tmp)
|
|
tmp2 = self.linear4(tmp)
|
|
if self.fake_quant_before_extra_input:
|
|
tmp2 = fake_quant(tmp2)
|
|
res = self.add_fn2(tmp1, tmp2)
|
|
if self.use_relu:
|
|
res = self.relu2(res)
|
|
return res
|
|
|
|
add_fn_list = [
|
|
lambda x, y: x + y,
|
|
lambda x, y: y + x,
|
|
lambda x, y: x.add_(y),
|
|
lambda x, y: y.add_(x),
|
|
]
|
|
fake_quant_x2_list = [False, True] if int8_mixed_bf16 else [False]
|
|
shape_list = [(4, 4), [4, 4, 4]]
|
|
cases = itertools.product(add_fn_list, fake_quant_x2_list, shape_list)
|
|
for add_fn, fq_x2, shape in cases:
|
|
mod = M(add_fn, use_relu, fq_x2).eval().to(device=device)
|
|
v = torch.randn(
|
|
shape, dtype=torch.float32, requires_grad=False, device=device
|
|
).add(1)
|
|
|
|
def matcher_check_fn():
|
|
# 1. Dequant-linear pattern matched in quantization weight prepack * 4
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4
|
|
)
|
|
# pattern = [dequant_per_tensor, (convert_dtype), dequant_per_channel, (convert_dtype), permute, addmm]
|
|
nodes_per_match = 6 if int8_mixed_bf16 else 4
|
|
if len(shape) == 3:
|
|
# pattern = [dequant_per_tensor, (convert_dtype), (view), \
|
|
# dequant_per_channel, (convert_dtype), (view), permute, addmm]
|
|
nodes_per_match += 2
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
|
|
4 * nodes_per_match,
|
|
)
|
|
# 2. Qlinear Binary Unary fusion in post-grad fusion pass * 2
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_binary_matcher_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
# Two linear-binary patterns are matched
|
|
# matched patter1 = [qlinear, add, (convert dtype), (relu), quantize_per_tensor]
|
|
# matched patter2 = [qlinear, add, (convert dtype), (relu)]
|
|
# If add_fn is x.add_(y), x is bf16 and y is fp32, there is a to_bf16 node after binary
|
|
to_bf16_after_binary = 2 * (add_fn == add_fn_list[2] and fq_x2)
|
|
expected_matcher_nodes = (
|
|
(4 if is_dynamic else 5) + 2 * use_relu + to_bf16_after_binary
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_binary_matcher_nodes"],
|
|
0 if TEST_ACL else expected_matcher_nodes,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_binary_lower_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
|
is_qat=is_qat,
|
|
is_dynamic=is_dynamic,
|
|
)
|
|
|
|
if TEST_ACL:
|
|
continue
|
|
|
|
if torch._inductor.config.cpp_wrapper:
|
|
# For CPP wrapper
|
|
self._test_code_common(
|
|
mod,
|
|
(v,),
|
|
[
|
|
f"aoti_torch_{device}__qlinear_pointwise_tensor",
|
|
f"aoti_torch_{device}__qlinear_pointwise_binary_tensor",
|
|
],
|
|
[],
|
|
check_quantization=True,
|
|
num_include_ops=[2, 2],
|
|
)
|
|
else:
|
|
# For python wrapper
|
|
self._test_code_common(
|
|
mod,
|
|
(v,),
|
|
[
|
|
"torch.ops.onednn.qlinear_pointwise.tensor",
|
|
"torch.ops.onednn.qlinear_pointwise.binary",
|
|
],
|
|
[],
|
|
check_quantization=True,
|
|
num_include_ops=[2, 2],
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@parametrize("use_relu", [True, False])
|
|
@parametrize("is_qat", [True, False])
|
|
@parametrize("is_dynamic", [True, False])
|
|
def test_qlinear_add_cpu(self, use_relu, is_qat, is_dynamic):
|
|
self._qlinear_add_test_helper(
|
|
use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
@config.patch({"fx_graph_cache": False})
|
|
@parametrize("use_relu", [True])
|
|
@parametrize("is_qat", [False])
|
|
@parametrize("is_dynamic", [False])
|
|
def test_qlinear_add_xpu(self, use_relu, is_qat, is_dynamic):
|
|
self._qlinear_add_test_helper(
|
|
device="xpu", use_relu=use_relu, is_qat=is_qat, is_dynamic=is_dynamic
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@parametrize("use_relu", [True, False])
|
|
@parametrize("is_qat", [True, False])
|
|
@parametrize("is_dynamic", [True, False])
|
|
def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
|
|
self._qlinear_add_test_helper(
|
|
int8_mixed_bf16=True,
|
|
use_relu=use_relu,
|
|
is_qat=is_qat,
|
|
is_dynamic=is_dynamic,
|
|
)
|
|
|
|
@skipIfNoXPU
|
|
@parametrize("use_relu", [True, False])
|
|
@parametrize("is_qat", [False])
|
|
@parametrize("is_dynamic", [False])
|
|
def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic):
|
|
self._qlinear_add_test_helper(
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
use_relu=use_relu,
|
|
is_qat=is_qat,
|
|
is_dynamic=is_dynamic,
|
|
)
|
|
|
|
def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"):
|
|
dtype = torch.float8_e4m3fn
|
|
qlinear_prepack = torch.ops.onednn.qlinear_prepack
|
|
post_op_algo = "none"
|
|
unary_post_op_args = ()
|
|
batch_size = 1
|
|
output_dtype = torch.float8_e4m3fn
|
|
y_scale, y_zp = 0.07, 0
|
|
ic = 4
|
|
oc = 16
|
|
|
|
torch._dynamo.reset()
|
|
used_y_scale = y_scale
|
|
used_y_zp = y_zp
|
|
x = torch.rand(batch_size, ic)
|
|
w = torch.rand(oc, ic)
|
|
qx = x.to(dtype)
|
|
qw = w.to(dtype)
|
|
x_scale = 0.5
|
|
w_scales = torch.randn(oc)
|
|
b = torch.rand(oc)
|
|
|
|
x_zp = 0
|
|
w_zps = torch.zeros_like(w_scales, dtype=torch.int)
|
|
|
|
if post_op == "none":
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.qw_packed = qlinear_prepack(qw, x.shape)
|
|
|
|
def forward(self, qx):
|
|
qy = qlinear_op(
|
|
qx,
|
|
x_scale,
|
|
x_zp,
|
|
self.qw_packed,
|
|
w_scales,
|
|
w_zps,
|
|
b,
|
|
used_y_scale,
|
|
used_y_zp,
|
|
output_dtype,
|
|
post_op,
|
|
unary_post_op_args,
|
|
post_op_algo,
|
|
)
|
|
return qy
|
|
|
|
elif post_op == "add":
|
|
x2 = torch.rand(batch_size, oc)
|
|
binary_alpha = 1.0 # we only support alpha=1.0 now
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.qw_packed = qlinear_prepack(qw, x.shape)
|
|
|
|
def forward(self, qx):
|
|
qy = qlinear_op(
|
|
qx,
|
|
x_scale,
|
|
x_zp,
|
|
self.qw_packed,
|
|
w_scales,
|
|
w_zps,
|
|
x2,
|
|
b,
|
|
used_y_scale,
|
|
used_y_zp,
|
|
output_dtype,
|
|
1.0,
|
|
0,
|
|
"add",
|
|
binary_alpha,
|
|
"none",
|
|
unary_post_op_args,
|
|
post_op_algo,
|
|
)
|
|
return qy
|
|
|
|
with torch.no_grad():
|
|
model = Mod()
|
|
y_refe = model(qx)
|
|
y_test = torch.compile(model)(qx)
|
|
self.assertEqual(y_refe.float(), y_test.float())
|
|
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_fp8_inductor_cpu(self):
|
|
qlinear_op = torch.ops.onednn.qlinear_pointwise.default
|
|
self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "none")
|
|
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_add_fp8_inductor_cpu(self):
|
|
qlinear_op = torch.ops.onednn.qlinear_pointwise.binary
|
|
self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add")
|
|
|
|
def _qlinear_dequant_promotion_test_helper(
|
|
self,
|
|
inputs,
|
|
device="cpu",
|
|
int8_mixed_bf16=False,
|
|
is_dynamic=False,
|
|
matcher_check_fn=None,
|
|
):
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(4, 4)
|
|
self.linear2 = torch.nn.Linear(4, 4)
|
|
self.linear3 = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
temp = self.linear1(x)
|
|
temp = self.linear2(temp) + self.linear3(temp)
|
|
return temp
|
|
|
|
mod = M().eval().to(device=device)
|
|
|
|
def default_matcher_check_fn():
|
|
# 1. Dequant pattern matcher for dequant promotion * 1
|
|
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
|
|
# 2. dequant-linear pattern matched in quantization weight prepack * 3
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
|
|
)
|
|
# 3. QLinear Unary fusion in post-grad fusion pass * 1
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_unary_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
inputs,
|
|
matcher_check_fn=(
|
|
matcher_check_fn
|
|
if matcher_check_fn is not None
|
|
else default_matcher_check_fn
|
|
),
|
|
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
|
check_quantization=True,
|
|
is_dynamic=is_dynamic,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_dequant_promotion_cpu(self):
|
|
r"""
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper((torch.randn((2, 4)),))
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_dequant_promotion_xpu(self):
|
|
r"""
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), device="xpu"
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_dequant_promotion_int8_mixed_bf16(self):
|
|
r"""
|
|
Test with int8_mixed_bf16 quantization.
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 4)),), int8_mixed_bf16=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_dequant_promotion_int8_mixed_bf16_xpu(self):
|
|
r"""
|
|
Test with int8_mixed_bf16 quantization.
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 4)).to(device="xpu"),), device="xpu", int8_mixed_bf16=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2(self):
|
|
r"""
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper((torch.randn((2, 3, 4)),))
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_dequant_promotion_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),), device="xpu"
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2(self):
|
|
r"""
|
|
Test with int8_mixed_bf16 quantization.
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 3, 4)),), int8_mixed_bf16=True
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNNBF16
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_dequant_promotion_int8_mixed_bf16_input_dim_exceeds_2_xpu(self):
|
|
r"""
|
|
Test with int8_mixed_bf16 quantization.
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 3, 4)).to(device="xpu"),),
|
|
device="xpu",
|
|
int8_mixed_bf16=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_dequant_promotion_dynamic_cpu(self):
|
|
r"""
|
|
This testcase test if dequant node before linear is promoted correctly:
|
|
X
|
|
|
|
|
Linear1(X)
|
|
/ \
|
|
Linear2(X) Linear3(X)
|
|
\ /
|
|
Add
|
|
|
|
|
Y
|
|
"""
|
|
|
|
def matcher_check_fn():
|
|
# 1. Dequant pattern matcher for dequant promotion * 1
|
|
self.assertEqual(counters["inductor"]["dequant_promotion_matcher_count"], 1)
|
|
# 2. dequant-linear pattern matched in quantization weight prepack * 3
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 3
|
|
)
|
|
|
|
self._qlinear_dequant_promotion_test_helper(
|
|
(torch.randn((2, 4)),),
|
|
matcher_check_fn=matcher_check_fn,
|
|
is_dynamic=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
@config.patch({"fx_graph_cache": False})
|
|
def test_qlinear_mul_xpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->Mul pattern.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_bias):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 5, use_bias)
|
|
|
|
def forward(self, x1, x2):
|
|
return torch.mul(self.linear(x1), x2)
|
|
|
|
bias_list = [True, False]
|
|
for bias in bias_list:
|
|
mod = M(bias).eval().to(device="xpu")
|
|
x1 = torch.randn((2, 4)).to(device="xpu")
|
|
x2 = torch.randn((2, 5)).to(device="xpu")
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(x1, x2),
|
|
check_quantization=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
def test_qlinear_mul_cpu(self):
|
|
r"""
|
|
This testcase will quantize a Linear->Mul pattern.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_bias):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 5, use_bias)
|
|
|
|
def forward(self, x1, x2):
|
|
return torch.mul(self.linear(x1), x2)
|
|
|
|
bias_list = [True, False]
|
|
for bias in bias_list:
|
|
mod = M(bias).eval()
|
|
x1 = torch.randn((2, 4))
|
|
x2 = torch.randn((2, 5))
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(x1, x2),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
@skipIfNoONEDNN
|
|
@skipIfNoXPU
|
|
def test_qlinear_mul(self):
|
|
r"""
|
|
This testcase will quantize a Linear->Mul pattern.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, use_bias):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(4, 5, use_bias)
|
|
|
|
def forward(self, x1, x2):
|
|
return torch.mul(self.linear(x1), x2)
|
|
|
|
bias_list = [True, False]
|
|
for bias in bias_list:
|
|
mod = M(bias).eval().to(device="xpu")
|
|
x1 = torch.randn((2, 4)).to(device="xpu")
|
|
x2 = torch.randn((2, 5)).to(device="xpu")
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(x1, x2),
|
|
check_quantization=True,
|
|
matcher_check_fn=matcher_check_fn,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
def test_qmaxpool2d(self):
|
|
r"""
|
|
This testcase will quantize Conv2d->ReLU->MaxPool2d pattern.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
kwargs,
|
|
):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
3, 64, 7, bias=True, stride=2, padding=3, dilation=1
|
|
)
|
|
self.relu = torch.nn.ReLU()
|
|
self.maxpool = torch.nn.MaxPool2d(3, **kwargs)
|
|
|
|
def forward(self, x):
|
|
return self.maxpool(self.relu(self.conv(x)))
|
|
|
|
kwargs_list = [
|
|
{"stride": 2},
|
|
{"stride": 2, "padding": 1},
|
|
{"stride": 2, "padding": 1, "dilation": 1},
|
|
{"stride": 2, "padding": 1, "dilation": 1, "ceil_mode": False},
|
|
]
|
|
for kwargs in kwargs_list:
|
|
mod = M(kwargs).eval()
|
|
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
|
|
1
|
|
)
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qmaxpool2d_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 1
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_matcher_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_lower_count"],
|
|
0 if TEST_ACL else 1,
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
def test_qflatten(self):
|
|
r"""
|
|
This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten->cat pattern.
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
3, 64, 7, bias=True, stride=2, padding=3, dilation=1
|
|
)
|
|
self.relu = torch.nn.ReLU()
|
|
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
|
|
|
|
def forward(self, x):
|
|
return torch.cat(
|
|
[
|
|
torch.flatten(
|
|
self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1
|
|
)
|
|
]
|
|
)
|
|
|
|
mod = M().eval()
|
|
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qreshape_matcher_count"], 0 if TEST_ACL else 1
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
)
|
|
|
|
@skipIfNoDynamoSupport
|
|
def test_qcat(self):
|
|
r"""
|
|
This testcase will quantize cat based pattern:
|
|
X
|
|
/ \
|
|
Conv1(X) Pow(x)
|
|
\ \
|
|
\ Conv2(X)
|
|
\ /
|
|
Cat
|
|
|
|
|
Y
|
|
"""
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
3, 64, 7, bias=True, stride=2, padding=3, dilation=1
|
|
)
|
|
self.conv2 = torch.nn.Conv2d(
|
|
3, 64, 7, bias=True, stride=2, padding=3, dilation=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
temp1 = self.conv(x)
|
|
temp2 = self.conv2(torch.pow(x, 2))
|
|
return torch.cat((temp1, temp2), 1)
|
|
|
|
mod = M().eval()
|
|
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["qcat_matcher_count"], 0 if TEST_ACL else 1
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_weight_prepack_matcher_count"], 2
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_matcher_count"],
|
|
0 if TEST_ACL else 2,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2
|
|
)
|
|
|
|
self._test_common(
|
|
mod,
|
|
(v,),
|
|
matcher_check_fn,
|
|
check_quantization=True,
|
|
)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/99841.
|
|
def test_hardtanh_pattern_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv_transpose = torch.nn.ConvTranspose2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, min_value, max_value):
|
|
conv_transpose_output = self.conv_transpose(x)
|
|
clamp_min_output = torch.clamp_min(conv_transpose_output, min_value)
|
|
clamp_max_output = torch.clamp_max(clamp_min_output, max_value)
|
|
return clamp_max_output
|
|
|
|
# check works for min_value > max_value.
|
|
min_values = [3, torch.randn(1, 32, 28, 28)]
|
|
max_values = [0, torch.randn(1, 32, 28, 28)]
|
|
v = torch.randn(1, 3, 28, 28)
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
|
0 if TEST_ACL else 3,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
|
)
|
|
|
|
for min_value, max_value in zip(min_values, max_values):
|
|
mod = Model().eval()
|
|
self._test_common(mod, (v, min_value, max_value), matcher_check_fn)
|
|
|
|
def test_leaky_relu_pattern_fallback(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, negative_slope):
|
|
conv_out = self.conv(x)
|
|
return torch.where(conv_out > 0, conv_out, conv_out * negative_slope)
|
|
|
|
negative_slopes = [0.1, torch.randn(1, 32, 28, 28)]
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(
|
|
counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
|
|
0 if TEST_ACL else 4,
|
|
)
|
|
self.assertEqual(
|
|
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
|
)
|
|
|
|
with torch.no_grad():
|
|
v = torch.randn(1, 3, 28, 28)
|
|
for negative_slope in negative_slopes:
|
|
mod = Model().eval()
|
|
self._test_common(mod, (v, negative_slope), matcher_check_fn)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/99838.
|
|
def test_conv2d_add_scalar(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
out_conv = self.conv(x)
|
|
out = torch.add(out_conv, 1.0)
|
|
return out
|
|
|
|
def matcher_check_fn():
|
|
self.assertEqual(counters["inductor"]["binary_folding"], 1)
|
|
self.assertEqual(
|
|
counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"], 1
|
|
)
|
|
|
|
with torch.no_grad():
|
|
mod = Model().eval()
|
|
v = torch.randn(1, 3, 28, 28)
|
|
self._test_common(mod, (v,), matcher_check_fn)
|
|
|
|
@xfailIfACL
|
|
def test_conv2d_binary_inplace_fusion_pass_cpu(
|
|
self, include_ops=None, exclude_ops=None
|
|
):
|
|
class Model_v1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, other):
|
|
conv_out = self.conv(x)
|
|
return torch.add(conv_out, other.relu())
|
|
|
|
class Model_v2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
self.conv2 = torch.nn.Conv2d(
|
|
in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
self.conv3 = torch.nn.Conv2d(
|
|
in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, _):
|
|
conv_out1 = self.conv(x)
|
|
pow_out = torch.pow(conv_out1, 2)
|
|
conv_out2 = self.conv2(pow_out)
|
|
conv_out3 = self.conv3(conv_out2)
|
|
res = torch.add(conv_out3, pow_out)
|
|
return res
|
|
|
|
input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
|
|
others = [
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
]
|
|
mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval()
|
|
mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval()
|
|
|
|
if include_ops is None:
|
|
include_ops = ["mkldnn._convolution_pointwise_.binary"]
|
|
if exclude_ops is None:
|
|
exclude_ops = ["mkldnn._convolution_pointwise.binary"]
|
|
|
|
for other, mod in zip(others, [mod_v1, mod_v2]):
|
|
self._test_code_common(mod, (input, other), include_ops, exclude_ops)
|
|
|
|
@xfailIfACL
|
|
def test_conv2d_binary_inplace_fusion_failed_cpu(
|
|
self, include_ops=None, exclude_ops=None
|
|
):
|
|
# Written buffer is graph input, we can't fuse inplace.
|
|
class Model_v1(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, other):
|
|
conv_out = self.conv(x)
|
|
return torch.add(conv_out, other)
|
|
|
|
# Written buffer is an alias tensor, we can't fuse inplace.
|
|
class Model_v2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, other):
|
|
conv_out = self.conv(x)
|
|
return torch.add(conv_out, other[1:2, :, :, :]), other
|
|
|
|
class Model_v3(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
self.conv2 = torch.nn.Conv2d(
|
|
in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, _):
|
|
pow_out = torch.pow(self.conv(x), 2)
|
|
other2 = F.relu(pow_out)
|
|
conv_out2 = self.conv2(pow_out)
|
|
res = torch.add(conv_out2, pow_out)
|
|
res = res + other2
|
|
return res
|
|
|
|
# Written buffer is an ReinterpretView, we can't fuse inplace.
|
|
class Model_v4(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 32, 3, padding=1, bias=True)
|
|
self.linear = torch.nn.Linear(32 * 28, 32 * 28)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x, y):
|
|
x = self.conv(self.relu(x))
|
|
y = self.linear(y)
|
|
y = torch.cat((y, y + 1), 1)
|
|
y = torch.ops.aten.permute.default(y, [0, 2, 1]).reshape(1, 32, 28, 28)
|
|
return x + y
|
|
|
|
class Model_v5(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(32, 32, 3, padding=1, bias=True)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, _, x):
|
|
x1 = self.relu(x)
|
|
return self.conv(x1) + x1
|
|
|
|
input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
|
|
others = [
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
torch.randn(2, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
torch.randn(1, 14, 32 * 28),
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
]
|
|
mod_v1 = Model_v1().to(memory_format=torch.channels_last).eval()
|
|
mod_v2 = Model_v2().to(memory_format=torch.channels_last).eval()
|
|
mod_v3 = Model_v3().to(memory_format=torch.channels_last).eval()
|
|
mod_v4 = Model_v4().to(memory_format=torch.channels_last).eval()
|
|
mod_v5 = Model_v5().to(memory_format=torch.channels_last).eval()
|
|
|
|
if include_ops is None:
|
|
include_ops = ["mkldnn._convolution_pointwise.binary"]
|
|
if exclude_ops is None:
|
|
exclude_ops = ["mkldnn._convolution_pointwise_.binary"]
|
|
|
|
for other, mod in zip(others, [mod_v1, mod_v2, mod_v3, mod_v4, mod_v5]):
|
|
self._test_code_common(mod, (input, other), include_ops, exclude_ops)
|
|
|
|
def test_conv2d_binary_fusion_failed(self):
|
|
# we don't support alpha !=1 case or other has different size with conv's output.
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x, other, alpha):
|
|
conv_out = self.conv(x)
|
|
return torch.add(conv_out, other, alpha=alpha)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/100802.
|
|
# we can't do the fusion when add's inputs are same tensor.
|
|
class Model2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
out = self.conv(x)
|
|
out = torch.add(out, out)
|
|
return out
|
|
|
|
# https://github.com/pytorch/pytorch/issues/101374.
|
|
# we can't do the fusion when add's inputs are mixed dtype.
|
|
class Model3(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(
|
|
in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
|
|
)
|
|
|
|
def forward(self, x):
|
|
temp = self.conv(x)
|
|
other = torch.ones(temp.shape, dtype=torch.double)
|
|
out = torch.add(temp, other)
|
|
return out
|
|
|
|
input = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
|
|
others = [
|
|
torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last),
|
|
torch.randn(32, 28, 28),
|
|
]
|
|
include_ops = ["mkldnn._convolution_pointwise"]
|
|
exclude_ops = [
|
|
"mkldnn._convolution_pointwise.binary",
|
|
"mkldnn._convolution_pointwise_.binary",
|
|
]
|
|
|
|
# case1
|
|
for other, alpha in zip(others, [0.1, 1.0]):
|
|
mod = Model().to(memory_format=torch.channels_last).eval()
|
|
self._test_code_common(mod, (input, other, alpha), include_ops, exclude_ops)
|
|
# case2:
|
|
mod = Model2().to(memory_format=torch.channels_last).eval()
|
|
self._test_code_common(mod, (input,), include_ops, exclude_ops)
|
|
# case3:
|
|
mod = Model3().to(memory_format=torch.channels_last).eval()
|
|
self._test_code_common(mod, (input,), include_ops, exclude_ops)
|
|
|
|
@xfailIfACL
|
|
def test_reproduce_99842_issue(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
|
|
|
def forward(self, input_tensor):
|
|
x = self.conv(input_tensor)
|
|
x = F.relu(x + torch.ones(x.size()))
|
|
return x
|
|
|
|
input = torch.randn(1, 3, 14, 14)
|
|
mod = Model().eval()
|
|
include_ops = ["mkldnn._convolution_pointwise_.binary"]
|
|
self._test_code_common(mod, (input,), include_ops, [])
|
|
|
|
def test_reproduce_113440_issue_1(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
add_fn,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
|
|
self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
|
|
self.add_fn = add_fn
|
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
|
|
self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
|
|
self.add_fn2 = add_fn
|
|
self.relu2 = torch.nn.ReLU(inplace=True)
|
|
self.use_relu = True
|
|
|
|
def forward(self, x):
|
|
x1 = self.conv1(x)
|
|
x2 = self.conv2(x)
|
|
tmp = self.add_fn(x1, x2)
|
|
if self.use_relu:
|
|
tmp = self.relu(tmp)
|
|
tmp1 = self.conv3(tmp)
|
|
tmp2 = self.conv4(tmp)
|
|
res = self.add_fn2(tmp1, tmp2)
|
|
if self.use_relu:
|
|
res = self.relu2(res)
|
|
return res
|
|
|
|
with torch.no_grad():
|
|
example_inputs = (
|
|
torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
|
|
1
|
|
),
|
|
)
|
|
example_inputs[0].get_device()
|
|
m = Mod(
|
|
lambda x, y: x.add_(y),
|
|
).eval()
|
|
om = torch.compile(m)
|
|
om(*example_inputs)
|
|
om(*example_inputs)
|
|
|
|
    def test_reproduce_113440_issue_2(self):
        class Mod(torch.nn.Module):
            def __init__(
                self,
                add_fn,
                **kwargs,
            ):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
                self.add_fn = add_fn
                self.relu = torch.nn.ReLU(inplace=True)
                self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.add_fn2 = add_fn
                self.relu2 = torch.nn.ReLU(inplace=True)

                self.conv5 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv6 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1)
                self.conv7 = torch.nn.Conv2d(6, 6, kernel_size=1, stride=1)
                self.add_fn3 = add_fn
                self.relu3 = torch.nn.ReLU(inplace=True)

                self.use_relu = True

            def forward(self, x):
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                tmp = self.add_fn(x1, x2)
                if self.use_relu:
                    tmp = self.relu(tmp)

                tmp1 = self.conv3(tmp)
                res = self.relu2(tmp1)

                return res

        with torch.no_grad():
            example_inputs = (
                torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
                    1
                ),
            )
            m = Mod(
                lambda x, y: x.add_(y),
            ).eval()
            om = torch.compile(m)
            om(*example_inputs)
            om(*example_inputs)

    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @xfailIfACL
    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
    def test_reproduce_121253_issue_addmm_fusion_check(self):
        class Mod(torch.nn.Module):
            def __init__(self, weight, bias, beta, alpha):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.beta = beta
                self.alpha = alpha

            def forward(self, x):
                return torch.addmm(
                    self.bias, x, self.weight, beta=self.beta, alpha=self.alpha
                )

        dtypes = [torch.float32]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        for dtype in dtypes:
            linear_op = (
                "mkl._mkl_linear"
                if dtype == torch.float32
                else "mkldnn._linear_pointwise"
            )
            for beta, alpha in zip([1.0, 0.1, 0.0], [1.0, 0.1, 1.0]):
                weight = torch.nn.Parameter(torch.randn(64, 64, dtype=dtype))
                bias = torch.nn.Parameter(torch.randn(64, dtype=dtype))
                mod = Mod(weight, bias, beta, alpha).to(dtype).eval()
                with torch.no_grad():
                    x = torch.randn(1, 64, dtype=dtype)
                    include_ops = []
                    exclude_ops = []
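                    # addmm maps to the MKL/oneDNN linear op only when it degenerates to
                    # a plain linear, i.e. alpha == 1.0 and beta is 0.0 or 1.0;
                    # otherwise the fusion must not be applied.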
                    if (beta != 1.0 and beta != 0.0) or alpha != 1.0:
                        exclude_ops = [linear_op]
                    else:
                        include_ops = [linear_op]
                    self._test_code_common(mod, (x,), include_ops, exclude_ops)

    @skipIfNoDynamoSupport
    def test_woq_int8(self):
        class M(torch.nn.Module):
            def __init__(self, is_permute):
                super().__init__()
                self.is_permute = is_permute

            def forward(self, x, weight, scales):
                if self.is_permute:
                    weight = weight.t()
                    m = torch.mm(
                        x.reshape(-1, x.shape[-1]),
                        weight.to(x.dtype),
                    )
                    y = m * scales.to(m.dtype)
                    y = y.reshape(*x.shape[:-1], y.shape[-1])
                    return y
                else:
                    return (
                        torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales
                    )

        x_shape = (1, 1, 256)
        s_shape = 12
        x_strides = [
            (256, 256, 1),  # linear dispatching to mm
            (256, 32, 1),  # linear dispatching to bmm
        ]
        is_permutes = [False, True]
        for x_stride, is_permute in itertools.product(x_strides, is_permutes):
            mod = M(is_permute=is_permute).eval()
            x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride)
            w_shape = (12, 256)
            w = torch.randint(-128, 127, w_shape, dtype=torch.int8)
            s = torch.randn(s_shape, dtype=torch.bfloat16)

            def matcher_check_fn():
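                # The WOQ int8 pattern is not rewritten when running with ACL
                # (TEST_ACL), so the expected match count drops to 0 in that case.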
                self.assertEqual(
                    counters["inductor"]["woq_matcher_count"], 0 if TEST_ACL else 1
                )

            self._test_common(
                mod,
                (x, w, s),
                matcher_check_fn,
                check_quantization=False,
                atol=0.001,
                rtol=0.07,
            )

    @skipIfNoDynamoSupport
    def test_woq_int4_cpu(self):
        class M(torch.nn.Module):
            def __init__(self, in_feature, out_feature, group_size):
                super().__init__()
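                # Packed int4 weight: each uint8 is assumed to hold two 4-bit values,
                # hence the (out_feature, in_feature // 2) shape; qScaleAndZeros is
                # assumed to hold per-group (scale, zero_point) pairs.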
                self.weight = torch.randint(
                    0, 255, (out_feature, in_feature // 2), dtype=torch.uint8
                )
                self.group_size = group_size
                self.qScaleAndZeros = torch.rand(
                    (in_feature // group_size, out_feature, 2), dtype=torch.bfloat16
                )

            def forward(self, x):
                if x.ndim > 2:
                    x = x.reshape(-1, x.shape[-1])
                    y = torch.ops.aten._weight_int4pack_mm_for_cpu.default(
                        x, self.weight, self.group_size, self.qScaleAndZeros
                    )
                    return y.reshape(*x.shape[:-1], y.shape[-1])
                return torch.ops.aten._weight_int4pack_mm_for_cpu.default(
                    x, self.weight, self.group_size, self.qScaleAndZeros
                )

        bs = 4
        seq = 8
        x_dim_list = [2, 3]
        in_feature_list = [256, 512]
        out_feature_list = [256, 512]
        group_size_list = [64, 128]
        cases = itertools.product(
            x_dim_list, in_feature_list, out_feature_list, group_size_list
        )
        for x_dim, in_feature, out_feature, group_size in cases:
            x_shape = (seq, in_feature) if x_dim == 2 else (bs, seq, in_feature)
            x = torch.randn(x_shape, dtype=torch.bfloat16)
            m = M(in_feature, out_feature, group_size).eval()

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["woq_matcher_count"], 0 if TEST_ACL else 1
                )

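            # With cpp_wrapper, the fused op is expected to show up in the generated
            # code under its AOTI C-shim name; otherwise the Python wrapper calls the
            # quantized op directly.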
            include_ops = [
                "aoti_torch_cpu__weight_int4pack_mm_cpu_tensor"
                if torch._inductor.config.cpp_wrapper
                else "torch.ops.quantized.int4mm_packed_weight_cpu.default"
            ]
            self._test_code_common(
                m,
                (x,),
                include_ops,
                ["torch.ops.aten._weight_int4pack_mm_for_cpu.default"],
            )

    def _test_linear_dynamic_fp16_helper(self, use_relu: bool):
        class M(torch.nn.Module):
            def __init__(self, bias: bool, use_relu: bool):
                super().__init__()
                self.linear = torch.nn.Linear(256, 256, bias=bias)
                self.relu = torch.nn.ReLU()
                self.use_relu = use_relu

            def forward(self, x):
                if self.use_relu:
                    return self.relu(self.linear(x))
                return self.linear(x)

        quantizer = X86InductorQuantizer().set_global(
            xiq.get_default_x86_inductor_quantization_config()
        )
        quantizer.set_module_type_qconfig(
            torch.nn.Linear, xiq.get_x86_inductor_linear_dynamic_fp16_config()
        )
        bias_list = [True, False]
        input_ndim_list = [2, 3]
        x_contig_list = [True, False]
        cases = itertools.product(bias_list, input_ndim_list, x_contig_list)
        for bias, input_ndim, x_contig in cases:
            x_shape = (4, 256) if input_ndim == 2 else (4, 1, 256)
            x = torch.randn(x_shape)
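            # Slicing with a step of 2 produces a non-contiguous input to exercise
            # the non-contiguous path.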
            if not x_contig:
                x = x[0::2, ...]
            mod = M(bias, use_relu).eval()

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
                )
                # Matched nodes:
                # (1) w to fp16, (2) w to fp32, (3) permute w, (4) mm/addmm/bmm
                # If x.ndim == 3 and x is contiguous, two view nodes are added.
                # If x.ndim == 3 and x is not contiguous, two expand nodes and one add node are added.
                nodes_count = 4
                if input_ndim > 2:
                    if x_contig:
                        nodes_count += 2
                    else:
                        nodes_count += 3 if bias else 2
                if use_relu:
                    nodes_count += 1
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
                    nodes_count,
                )

            self._test_common(
                mod,
                (x,),
                atol=1e-2,
                rtol=1e-2,
                matcher_check_fn=matcher_check_fn,
                check_quantization=True,
                quantizer=quantizer,
            )
            linear_op_str = (
                "torch.ops.onednn.linear_relu_dynamic_fp16.default"
                if use_relu
                else "torch.ops.onednn.linear_dynamic_fp16.default"
            )
            self._test_code_common(
                mod,
                (x,),
                [linear_op_str],
                ["torch.ops.aten.addmm.default", "torch.ops.aten.mm.default"],
                check_quantization=True,
                quantizer=quantizer,
            )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_linear_dynamic_fp16(self):
        self._test_linear_dynamic_fp16_helper(use_relu=False)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_linear_relu_dynamic_fp16(self):
        self._test_linear_dynamic_fp16_helper(use_relu=True)

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    # TODO: investigate options of torch.compile in fbcode
    @unittest.skipIf(IS_FBCODE, "Failing in fbcode")
    @parametrize("has_bias", [True, False])
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("per_channel_quant", [True, False])
    @parametrize("dynamic", [True, False])
    def test_smooth_quant_with_int_mm(
        self, has_bias, dtype, per_channel_quant, dynamic
    ):
        r"""
        This testcase checks if we can match the SmoothQuant int8 linear pattern from Torchao.
        The pattern is:
        (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
        or
        (with bias) pattern_no_bias -> add -> reshape -> reshape
        """
        if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            return
        M = 16
        in_feature = 32
        out_feature = 64
        q_min, q_max = -32, 31

        class Mod(torch.nn.Module):
            def __init__(
                self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
            ):
                super().__init__()
                self.dtype = dtype
                self.has_bias = has_bias
                self.b = torch.randint(
                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
                )
                self.per_channel_quant = per_channel_quant
                a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
                a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
                self.a_scale = (
                    a_scale_per_channel
                    if self.per_channel_quant
                    else a_scale_per_tensor
                )
                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
                self.b_scale = self.b_scale.to(dtype)
                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None

            def forward(self, a):
                out_shape = a.shape[:-1] + (self.b.size(-1),)
                a_reshaped = a.reshape(-1, a.size(-1))
                c = torch._int_mm(a_reshaped, self.b)
                c = c.to(self.dtype)
                c_shape = c.shape
                a_scale = self.a_scale.expand(c.shape)
                c = c * a_scale
                c = c * self.b_scale
                if self.has_bias:
                    c = c.reshape([1, *list(c_shape)])
                    c = c + self.bias
                    c = c.reshape(c_shape)
                c = c.reshape(out_shape)
                return c

        mod = Mod(dtype, has_bias, per_channel_quant).eval()
        a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
            )
            if dynamic:
                nodes_count = 10 if has_bias else 7
            else:
                nodes_count = 7 if has_bias else 6
            if counters["inductor"]["removed_pointless_view_pair"] == 0:
                # Removing pointless view pairs affects how the pattern
                # for this test is matched.
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
                    nodes_count,
                )

        self._test_common(
            mod,
            (a,),
            matcher_check_fn=matcher_check_fn,
            check_autocast=dtype,
            compile_options={"dynamic": dynamic},
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    # TODO: investigate options of torch.compile in fbcode
    @unittest.skipIf(IS_FBCODE, "Failing in fbcode")
    @parametrize("has_bias", [True, False])
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("dynamic", [True, False])
    @parametrize("reshape_a", [True, False])
    @parametrize(
        "M",
        [
            1,
            32,
        ],
    )
    @parametrize("inplace_add", [True, False])
    @parametrize("expand_a_scale", [True, False])
    def test_da8w8_sym_act_sym_wgt_with_int_mm(
        self, has_bias, dtype, dynamic, reshape_a, M, inplace_add, expand_a_scale
    ):
        r"""
        This testcase checks if we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao,
        where the activation is symmetrically quantized dynamically and the weights are symmetrically quantized (statically).
        The pattern is:
        (no bias) _int_mm -> convert_element_type -> ([expand_a] -> mul) -> mul
        or
        (with bias) pattern_no_bias -> add
        Expansion of the activation scale is optional.
        The pattern depiction doesn't mean that the convert_element_type output is fed into expand_a as input,
        but simply that the activation scale may be applied after an expand operation on it.
        """
        if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            return
        in_feature = 32
        out_feature = 64
        q_min, q_max = -32, 31
        # we only test for qlinear_binary in this case
        test_for_pointwise_binary = (
            True
            if M == 1
            and inplace_add
            and not expand_a_scale
            and not dynamic
            and not has_bias
            else False
        )
        if test_for_pointwise_binary and not IS_X86:
            self.skipTest("Some UTs are only supported on x86_64 CPUs")

        class Mod(torch.nn.Module):
            def __init__(self, dtype: torch.dtype, has_bias: bool):
                super().__init__()
                self.dtype = dtype
                self.has_bias = has_bias
                self.b = torch.randint(
                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
                )
                self.a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
                self.b_scale = self.b_scale.to(dtype)
                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
                self.additive = torch.rand([M, out_feature], dtype=dtype)

            def forward(self, a):
                if reshape_a:
                    a_reshaped = a.reshape(-1, a.size(-1))
                else:
                    a_reshaped = a
                c = torch._int_mm(a_reshaped, self.b)
                c = c.to(self.dtype)
                if expand_a_scale:
                    a_scale = self.a_scale.expand(c.shape)
                else:
                    a_scale = self.a_scale
                c = c * a_scale
                c = c * self.b_scale
                if self.has_bias:
                    c = c + self.bias
                elif inplace_add and test_for_pointwise_binary:
                    # When M is 1, dynamic shapes are enabled with torch.compile, has_bias is False,
                    # expand_a_scale is False and inplace_add is True,
                    # the output's outermost dim's stride can't be determined due to an Inductor bug.
                    c.add_(self.additive)
                return c

        mod = Mod(dtype, has_bias).eval()
        a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
            )

        self._test_common(
            mod,
            (a,),
            matcher_check_fn,
            check_autocast=dtype,
            compile_options={"dynamic": dynamic},
        )
        if test_for_pointwise_binary:
            self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1)


class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase):
    def setUp(self):
        super().setUp()
        self.ctx_stack.enter_context(
            # When testing kernel counts, unspecializing float causes wobbling of our tests because
            # we end up reusing the same compiled region across tests. Thus we purposely specialize floats
            # here since we primarily care about number of kernels generated in the absence of compile
            # caching.
            dynamo_config.patch(
                {
                    "dynamic_shapes": True,
                    "assume_static_by_default": False,
                    "specialize_float": True,
                }
            )
        )

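    # Reuse the static-shape test bodies below; the setUp above re-runs them with
    # dynamic shapes enabled.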
    _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base
    test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary
    test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary
    _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base
    test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary
    test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary

    def test_conv_transpose2d_dynamic_shapes(self, device):
        self.device = device

        # We don't support conv_transpose2d for now.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_transpose2d = torch.nn.ConvTranspose2d(
                    3, 16, 3, stride=2, padding=1
                )

            def forward(self, x):
                return self.conv_transpose2d(x)

        x_shape = (1, 3, 28, 28)
        mod = M().eval()
        v = torch.randn(x_shape, dtype=torch.float32)

        def matcher_check_fn():
            return

        self._test_common(mod, (v,), matcher_check_fn)

    @skipIfXpu(
        msg="Different from CPU, two linears will be concatenated on XPU for better performance"
    )
    def test_multi_linear_share_same_input_dynamic(self, device):
        self.device = device

        # llama pattern.
        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.w1 = torch.nn.Linear(16, 16, bias=False)
                self.w2 = torch.nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return F.silu(self.w1(x)) * F.relu(self.w2(x))

        dtypes = []
        if is_mkldnn_bf16_supported(self.device):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"],
                0 if TEST_ACL else 7,
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_unary_fusion_matcher_count"],
                0 if TEST_ACL else 2,
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"], 6
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"], 2
            )
            self.assertEqual(
                counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"], 2
            )

        for dtype in dtypes:
            mod = M().to(dtype).eval()
            v = torch.randn(2, 4, 16).to(dtype)
            self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2)


class TestDynamicPatternMatcher(TestPatternMatcherBase):
    test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary
    test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = (
        TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias
    )

    def setUp(self):
        super().setUp()
        self.ctx_stack.enter_context(
            # When testing kernel counts, unspecializing float causes wobbling of our tests because
            # we end up reusing the same compiled region across tests. Thus we purposely specialize floats
            # here since we primarily care about number of kernels generated in the absence of compile
            # caching.
            dynamo_config.patch(
                {
                    "dynamic_shapes": True,
                    "assume_static_by_default": False,
                    "specialize_float": True,
                }
            )
        )

    @xfailIfACL
    def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None):
        r"""
        This testcase will quantize a single Conv2d->Maxpool2d->Linear module
        with dynamic batch size input.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
                **kwargs,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, (2, 2), stride=(1, 1), padding=(1, 1)
                )
                self.relu = torch.nn.ReLU()
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.linear = torch.nn.Linear(16, 16)

            def forward(self, x):
                temp = self.relu(self.conv(x))
                temp = self.maxpool2d(temp)
                temp = self.avgpool(temp)
                temp = torch.flatten(temp, 1)
                return self.linear(temp)

        mod = M().eval()
        v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1)
        if include_ops is None:
            include_ops = [
                "torch.ops.onednn.qconv_pointwise",
                "torch.ops.quantized.max_pool2d",
                "torch.ops.onednn.qlinear_pointwise",
            ]
        exclude_ops = []
        self._test_code_common(
            mod,
            (v,),
            include_ops,
            exclude_ops,
            check_quantization=True,
            check_dynamic=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_qat_bn_conv2d(self):
        r"""
        This testcase will quantize a single BN Conv2d module with qat flow.
        """

        class M(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn1 = torch.nn.BatchNorm2d(3)
                self.bn2 = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(self.bn1(x))
                return self.bn2(x)

        mod = M().train()
        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)

        def matcher_check_fn():
            self.assertEqual(
                counters["inductor"]["qconv_weight_prepack_matcher_count"], 1
            )

        self._test_common(
            mod,
            (v,),
            matcher_check_fn,
            check_quantization=True,
            is_qat=True,
        )

    @skipIfNoDynamoSupport
    @skipIfNoONEDNN
    def test_q_attention_block(self):
        class SelfAttnLikeModule(torch.nn.Module):
            def __init__(
                self,
                input_dim,
                num_attention_heads=None,
                attention_head_size=None,
            ) -> None:
                super().__init__()
                self.input_dim = input_dim
                self.q_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.k_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.v_proj = torch.nn.Linear(input_dim, input_dim, bias=False)
                self.softmax = torch.nn.Softmax(dim=-1)
                self.num_attention_heads = num_attention_heads
                self.attention_head_size = attention_head_size
                self.all_head_size = self.num_attention_heads * self.attention_head_size
                self.dense = torch.nn.Linear(self.all_head_size, self.all_head_size)

            def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
                new_x_shape = x.size()[:-1] + (
                    self.num_attention_heads,
                    self.attention_head_size,
                )
                x = x.view(new_x_shape)
                return x.permute(0, 2, 1, 3)

            def forward(self, x):
                q = self.q_proj(x)
                k = self.k_proj(x)
                v = self.v_proj(x)
                q = self.transpose_for_scores(q)
                k = self.transpose_for_scores(k)
                v = self.transpose_for_scores(v)
                scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5)
                attention = self.softmax(scores)
                weighted = torch.matmul(attention, v)
                weighted = weighted.permute(0, 2, 1, 3).contiguous()
                weighted = weighted.reshape(
                    weighted.size()[:-2] + (self.all_head_size,)
                )
                return self.dense(weighted)

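        # When annotate_matmul is True, torch.matmul is also annotated for quantization,
        # which changes how many qlinear_unary patterns are expected to match below.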
        for annotate_matmul in [True, False]:
            mod = SelfAttnLikeModule(
                input_dim=64 * 16,
                num_attention_heads=16,
                attention_head_size=64,
            ).eval()
            v = torch.randn(2, 384, 1024)

            def matcher_check_fn():
                self.assertEqual(
                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 4
                )
                self.assertEqual(
                    counters["inductor"]["qlinear_unary_matcher_count"],
                    3 if annotate_matmul and not TEST_ACL else 0,
                )
                if IS_X86:  # Some issues on ARM
                    self.assertEqual(
                        counters["inductor"]["quant_lift_up_count"],
                        4 if annotate_matmul and not TEST_ACL else 1,
                    )

            quantizer = X86InductorQuantizer()
            quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
            if annotate_matmul:
                quantizer.set_function_type_qconfig(
                    torch.matmul, quantizer.get_global_quantization_config()
                )

            self._test_common(
                mod,
                (v,),
                matcher_check_fn,
                check_quantization=True,
                quantizer=quantizer,
            )


instantiate_device_type_tests(
    TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu", "xpu")
)
instantiate_device_type_tests(
    TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu", "xpu")
)
instantiate_parametrized_tests(TestPatternMatcher)

if __name__ == "__main__":
    if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available():
        run_tests()