pytorch/test/inductor/test_cpu_select_algorithm.py
# Owner(s): ["oncall: cpu inductor"]
import contextlib
import functools
import sys
import unittest
from typing import Optional
from unittest.mock import patch
import torch
import torch._dynamo.config
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.cpu_vec_isa
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor import test_operators
from torch._inductor.cpu_vec_isa import VecAMX
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
)
from torch.testing._internal.common_quantization import _generate_qdq_quantized_model
from torch.testing._internal.common_quantized import (
_calculate_dynamic_per_channel_qparams,
)
from torch.testing._internal.common_utils import (
IS_MACOS,
parametrize,
skipIfWindows,
TEST_MKL,
)
try:
try:
from . import test_cpu_repro, test_torchinductor
except ImportError:
import test_cpu_repro # @manual=fbcode//caffe2/test/inductor:test_cpu_repro-library
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
check_model = test_torchinductor.check_model
set_num_threads = test_cpu_repro.set_num_threads
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
aten = torch.ops.aten
def patches(fn):
def skip_cache(self, choices, name, key, benchmark, hint_override=None):
if benchmark is None:
return {}
timings = benchmark(choices)
for choice, timing in timings.items():
if isinstance(choice, select_algorithm.ExternKernelCaller):
# we intentionally make the ATen kernel much slower so that the template
# kernels are always chosen, exercising fused epilogues and the
# correctness checks at runtime.
timings[choice] = timing * 1000
return timings
for patcher in [
dynamo_config.patch(verbose=True),
dynamo_config.patch(inline_inbuilt_nn_modules=True),
inductor_config.patch(
debug=True,
max_autotune=True,
epilogue_fusion=True,
max_autotune_gemm_backends="CPP,ATEN",
),
patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
]:
fn = patcher(fn)
@functools.wraps(fn)
def wrapped(*args, **kwargs):
counters.clear()
torch.manual_seed(12345)
return fn(*args, **kwargs)
return wrapped
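# Usage sketch (illustrative, not an additional test): stacking @patches on a test
# enables verbose dynamo logging, inlining of built-in nn.Modules, max-autotune with
# the "CPP,ATEN" GEMM backends, runtime VERIFY checks, and the cache-skipping lookup
# above; each wrapped call also starts with cleared counters and a fixed RNG seed:
#
#     @inductor_config.patch({"freezing": True})
#     @patches
#     @torch.no_grad
#     def test_example(self, ...):
#         ...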
@contextlib.contextmanager
def verify(dtype):
# For bfloat16 and half, we have to relax the tolerance
# because of the different associative orders used by
# different kernel implementations.
atol, rtol = 1e-4, 1e-4
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
yield atol, rtol
def _get_epilogue(epilogue: str, other: Optional[torch.Tensor] = None):
if epilogue == "none":
return lambda x: x
elif epilogue == "relu":
return torch.nn.ReLU()
elif epilogue == "gelu":
return torch.nn.GELU()
elif epilogue == "silu":
return torch.nn.SiLU()
elif epilogue == "sigmoid":
return torch.nn.Sigmoid()
elif epilogue == "tanh":
return torch.nn.Tanh()
elif epilogue == "hardswish":
return torch.nn.Hardswish()
elif epilogue == "hardsigmoid":
return torch.nn.Hardsigmoid()
elif epilogue == "leaky_relu":
return torch.nn.LeakyReLU()
elif epilogue == "hardtanh":
return torch.nn.Hardtanh()
elif epilogue == "add":
return lambda x: x + other
elif epilogue == "sub":
return lambda x: x - other
elif epilogue == "mul":
return lambda x: x * other
elif epilogue == "div":
return lambda x: x / other
class BaseTestSelectAlgorithm(TestCase):
def _check_amx_counter(self, vec_amx):
if vec_amx:
self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0)
else:
self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0)
def _check_brgemm_counter(self, vec_amx):
if vec_amx and torch.cpu._is_amx_fp16_supported():
self.assertTrue(counters["inductor"]["cpp_micro_brgemm_counter"] > 0)
else:
self.assertEqual(counters["inductor"]["cpp_micro_brgemm_counter"], 0)
class TestSelectAlgorithm(BaseTestSelectAlgorithm):
common = check_model
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (1, 2, 1000))
@parametrize("in_features", (1, 1000))
@parametrize("out_features", (1, 1024))
@parametrize("bias", (True, False))
@parametrize("input_3d", (True, False))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_static_shapes(
self, batch_size, in_features, out_features, bias, input_3d, dtype
):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x)
counters.clear()
mod = M(bias=bias).to(dtype=dtype).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype=dtype)
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
if (
counters["inductor"]["decompose_mm"] > 0
or counters["inductor"]["decompose_addmm"] > 0
):
# This is a special case where we go directly with vectorized codegen
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
else:
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("in_features", (1000,))
@parametrize("out_features", (1024,))
@parametrize("bias", (True,))
@dtypes(
torch.float,
)
def test_linear_wgt_multi_users(self, in_features, out_features, bias, dtype):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.embeddings = torch.nn.Embedding(out_features, in_features)
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.linear.weight = self.embeddings.weight
def forward(self, x):
x = self.embeddings(x)
return self.linear(x)
counters.clear()
mod = M(bias=bias).to(dtype=dtype).eval()
v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bias", (True, False))
@dtypes(torch.float)
def test_linear_input_transpose(self, bias, dtype):
batch_size = 384
in_features = 196
out_features = 384
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
@torch.compile
def forward(self, x):
return self.linear(x)
counters.clear()
mod = M(bias=bias).to(dtype=dtype).eval()
v = torch.randn(in_features, batch_size).to(dtype=dtype)
self.common(mod, (v.transpose(0, 1),))
# TODO(jgong5): support transposed input
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 0)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (384,))
@parametrize("in_features", (196,))
@parametrize("out_features", (384, 385))
@parametrize("bias", (True, False))
@parametrize(
"epilogue",
(
"relu",
"gelu",
"silu",
"sigmoid",
"tanh",
"hardswish",
"hardsigmoid",
"leaky_relu",
"hardtanh",
"add",
"sub",
"mul",
"div",
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
@torch.fx.experimental._config.patch(use_duck_shape=False)
@torch._dynamo.config.patch(specialize_float=True)
def test_linear_with_pointwise(
self, batch_size, in_features, out_features, bias, epilogue, dtype
):
class M(torch.nn.Module):
def __init__(self, bias, epilogue, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.epilogue = _get_epilogue(epilogue, other)
def forward(self, x):
return self.epilogue(self.linear(x))
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if (
(
(
dtype == torch.bfloat16
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
)
or (
dtype == torch.float16
and torch.ops.mkldnn._is_mkldnn_fp16_supported()
)
)
and epilogue != "mul"
and epilogue != "div"
or (
dtype in (torch.float16, torch.bfloat16)
and epilogue == "add"
and not bias
)
or (
dtype == torch.float32
and epilogue == "add"
and not bias
and not dynamo_config.assume_static_by_default
)
):
# Several scenarios where epilogue fusion is not counted:
# 1. For bfloat16, the epilogue fusion is part of the template,
# not fused via the scheduler. The same holds for float16 when
# the hardware has float16 instructions. The exception is mul or
# div fusion, which is not supported for oneDNN linear.
# 2. For bfloat16/float16, when oneDNN linear is not applied, linear w/o bias
# plus epilogue add is treated as linear w/ bias.
# 3. For float32, when dynamic shapes are enabled, mkl linear is not applied,
# and linear w/o bias plus epilogue add is treated as addmm.
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
else:
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (384,))
@parametrize("in_features", (196,))
@parametrize("out_features", (128, 129))
@parametrize("bias", (True, False))
@parametrize(
"epilogue",
(
"none",
"relu",
"add",
"sub",
"mul",
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_transpose(
self, batch_size, in_features, out_features, bias, epilogue, dtype
):
class M(torch.nn.Module):
def __init__(self, bias, epilogue, other):
super().__init__()
self.epilogue = _get_epilogue(epilogue, other)
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x, y):
return self.epilogue(self.linear(x)).transpose(0, 1) + y
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(out_features, batch_size).to(dtype=dtype)
other = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v, u), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (1,))
@parametrize("in_features", (16,))
@parametrize("image_size", (18,))
@parametrize("out_features", (32,))
@parametrize(
"bias",
(
False,
True,
),
)
@parametrize(
"has_non_epilogue_users",
(
True,
False,
),
)
@dtypes(torch.bfloat16)
def test_linear_with_permute(
self,
batch_size,
in_features,
image_size,
out_features,
bias,
has_non_epilogue_users,
dtype,
):
# Reproducer from the convnext model in timm
class M(torch.nn.Module):
def __init__(self, bias, has_non_epilogue_users):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1)
self.conv = torch.nn.Conv2d(
out_features,
out_features,
kernel_size=7,
padding=3,
groups=out_features,
)
self.linear2 = torch.nn.Linear(out_features, out_features, bias)
self._frozen_param400 = torch.randn(batch_size, out_features, 1, 1)
self.has_non_epilogue_users = has_non_epilogue_users
def forward(self, mul_272, _convolution_pointwise_default_31):
out1 = torch.ops.prims.convert_element_type.default(
mul_272, torch.bfloat16
)
mul_272 = None
_linear_pointwise_default_131 = self.linear(out1)
permute_188 = torch.ops.aten.permute.default(
_linear_pointwise_default_131, [0, 3, 1, 2]
)
mul_273 = torch.ops.aten.mul.Tensor(permute_188, self._frozen_param398)
add_187 = torch.ops.aten.add.Tensor(
mul_273, _convolution_pointwise_default_31
)
convert_element_type_847 = torch.ops.prims.convert_element_type.default(
add_187, torch.bfloat16
)
_convolution_pointwise_default_29 = self.conv(convert_element_type_847)
permute_189 = torch.ops.aten.permute.default(
_convolution_pointwise_default_29, [0, 2, 3, 1]
)
permute_189 = self.linear2(permute_189)
permute_189 = torch.ops.aten.permute.default(permute_189, [0, 3, 1, 2])
permute_189 = torch.ops.aten.mul.Tensor(
permute_189, self._frozen_param400
)
# If template_buffer will be used by nodes other than the epilogue nodes,
# we can't alias the template_buffer with the Y buffer.
if self.has_non_epilogue_users:
add_191 = torch.ops.aten.add.Tensor(permute_189, add_187)
return add_191
return permute_189
view_12 = torch.randn(batch_size, image_size, image_size, in_features)
_convolution_pointwise_default_31 = torch.randn(
batch_size, out_features, image_size, image_size
).to(memory_format=torch.channels_last)
mod = M(bias=bias, has_non_epilogue_users=has_non_epilogue_users).eval()
with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast():
self.common(
mod,
(
view_12,
_convolution_pointwise_default_31,
),
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (8,))
@parametrize("in_features", (3,))
@parametrize("linear_in_features", (384,))
@parametrize("out_features", (196,))
@parametrize("bias", (True,))
@dtypes(torch.float)
def test_linear_with_input_of_flexible_layout(
self, batch_size, in_features, linear_in_features, out_features, bias, dtype
):
# Reproducer from the resmlp_12_224 model in timm
flatten_BS = int(batch_size * linear_in_features)
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.conv = torch.nn.Conv2d(
in_features,
linear_in_features,
kernel_size=16,
padding=0,
stride=16,
dilation=1,
groups=1,
)
self._frozen_param151 = torch.randn(1, 1, linear_in_features)
self._frozen_param3 = torch.randn(1, 1, linear_in_features)
self._frozen_param2 = torch.randn(linear_in_features)
self.linear = torch.nn.Linear(out_features, out_features, bias)
def forward(self, arg150_1):
_convolution_pointwise_default = self.conv(arg150_1)
view_73 = torch.ops.aten.reshape.default(
_convolution_pointwise_default,
[batch_size, linear_in_features, out_features],
)
_convolution_pointwise_default = None
permute_62 = torch.ops.aten.permute.default(view_73, [0, 2, 1])
view_73 = None
mul_111 = torch.ops.aten.mul.Tensor(self._frozen_param151, permute_62)
add_73 = torch.ops.aten.add.Tensor(self._frozen_param3, mul_111)
permute_63 = torch.ops.aten.permute.default(add_73, [0, 2, 1])
add_73 = None
view_74 = torch.ops.aten.reshape.default(
permute_63, [flatten_BS, out_features]
)
permute_63 = None
_mkl_linear_36 = self.linear(view_74)
view_75 = torch.ops.aten.reshape.default(
_mkl_linear_36, [batch_size, linear_in_features, out_features]
)
_mkl_linear_36 = None
permute_65 = torch.ops.aten.permute.default(view_75, [0, 2, 1])
view_75 = None
mul_112 = torch.ops.aten.mul.Tensor(self._frozen_param2, permute_65)
_frozen_param2 = permute_65 = None
add_74 = torch.ops.aten.add.Tensor(permute_62, mul_112)
permute_62 = mul_112 = None
return add_74
v = torch.randn(batch_size, in_features, 224, 224).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (8,))
@parametrize("in_features", (128,))
@parametrize("size_0", (4,))
@parametrize("size_1", (14,))
@parametrize("out_features", (512,))
@parametrize("out_features_conv", (256,))
@parametrize(
"bias",
(
False,
True,
),
)
@parametrize(
"epilogue",
(
False,
True,
),
)
@dtypes(torch.float32)
def test_linear_unsupported_epilogue_fusion(
self,
batch_size,
in_features,
size_0,
size_1,
out_features,
out_features_conv,
bias,
epilogue,
dtype,
):
img_size_0 = int(size_0 * size_0)
img_size_1 = int(size_1 * size_1)
conv_shape = int(size_0 * size_1)
flatten_BS = int(batch_size * size_0 * size_0 * size_1 * size_1)
# Reproducer from the jx_nest_base model in timm
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear1 = torch.nn.Linear(in_features, in_features, bias=bias)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)
self.conv = torch.nn.Conv2d(
in_features,
out_features_conv,
kernel_size=3,
padding=1,
stride=1,
dilation=1,
groups=1,
)
self.epilogue = epilogue
def forward(self, mul_239, view_425, add_184):
_mkl_linear_91 = self.linear1(view_425)
view_426 = torch.ops.aten.reshape.default(
_mkl_linear_91, [batch_size, img_size_0, img_size_1, in_features]
)
_mkl_linear_91 = None
add_187 = torch.ops.aten.add.Tensor(add_184, view_426)
add_184 = view_426 = None
view_429 = torch.ops.aten.reshape.default(
mul_239, [flatten_BS, out_features]
)
mul_239 = None
_mkl_linear_89 = self.linear2(view_429)
if self.epilogue:
_mkl_linear_89 = torch.pow(_mkl_linear_89, 2)
_mkl_linear_89 = test_operators.realize(_mkl_linear_89)
view_430 = torch.ops.aten.reshape.default(
_mkl_linear_89, [batch_size, img_size_0, img_size_1, in_features]
)
_mkl_linear_89 = None
add_191 = torch.ops.aten.add.Tensor(add_187, view_430)
add_187 = view_430 = None
view_431 = torch.ops.aten.reshape.default(
add_191, [batch_size, size_0, size_0, size_1, size_1, in_features]
)
add_191 = None
permute_203 = torch.ops.aten.permute.default(
view_431, [0, 1, 3, 2, 4, 5]
)
view_431 = None
clone_188 = torch.ops.aten.clone.default(
permute_203, memory_format=torch.contiguous_format
)
permute_203 = None
view_432 = torch.ops.aten.reshape.default(
clone_188, [batch_size, conv_shape, conv_shape, in_features]
)
clone_188 = None
permute_204 = torch.ops.aten.permute.default(view_432, [0, 3, 1, 2])
view_432 = None
_convolution_pointwise_default_1 = self.conv(permute_204)
return _convolution_pointwise_default_1
mul_239 = torch.randn(batch_size, img_size_0, img_size_1, out_features)
view_425 = torch.randn(flatten_BS, in_features)
add_184 = torch.randn(batch_size, img_size_0, img_size_1, in_features)
mod = M(bias=bias).eval()
with (
verify(dtype) as (atol, rtol),
torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16),
):
self.common(
mod,
(
mul_239,
view_425,
add_184,
),
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
# TODO: change cpp_epilogue_fusion_counter to 1 once supported
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"], 1 if epilogue else 0
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (384,))
@parametrize("in_features", (196,))
@parametrize("out_features", (384, 385))
@parametrize("bias", (True, False))
@parametrize(
"unary",
("relu",),
)
@parametrize(
"binary",
(
"add",
"sub",
"mul",
"div",
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_unary_binary(
self, batch_size, in_features, out_features, bias, unary, binary, dtype
):
class M(torch.nn.Module):
def __init__(self, bias, unary, binary, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.unary = _get_epilogue(unary)
self.binary = _get_epilogue(binary, other)
def forward(self, x):
return self.binary(self.unary(self.linear(x)))
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (384,))
@parametrize("in_features", (196,))
@parametrize("out_features", (384,))
@parametrize("bias", (True, False))
@parametrize(
"binary",
("add",),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_binary_input_3d(
self, batch_size, in_features, out_features, bias, binary, dtype
):
class M(torch.nn.Module):
def __init__(self, bias, binary, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.binary = _get_epilogue(binary, other)
def forward(self, x):
return self.binary(self.linear(x))
counters.clear()
B = (2, batch_size)
v = torch.randn(*B, in_features).to(dtype=dtype)
u = torch.randn(*B, out_features).to(dtype=dtype)
mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@set_num_threads(1)
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
@parametrize("batch_size", (256,))
@parametrize("in_features", (3,))
@parametrize("out_features", (1024,))
@parametrize("out_features2", (2,))
@parametrize("bias", (True, False))
@dtypes(torch.float)
def test_linear_local_and_global_buffer_dynamic_shapes(
self, batch_size, in_features, out_features, out_features2, bias, dtype
):
# Reproducer from soft_actor_critic
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.linear1 = torch.nn.Linear(out_features, out_features, bias)
self.linear2 = torch.nn.Linear(out_features, out_features2, bias)
def forward(self, arg7_1):
addmm_3 = self.linear(arg7_1)
relu_2 = torch.ops.aten.relu.default(addmm_3)
addmm_4 = self.linear1(relu_2)
relu_3 = torch.ops.aten.relu.default(addmm_4)
addmm_5 = self.linear2(relu_3)
split_1 = torch.ops.aten.split.Tensor(addmm_5, 1, 1)
getitem_2 = split_1[0]
getitem_3 = split_1[1]
tanh_1 = torch.ops.aten.tanh.default(getitem_3)
add_62 = torch.ops.aten.add.Tensor(tanh_1, 1)
mul_36 = torch.ops.aten.mul.Tensor(add_62, 6.0)
add_69 = torch.ops.aten.add.Tensor(mul_36, -10.0)
exp_1 = torch.ops.aten.exp.default(add_69)
return (getitem_2, exp_1)
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 3)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (1024,))
@parametrize("in_features", (1024,))
@parametrize("out_features", (1024, 1025))
@parametrize("bias", (True, False))
@dtypes(torch.bfloat16, torch.half)
def test_linear_amx(self, batch_size, in_features, out_features, bias, dtype):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x)
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
# Currently brgemm config is only added for half
if dtype == torch.half:
self._check_brgemm_counter(vec_amx)
else:
self._check_amx_counter(vec_amx)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (8,))
@parametrize("in_features", (128,))
@parametrize("in_features_2", (196,))
@parametrize("out_features", (256,))
@parametrize(
"bias",
(True,),
)
@dtypes(torch.float32)
def test_linear_with_multiple_reindexers(
self,
batch_size,
in_features,
in_features_2,
out_features,
bias,
dtype,
):
flatten_BS = int(batch_size * in_features_2)
# Reproducer from the levit_128 model in timm
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.conv = torch.nn.Conv2d(
64,
128,
kernel_size=3,
padding=1,
stride=2,
dilation=1,
groups=1,
)
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
self._frozen_param221 = torch.randn(out_features)
self._frozen_param389 = torch.randn(out_features)
self._frozen_param20 = torch.randn(out_features)
self._frozen_param21 = torch.randn(out_features)
def forward(self, view_368):
_mkl_linear_57 = self.linear(view_368)
view_369 = torch.ops.aten.reshape.default(
_mkl_linear_57, [batch_size, in_features_2, out_features]
)
_mkl_linear_57 = None
view_370 = torch.ops.aten.reshape.default(
view_369, [flatten_BS, out_features]
)
view_369 = None
sub_85 = torch.ops.aten.sub.Tensor(view_370, self._frozen_param221)
view_370 = _frozen_param221 = None
mul_261 = torch.ops.aten.mul.Tensor(sub_85, self._frozen_param389)
sub_85 = _frozen_param389 = None
mul_262 = torch.ops.aten.mul.Tensor(mul_261, self._frozen_param20)
mul_261 = _frozen_param20 = None
add_219 = torch.ops.aten.add.Tensor(mul_262, self._frozen_param21)
mul_262 = _frozen_param21 = None
view_371 = torch.ops.aten.reshape.default(
add_219, [batch_size, in_features_2, out_features]
)
add_219 = None
add_220 = torch.ops.aten.add.Tensor(view_371, 3)
clamp_min_35 = torch.ops.aten.clamp_min.default(add_220, 0)
add_220 = None
clamp_max_35 = torch.ops.aten.clamp_max.default(clamp_min_35, 6)
clamp_min_35 = None
mul_263 = torch.ops.aten.mul.Tensor(view_371, clamp_max_35)
view_371 = clamp_max_35 = None
div_51 = torch.ops.aten.div.Tensor(mul_263, 6)
mul_263 = None
return div_51
view_368 = torch.randn(flatten_BS, in_features)
mod = M(bias=bias).eval()
with verify(dtype) as (atol, rtol):
self.common(
mod,
(view_368,),
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (384,))
@parametrize("in_features", (196,))
@parametrize("out_features", (384,))
@parametrize("bias", (True, False))
@dtypes(torch.bfloat16)
def test_linear_with_embedding(
self, batch_size, in_features, out_features, bias, dtype
):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias).to(
dtype=dtype
)
self.emb = torch.nn.Embedding(64, out_features)
def forward(self, idx, x):
return self.emb(idx) + self.linear(x)
idx = torch.randint(0, 64, (batch_size,))
x = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (idx, x), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (2,))
@parametrize("in_features", (16,))
@parametrize("seq_lens", (128,))
@parametrize("out_features", (32,))
@parametrize("bias", (True,))
@dtypes(torch.bfloat16)
def test_linear_with_indirect_indexing(
self, batch_size, in_features, seq_lens, out_features, bias, dtype
):
# Reproducer from the GPT2ForSequenceClassification model in HuggingFace
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.wte = torch.nn.Embedding(128, seq_lens)
self.wpe = torch.nn.Embedding(in_features, seq_lens)
self.linear = torch.nn.Linear(out_features, seq_lens, bias)
def forward(self, view_12, input_ids, view_9):
inputs_embeds = self.wte(input_ids)
position_ids = torch.arange(0, in_features, dtype=torch.long)
position_ids = position_ids.unsqueeze(0)
position_embeds = self.wpe(position_ids)
add = inputs_embeds + position_embeds
add_4 = view_9 + add
_linear_pointwise_default_45 = self.linear(view_12)
view_13 = torch.ops.aten.reshape.default(
_linear_pointwise_default_45, [batch_size, in_features, seq_lens]
)
out = torch.ops.aten.add.Tensor(add_4, view_13)
return out
view_12 = torch.randn(batch_size * in_features, out_features)
input_ids = torch.randint(0, 128, (batch_size, in_features))
view_9 = torch.randn(batch_size, in_features, seq_lens)
mod = M(bias=bias).eval()
with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast():
self.common(
mod,
(
view_12,
input_ids,
view_9,
),
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (8,))
@parametrize("in_features", (3,))
@parametrize("in_features2", (192,))
@parametrize("image_size", (224,))
@parametrize("out_features", (64,))
@parametrize(
"bias",
(True,),
)
@dtypes(torch.float32)
def test_linear_with_in_out_buffer(
self,
batch_size,
in_features,
in_features2,
image_size,
out_features,
bias,
dtype,
):
# Reproducer from the coat_lite_mini model in timm
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1)
self.conv = torch.nn.Conv2d(
in_features,
out_features,
kernel_size=4,
padding=0,
stride=4,
dilation=1,
groups=1,
)
self.conv2 = torch.nn.Conv2d(
out_features,
out_features,
kernel_size=3,
padding=1,
stride=1,
dilation=1,
groups=out_features,
)
self.conv3 = torch.nn.Conv2d(
16,
16,
kernel_size=3,
padding=1,
stride=1,
dilation=1,
groups=16,
)
self.conv4 = torch.nn.Conv2d(
24,
24,
kernel_size=5,
padding=2,
stride=1,
dilation=1,
groups=24,
)
self.conv5 = torch.nn.Conv2d(
24,
24,
kernel_size=7,
padding=3,
stride=1,
dilation=1,
groups=24,
)
self.linear = torch.nn.Linear(out_features, in_features2, bias)
self.linear2 = torch.nn.Linear(out_features, out_features, bias)
self._frozen_param2 = torch.randn(out_features)
self._frozen_param3 = torch.randn(out_features)
self._frozen_param7 = torch.randn(out_features)
self._frozen_param8 = torch.randn(out_features)
self._frozen_param153 = torch.randn(batch_size, 1, out_features)
def forward(self, arg152_1):
_convolution_pointwise_default_35 = self.conv(arg152_1)
arg152_1 = None
view_168 = torch.ops.aten.reshape.default(
_convolution_pointwise_default_35, [8, 64, 3136]
)
_convolution_pointwise_default_35 = None
permute_97 = torch.ops.aten.permute.default(view_168, [0, 2, 1])
view_168 = None
clone_65 = torch.ops.aten.clone.default(
permute_97, memory_format=torch.contiguous_format
)
permute_97 = None
var_mean_21 = torch.ops.aten.var_mean.correction(
clone_65, [2], correction=0, keepdim=True
)
getitem_90 = var_mean_21[0]
getitem_91 = var_mean_21[1]
var_mean_21 = None
add_82 = torch.ops.aten.add.Tensor(getitem_90, 1e-05)
getitem_90 = None
rsqrt_21 = torch.ops.aten.rsqrt.default(add_82)
add_82 = None
sub_29 = torch.ops.aten.sub.Tensor(clone_65, getitem_91)
clone_65 = getitem_91 = None
mul_82 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_21)
sub_29 = rsqrt_21 = None
mul_83 = torch.ops.aten.mul.Tensor(mul_82, self._frozen_param2)
mul_82 = None
add_83 = torch.ops.aten.add.Tensor(mul_83, self._frozen_param3)
mul_83 = None
_frozen_param153 = self._frozen_param153
cat_20 = torch.ops.aten.cat.default([_frozen_param153, add_83], 1)
_frozen_param153 = add_83 = None
slice_111 = torch.ops.aten.slice.Tensor(cat_20, 1, 0, 1)
slice_113 = torch.ops.aten.slice.Tensor(
cat_20, 1, 1, 9223372036854775807
)
cat_20 = None
permute_98 = torch.ops.aten.permute.default(slice_113, [0, 2, 1])
slice_113 = None
view_169 = torch.ops.aten.reshape.default(permute_98, [8, 64, 56, 56])
permute_98 = None
_convolution_pointwise_default_34 = self.conv2(view_169)
add_84 = torch.ops.aten.add.Tensor(
_convolution_pointwise_default_34, view_169
)
_convolution_pointwise_default_34 = view_169 = None
view_170 = torch.ops.aten.reshape.default(add_84, [8, 64, 3136])
add_84 = None
permute_99 = torch.ops.aten.permute.default(view_170, [0, 2, 1])
view_170 = None
cat_21 = torch.ops.aten.cat.default([slice_111, permute_99], 1)
slice_111 = permute_99 = None
var_mean_22 = torch.ops.aten.var_mean.correction(
cat_21, [2], correction=0, keepdim=True
)
getitem_92 = var_mean_22[0]
getitem_93 = var_mean_22[1]
var_mean_22 = None
add_85 = torch.ops.aten.add.Tensor(getitem_92, 1e-06)
getitem_92 = None
rsqrt_22 = torch.ops.aten.rsqrt.default(add_85)
add_85 = None
sub_30 = torch.ops.aten.sub.Tensor(cat_21, getitem_93)
getitem_93 = None
mul_84 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_22)
sub_30 = rsqrt_22 = None
mul_85 = torch.ops.aten.mul.Tensor(mul_84, self._frozen_param7)
mul_84 = None
add_86 = torch.ops.aten.add.Tensor(mul_85, self._frozen_param8)
mul_85 = None
view_171 = torch.ops.aten.reshape.default(add_86, [25096, 64])
add_86 = None
_mkl_linear_32 = self.linear(view_171)
view_171 = None
view_172 = torch.ops.aten.reshape.default(
_mkl_linear_32, [8, 3137, 192]
)
_mkl_linear_32 = None
view_173 = torch.ops.aten.reshape.default(view_172, [8, 3137, 3, 8, 8])
view_172 = None
permute_101 = torch.ops.aten.permute.default(view_173, [2, 0, 3, 1, 4])
view_173 = None
unbind_8 = torch.ops.aten.unbind.int(permute_101)
permute_101 = None
getitem_94 = unbind_8[0]
getitem_95 = unbind_8[1]
getitem_96 = unbind_8[2]
unbind_8 = None
clone_66 = torch.ops.aten.clone.default(
getitem_95, memory_format=torch.contiguous_format
)
getitem_95 = None
amax_8 = torch.ops.aten.amax.default(clone_66, [2], True)
sub_31 = torch.ops.aten.sub.Tensor(clone_66, amax_8)
clone_66 = amax_8 = None
exp_8 = torch.ops.aten.exp.default(sub_31)
sub_31 = None
sum_9 = torch.ops.aten.sum.dim_IntList(exp_8, [2], True)
div_8 = torch.ops.aten.div.Tensor(exp_8, sum_9)
exp_8 = sum_9 = None
permute_102 = torch.ops.aten.permute.default(div_8, [0, 1, 3, 2])
div_8 = None
expand_37 = torch.ops.aten.expand.default(permute_102, [8, 8, 8, 3137])
permute_102 = None
view_174 = torch.ops.aten.reshape.default(expand_37, [64, 8, 3137])
expand_37 = None
expand_38 = torch.ops.aten.expand.default(getitem_96, [8, 8, 3137, 8])
clone_67 = torch.ops.aten.clone.default(
expand_38, memory_format=torch.contiguous_format
)
expand_38 = None
view_175 = torch.ops.aten.reshape.default(clone_67, [64, 3137, 8])
clone_67 = None
bmm_16 = torch.ops.aten.bmm.default(view_174, view_175)
view_174 = view_175 = None
view_176 = torch.ops.aten.reshape.default(bmm_16, [8, 8, 8, 8])
bmm_16 = None
expand_39 = torch.ops.aten.expand.default(getitem_94, [8, 8, 3137, 8])
clone_68 = torch.ops.aten.clone.default(
expand_39, memory_format=torch.contiguous_format
)
expand_39 = None
view_177 = torch.ops.aten.reshape.default(clone_68, [64, 3137, 8])
clone_68 = None
expand_40 = torch.ops.aten.expand.default(view_176, [8, 8, 8, 8])
view_176 = None
view_178 = torch.ops.aten.reshape.default(expand_40, [64, 8, 8])
expand_40 = None
bmm_17 = torch.ops.aten.bmm.default(view_177, view_178)
view_177 = view_178 = None
view_179 = torch.ops.aten.reshape.default(bmm_17, [8, 8, 3137, 8])
bmm_17 = None
slice_116 = torch.ops.aten.slice.Tensor(
getitem_94, 2, 1, 9223372036854775807
)
getitem_94 = None
slice_120 = torch.ops.aten.slice.Tensor(
getitem_96, 2, 1, 9223372036854775807
)
getitem_96 = None
permute_103 = torch.ops.aten.permute.default(slice_120, [0, 1, 3, 2])
slice_120 = None
view_180 = torch.ops.aten.reshape.default(permute_103, [8, 64, 56, 56])
permute_103 = None
split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(
view_180, [16, 24, 24], 1
)
view_180 = None
getitem_97 = split_with_sizes_8[0]
getitem_98 = split_with_sizes_8[1]
getitem_99 = split_with_sizes_8[2]
split_with_sizes_8 = None
_convolution_pointwise_default_33 = self.conv3(getitem_97)
_convolution_pointwise_default_32 = self.conv4(getitem_98)
_convolution_pointwise_default_31 = self.conv5(getitem_99)
cat_22 = torch.ops.aten.cat.default(
[
_convolution_pointwise_default_33,
_convolution_pointwise_default_32,
_convolution_pointwise_default_31,
],
1,
)
_convolution_pointwise_default_33 = (
_convolution_pointwise_default_32
) = _convolution_pointwise_default_31 = None
view_181 = torch.ops.aten.reshape.default(cat_22, [8, 8, 8, 3136])
cat_22 = None
permute_104 = torch.ops.aten.permute.default(view_181, [0, 1, 3, 2])
view_181 = None
mul_86 = torch.ops.aten.mul.Tensor(slice_116, permute_104)
slice_116 = permute_104 = None
constant_pad_nd_8 = torch.ops.aten.constant_pad_nd.default(
mul_86, [0, 0, 1, 0, 0, 0], 0.0
)
mul_86 = None
mul_87 = torch.ops.aten.mul.Tensor(view_179, 0.3535533905932738)
view_179 = None
add_87 = torch.ops.aten.add.Tensor(mul_87, constant_pad_nd_8)
mul_87 = constant_pad_nd_8 = None
return add_87
view_12 = torch.randn(batch_size, in_features, image_size, image_size)
mod = M(bias=bias).eval()
with verify(dtype) as (atol, rtol):
self.common(
mod,
(view_12,),
atol=atol,
rtol=rtol,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"], 2 if TEST_MKL else 1
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (32,))
@parametrize("in_features", (128,))
@parametrize("out_features", (64, 65))
@parametrize("bias", (False, True))
@parametrize("input_3d", (False, True))
@dtypes(torch.float32, torch.bfloat16)
@parametrize(
"epilogue",
(
"none",
"relu",
"gelu",
),
)
@skipIfWindows(msg="Windows don't support quantize.")
def test_quantized_linear_with_pointwise(
self, batch_size, in_features, out_features, bias, input_3d, dtype, epilogue
):
B = (2, batch_size) if input_3d else (batch_size,)
input = torch.randn(*B, in_features).to(dtype=torch.float32)
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.epilogue = _get_epilogue(epilogue)
self.linear2 = torch.nn.Linear(out_features, out_features, bias)
self.epilogue2 = _get_epilogue(epilogue)
def forward(self, x):
res = self.epilogue(self.linear(x))
res = self.epilogue2(self.linear2(res))
return res
counters.clear()
ref_quantized_mod = _generate_qdq_quantized_model(
M(bias=bias).eval(),
(input,),
)
atol, rtol = 1e-3, 1e-3
if dtype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2
with (
patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)),
torch.no_grad(),
torch.autocast("cpu", enabled=(dtype == torch.bfloat16), dtype=dtype),
):
ref_res = ref_quantized_mod(input)
cfn = torch.compile(ref_quantized_mod)
res = cfn(input)
self.assertEqual(
res,
ref_res,
atol=atol,
rtol=rtol,
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize(
"batch_size",
(
1,
17,
32,
),
)
@parametrize(
"mid_dim",
(
1,
8,
),
)
@parametrize("in_features", (128, 144, 1024))
@parametrize("out_features", (64, 65, 1024))
def test_int8_woq_mm(self, dtype, batch_size, mid_dim, in_features, out_features):
def _convert_weight_to_int8pack(w):
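# Descriptive note: derive dynamic per-output-channel qparams, quantize w to
# int8 along axis 0, and return the int8 weight with bfloat16 per-channel scales.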
scale, zp = _calculate_dynamic_per_channel_qparams(
w.to(torch.float), torch.int8
)
scale = torch.from_numpy(scale)
zp = torch.from_numpy(zp)
w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel(
input=w,
scales=scale,
zero_points=zp,
axis=0,
quant_min=-128,
quant_max=127,
dtype=torch.int8,
)
return w_int8, scale.to(torch.bfloat16)
class M(torch.nn.Module):
def __init__(self, w):
super().__init__()
self.linear_weight = torch.nn.Parameter(w, requires_grad=False)
def forward(self, x, scale):
return (
torch.nn.functional.linear(x, self.linear_weight.to(x.dtype))
* scale
)
counters.clear()
# Currently, the corresponding torch.fx pattern only supports 3D x.
# Add a 2D x case once the corresponding pattern-matcher pattern is added.
x = torch.rand((batch_size, mid_dim, in_features), dtype=dtype)
w = torch.rand((out_features, in_features), dtype=dtype)
w_int8pack, w_scales = _convert_weight_to_int8pack(w)
mod = M(w_int8pack).eval()
self.common(mod, (x, w_scales))
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if batch_size * mid_dim >= 16:
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
@inductor_config.patch({"freezing": True, "cpp.enable_concat_linear": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize(
"batch_size",
(
1,
32,
),
)
@parametrize(
"mid_dim",
(
1,
8,
),
)
@parametrize("in_features", (128,))
@parametrize("out_features", (64,))
def test_int8_woq_mm_concat(
self, dtype, batch_size, mid_dim, in_features, out_features
):
def _convert_weight_to_int8pack(w):
scale, zp = _calculate_dynamic_per_channel_qparams(
w.to(torch.float), torch.int8
)
scale = torch.from_numpy(scale)
zp = torch.from_numpy(zp)
w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel(
input=w,
scales=scale,
zero_points=zp,
axis=0,
quant_min=-128,
quant_max=127,
dtype=torch.int8,
)
return w_int8, scale.to(torch.bfloat16)
class M(torch.nn.Module):
def __init__(self, w1, w2, w3):
super().__init__()
self.w1 = torch.nn.Parameter(w1, requires_grad=False)
self.w2 = torch.nn.Parameter(w2, requires_grad=False)
self.w3 = torch.nn.Parameter(w3, requires_grad=False)
def forward(self, x, scale1, scale2, scale3):
# Ref: _linear_fp_act_int8_weight_impl in torchao/dtypes/uintx/plain_layout.py
y1 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w1.t().to(x.dtype))
* scale1
)
y2 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w2.t().to(x.dtype))
* scale2
)
y3 = (
torch.mm(x.reshape(-1, x.shape[-1]), self.w3.t().to(x.dtype))
* scale3
)
return (
y1.reshape(*x.shape[:-1], y1.shape[-1]),
y2.reshape(*x.shape[:-1], y2.shape[-1]),
y3.reshape(*x.shape[:-1], y3.shape[-1]),
)
counters.clear()
# Currently, the corresponding torch.fx pattern only supports 3D x.
# Add a 2D x case once the corresponding pattern-matcher pattern is added.
x = torch.rand((batch_size, mid_dim, in_features), dtype=dtype)
w1 = torch.rand((out_features, in_features), dtype=dtype)
w2 = torch.rand((out_features, in_features), dtype=dtype)
w3 = torch.rand((out_features, in_features), dtype=dtype)
w1_int8pack, w1_scales = _convert_weight_to_int8pack(w1)
w2_int8pack, w2_scales = _convert_weight_to_int8pack(w2)
w3_int8pack, w3_scales = _convert_weight_to_int8pack(w3)
mod = M(w1_int8pack, w2_int8pack, w3_int8pack).eval()
self.common(mod, (x, w1_scales, w2_scales, w3_scales))
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
if batch_size * mid_dim >= 16:
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
# We set allow_ignore_mark_dynamic to True because Dynamo may end up specializing the M dimension
# even though it is marked as dynamic with mark_dynamic.
@dynamo_config.patch({"allow_ignore_mark_dynamic": True})
@parametrize("has_bias", [True, False])
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("per_channel_quant", [True, False])
@parametrize("reshape_a", [True, False])
@parametrize("expand_a_scale", [True, False])
@parametrize("dynamic", [True, False])
@parametrize("M", [1, 32])
def test_da8w8_sym_act_sym_wgt_with_int_mm(
self, has_bias, dtype, per_channel_quant, reshape_a, expand_a_scale, dynamic, M
):
r"""
This test case checks whether we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao,
where the activation is dynamically and symmetrically quantized and the weights are statically and symmetrically quantized.
The pattern is:
(no bias) _int_mm -> convert_element_type -> ([maybe_expand_a_scale] -> mul) -> mul
or
(with bias) pattern_no_bias -> add
Expansion of the activation scale is optional.
The pattern depiction doesn't mean that the convert_element_type output is fed into expand_a as input,
but simply that the activation scale may be applied after an expand operation on it.
"""
if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
return
in_feature = 48
out_feature = 64
q_min, q_max = -32, 31
class Mod(torch.nn.Module):
def __init__(self, dtype: torch.dtype, has_bias: bool):
super().__init__()
self.dtype = dtype
self.has_bias = has_bias
self.b = torch.randint(
q_min, q_max, [in_feature, out_feature], dtype=torch.int8
)
self.per_channel_quant = per_channel_quant
a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
self.a_scale = (
a_scale_per_channel if per_channel_quant else a_scale_per_tensor
)
self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
self.b_scale = self.b_scale.to(dtype)
self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
def forward(self, a):
if reshape_a:
a_reshaped = a.reshape(-1, a.size(-1))
else:
a_reshaped = a
c = torch._int_mm(a_reshaped, self.b)
c = c.to(self.dtype)
if not expand_a_scale:
a_scale = self.a_scale
else:
a_scale = self.a_scale.expand(c.shape)
c = c * a_scale
c = c * self.b_scale
if self.has_bias:
c = c + self.bias
return c
mod = Mod(dtype, has_bias).eval()
a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8)
if dynamic:
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_static(a, 1)
self.common(
mod,
(a,),
atol=1e-2 if dtype is torch.bfloat16 else None,
rtol=1e-2 if dtype is torch.bfloat16 else None,
)
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
if torch._C._cpu._is_amx_tile_supported():
# Only AMX ISA based micro-kernel is currently supported for da8w8
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize("batch_size", (1,))
@parametrize("in_features", (128, 256))
@parametrize("out_features", (64, 128))
@parametrize("group_size", (32, 64))
def test_int4_woq_mm_avx512(
self, dtype, batch_size, in_features, out_features, group_size
):
class M(torch.nn.Module):
def __init__(self, K, N, group_size):
super().__init__()
self.linear_weight = torch.randint(
0, 15, (N, K // 2), dtype=torch.uint8
)
self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
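# Descriptive note: linear_weight presumably packs two 4-bit values per uint8 byte,
# hence shape (N, K // 2); qscale_and_zeros holds per-group scale/zero-point pairs
# of shape (K // group_size, N, 2), as consumed by torch._weight_int4pack_mm_for_cpu below.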
self.group_size = group_size
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
y = torch._weight_int4pack_mm_for_cpu(
x, self.linear_weight, self.group_size, self.qscale_and_zeros
)
return y.reshape(*x_shape[:-1], out_features)
counters.clear()
seq_len = 4
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
mod = M(in_features, out_features, group_size).eval()
self.common(mod, (x,), reference_in_float=False)
available_isa = torch._inductor.cpu_vec_isa.pick_vec_isa()
avx512_available = "avx512" in str(available_isa)
autotune_count = 1 if avx512_available else 0
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"], autotune_count
)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize("batch_size", (64,))
@parametrize("in_features", (14336,))
@parametrize("out_features", (96,))
@parametrize("group_size", (128,))
@set_num_threads(1)
def test_int4_woq_mm_amx_Nc_larger_than_one(
self, dtype, batch_size, in_features, out_features, group_size
):
"""
Note:
`torch._weight_int4pack_mm_for_cpu` computes in float32, while the AMX-based GEMM
template computes in bfloat16, so the difference between the computation results may be large.
However, we need `_weight_int4pack_mm_for_cpu` for its pattern.
Therefore, we define module M1 for the pattern and its parameters, and module M2 for
the reference computation. M2's forward function obtains the dequantized, unpacked weight
in bfloat16 and then computes the GEMM in bfloat16.
Because of this, we need to skip the VERIFY patch and cannot use self.common for testing.
"""
class M1(torch.nn.Module):
def __init__(self, K, N, group_size):
super().__init__()
self.linear_weight = torch.randint(
0, 255, (N, K // 2), dtype=torch.uint8
)
self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
self.group_size = group_size
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
y = torch._weight_int4pack_mm_for_cpu(
x, self.linear_weight, self.group_size, self.qscale_and_zeros
)
return y.reshape(*x_shape[:-1], out_features)
class M2(torch.nn.Module):
def __init__(self, mod: M1):
super().__init__()
self.mod = mod
def forward(self, x):
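# Passing an identity matrix through M1 yields M1's dequantized weight; it is
# transposed below and used for the bfloat16 reference GEMM (see the note above).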
x_eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
dq_w = self.mod(x_eye).T.contiguous()
return torch.nn.functional.linear(x, dq_w)
counters.clear()
seq_len = 8
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
mod = M1(in_features, out_features, group_size).eval()
mod2 = M2(mod)
# Skip VERIFY during torch.compile and don't use self.common. See explanation above.
with patch.object(select_algorithm, "VERIFY", None):
m = torch.compile(mod)
y_ref = mod2(x)
y = m(x)
self.assertEqual(
y,
y_ref,
atol=1e-2,
rtol=1e-2,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.use_small_dequant_buffer": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize("batch_size", (16,))
@parametrize("in_features", (14336,))
@parametrize("out_features", (96,))
@parametrize("group_size", (128,))
@set_num_threads(1)
def test_int4_woq_mm_with_small_buffer_config(
self, dtype, batch_size, in_features, out_features, group_size
):
class M1(torch.nn.Module):
def __init__(self, K, N, group_size):
super().__init__()
self.linear_weight = torch.randint(
0, 255, (N, K // 2), dtype=torch.uint8
)
self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
self.group_size = group_size
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
y = torch._weight_int4pack_mm_for_cpu(
x, self.linear_weight, self.group_size, self.qscale_and_zeros
)
return y.reshape(*x_shape[:-1], out_features)
counters.clear()
seq_len = 1
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
mod = M1(in_features, out_features, group_size).eval()
with patch.object(select_algorithm, "VERIFY", None):
m = torch.compile(mod)
_, code = run_and_get_cpp_code(m, x)
kr = 32 # only kr=32 supported in woq int4 amx kernel
_target_code_check = f"constexpr int64_t Kc_blocks = {group_size // kr};"
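# For instance, with the parametrization above (group_size=128) and kr=32, the
# generated kernel is expected to declare "constexpr int64_t Kc_blocks = 4;".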
torch._C.FileCheck().check(_target_code_check).run(code)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize("batch_size", (1, 4, 6))
@parametrize("in_features", (128, 1024))
@parametrize("out_features", (128, 1024))
@parametrize("group_size", (32, 64, 128))
def test_int4_woq_mm_amx(
self, dtype, batch_size, in_features, out_features, group_size
):
"""
Note:
`torch._weight_int4pack_mm_for_cpu` computes in float32, while the AMX-based GEMM
template computes in bfloat16, so the difference between the computation results may be large.
However, we need `_weight_int4pack_mm_for_cpu` for its pattern.
Therefore, we define module M1 for the pattern and its parameters, and module M2 for
the reference computation. M2's forward function obtains the dequantized, unpacked weight
in bfloat16 and then computes the GEMM in bfloat16.
Because of this, we need to skip the VERIFY patch and cannot use self.common for testing.
"""
class M1(torch.nn.Module):
def __init__(self, K, N, group_size):
super().__init__()
self.linear_weight = torch.randint(
0, 255, (N, K // 2), dtype=torch.uint8
)
self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
self.group_size = group_size
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
y = torch._weight_int4pack_mm_for_cpu(
x, self.linear_weight, self.group_size, self.qscale_and_zeros
)
return y.reshape(*x_shape[:-1], out_features)
class M2(torch.nn.Module):
def __init__(self, mod: M1):
super().__init__()
self.mod = mod
def forward(self, x):
x_eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
dq_w = self.mod(x_eye).T.contiguous()
return torch.nn.functional.linear(x, dq_w)
counters.clear()
seq_len = 8
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
mod = M1(in_features, out_features, group_size).eval()
mod2 = M2(mod)
# Skip VERIFY during torch.compile and don't use self.common. See explanation above.
with patch.object(select_algorithm, "VERIFY", None):
m = torch.compile(mod)
y_ref = mod2(x)
y = m(x)
self.assertEqual(
y,
y_ref,
atol=1e-2,
rtol=1e-2,
)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_concat_linear": True})
@patches
@torch.no_grad
@dtypes(torch.bfloat16)
@parametrize("batch_size", (4,))
@parametrize("in_features", (256,))
@parametrize("out_features", ((512, 256, 256), (512, 512)))
@parametrize("group_size", (32, 128))
def test_int4_concat_woq_mm(
self, dtype, batch_size, in_features, out_features, group_size
):
class M1(torch.nn.Module):
def __init__(self, K, out_features, group_size):
super().__init__()
self.linear_weight = [
torch.randint(0, 255, (N, K // 2), dtype=torch.uint8)
for N in out_features
]
self.qscale_and_zeros = [
torch.rand(K // group_size, N, 2, dtype=dtype) for N in out_features
]
self.group_size = group_size
self.out_features = out_features
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
y = [
torch._weight_int4pack_mm_for_cpu(
x,
self.linear_weight[idx],
self.group_size,
self.qscale_and_zeros[idx],
)
for idx in range(len(self.out_features))
]
return [
y[idx].reshape(*x_shape[:-1], self.out_features[idx])
for idx in range(len(self.out_features))
]
class M2(torch.nn.Module):
def __init__(self, mod: M1):
super().__init__()
self.mod = mod
def forward(self, x):
x_eye = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
dq_w_list = []
for idx in range(len(self.mod.out_features)):
x_shape = x_eye.shape
dq_w = torch._weight_int4pack_mm_for_cpu(
x_eye,
self.mod.linear_weight[idx],
self.mod.group_size,
self.mod.qscale_and_zeros[idx],
)
dq_w_list.append(
dq_w.reshape(
*x_shape[:-1], self.mod.out_features[idx]
).T.contiguous()
)
return [torch.nn.functional.linear(x, dq_w) for dq_w in dq_w_list]
counters.clear()
seq_len = 8
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
mod = M1(in_features, out_features, group_size).eval()
mod2 = M2(mod)
# Skip VERIFY during torch.compile and don't use self.common. See explanation above.
with patch.object(select_algorithm, "VERIFY", None):
y_ref = mod2(x)
m = torch.compile(mod)
y = m(x)
self.assertEqual(
y,
y_ref,
atol=1e-2,
rtol=1e-2,
)
# Only one tuning is performed, since the weights have been concatenated
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (32,))
@parametrize("in_features", (128,))
@parametrize("out_features", (64, 65))
@parametrize("bias", (False, True))
@parametrize("input_3d", (False, True))
@parametrize("int8_mixed_bf16", (False, True))
@dtypes(torch.float32, torch.bfloat16)
@parametrize(
"epilogue",
(
"none",
"relu",
),
)
@skipIfWindows(msg="Windows don't support quantize.")
def test_quantized_linear_with_pointwise_binary(
self,
batch_size,
in_features,
out_features,
bias,
input_3d,
int8_mixed_bf16,
dtype,
epilogue,
):
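        # Quantized linear (qdq model) followed by a binary add with `other` and an
        # optional relu, applied twice. The checks below expect two templated kernels
        # and a zero cpp_epilogue_fusion_counter.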
if not int8_mixed_bf16 and dtype == torch.bfloat16:
return
B = (2, batch_size) if input_3d else (batch_size,)
input = torch.randn(*B, in_features).to(dtype=torch.float32)
other = torch.randn(*B, out_features).to(dtype=dtype)
# Avoid hitting qlinear inplace sum fusion
if input_3d:
other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype)
else:
other2 = torch.randn(1, *B, out_features).to(dtype=dtype)
class M(torch.nn.Module):
def __init__(self, bias, input_3d):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.epilogue = _get_epilogue(epilogue)
self.linear2 = torch.nn.Linear(out_features, out_features, bias)
self.epilogue2 = _get_epilogue(epilogue)
self.input_3d = input_3d
def forward(self, x, other, other2):
res = self.epilogue(self.linear(x) + other)
# Avoid hitting qlinear inplace sum fusion
if self.input_3d:
other2 = other2.view(2, other2.size(0) // 2, other2.size(1))
else:
other2 = other2.view(other2.size(1), other2.size(2))
res = self.epilogue2(self.linear2(res) + other2)
return res
counters.clear()
ref_quantized_mod = _generate_qdq_quantized_model(
M(bias=bias, input_3d=input_3d).eval(),
(input, other, other2),
)
atol, rtol = 5e-2, 5e-2
with (
patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)),
torch.no_grad(),
torch.autocast("cpu", enabled=int8_mixed_bf16, dtype=torch.bfloat16),
):
ref_res = ref_quantized_mod(input, other, other2)
cfn = torch.compile(ref_quantized_mod)
res = cfn(input, other, other2)
self.assertEqual(
res,
ref_res,
atol=atol,
rtol=rtol,
equal_nan=True,
exact_dtype=True,
)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
self.assertEqual(
counters["inductor"]["cpp_epilogue_fusion_counter"],
0,
)
@unittest.skipIf(
not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required"
)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@parametrize("batch_size", (3, 16, 32, 49))
@parametrize("in_features", (4, 68, 128)) # k should be a multiple of 4
@parametrize("out_features", (64, 65))
@parametrize("bias", (True, False))
    @skipIfWindows(msg="Quantization is not supported on Windows.")
def test_quantized_linear_amx(self, batch_size, in_features, out_features, bias):
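        # Quantized linear on an AMX-capable machine; besides the templated kernel
        # counter, _check_amx_counter verifies the AMX-related counters.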
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x)
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=torch.float32)
ref_quantized_mod = _generate_qdq_quantized_model(
M(bias=bias).eval(),
(v,),
)
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
self._check_amx_counter(vec_amx)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_max_k_slices": 0})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (2,))
@parametrize("in_features", (1000,))
@parametrize("out_features", (2,))
@parametrize("bias", (True, False))
@parametrize(
"epilogue",
(
"none",
"relu",
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_k_slicing(
self, batch_size, in_features, out_features, bias, epilogue, dtype
):
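        # Linear with a large K (in_features=1000) and tiny M/N; the
        # cpp.gemm_max_k_slices=0 patch above exercises the template's K-slicing logic.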
class M(torch.nn.Module):
def __init__(self, bias, epilogue, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.epilogue = _get_epilogue(epilogue, other)
def forward(self, x):
return self.epilogue(self.linear(x))
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@set_num_threads(1)
@parametrize("batch_size", (512,))
@parametrize("in_features", (1024,))
@parametrize("out_features", (1024,))
@parametrize("bias", (True, False))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_cache_blocking(
self, batch_size, in_features, out_features, bias, dtype
):
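        # The cpp.gemm_cache_blocking="2,2,2" patch above overrides the template's
        # cache-blocking choice; this test only checks numerical correctness and that
        # the templated kernel is still selected.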
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x)
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@set_num_threads(56)
@parametrize("batch_size", (1024,))
@parametrize("in_features", (1024,))
@parametrize("out_features", (1024,))
@parametrize("bias", (True, False))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_thread_factors(
self, batch_size, in_features, out_features, bias, dtype
):
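        # With 56 threads, cpp.gemm_thread_factors="4,2,7" pins how the threads are
        # factored across the GEMM dimensions (4 * 2 * 7 == 56); the test checks
        # correctness and kernel selection under that decomposition.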
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x)
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": False})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (128,))
@parametrize("out_features", (64,))
@parametrize("bias", (True,))
@dtypes(
torch.float,
)
def test_aoti_linear(self, batch_size, in_features, out_features, bias, dtype):
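        # Linear + ReLU compiled through AOTInductor (AOTIRunnerUtil.run) with freezing
        # disabled; compares against eager and expects one templated kernel.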
try:
try:
from . import test_aot_inductor_utils
except ImportError:
import test_aot_inductor_utils
except Exception:
            # Skip this test if the import fails.
return
class M(torch.nn.Module):
def __init__(self, bias=bias) -> None:
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(in_features, out_features, bias=bias),
torch.nn.ReLU(),
)
def forward(self, x):
return self.mlp(x)
assert torch._inductor.config.freezing is False
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M(bias=bias).to(dtype=dtype).eval()
torch._dynamo.reset()
torch._inductor.metrics.reset()
torch.manual_seed(0)
with verify(dtype) as (atol, rtol), torch.no_grad():
expected = mod(v)
actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
mod,
(v,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear_invalid(
self,
batch_size,
in_features,
out_features,
gemm_num,
):
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature + gemm_idx, bias=False)
for gemm_idx in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
        # Each linear has a different number of output features, so the grouped GEMM
        # template does not apply.
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
v = torch.randn(batch_size, in_features).to(dtype)
with (
verify(dtype) as (atol, rtol),
torch.autocast(device_type="cpu", dtype=dtype),
torch.no_grad(),
):
self.common(mod, (v,), atol=atol, rtol=rtol)
            # Expect gemm_num independent templated kernels instead of a single
            # grouped GEMM template.
self.assertEqual(
counters["inductor"]["cpp_templated_kernel_counter"], gemm_num
)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("input_3d", (False, True))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear(
self,
batch_size,
in_features,
out_features,
input_3d,
gemm_num,
):
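        # gemm_num linears with identical output features consume the same input, so a
        # single grouped GEMM template kernel is expected
        # (cpp_grouped_gemm_template == 1).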
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature, bias=False)
for _ in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
if dtype == torch.float16 and input_3d:
# reduce the number of tests
continue
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype)
with (
verify(dtype) as (atol, rtol),
torch.autocast(device_type="cpu", dtype=dtype),
torch.no_grad(),
):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("input_3d", (True, False))
@parametrize(
"bias",
(
[True, True],
[True, False],
[False, True],
[False, False],
),
)
@parametrize(
"epilogue",
(
["none", "none"],
["relu", "none"],
["none", "relu"],
["relu", "relu"],
["silu", "mul"],
),
)
def test_grouped_linear_epilogue(
self,
batch_size,
in_features,
out_features,
input_3d,
bias,
epilogue,
):
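        # Grouped GEMM where each linear gets its own epilogue, including the
        # silu(res0) * res1 pattern; when any epilogue is present the epilogue fusion
        # counter is expected to be non-zero.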
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, bias, epilogue):
super().__init__()
self.linear0 = torch.nn.Linear(in_feature, out_feature, bias=bias[0])
self.linear1 = torch.nn.Linear(in_feature, out_feature, bias=bias[1])
self.epilogue0 = epilogue[0]
self.epilogue1 = epilogue[1]
def forward(self, x):
res0 = self.linear0(x)
res1 = self.linear1(x)
if self.epilogue0 == "silu" and self.epilogue1 == "mul":
return torch.nn.functional.silu(res0) * res1
else:
if self.epilogue0 == "relu":
res0 = torch.nn.functional.relu(res0)
if self.epilogue1 == "relu":
res1 = torch.nn.functional.relu(res1)
return res0, res1
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
if input_3d and dtype == torch.float16:
# Reduce the number of test cases
continue
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, bias, epilogue).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype)
with (
verify(dtype) as (atol, rtol),
torch.autocast(device_type="cpu", dtype=dtype),
torch.no_grad(),
):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)
if any(e != "none" for e in epilogue):
self.assertGreater(
counters["inductor"]["cpp_epilogue_fusion_counter"], 0
)
@inductor_config.patch({"freezing": False})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (128,))
@parametrize("out_features", (64,))
@dtypes(
torch.float,
)
def test_aoti_linear_multi_view_operations(
self, batch_size, in_features, out_features, dtype
):
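        # addmm whose weight is produced by a permute + view of a 3D parameter,
        # compiled through AOTInductor; checks correctness and that one templated
        # kernel is generated.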
try:
try:
from . import test_aot_inductor_utils
except ImportError:
import test_aot_inductor_utils
except Exception:
            # Skip this test if the import fails.
return
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bias = torch.randn(out_features)
self.weight = torch.randn(out_features // 2, 2, in_features)
self.relu = torch.nn.ReLU()
def forward(self, x):
tmp = torch.addmm(
self.bias,
x,
self.weight.permute(2, 0, 1).view(in_features, out_features),
)
return self.relu(tmp)
assert torch._inductor.config.freezing is False
counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
torch._dynamo.reset()
torch._inductor.metrics.reset()
torch.manual_seed(0)
with verify(dtype) as (atol, rtol), torch.no_grad():
expected = mod(v)
actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
mod,
(v,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"coordinate_descent_tuning": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
def test_cpp_coordinate_descent_tuning(self):
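        # With coordinate_descent_tuning enabled and a single-row input (1 x 512), the
        # CPP GEMM template should still be selected under bf16 autocast.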
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(512, 1024, bias=False)
def forward(self, x):
return self.linear(x)
v = torch.randn(1, 512)
mod = M().eval()
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
with verify(torch.bfloat16) as (atol, rtol), torch.autocast(device_type="cpu"):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (2,))
@parametrize("in_features", (128,))
@parametrize("out_features", (64,))
@parametrize("bias", (True, False))
def test_linear_to_lowp_fp(self, batch_size, in_features, out_features, bias):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
def forward(self, x):
return self.linear(x).to(torch.float16)
counters.clear()
dtype = torch.float32
mod = M(bias=bias).to(dtype=dtype).eval()
B = (batch_size,)
v = torch.randn(*B, in_features).to(dtype=dtype)
with verify(dtype) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
def test_cpp_weight_prune(self):
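        # bf16 GEMM with a small M (2 x 32 input); the select_algorithm_weight_prune
        # counter below verifies that the weight-pruning path is taken.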
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(32, 128, bias=False)
def forward(self, x):
return self.linear(x)
v = torch.randn(2, 32).to(torch.bfloat16)
mod = M().eval().to(torch.bfloat16)
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
with verify(torch.bfloat16) as (atol, rtol):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (1, 50))
@parametrize("Mdim", (192,))
@parametrize("Kdim", (196,))
@parametrize("Ndim", (84, 385))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm(self, dtype, bs, Mdim, Kdim, Ndim):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x @ y
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (2,))
@parametrize("Mdim", (16, 32))
@parametrize("Kdim", (32,))
@parametrize("Ndim", (3, 16, 32, 48, 128, 1024, 1025))
@dtypes(torch.bfloat16, torch.half)
def test_bmm_amx(self, dtype, bs, Mdim, Kdim, Ndim):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x @ y
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
vec_amx = VecAMX()
        # Currently, the brgemm config is only added for half precision
if dtype == torch.half:
self._check_brgemm_counter(vec_amx)
else:
self._check_amx_counter(vec_amx)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (1,))
@parametrize("Mdim", (192,))
@parametrize("Kdim", (196,))
@parametrize("Ndim", (84,))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_amp(self, dtype, bs, Mdim, Kdim, Ndim):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
return x @ y
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol), torch.amp.autocast("cpu"):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (1,))
@parametrize("Mdim", (192,))
@parametrize("Kdim", (196,))
@parametrize("Ndim", (64, 65))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_freezing(self, dtype, bs, Mdim, Kdim, Ndim):
class M(torch.nn.Module):
def __init__(self, w):
super().__init__()
self.w = torch.nn.Parameter(w, requires_grad=False)
def forward(self, x):
return x @ self.w
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M(v).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("Ndim", (64, 61))
@parametrize(
"order",
(
((0, 1, 2), (0, 2, 1)), # First BMM in hf_Reformer
((0, 1, 2), (1, 2, 0)), # First BMM in hf_DistilBert
((0, 1, 2), (1, 0, 2)), # Second BMM in hf_DistilBert, hf_T5
((1, 0, 2), (0, 1, 2)), # Third BMM in hf_Reformer
((1, 0, 2), (1, 2, 0)), # First in hf_T5
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_2d_permute(self, Ndim, order, dtype):
# TODO: Support bmm with transposed X
bs = 12
Mdim = 10
Kdim = 62
x_args = (bs, Mdim, Kdim)
w_args = (bs, Kdim, Ndim)
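        # inverse_order undoes each permutation: reshaping to the inversely permuted
        # shape and then applying permute(*order[i]) yields an operand with the logical
        # (bs, Mdim, Kdim) / (bs, Kdim, Ndim) shape but non-contiguous strides.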
inverse_order = [torch.argsort(torch.tensor(o)).tolist() for o in order]
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, w):
if order[0] != (0, 1, 2):
x_order = [x_args[i] for i in inverse_order[0]]
x = x.reshape(x_order[0], x_order[1] * x_order[2]).clone()
x = x.reshape(*x_order).permute(*order[0])
if order[1] != (0, 1, 2):
w_order = [w_args[i] for i in inverse_order[1]]
w = w.reshape(w_order[0], w_order[1] * w_order[2]).clone()
w = w.reshape(*w_order).permute(*order[1])
y = x @ w
return y
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(
counters["inductor"]["cpp_templated_kernel_counter"],
1 if order[0] == (0, 1, 2) else 0,
)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (5,))
@parametrize("Mdim", (64,))
@parametrize("Kdim", (96,))
@dtypes(torch.float, torch.float16, torch.bfloat16)
def test_bmm_self_permute(self, bs, Mdim, Kdim, dtype):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x @ x.permute(0, 2, 1)
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (5,))
@parametrize("Mdim", (64,))
@dtypes(torch.float)
def test_bmm_self_square(self, bs, Mdim, dtype):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x @ x
counters.clear()
u = torch.randn(bs, Mdim, Mdim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (5,))
@parametrize("Mdim", (384,))
@parametrize("Kdim", (96,))
@parametrize("Ndim", (64, 65))
@parametrize(
"epilogue",
(
"relu",
"add",
"sub",
"mul",
"div",
),
)
@dtypes(torch.float32, torch.bfloat16, torch.half)
def test_bmm_with_pointwise(self, bs, Mdim, Kdim, Ndim, epilogue, dtype):
class M(torch.nn.Module):
def __init__(self, epilogue, other):
super().__init__()
self.epilogue = _get_epilogue(epilogue, other)
def forward(self, x, w):
return self.epilogue(x @ w)
counters.clear()
x = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
w = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
other = torch.randn(bs, Mdim, Ndim).to(dtype=dtype)
mod = M(epilogue, other).to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (x, w), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@dtypes(torch.float32, torch.bfloat16, torch.half)
def test_bmm_with_fused_epilogues(self, dtype):
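        # bmm wrapped in reshapes, followed by a scale (mul by ~1/sqrt(8)) and an add
        # with a padded tensor attribute; both pointwise ops are expected to fuse as
        # epilogues (cpp_epilogue_fusion_counter == 1).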
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mul = torch.randn(8, 8, 3136, 8).as_strided(
(8, 8, 3136, 8), (200704, 8, 64, 1)
)
def forward(self, x, w):
x = torch.ops.aten.reshape.default(x, [64, 3137, 8])
w = torch.ops.aten.reshape.default(w, [64, 8, 8])
bmm = torch.ops.aten.bmm.default(x, w)
bmm = torch.ops.aten.reshape.default(bmm, [8, 8, 3137, 8])
constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
self.mul, [0, 0, 1, 0, 0, 0], 0.0
)
mul_2 = torch.ops.aten.mul.Tensor(bmm, 0.3535533905932738)
add = torch.ops.aten.add.Tensor(mul_2, constant_pad_nd)
return add
counters.clear()
x = torch.randn(8, 8, 3137, 8).to(dtype=dtype)
w = torch.randn(8, 8, 8, 8).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (x, w), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@dtypes(torch.float)
def test_aoti_bmm_unique_identifiers(self, dtype):
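        # Two chained bmms compiled through AOTInductor; each should get its own
        # templated kernel (counter == 2), exercising unique kernel identifiers.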
try:
try:
from . import test_aot_inductor_utils
except ImportError:
import test_aot_inductor_utils
except Exception:
            # Skip this test if the import fails.
return
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, w):
y = x @ w
return y @ w
counters.clear()
x = torch.randn(3, 64, 64).to(dtype=dtype)
w = torch.randn(3, 64, 64).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol), torch.no_grad():
expected = mod(x, w)
actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
mod,
(x, w),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@set_num_threads(1) # avoid k_slicing to make the test deterministic
@parametrize(
"out_features1",
(
8,
16,
24,
32,
48,
),
)
@dtypes(torch.float)
def test_local_and_global_accumulator(self, out_features1, dtype):
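        # Two chained linears where the first output (256 x 129) is re-viewed as
        # (-1, 128) before the second GEMM; run through AOTInductor with a single
        # thread (no K-slicing) to exercise the local and global accumulation paths,
        # as the test name indicates.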
batch_size = 256
in_features = 64
out_features = 129
in_features1 = 128
bias = True
try:
try:
from . import test_aot_inductor_utils
except ImportError:
import test_aot_inductor_utils
except Exception:
            # Skip this test if the import fails.
return
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
self.linear1 = torch.nn.Linear(in_features1, out_features1, bias)
def forward(self, x):
y = self.linear(x)
view = torch.ops.aten.view.default(y, [-1, in_features1])
return self.linear1(view)
counters.clear()
x = torch.randn(batch_size, in_features).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol), torch.no_grad():
expected = mod(
x,
)
actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
mod,
(x,),
)
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 2)
@patches
@inductor_config.patch(freezing=True)
@unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
def test_bmm_flexible_layout(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, u, v):
view_3 = torch.ops.aten.reshape.default(u, [-1, 512, 64])
clone_1 = torch.ops.aten.clone.default(
v, memory_format=torch.contiguous_format
)
view_7 = torch.ops.aten.reshape.default(clone_1, [-1, 512, 64])
permute_6 = torch.ops.aten.permute.default(view_7, [0, 2, 1])
div = torch.ops.aten.div.Tensor(permute_6, 8.0)
                # view_3 is a ReinterpretView and div has a FlexibleLayout that will
                # become a FixedLayout
bmm = torch.ops.aten.bmm.default(view_3, div)
return bmm
mod = M().eval()
u = torch.randn(2, 24, 512, 64)
v = torch.randn(48, 512, 64)
with verify(u.dtype) as (atol, rtol):
self.common(mod, (u, v))
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
pass
class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
common = check_model
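    # Re-run the static-shape tests defined above under dynamic shapes, enabled by the
    # class-level dynamo_config patch on _DynamicShapesTestBase.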
test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes
test_linear_with_pointwise_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_pointwise
)
test_linear_with_transpose_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_transpose
)
test_linear_with_unary_binary_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_unary_binary
)
test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx
test_linear_with_embedding_dynamic_shapes = (
TestSelectAlgorithm.test_linear_with_embedding
)
test_quantized_linear_with_pointwise_dynamic_shapes = (
TestSelectAlgorithm.test_quantized_linear_with_pointwise
)
test_quantized_linear_with_pointwise_binary_dynamic_shapes = (
TestSelectAlgorithm.test_quantized_linear_with_pointwise_binary
)
test_quantized_linear_amx_dynamic_shapes = (
TestSelectAlgorithm.test_quantized_linear_amx
)
test_grouped_linear_dynamic_shapes = TestSelectAlgorithm.test_grouped_linear
test_grouped_linear_epilogue_dynamic_shapes = (
TestSelectAlgorithm.test_grouped_linear_epilogue
)
test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
test_linear_cache_blocking_dynamic_shapes = (
TestSelectAlgorithm.test_linear_cache_blocking
)
test_linear_thread_factors_dynamic_shapes = (
TestSelectAlgorithm.test_linear_thread_factors
)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (5,))
@parametrize("Mdim", (384,))
@parametrize("Kdim", (96,))
@parametrize("Ndim", (64, 65))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_with_pointwise_dynamic_shapes(self, bs, Mdim, Kdim, Ndim, dtype):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.epilogue = torch.nn.ReLU()
def forward(self, x, other):
return self.epilogue(x @ other)
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
torch._dynamo.mark_dynamic(u, 0)
torch._dynamo.mark_dynamic(u, 1)
torch._dynamo.mark_static(u, 2)
torch._dynamo.mark_static(v, 2)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bs", (5,))
@parametrize("Mdim", (384,))
@parametrize("Kdim", (96,))
@parametrize("Ndim", (64, 65))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_with_pointwise_with_reshape_dynamic_shapes(
self, bs, Mdim, Kdim, Ndim, dtype
):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.epilogue = torch.nn.ReLU()
def forward(self, x, other, noise):
result = x.reshape(-1, Mdim, Kdim) @ other.reshape(-1, Kdim, Ndim)
return self.epilogue(result) + noise
counters.clear()
u = torch.randn(bs, 8, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, 8, Kdim, Ndim).to(dtype=dtype)
noise = torch.randn(bs * 8, Mdim, Ndim).to(dtype=dtype)
torch._dynamo.mark_dynamic(u, 0)
torch._dynamo.mark_dynamic(u, 1)
torch._dynamo.mark_static(u, 2)
torch._dynamo.mark_static(u, 3)
torch._dynamo.mark_static(v, 2)
torch._dynamo.mark_static(v, 3)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v, noise), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@dtypes(torch.float, torch.bfloat16)
def test_bmm_epilogue_dynamic_reshape(self, dtype):
bs = 5
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.epilogue = torch.nn.ReLU()
def forward(self, x, w, arg5_1):
arg131_1 = x.shape[0]
mul_91 = arg131_1 * 8
view_422 = torch.ops.aten.reshape.default(x, [mul_91, 512, 64])
view_423 = torch.ops.aten.reshape.default(w, [mul_91, 64, 512])
bmm_36 = torch.ops.aten.bmm.default(view_422, view_423)
view_424 = torch.ops.aten.reshape.default(
bmm_36, [arg131_1, 8, 512, 512]
)
abs_2 = torch.ones(512, 512, dtype=torch.int64)
lt_562 = torch.ops.aten.lt.Scalar(abs_2, 8)
add_5084 = torch.ones(512, 512, dtype=torch.int64)
add_5085 = torch.ones(512, 512, dtype=torch.int64)
full_default_1 = torch.ops.aten.full.default(
[512, 512], 15, dtype=torch.int64, layout=torch.strided
)
minimum_3 = torch.ops.aten.minimum.default(add_5085, full_default_1)
where_2 = torch.ops.aten.where.self(lt_562, abs_2, minimum_3)
add_5086 = torch.ops.aten.add.Tensor(add_5084, where_2)
embedding_5 = torch.ops.aten.embedding.default(arg5_1, add_5086)
permute_196 = torch.ops.aten.permute.default(embedding_5, [2, 0, 1])
unsqueeze_21 = torch.ops.aten.unsqueeze.default(permute_196, 0)
full_default = torch.ops.aten.full.default(
[arg131_1, 1, 1, 512],
-0.0,
dtype=torch.float32,
layout=torch.strided,
)
add_5087 = torch.ops.aten.add.Tensor(unsqueeze_21, full_default)
add_5103 = torch.ops.aten.add.Tensor(view_424, add_5087)
return add_5103
counters.clear()
u = torch.randn(bs, 8, 512, 64).to(dtype=dtype)
v = torch.randn(bs, 8, 64, 512).to(dtype=dtype)
arg5 = torch.randn(32, 8)
torch._dynamo.mark_dynamic(u, 0)
torch._dynamo.mark_static(u, 1)
torch._dynamo.mark_static(u, 2)
torch._dynamo.mark_static(u, 3)
torch._dynamo.mark_static(v, 2)
torch._dynamo.mark_static(v, 3)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v, arg5), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
def test_bmm_dynamic_bm_stride(self):
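        # Dynamic batch/M dims where the second operand comes from a permute, so its
        # batch stride is the smallest stride (1) rather than the largest; checks that
        # the template handles this stride pattern.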
bs = 8
Mdim = 256
Kdim = 64
dtype = torch.float
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight):
return x @ weight.permute(2, 0, 1)
counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(Kdim, Mdim, bs).to(dtype=dtype)
torch._dynamo.mark_dynamic(u, 0)
torch._dynamo.mark_dynamic(u, 1)
torch._dynamo.mark_static(u, 2)
torch._dynamo.mark_static(v, 0)
torch._dynamo.mark_static(v, 1)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
instantiate_device_type_tests(
TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu"
)
if __name__ == "__main__":
from torch.testing._internal.inductor_utils import HAS_CPU
if HAS_CPU and not IS_MACOS:
run_tests()