From 914d3ca2ba0d63437d835da10364cccb2bb41a16 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Thu, 20 Jun 2024 17:19:58 -0700 Subject: [PATCH] [inductor][cpp] BF16 AMX micro-gemm support (#127195) This PR adds the intrinsics based micro-gemm for BF16 using Advanced Matrix eXtension (AMX) instructions available in Intel 4th and 5th Xeon processors. A compilation check is added to `codecache.py` to check the validity of the compiler support. Also, since AMX requires an initialization in the Linux kernel to extra register states, an initialization function is added to do that and triggered via `codecache.py`. Performance speedups with >=10% on BF16 AMP, max_autotune vs. no autotune, measured on Intel(R) Xeon(R) Platinum 8488C: Static shapes Single-threaded | Model Family | Model Name | Speedup | |--------------|------------|---------| | timm_models | mixer_b16_224 | 1.54 | | timm_models | convit_base | 1.53 | | huggingface | MobileBertForQuestionAnswering | 1.52 | | torchbench | fastNLP_Bert | 1.44 | | torchbench | llama | 1.33 | | timm_models | swin_base_patch4_window7_224 | 1.31 | | torchbench | dlrm | 1.28 | | torchbench | timm_vision_transformer_large | 1.28 | | huggingface | MobileBertForMaskedLM | 1.27 | | timm_models | vit_base_patch16_224 | 1.26 | | timm_models | beit_base_patch16_224 | 1.23 | | timm_models | jx_nest_base | 1.21 | | torchbench | pyhpc_equation_of_state | 1.18 | | huggingface | Speech2Text2ForCausalLM | 1.15 | | timm_models | pit_b_224 | 1.14 | | timm_models | twins_pcpvt_base | 1.14 | | torchbench | maml_omniglot | 1.1 | | timm_models | botnet26t_256 | 1.1 | Multi-threaded | Model Family | Model Name | Speedup | |--------------|------------|---------| | torchbench | BERT_pytorch | 1.35 | | torchbench | lennard_jones | 2.43 | | torchbench | hf_Albert | 1.35 | | torchbench | hf_T5 | 1.34 | | torchbench | soft_actor_critic | 1.34 | | torchbench | fastNLP_Bert | 1.28 | | huggingface | LayoutLMForSequenceClassification | 1.26 | | torchbench | llama | 1.24 | | huggingface | GPT2ForSequenceClassification | 1.19 | | torchbench | hf_Bart | 1.17 | | torchbench | hf_Bert_large | 1.16 | | torchbench | hf_GPT2 | 1.16 | | timm_models | gmixer_24_224 | 1.16 | | torchbench | hf_GPT2_large | 1.15 | | torchbench | maml_omniglot | 1.14 | | torchbench | hf_Bert | 1.13 | | torchbench | hf_DistilBert | 1.13 | | torchbench | hf_T5_large | 1.12 | | huggingface | MT5ForConditionalGeneration | 1.11 | Dynamic shapes Single-threaded | Model Family | Model Name | Speedup | |--------------|------------|-------| | timm_models | mixer_b16_224 | 1.52 | | timm_models | convit_base | 1.5 | | huggingface | MobileBertForQuestionAnswering | 1.49 | | torchbench | fastNLP_Bert | 1.42 | | torchbench | timm_vision_transformer_large | 1.28 | | timm_models | swin_base_patch4_window7_224 | 1.27 | | torchbench | llama | 1.26 | | huggingface | MobileBertForMaskedLM | 1.25 | | timm_models | vit_base_patch16_224 | 1.25 | | timm_models | beit_base_patch16_224 | 1.24 | | timm_models | jx_nest_base | 1.2 | | torchbench | dlrm | 1.19 | | timm_models | pit_b_224 | 1.13 | | timm_models | twins_pcpvt_base | 1.13 | | torchbench | hf_Bert_large | 1.12 | | torchbench | hf_BigBird | 1.11 | | huggingface | Speech2Text2ForCausalLM | 1.11 | | timm_models | eca_botnext26ts_256 | 1.11 | | timm_models | botnet26t_256 | 1.1 | Multi-threaded | Model Family | Model Name | Speedup | |--------------|------------|-------| | torchbench | BERT_pytorch | 1.18 | | torchbench | lennard_jones | 2.18 | | torchbench | hf_Albert | 1.37 | | torchbench | soft_actor_critic | 1.31 | | huggingface | GPT2ForSequenceClassification | 1.29 | | torchbench | hf_T5 | 1.28 | | torchbench | fastNLP_Bert | 1.27 | | torchbench | hf_Bart | 1.21 | | torchbench | hf_Bert_large | 1.19 | | torchbench | hf_T5_large | 1.19 | | torchbench | hf_Bert | 1.16 | | torchbench | hf_GPT2 | 1.16 | | huggingface | CamemBert | 1.16 | | torchbench | hf_GPT2_large | 1.13 | | torchbench | functorch_maml_omniglot | 1.12 | | huggingface | BertForMaskedLM | 1.12 | | huggingface | MT5ForConditionalGeneration | 1.12 | | torchbench | hf_DistilBert | 1.11 | | timm_models | mixnet_l | 1.11 | | timm_models | tf_mixnet_l | 1.11 | No perf regressions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127195 Approved by: https://github.com/jansel --- aten/src/ATen/cpu/Utils.cpp | 47 +++ aten/src/ATen/cpu/Utils.h | 6 + test/inductor/test_cpu_repro.py | 4 +- test/inductor/test_cpu_select_algorithm.py | 33 ++ torch/_C/_cpu.pyi | 2 + torch/_dynamo/trace_rules.py | 4 + torch/_inductor/codecache.py | 67 +++- torch/_inductor/codegen/cpp_gemm_template.py | 22 +- torch/_inductor/codegen/cpp_micro_gemm.py | 328 +++++++++++++++++-- torch/_inductor/codegen/cpp_prefix.h | 62 ++++ torch/cpu/__init__.py | 10 + torch/csrc/cpu/Module.cpp | 2 + 12 files changed, 545 insertions(+), 42 deletions(-) diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index fbf861dcabcf..626ffd3a61e6 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -2,6 +2,10 @@ #if !defined(__s390x__ ) && !defined(__powerpc__) #include #endif +#if defined(__linux__) +#include +#include +#endif namespace at::cpu { bool is_cpu_support_avx2() { @@ -28,4 +32,47 @@ bool is_cpu_support_avx512_vnni() { #endif } +bool is_cpu_support_amx_tile() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile(); +#else + return false; +#endif +} + +bool init_amx() { + if (!is_cpu_support_amx_tile()) { + return false; + } + +#if defined(__linux__) && !defined(__ANDROID__) +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) + +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + + unsigned long bitmask = 0; + // Request permission to use AMX instructions + long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (rc) { + return false; + } + // Check if the system supports AMX instructions + rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (rc) { + return false; + } + if (bitmask & XFEATURE_MASK_XTILE) { + return true; + } + return false; +#else + return true; +#endif +} + } // namespace at::cpu diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 0ad6f8e893ca..03136f04a85a 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -10,4 +10,10 @@ TORCH_API bool is_cpu_support_avx512(); // Detect if CPU support Vector Neural Network Instruction. TORCH_API bool is_cpu_support_avx512_vnni(); +// Detect if CPU support Advanced Matrix Extension. +TORCH_API bool is_cpu_support_amx_tile(); + +// Enable the system to use AMX instructions. +TORCH_API bool init_amx(); + } // namespace at::cpu diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 1d0fb3e8e1b9..fd9e75847b52 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1588,8 +1588,8 @@ class CPUReproTests(TestCase): ) @patch("torch.cuda.is_available", lambda: False) def test_auto_simd(self): - vec_avx512 = codecache.supported_vec_isa_list[0] - vec_avx2 = codecache.supported_vec_isa_list[1] + vec_avx512 = codecache.supported_vec_isa_list[1] + vec_avx2 = codecache.supported_vec_isa_list[2] self.assertTrue(vec_avx512.bit_width() == 512) self.assertTrue(vec_avx2.bit_width() == 256) self.assertTrue(vec_avx512.nelements() == 16) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 27e14e7e9f5b..70c8a2460be6 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -12,6 +12,7 @@ import torch._dynamo.config as dynamo_config import torch._inductor.config as inductor_config import torch._inductor.select_algorithm as select_algorithm from torch._dynamo.utils import counters +from torch._inductor.codecache import VecAMX from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_device_type import ( dtypes, @@ -333,6 +334,37 @@ class TestSelectAlgorithm(TestCase): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @parametrize("bias", (True, False)) + def test_linear_amx(self, bias): + batch_size = 1024 + in_features = 1024 + out_features = 1024 + dtype = torch.bfloat16 + + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias) + + def forward(self, x): + return self.linear(x) + + counters.clear() + v = torch.randn(batch_size, in_features).to(dtype=dtype) + mod = M(bias=bias).to(dtype=dtype).eval() + atol, rtol = 1e-2, 1e-2 + with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): + self.common(mod, (v,), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + vec_amx = VecAMX() + if vec_amx: + self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0) + else: + self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0) + @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) class _DynamicShapesTestBase(TestCase): @@ -351,6 +383,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): test_linear_with_unary_binary_dynamic_shapes = ( TestSelectAlgorithm.test_linear_with_unary_binary ) + test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 37794bd7c10b..e70d9cb4d179 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -5,3 +5,5 @@ from torch.types import _bool def _is_cpu_support_avx2() -> _bool: ... def _is_cpu_support_avx512() -> _bool: ... def _is_cpu_support_avx512_vnni() -> _bool: ... +def _is_cpu_support_amx_tile() -> _bool: ... +def _init_amx() -> _bool: ... diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index abbef02e63c6..a9486a530db1 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -409,6 +409,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._cpu._is_cpu_support_avx2", "torch._C._cpu._is_cpu_support_avx512", "torch._C._cpu._is_cpu_support_avx512_vnni", + "torch._C._cpu._is_cpu_support_amx_tile", + "torch._C._cpu._init_amx", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -2423,6 +2425,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.cpu._is_cpu_support_avx2", "torch.cpu._is_cpu_support_avx512", "torch.cpu._is_cpu_support_avx512_vnni", + "torch.cpu._is_cpu_support_amx_tile", + "torch.cpu._init_amx", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 5109011f064a..c7dcc46b433c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1336,18 +1336,11 @@ cdll.LoadLibrary("__lib_path__") def __hash__(self) -> int: return hash(str(self)) - @functools.lru_cache(None) # noqa: B019 - def __bool__(self) -> bool: + def check_build(self, code) -> bool: from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions - if config.cpp.vec_isa_ok is not None: - return config.cpp.vec_isa_ok - - if config.is_fbcode(): - return True - key, input_path = write( - VecISA._avx_code, + code, "cpp", extra=_get_isa_dry_compile_fingerprint(self._arch_flags), ) @@ -1385,6 +1378,16 @@ cdll.LoadLibrary("__lib_path__") return True + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if config.cpp.vec_isa_ok is not None: + return config.cpp.vec_isa_ok + + if config.is_fbcode(): + return True + + return self.check_build(VecISA._avx_code) + @dataclasses.dataclass class VecNEON(VecISA): @@ -1418,6 +1421,46 @@ class VecAVX512(VecISA): __hash__: Callable[[VecISA], Any] = VecISA.__hash__ +@dataclasses.dataclass +class VecAMX(VecAVX512): + _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" + + def __str__(self) -> str: + return super().__str__() + " amx_tile" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + _amx_code = """ +#include +#include + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +extern "C" void __amx_chk_kernel() { + amx_tilecfg cfg = {0}; + _tile_loadconfig(&cfg); + _tile_zero(0); + _tile_dpbf16ps(0, 1, 2); + _tile_dpbusd(0, 1, 2); +} +""" + + @functools.lru_cache(None) # noqa: B019 + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): + return True + return False + + @dataclasses.dataclass class VecAVX2(VecISA): _bit_width = 256 @@ -1483,15 +1526,17 @@ def x86_isa_checker() -> List[str]: avx2 = torch.cpu._is_cpu_support_avx2() avx512 = torch.cpu._is_cpu_support_avx512() + amx_tile = torch.cpu._is_cpu_support_amx_tile() _check_and_append_supported_isa(supported_isa, avx2, "avx2") _check_and_append_supported_isa(supported_isa, avx512, "avx512") + _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile") return supported_isa invalid_vec_isa = InvalidVecISA() -supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()] +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()] # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content @@ -1528,7 +1573,7 @@ def valid_vec_isa_list() -> List[VecISA]: """ _cpu_supported_x86_isa = x86_isa_checker() for isa in supported_vec_isa_list: - if str(isa) in _cpu_supported_x86_isa and isa: + if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa: isa_list.append(isa) return isa_list diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 60ae0bfdc750..b639eac3f5c4 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -3,13 +3,14 @@ from typing import Any, Callable, cast, List, Optional, Union import torch import torch.utils +from ..._dynamo.utils import counters from .. import ir, lowering as L from ..kernel.mm_common import mm_args from ..select_algorithm import DataProcessorTemplateWrapper from ..utils import cache_on_self, has_free_symbols, parallel_num_threads from ..virtualized import V -from .cpp_micro_gemm import create_micro_gemm +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType from .cpp_template import CppTemplate from .cpp_template_kernel import CppTemplateKernel @@ -84,15 +85,18 @@ extern "C" int64_t k_block_start = 0; int64_t k_block_end = K0_blocks; {%- endif %} + {{ micro_gemm.codegen_init(kernel) }} for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { const int64_t m_start = mc * M0; const int64_t m_end = std::min((mc + Mc_blocks) * M0, M); const int64_t m_size = m_end - m_start; + {%- if use_local_acc %} + {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }} + {%- endif %} for (int64_t nc = n_block_start; nc < n_block_end; ++nc) { const int64_t n_start = nc * N0; const int64_t n_size = N0; {%- if use_local_acc %} - {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }} {%- set acc = kernel.local_buffers[acc_buf_name] %} {%- else %} {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} @@ -128,6 +132,7 @@ extern "C" }} } } + {{ micro_gemm.codegen_finalize(kernel) }} } } """ @@ -332,6 +337,17 @@ class CppPackedGemmTemplate(CppTemplate): blocked_w = ( W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() ) + if micro_gemm.get_b_layout() != LayoutType.NORMAL: + assert ( + micro_gemm.get_b_layout() == LayoutType.VNNI2 + ), "We only support VNNI2 for now" + assert k % 2 == 0, "k should be even for VNNI2 layout" + blocked_w = ( + blocked_w.view(n // block_n, k // 2, 2, block_n) + .transpose(-1, -2) + .contiguous() + .view(n // block_n, k, block_n) + ) # normalize stride to be "contiguous_strides" per size # this avoids the problems in L.view during template codegen new_stride = [1] @@ -462,6 +478,8 @@ class CppPackedGemmTemplate(CppTemplate): ) assert micro_gemm is not None assert self.register_blocking == micro_gemm.register_blocking + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 options = dict( X=X, diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 47d6e87e5a70..64514c0bc4f4 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,13 +1,14 @@ # mypy: allow-untyped-defs -from collections import namedtuple -from typing import Dict, List, Optional, Type +import dataclasses +from enum import Enum +from typing import Callable, Dict, List, Optional, Type import sympy import torch from .. import ir -from ..codecache import pick_vec_isa, VecAVX2, VecAVX512 +from ..codecache import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA from ..utils import IndentedBuffer, parallel_num_threads from ..virtualized import V from .common import KernelTemplate @@ -15,6 +16,11 @@ from .cpp_template_kernel import CppTemplateKernel from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp +class LayoutType(Enum): + NORMAL = 0 + VNNI2 = 1 + + class CppMicroGemm: """ A class that codegens a kernel that computes small-sized matrix multiplication. @@ -30,6 +36,9 @@ class CppMicroGemm: DECLARE_KERNEL = r""" template inline void {{kernel_name}}( +{%- if kernel_extra_args_declare %} + {{kernel_extra_args_declare}} +{%- endif %} const {{input_t}}* __restrict__ A, const {{input_t}}* __restrict__ B, {{output_t}}* __restrict__ C, @@ -69,12 +78,19 @@ inline void {{kernel_name}}( "output_t": DTYPE_TO_CPP[self.output_dtype], "compute_t": DTYPE_TO_CPP[self.compute_dtype], "alpha": self.alpha, + "kernel_extra_args_declare": self.get_kernel_extra_args_declare(), } def get_kernel_declaration(self): options = self.get_common_options() return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + def get_kernel_extra_args_declare(self) -> str: + return "" + + def get_kernel_extra_args(self) -> str: + return "" + def codegen_define(self, kernel: CppTemplateKernel) -> str: raise NotImplementedError @@ -102,6 +118,9 @@ inline void {{kernel_name}}( res = IndentedBuffer() res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") with res.indent(): + extra_args = self.get_kernel_extra_args() + if extra_args: + res.writeline(extra_args) res.writeline(f"{A_ptr},") res.writeline(f"{B_ptr},") res.writeline(f"{C_ptr},") @@ -114,17 +133,31 @@ inline void {{kernel_name}}( res.writeline(");") return res.getvalue() + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def get_b_layout(self) -> LayoutType: + return LayoutType.NORMAL + + +@dataclasses.dataclass +class CppMicroGemmConfig: + input_dtype: torch.dtype + output_dtype: torch.dtype + compute_dtype: torch.dtype + vec_isa_cls: Type[VecISA] + register_blocking: GemmBlocking + extra_check: Optional[Callable[..., bool]] = None -CppMicroGemmConfig = namedtuple( - "CppMicroGemmConfig", - [ - "input_dtype", - "output_dtype", - "compute_dtype", - "vec_isa_cls", - "register_blocking", - ], -) micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} @@ -147,6 +180,7 @@ def generate_gemm_config( input_dtype=torch.float, output_dtype=None, compute_dtype=None, + extra_check=None, ): if output_dtype is None: output_dtype = input_dtype @@ -159,6 +193,7 @@ def generate_gemm_config( compute_dtype, vec_isa_cls, GemmBlocking(*blocking), + extra_check, ) for blocking in register_blockings ] @@ -197,36 +232,51 @@ class CppMicroGemmRef(CppMicroGemm): return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) +def check_fp32_vec_extra(config, m, n, k, alpha, num_threads): + # TODO(jgong5): support n % n_block_size != 0 + return n % config.register_blocking.block_n == 0 + + @register_micro_gemm( *generate_gemm_config( - VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), *generate_gemm_config( VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.bfloat16, output_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), *generate_gemm_config( VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.half, output_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), *generate_gemm_config( - VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), *generate_gemm_config( VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.bfloat16, output_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), *generate_gemm_config( VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.half, output_dtype=torch.float, + extra_check=check_fp32_vec_extra, ), ) class CppMicroGemmFP32Vec(CppMicroGemm): @@ -367,6 +417,218 @@ inline void {{kernel_name}}_kernel( return result +# extra check for CppMicroGemmAMX +def check_amx_extra(config, m, n, k, alpha, num_threads): + return n % config.register_blocking.block_n == 0 and k % 2 == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmAMX(CppMicroGemm): + """ + This class generates the code for micro gemm using Advanced Matrix eXtention (AMX) + instructions available in 4th generation Intel Xeon for compute. + It supports input types of torch.bfloat16 with fp32 output. + TODO(jgong5): support int8 data type. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; + for (int64_t n = 0; n < N; n += {{block_n}}) { + {%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } + {%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, + B + n, + C + m_tail * ldc + n, + K, + lda, + ldb, + ldc, + block_m + ); + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +template +inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + AMXState& amx_state, + const {{input_t}}* __restrict__ A, + const {{input_t}}* __restrict__ B, + {{output_t}}* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + uint8_t tilecfg_rows +) { + // TODO(jgong5): add prefetch hint for A, B, C + auto loadconfig = [](const amx_tilecfg& cfg) { + _tile_loadconfig(&cfg); + }; + const auto last_k_offset = K / {{block_k}} * {{block_k}}; + const auto tail_k_size = K - last_k_offset; + if C10_LIKELY (last_k_offset > 0) { + amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig); + } else { + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + } + auto load_c = [&]() { + {%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} + {%- endfor %} + }; + auto zero_c = [&]() { + {%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_zero({{tile_idx}}); + {%- endfor %} + {%- endfor %} + }; + + if constexpr (accum) { + load_c(); + } else { + zero_c(); + } + + auto compute = [&](int k) { + {%- set tile_offset_a = num_rows // 16 * num_columns %} + {%- set tile_offset_b = tile_offset_a + num_rows // 16 %} + {%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx_a = tile_offset_a + tile_row %} + {%- set tile_idx_b = tile_offset_b + tile_col %} + {%- set tile_idx_c = tile_row * num_columns + tile_col %} + {%- if tile_col == 0 %} + _tile_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); + {%- endif %} + {%- if tile_row == 0 %} + _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * 2}}, ldb * 2 * sizeof({{input_t}})); + {%- endif %} + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endfor %} + {%- endfor %} + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < last_k_offset; k += {{block_k}}) { + compute(k); + } + + auto store_c = [&]() { + // store to C + {%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} + {%- endfor %} + }; + + // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead + if C10_UNLIKELY (tail_k_size > 0) { + if C10_LIKELY (last_k_offset > 0) { + store_c(); + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + load_c(); + } + compute(last_k_offset); + } + + store_c(); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + block_m, block_n, block_k = self.register_blocking + assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" + assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" + assert block_k == 32, "Only support block_k = 32 for AMX" + num_columns = block_n // 16 + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_columns": num_columns, + **self.get_common_options(), + } + result = "" + for num_rows in range(block_m, 0, -16): + amx_kernel_options = {**options, "num_rows": num_rows} + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + amx_kernel_options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "AMXState amx_state;" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "amx_state.release([]() { _tile_release(); });" + + def get_kernel_extra_args_declare(self) -> str: + return "AMXState& amx_state," + + def get_kernel_extra_args(self) -> str: + return "amx_state," + + def get_b_layout(self): + return LayoutType.VNNI2 + + def create_micro_gemm( name, m, @@ -403,36 +665,48 @@ def create_micro_gemm( matched_configs = [] for cls, configs in micro_gemm_configs.items(): for config in configs: - if not isinstance(vec_isa, config.vec_isa_cls): + if not issubclass(vec_isa.__class__, config.vec_isa_cls): continue if ( config.input_dtype == input_dtype and config.output_dtype == output_dtype and config.compute_dtype == compute_dtype ): - block_m, block_n, block_k = config.register_blocking - # TODO(jgong5): support n % n_block_size != 0 - if n % block_n != 0: + if config.extra_check is not None and not config.extra_check( + config, m, n, k, alpha, num_threads + ): continue + block_m, block_n, block_k = config.register_blocking # Criteria on the ranking of configurations - # 1. Dividable by block sizes (block_m, block_k) - # 2. Number of mxn blocks is large enough to occupy all the threads - # 3. Register blocks are larger + # 1. ISA: AMX > VEC + # 2. Dividable by block sizes (block_m, block_n, block_k) + # 3. Number of mxn blocks is large enough to occupy all the threads + # 4. Register blocks are larger + isa_score = 0 + if config.vec_isa_cls == VecAMX: + isa_score += 1 dividable_score = 0 - if k % block_k == 0: - dividable_score += 1 if m % block_m == 0: dividable_score += 1 + if n % block_n == 0: + dividable_score += 1 + if k % block_k == 0: + dividable_score += 1 occupancy_score = 0 - n_blocks = n // block_n - total_mxn_blocks = n // block_n * ((m + block_m - 1) // block_m) + n_blocks = (n + block_n - 1) // block_n + total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) if n_blocks >= num_threads: occupancy_score += 1 if total_mxn_blocks >= num_threads: occupancy_score += 1 + register_bytes = ( + block_m * block_n * config.compute_dtype.itemsize + + (block_m * block_k + block_k * block_n) + * config.input_dtype.itemsize + ) matched_configs.append( ( - (dividable_score, occupancy_score, block_m * block_n * block_k), + (isa_score, dividable_score, occupancy_score, register_bytes), cls, config, ) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 5da8b3918bc2..a8715deeb631 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -416,3 +416,65 @@ inline void mm_get_thread_blocks( m_block_start = std::min(thread_id * Mt_blocks, M_blocks); m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); } + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +class AMXState { + private: + amx_tilecfg tilecfg_; + uint8_t rows_; + uint16_t colsb_; + uint8_t num_tile_rows_; + uint8_t num_tile_columns_; + + public: + AMXState() : rows_(0), colsb_(0), num_tile_rows_(0), num_tile_columns_(0) { + memset(&tilecfg_, 0, sizeof(tilecfg_)); + } + + inline void configure( + uint8_t rows, + uint16_t colsb, + uint8_t num_tile_rows, + uint8_t num_tile_columns, + void (*loadconfig)(const amx_tilecfg&)) { + if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb && + num_tile_rows_ == num_tile_rows && + num_tile_columns_ == num_tile_columns) { + return; + } + tilecfg_.palette_id = 1; + rows_ = rows; + colsb_ = colsb; + num_tile_rows_ = num_tile_rows; + num_tile_columns_ = num_tile_columns; + const auto num_c_tiles = num_tile_rows * num_tile_columns; + // For C + for (int i = 0; i < num_c_tiles; i++) { + tilecfg_.rows[i] = rows; + tilecfg_.colsb[i] = 64; + } + // For A + for (int i = 0; i < num_tile_rows; i++) { + tilecfg_.rows[i + num_c_tiles] = rows; + tilecfg_.colsb[i + num_c_tiles] = colsb; + } + // For B + for (int i = 0; i < num_tile_columns; i++) { + tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4; + tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64; + } + loadconfig(tilecfg_); + } + + inline void release(void (*tile_release)()) { + tilecfg_.palette_id = 0; + tile_release(); + } +}; diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index d404ad4ba3b9..978405c0410c 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -45,6 +45,16 @@ def _is_cpu_support_vnni() -> bool: return torch._C._cpu._is_cpu_support_avx512_vnni() +def _is_cpu_support_amx_tile() -> bool: + r"""Returns a bool indicating if CPU supports AMX_TILE.""" + return torch._C._cpu._is_cpu_support_amx_tile() + + +def _init_amx() -> bool: + r"""Initializes AMX instructions.""" + return torch._C._cpu._init_amx() + + def is_available() -> bool: r"""Returns a bool indicating if CPU is currently available. diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 3485f2a991cb..37109e89aa48 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -11,6 +11,8 @@ void initModule(PyObject* module) { cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni); + cpu.def("_is_cpu_support_amx_tile", at::cpu::is_cpu_support_amx_tile); + cpu.def("_init_amx", at::cpu::init_amx); } } // namespace torch::cpu