mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor][cpp] BF16 AMX micro-gemm support (#127195)
This PR adds the intrinsics based micro-gemm for BF16 using Advanced Matrix eXtension (AMX) instructions available in Intel 4th and 5th Xeon processors. A compilation check is added to `codecache.py` to check the validity of the compiler support. Also, since AMX requires an initialization in the Linux kernel to extra register states, an initialization function is added to do that and triggered via `codecache.py`. Performance speedups with >=10% on BF16 AMP, max_autotune vs. no autotune, measured on Intel(R) Xeon(R) Platinum 8488C: Static shapes Single-threaded | Model Family | Model Name | Speedup | |--------------|------------|---------| | timm_models | mixer_b16_224 | 1.54 | | timm_models | convit_base | 1.53 | | huggingface | MobileBertForQuestionAnswering | 1.52 | | torchbench | fastNLP_Bert | 1.44 | | torchbench | llama | 1.33 | | timm_models | swin_base_patch4_window7_224 | 1.31 | | torchbench | dlrm | 1.28 | | torchbench | timm_vision_transformer_large | 1.28 | | huggingface | MobileBertForMaskedLM | 1.27 | | timm_models | vit_base_patch16_224 | 1.26 | | timm_models | beit_base_patch16_224 | 1.23 | | timm_models | jx_nest_base | 1.21 | | torchbench | pyhpc_equation_of_state | 1.18 | | huggingface | Speech2Text2ForCausalLM | 1.15 | | timm_models | pit_b_224 | 1.14 | | timm_models | twins_pcpvt_base | 1.14 | | torchbench | maml_omniglot | 1.1 | | timm_models | botnet26t_256 | 1.1 | Multi-threaded | Model Family | Model Name | Speedup | |--------------|------------|---------| | torchbench | BERT_pytorch | 1.35 | | torchbench | lennard_jones | 2.43 | | torchbench | hf_Albert | 1.35 | | torchbench | hf_T5 | 1.34 | | torchbench | soft_actor_critic | 1.34 | | torchbench | fastNLP_Bert | 1.28 | | huggingface | LayoutLMForSequenceClassification | 1.26 | | torchbench | llama | 1.24 | | huggingface | GPT2ForSequenceClassification | 1.19 | | torchbench | hf_Bart | 1.17 | | torchbench | hf_Bert_large | 1.16 | | torchbench | hf_GPT2 | 1.16 | | timm_models | gmixer_24_224 | 1.16 | | torchbench | hf_GPT2_large | 1.15 | | torchbench | maml_omniglot | 1.14 | | torchbench | hf_Bert | 1.13 | | torchbench | hf_DistilBert | 1.13 | | torchbench | hf_T5_large | 1.12 | | huggingface | MT5ForConditionalGeneration | 1.11 | Dynamic shapes Single-threaded | Model Family | Model Name | Speedup | |--------------|------------|-------| | timm_models | mixer_b16_224 | 1.52 | | timm_models | convit_base | 1.5 | | huggingface | MobileBertForQuestionAnswering | 1.49 | | torchbench | fastNLP_Bert | 1.42 | | torchbench | timm_vision_transformer_large | 1.28 | | timm_models | swin_base_patch4_window7_224 | 1.27 | | torchbench | llama | 1.26 | | huggingface | MobileBertForMaskedLM | 1.25 | | timm_models | vit_base_patch16_224 | 1.25 | | timm_models | beit_base_patch16_224 | 1.24 | | timm_models | jx_nest_base | 1.2 | | torchbench | dlrm | 1.19 | | timm_models | pit_b_224 | 1.13 | | timm_models | twins_pcpvt_base | 1.13 | | torchbench | hf_Bert_large | 1.12 | | torchbench | hf_BigBird | 1.11 | | huggingface | Speech2Text2ForCausalLM | 1.11 | | timm_models | eca_botnext26ts_256 | 1.11 | | timm_models | botnet26t_256 | 1.1 | Multi-threaded | Model Family | Model Name | Speedup | |--------------|------------|-------| | torchbench | BERT_pytorch | 1.18 | | torchbench | lennard_jones | 2.18 | | torchbench | hf_Albert | 1.37 | | torchbench | soft_actor_critic | 1.31 | | huggingface | GPT2ForSequenceClassification | 1.29 | | torchbench | hf_T5 | 1.28 | | torchbench | fastNLP_Bert | 1.27 | | torchbench | hf_Bart | 1.21 | | torchbench | hf_Bert_large | 1.19 | | torchbench | hf_T5_large | 1.19 | | torchbench | hf_Bert | 1.16 | | torchbench | hf_GPT2 | 1.16 | | huggingface | CamemBert | 1.16 | | torchbench | hf_GPT2_large | 1.13 | | torchbench | functorch_maml_omniglot | 1.12 | | huggingface | BertForMaskedLM | 1.12 | | huggingface | MT5ForConditionalGeneration | 1.12 | | torchbench | hf_DistilBert | 1.11 | | timm_models | mixnet_l | 1.11 | | timm_models | tf_mixnet_l | 1.11 | No perf regressions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127195 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
632910e2a8
commit
914d3ca2ba
@ -2,6 +2,10 @@
|
||||
#if !defined(__s390x__ ) && !defined(__powerpc__)
|
||||
#include <cpuinfo.h>
|
||||
#endif
|
||||
#if defined(__linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace at::cpu {
|
||||
bool is_cpu_support_avx2() {
|
||||
@ -28,4 +32,47 @@ bool is_cpu_support_avx512_vnni() {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool is_cpu_support_amx_tile() {
|
||||
#if !defined(__s390x__) && !defined(__powerpc__)
|
||||
return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile();
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool init_amx() {
|
||||
if (!is_cpu_support_amx_tile()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(__linux__) && !defined(__ANDROID__)
|
||||
#define XFEATURE_XTILECFG 17
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)
|
||||
#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)
|
||||
#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA)
|
||||
|
||||
#define ARCH_GET_XCOMP_PERM 0x1022
|
||||
#define ARCH_REQ_XCOMP_PERM 0x1023
|
||||
|
||||
unsigned long bitmask = 0;
|
||||
// Request permission to use AMX instructions
|
||||
long rc = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
|
||||
if (rc) {
|
||||
return false;
|
||||
}
|
||||
// Check if the system supports AMX instructions
|
||||
rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);
|
||||
if (rc) {
|
||||
return false;
|
||||
}
|
||||
if (bitmask & XFEATURE_MASK_XTILE) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace at::cpu
|
||||
|
@ -10,4 +10,10 @@ TORCH_API bool is_cpu_support_avx512();
|
||||
// Detect if CPU support Vector Neural Network Instruction.
|
||||
TORCH_API bool is_cpu_support_avx512_vnni();
|
||||
|
||||
// Detect if CPU support Advanced Matrix Extension.
|
||||
TORCH_API bool is_cpu_support_amx_tile();
|
||||
|
||||
// Enable the system to use AMX instructions.
|
||||
TORCH_API bool init_amx();
|
||||
|
||||
} // namespace at::cpu
|
||||
|
@ -1588,8 +1588,8 @@ class CPUReproTests(TestCase):
|
||||
)
|
||||
@patch("torch.cuda.is_available", lambda: False)
|
||||
def test_auto_simd(self):
|
||||
vec_avx512 = codecache.supported_vec_isa_list[0]
|
||||
vec_avx2 = codecache.supported_vec_isa_list[1]
|
||||
vec_avx512 = codecache.supported_vec_isa_list[1]
|
||||
vec_avx2 = codecache.supported_vec_isa_list[2]
|
||||
self.assertTrue(vec_avx512.bit_width() == 512)
|
||||
self.assertTrue(vec_avx2.bit_width() == 256)
|
||||
self.assertTrue(vec_avx512.nelements() == 16)
|
||||
|
@ -12,6 +12,7 @@ import torch._dynamo.config as dynamo_config
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch._inductor.select_algorithm as select_algorithm
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codecache import VecAMX
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
@ -333,6 +334,37 @@ class TestSelectAlgorithm(TestCase):
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
@parametrize("bias", (True, False))
|
||||
def test_linear_amx(self, bias):
|
||||
batch_size = 1024
|
||||
in_features = 1024
|
||||
out_features = 1024
|
||||
dtype = torch.bfloat16
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features, bias)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
counters.clear()
|
||||
v = torch.randn(batch_size, in_features).to(dtype=dtype)
|
||||
mod = M(bias=bias).to(dtype=dtype).eval()
|
||||
atol, rtol = 1e-2, 1e-2
|
||||
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
|
||||
self.common(mod, (v,), atol=atol, rtol=rtol)
|
||||
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
|
||||
vec_amx = VecAMX()
|
||||
if vec_amx:
|
||||
self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0)
|
||||
else:
|
||||
self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0)
|
||||
|
||||
|
||||
@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
|
||||
class _DynamicShapesTestBase(TestCase):
|
||||
@ -351,6 +383,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
|
||||
test_linear_with_unary_binary_dynamic_shapes = (
|
||||
TestSelectAlgorithm.test_linear_with_unary_binary
|
||||
)
|
||||
test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
|
||||
|
@ -5,3 +5,5 @@ from torch.types import _bool
|
||||
def _is_cpu_support_avx2() -> _bool: ...
|
||||
def _is_cpu_support_avx512() -> _bool: ...
|
||||
def _is_cpu_support_avx512_vnni() -> _bool: ...
|
||||
def _is_cpu_support_amx_tile() -> _bool: ...
|
||||
def _init_amx() -> _bool: ...
|
||||
|
@ -409,6 +409,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._C._cpu._is_cpu_support_avx2",
|
||||
"torch._C._cpu._is_cpu_support_avx512",
|
||||
"torch._C._cpu._is_cpu_support_avx512_vnni",
|
||||
"torch._C._cpu._is_cpu_support_amx_tile",
|
||||
"torch._C._cpu._init_amx",
|
||||
"torch._C._crash_if_aten_asan",
|
||||
"torch._C._crash_if_csrc_asan",
|
||||
"torch._C._crash_if_csrc_ubsan",
|
||||
@ -2423,6 +2425,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch.cpu._is_cpu_support_avx2",
|
||||
"torch.cpu._is_cpu_support_avx512",
|
||||
"torch.cpu._is_cpu_support_avx512_vnni",
|
||||
"torch.cpu._is_cpu_support_amx_tile",
|
||||
"torch.cpu._init_amx",
|
||||
"torch.cpu.current_device",
|
||||
"torch.cpu.current_stream",
|
||||
"torch.cpu.device_count",
|
||||
|
@ -1336,18 +1336,11 @@ cdll.LoadLibrary("__lib_path__")
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
def check_build(self, code) -> bool:
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
|
||||
|
||||
if config.cpp.vec_isa_ok is not None:
|
||||
return config.cpp.vec_isa_ok
|
||||
|
||||
if config.is_fbcode():
|
||||
return True
|
||||
|
||||
key, input_path = write(
|
||||
VecISA._avx_code,
|
||||
code,
|
||||
"cpp",
|
||||
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
|
||||
)
|
||||
@ -1385,6 +1378,16 @@ cdll.LoadLibrary("__lib_path__")
|
||||
|
||||
return True
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if config.cpp.vec_isa_ok is not None:
|
||||
return config.cpp.vec_isa_ok
|
||||
|
||||
if config.is_fbcode():
|
||||
return True
|
||||
|
||||
return self.check_build(VecISA._avx_code)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecNEON(VecISA):
|
||||
@ -1418,6 +1421,46 @@ class VecAVX512(VecISA):
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAMX(VecAVX512):
|
||||
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__() + " amx_tile"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
_amx_code = """
|
||||
#include <cstdint>
|
||||
#include <immintrin.h>
|
||||
|
||||
struct amx_tilecfg {
|
||||
uint8_t palette_id;
|
||||
uint8_t start_row;
|
||||
uint8_t reserved_0[14];
|
||||
uint16_t colsb[16];
|
||||
uint8_t rows[16];
|
||||
};
|
||||
|
||||
extern "C" void __amx_chk_kernel() {
|
||||
amx_tilecfg cfg = {0};
|
||||
_tile_loadconfig(&cfg);
|
||||
_tile_zero(0);
|
||||
_tile_dpbf16ps(0, 1, 2);
|
||||
_tile_dpbusd(0, 1, 2);
|
||||
}
|
||||
"""
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if super().__bool__():
|
||||
if config.is_fbcode():
|
||||
return False
|
||||
if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAVX2(VecISA):
|
||||
_bit_width = 256
|
||||
@ -1483,15 +1526,17 @@ def x86_isa_checker() -> List[str]:
|
||||
|
||||
avx2 = torch.cpu._is_cpu_support_avx2()
|
||||
avx512 = torch.cpu._is_cpu_support_avx512()
|
||||
amx_tile = torch.cpu._is_cpu_support_amx_tile()
|
||||
|
||||
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
|
||||
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
|
||||
_check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
|
||||
|
||||
return supported_isa
|
||||
|
||||
|
||||
invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
|
||||
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
|
||||
|
||||
|
||||
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
|
||||
@ -1528,7 +1573,7 @@ def valid_vec_isa_list() -> List[VecISA]:
|
||||
"""
|
||||
_cpu_supported_x86_isa = x86_isa_checker()
|
||||
for isa in supported_vec_isa_list:
|
||||
if str(isa) in _cpu_supported_x86_isa and isa:
|
||||
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
|
||||
isa_list.append(isa)
|
||||
|
||||
return isa_list
|
||||
|
@ -3,13 +3,14 @@ from typing import Any, Callable, cast, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils
|
||||
from ..._dynamo.utils import counters
|
||||
from .. import ir, lowering as L
|
||||
|
||||
from ..kernel.mm_common import mm_args
|
||||
from ..select_algorithm import DataProcessorTemplateWrapper
|
||||
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
|
||||
from ..virtualized import V
|
||||
from .cpp_micro_gemm import create_micro_gemm
|
||||
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
|
||||
from .cpp_template import CppTemplate
|
||||
|
||||
from .cpp_template_kernel import CppTemplateKernel
|
||||
@ -84,15 +85,18 @@ extern "C"
|
||||
int64_t k_block_start = 0;
|
||||
int64_t k_block_end = K0_blocks;
|
||||
{%- endif %}
|
||||
{{ micro_gemm.codegen_init(kernel) }}
|
||||
for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
|
||||
const int64_t m_start = mc * M0;
|
||||
const int64_t m_end = std::min((mc + Mc_blocks) * M0, M);
|
||||
const int64_t m_size = m_end - m_start;
|
||||
{%- if use_local_acc %}
|
||||
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
|
||||
{%- endif %}
|
||||
for (int64_t nc = n_block_start; nc < n_block_end; ++nc) {
|
||||
const int64_t n_start = nc * N0;
|
||||
const int64_t n_size = N0;
|
||||
{%- if use_local_acc %}
|
||||
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
|
||||
{%- set acc = kernel.local_buffers[acc_buf_name] %}
|
||||
{%- else %}
|
||||
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
|
||||
@ -128,6 +132,7 @@ extern "C"
|
||||
}}
|
||||
}
|
||||
}
|
||||
{{ micro_gemm.codegen_finalize(kernel) }}
|
||||
}
|
||||
}
|
||||
"""
|
||||
@ -332,6 +337,17 @@ class CppPackedGemmTemplate(CppTemplate):
|
||||
blocked_w = (
|
||||
W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous()
|
||||
)
|
||||
if micro_gemm.get_b_layout() != LayoutType.NORMAL:
|
||||
assert (
|
||||
micro_gemm.get_b_layout() == LayoutType.VNNI2
|
||||
), "We only support VNNI2 for now"
|
||||
assert k % 2 == 0, "k should be even for VNNI2 layout"
|
||||
blocked_w = (
|
||||
blocked_w.view(n // block_n, k // 2, 2, block_n)
|
||||
.transpose(-1, -2)
|
||||
.contiguous()
|
||||
.view(n // block_n, k, block_n)
|
||||
)
|
||||
# normalize stride to be "contiguous_strides" per size
|
||||
# this avoids the problems in L.view during template codegen
|
||||
new_stride = [1]
|
||||
@ -462,6 +478,8 @@ class CppPackedGemmTemplate(CppTemplate):
|
||||
)
|
||||
assert micro_gemm is not None
|
||||
assert self.register_blocking == micro_gemm.register_blocking
|
||||
if isinstance(micro_gemm, CppMicroGemmAMX):
|
||||
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
|
||||
|
||||
options = dict(
|
||||
X=X,
|
||||
|
@ -1,13 +1,14 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, Optional, Type
|
||||
import dataclasses
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
||||
from .. import ir
|
||||
from ..codecache import pick_vec_isa, VecAVX2, VecAVX512
|
||||
from ..codecache import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
|
||||
from ..utils import IndentedBuffer, parallel_num_threads
|
||||
from ..virtualized import V
|
||||
from .common import KernelTemplate
|
||||
@ -15,6 +16,11 @@ from .cpp_template_kernel import CppTemplateKernel
|
||||
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp
|
||||
|
||||
|
||||
class LayoutType(Enum):
|
||||
NORMAL = 0
|
||||
VNNI2 = 1
|
||||
|
||||
|
||||
class CppMicroGemm:
|
||||
"""
|
||||
A class that codegens a kernel that computes small-sized matrix multiplication.
|
||||
@ -30,6 +36,9 @@ class CppMicroGemm:
|
||||
DECLARE_KERNEL = r"""
|
||||
template <bool accum>
|
||||
inline void {{kernel_name}}(
|
||||
{%- if kernel_extra_args_declare %}
|
||||
{{kernel_extra_args_declare}}
|
||||
{%- endif %}
|
||||
const {{input_t}}* __restrict__ A,
|
||||
const {{input_t}}* __restrict__ B,
|
||||
{{output_t}}* __restrict__ C,
|
||||
@ -69,12 +78,19 @@ inline void {{kernel_name}}(
|
||||
"output_t": DTYPE_TO_CPP[self.output_dtype],
|
||||
"compute_t": DTYPE_TO_CPP[self.compute_dtype],
|
||||
"alpha": self.alpha,
|
||||
"kernel_extra_args_declare": self.get_kernel_extra_args_declare(),
|
||||
}
|
||||
|
||||
def get_kernel_declaration(self):
|
||||
options = self.get_common_options()
|
||||
return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options)
|
||||
|
||||
def get_kernel_extra_args_declare(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_kernel_extra_args(self) -> str:
|
||||
return ""
|
||||
|
||||
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -102,6 +118,9 @@ inline void {{kernel_name}}(
|
||||
res = IndentedBuffer()
|
||||
res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
|
||||
with res.indent():
|
||||
extra_args = self.get_kernel_extra_args()
|
||||
if extra_args:
|
||||
res.writeline(extra_args)
|
||||
res.writeline(f"{A_ptr},")
|
||||
res.writeline(f"{B_ptr},")
|
||||
res.writeline(f"{C_ptr},")
|
||||
@ -114,17 +133,31 @@ inline void {{kernel_name}}(
|
||||
res.writeline(");")
|
||||
return res.getvalue()
|
||||
|
||||
def codegen_init(
|
||||
self,
|
||||
kernel: CppTemplateKernel,
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def codegen_finalize(
|
||||
self,
|
||||
kernel: CppTemplateKernel,
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def get_b_layout(self) -> LayoutType:
|
||||
return LayoutType.NORMAL
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CppMicroGemmConfig:
|
||||
input_dtype: torch.dtype
|
||||
output_dtype: torch.dtype
|
||||
compute_dtype: torch.dtype
|
||||
vec_isa_cls: Type[VecISA]
|
||||
register_blocking: GemmBlocking
|
||||
extra_check: Optional[Callable[..., bool]] = None
|
||||
|
||||
CppMicroGemmConfig = namedtuple(
|
||||
"CppMicroGemmConfig",
|
||||
[
|
||||
"input_dtype",
|
||||
"output_dtype",
|
||||
"compute_dtype",
|
||||
"vec_isa_cls",
|
||||
"register_blocking",
|
||||
],
|
||||
)
|
||||
|
||||
micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {}
|
||||
|
||||
@ -147,6 +180,7 @@ def generate_gemm_config(
|
||||
input_dtype=torch.float,
|
||||
output_dtype=None,
|
||||
compute_dtype=None,
|
||||
extra_check=None,
|
||||
):
|
||||
if output_dtype is None:
|
||||
output_dtype = input_dtype
|
||||
@ -159,6 +193,7 @@ def generate_gemm_config(
|
||||
compute_dtype,
|
||||
vec_isa_cls,
|
||||
GemmBlocking(*blocking),
|
||||
extra_check,
|
||||
)
|
||||
for blocking in register_blockings
|
||||
]
|
||||
@ -197,36 +232,51 @@ class CppMicroGemmRef(CppMicroGemm):
|
||||
return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options)
|
||||
|
||||
|
||||
def check_fp32_vec_extra(config, m, n, k, alpha, num_threads):
|
||||
# TODO(jgong5): support n % n_block_size != 0
|
||||
return n % config.register_blocking.block_n == 0
|
||||
|
||||
|
||||
@register_micro_gemm(
|
||||
*generate_gemm_config(
|
||||
VecAVX512, [(8, 48, 1), (8, 32, 1), (16, 16, 1)], input_dtype=torch.float
|
||||
VecAVX512,
|
||||
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
||||
input_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
*generate_gemm_config(
|
||||
VecAVX512,
|
||||
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
||||
input_dtype=torch.bfloat16,
|
||||
output_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
*generate_gemm_config(
|
||||
VecAVX512,
|
||||
[(8, 48, 1), (8, 32, 1), (16, 16, 1)],
|
||||
input_dtype=torch.half,
|
||||
output_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
*generate_gemm_config(
|
||||
VecAVX2, [(4, 24, 1), (4, 16, 1), (8, 8, 1)], input_dtype=torch.float
|
||||
VecAVX2,
|
||||
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
||||
input_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
*generate_gemm_config(
|
||||
VecAVX2,
|
||||
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
||||
input_dtype=torch.bfloat16,
|
||||
output_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
*generate_gemm_config(
|
||||
VecAVX2,
|
||||
[(4, 24, 1), (4, 16, 1), (8, 8, 1)],
|
||||
input_dtype=torch.half,
|
||||
output_dtype=torch.float,
|
||||
extra_check=check_fp32_vec_extra,
|
||||
),
|
||||
)
|
||||
class CppMicroGemmFP32Vec(CppMicroGemm):
|
||||
@ -367,6 +417,218 @@ inline void {{kernel_name}}_kernel(
|
||||
return result
|
||||
|
||||
|
||||
# extra check for CppMicroGemmAMX
|
||||
def check_amx_extra(config, m, n, k, alpha, num_threads):
|
||||
return n % config.register_blocking.block_n == 0 and k % 2 == 0 and alpha == 1
|
||||
|
||||
|
||||
@register_micro_gemm(
|
||||
*generate_gemm_config(
|
||||
VecAMX,
|
||||
[(32, 32, 32), (48, 16, 32), (16, 48, 32)],
|
||||
input_dtype=torch.bfloat16,
|
||||
output_dtype=torch.float,
|
||||
extra_check=check_amx_extra,
|
||||
),
|
||||
)
|
||||
class CppMicroGemmAMX(CppMicroGemm):
|
||||
"""
|
||||
This class generates the code for micro gemm using Advanced Matrix eXtention (AMX)
|
||||
instructions available in 4th generation Intel Xeon for compute.
|
||||
It supports input types of torch.bfloat16 with fp32 output.
|
||||
TODO(jgong5): support int8 data type.
|
||||
"""
|
||||
|
||||
TEMPLATE_ENTRY = r"""
|
||||
{{declare_kernel}} {
|
||||
TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
|
||||
TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
|
||||
// TODO(jgong5): loop unroll for M and N
|
||||
for (int64_t m = 0; m < M; m += {{block_m}}) {
|
||||
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
|
||||
int64_t m_tail = m;
|
||||
for (int64_t n = 0; n < N; n += {{block_n}}) {
|
||||
{%- for num_rows in range(block_m, 0, -16) %}
|
||||
{%- if num_rows != block_m %}
|
||||
else
|
||||
{%- endif %}
|
||||
if (block_m >= {{num_rows}}) {
|
||||
{{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}<accum>(
|
||||
amx_state,
|
||||
A + m * lda,
|
||||
B + n,
|
||||
C + m * ldc + n,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
16
|
||||
);
|
||||
block_m -= {{num_rows}};
|
||||
m_tail += {{num_rows}};
|
||||
}
|
||||
{%- endfor %}
|
||||
if (block_m > 0) {
|
||||
{{kernel_name}}_amx_kernel_16_{{num_columns}}<accum>(
|
||||
amx_state,
|
||||
A + m_tail * lda,
|
||||
B + n,
|
||||
C + m_tail * ldc + n,
|
||||
K,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
block_m
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
TEMPLATE_KERNEL = r"""
|
||||
template <bool accum>
|
||||
inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
|
||||
AMXState& amx_state,
|
||||
const {{input_t}}* __restrict__ A,
|
||||
const {{input_t}}* __restrict__ B,
|
||||
{{output_t}}* __restrict__ C,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
uint8_t tilecfg_rows
|
||||
) {
|
||||
// TODO(jgong5): add prefetch hint for A, B, C
|
||||
auto loadconfig = [](const amx_tilecfg& cfg) {
|
||||
_tile_loadconfig(&cfg);
|
||||
};
|
||||
const auto last_k_offset = K / {{block_k}} * {{block_k}};
|
||||
const auto tail_k_size = K - last_k_offset;
|
||||
if C10_LIKELY (last_k_offset > 0) {
|
||||
amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
||||
} else {
|
||||
amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
||||
}
|
||||
auto load_c = [&]() {
|
||||
{%- for tile_row in range(num_rows // 16) %}
|
||||
{%- for tile_col in range(num_columns) %}
|
||||
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
||||
_tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
};
|
||||
auto zero_c = [&]() {
|
||||
{%- for tile_row in range(num_rows // 16) %}
|
||||
{%- for tile_col in range(num_columns) %}
|
||||
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
||||
_tile_zero({{tile_idx}});
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
};
|
||||
|
||||
if constexpr (accum) {
|
||||
load_c();
|
||||
} else {
|
||||
zero_c();
|
||||
}
|
||||
|
||||
auto compute = [&](int k) {
|
||||
{%- set tile_offset_a = num_rows // 16 * num_columns %}
|
||||
{%- set tile_offset_b = tile_offset_a + num_rows // 16 %}
|
||||
{%- for tile_row in range(num_rows // 16) %}
|
||||
{%- for tile_col in range(num_columns) %}
|
||||
{%- set tile_idx_a = tile_offset_a + tile_row %}
|
||||
{%- set tile_idx_b = tile_offset_b + tile_col %}
|
||||
{%- set tile_idx_c = tile_row * num_columns + tile_col %}
|
||||
{%- if tile_col == 0 %}
|
||||
_tile_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}}));
|
||||
{%- endif %}
|
||||
{%- if tile_row == 0 %}
|
||||
_tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * 2}}, ldb * 2 * sizeof({{input_t}}));
|
||||
{%- endif %}
|
||||
_tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}});
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
};
|
||||
|
||||
{{kernel.unroll_pragma(4)}}
|
||||
for (int k = 0; k < last_k_offset; k += {{block_k}}) {
|
||||
compute(k);
|
||||
}
|
||||
|
||||
auto store_c = [&]() {
|
||||
// store to C
|
||||
{%- for tile_row in range(num_rows // 16) %}
|
||||
{%- for tile_col in range(num_columns) %}
|
||||
{%- set tile_idx = tile_row * num_columns + tile_col %}
|
||||
_tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}}));
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
};
|
||||
|
||||
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
|
||||
if C10_UNLIKELY (tail_k_size > 0) {
|
||||
if C10_LIKELY (last_k_offset > 0) {
|
||||
store_c();
|
||||
amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig);
|
||||
load_c();
|
||||
}
|
||||
compute(last_k_offset);
|
||||
}
|
||||
|
||||
store_c();
|
||||
}
|
||||
"""
|
||||
|
||||
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
||||
block_m, block_n, block_k = self.register_blocking
|
||||
assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX"
|
||||
assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX"
|
||||
assert block_k == 32, "Only support block_k = 32 for AMX"
|
||||
num_columns = block_n // 16
|
||||
options = {
|
||||
"declare_kernel": self.get_kernel_declaration(),
|
||||
"kernel": kernel,
|
||||
"block_m": block_m,
|
||||
"block_n": block_n,
|
||||
"block_k": block_k,
|
||||
"num_columns": num_columns,
|
||||
**self.get_common_options(),
|
||||
}
|
||||
result = ""
|
||||
for num_rows in range(block_m, 0, -16):
|
||||
amx_kernel_options = {**options, "num_rows": num_rows}
|
||||
result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render(
|
||||
amx_kernel_options
|
||||
)
|
||||
result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(
|
||||
options
|
||||
)
|
||||
return result
|
||||
|
||||
def codegen_init(
|
||||
self,
|
||||
kernel: CppTemplateKernel,
|
||||
) -> str:
|
||||
return "AMXState amx_state;"
|
||||
|
||||
def codegen_finalize(
|
||||
self,
|
||||
kernel: CppTemplateKernel,
|
||||
) -> str:
|
||||
return "amx_state.release([]() { _tile_release(); });"
|
||||
|
||||
def get_kernel_extra_args_declare(self) -> str:
|
||||
return "AMXState& amx_state,"
|
||||
|
||||
def get_kernel_extra_args(self) -> str:
|
||||
return "amx_state,"
|
||||
|
||||
def get_b_layout(self):
|
||||
return LayoutType.VNNI2
|
||||
|
||||
|
||||
def create_micro_gemm(
|
||||
name,
|
||||
m,
|
||||
@ -403,36 +665,48 @@ def create_micro_gemm(
|
||||
matched_configs = []
|
||||
for cls, configs in micro_gemm_configs.items():
|
||||
for config in configs:
|
||||
if not isinstance(vec_isa, config.vec_isa_cls):
|
||||
if not issubclass(vec_isa.__class__, config.vec_isa_cls):
|
||||
continue
|
||||
if (
|
||||
config.input_dtype == input_dtype
|
||||
and config.output_dtype == output_dtype
|
||||
and config.compute_dtype == compute_dtype
|
||||
):
|
||||
block_m, block_n, block_k = config.register_blocking
|
||||
# TODO(jgong5): support n % n_block_size != 0
|
||||
if n % block_n != 0:
|
||||
if config.extra_check is not None and not config.extra_check(
|
||||
config, m, n, k, alpha, num_threads
|
||||
):
|
||||
continue
|
||||
block_m, block_n, block_k = config.register_blocking
|
||||
# Criteria on the ranking of configurations
|
||||
# 1. Dividable by block sizes (block_m, block_k)
|
||||
# 2. Number of mxn blocks is large enough to occupy all the threads
|
||||
# 3. Register blocks are larger
|
||||
# 1. ISA: AMX > VEC
|
||||
# 2. Dividable by block sizes (block_m, block_n, block_k)
|
||||
# 3. Number of mxn blocks is large enough to occupy all the threads
|
||||
# 4. Register blocks are larger
|
||||
isa_score = 0
|
||||
if config.vec_isa_cls == VecAMX:
|
||||
isa_score += 1
|
||||
dividable_score = 0
|
||||
if k % block_k == 0:
|
||||
dividable_score += 1
|
||||
if m % block_m == 0:
|
||||
dividable_score += 1
|
||||
if n % block_n == 0:
|
||||
dividable_score += 1
|
||||
if k % block_k == 0:
|
||||
dividable_score += 1
|
||||
occupancy_score = 0
|
||||
n_blocks = n // block_n
|
||||
total_mxn_blocks = n // block_n * ((m + block_m - 1) // block_m)
|
||||
n_blocks = (n + block_n - 1) // block_n
|
||||
total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m)
|
||||
if n_blocks >= num_threads:
|
||||
occupancy_score += 1
|
||||
if total_mxn_blocks >= num_threads:
|
||||
occupancy_score += 1
|
||||
register_bytes = (
|
||||
block_m * block_n * config.compute_dtype.itemsize
|
||||
+ (block_m * block_k + block_k * block_n)
|
||||
* config.input_dtype.itemsize
|
||||
)
|
||||
matched_configs.append(
|
||||
(
|
||||
(dividable_score, occupancy_score, block_m * block_n * block_k),
|
||||
(isa_score, dividable_score, occupancy_score, register_bytes),
|
||||
cls,
|
||||
config,
|
||||
)
|
||||
|
@ -416,3 +416,65 @@ inline void mm_get_thread_blocks(
|
||||
m_block_start = std::min(thread_id * Mt_blocks, M_blocks);
|
||||
m_block_end = std::min(m_block_start + Mt_blocks, M_blocks);
|
||||
}
|
||||
|
||||
struct amx_tilecfg {
|
||||
uint8_t palette_id;
|
||||
uint8_t start_row;
|
||||
uint8_t reserved_0[14];
|
||||
uint16_t colsb[16];
|
||||
uint8_t rows[16];
|
||||
};
|
||||
|
||||
class AMXState {
|
||||
private:
|
||||
amx_tilecfg tilecfg_;
|
||||
uint8_t rows_;
|
||||
uint16_t colsb_;
|
||||
uint8_t num_tile_rows_;
|
||||
uint8_t num_tile_columns_;
|
||||
|
||||
public:
|
||||
AMXState() : rows_(0), colsb_(0), num_tile_rows_(0), num_tile_columns_(0) {
|
||||
memset(&tilecfg_, 0, sizeof(tilecfg_));
|
||||
}
|
||||
|
||||
inline void configure(
|
||||
uint8_t rows,
|
||||
uint16_t colsb,
|
||||
uint8_t num_tile_rows,
|
||||
uint8_t num_tile_columns,
|
||||
void (*loadconfig)(const amx_tilecfg&)) {
|
||||
if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb &&
|
||||
num_tile_rows_ == num_tile_rows &&
|
||||
num_tile_columns_ == num_tile_columns) {
|
||||
return;
|
||||
}
|
||||
tilecfg_.palette_id = 1;
|
||||
rows_ = rows;
|
||||
colsb_ = colsb;
|
||||
num_tile_rows_ = num_tile_rows;
|
||||
num_tile_columns_ = num_tile_columns;
|
||||
const auto num_c_tiles = num_tile_rows * num_tile_columns;
|
||||
// For C
|
||||
for (int i = 0; i < num_c_tiles; i++) {
|
||||
tilecfg_.rows[i] = rows;
|
||||
tilecfg_.colsb[i] = 64;
|
||||
}
|
||||
// For A
|
||||
for (int i = 0; i < num_tile_rows; i++) {
|
||||
tilecfg_.rows[i + num_c_tiles] = rows;
|
||||
tilecfg_.colsb[i + num_c_tiles] = colsb;
|
||||
}
|
||||
// For B
|
||||
for (int i = 0; i < num_tile_columns; i++) {
|
||||
tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4;
|
||||
tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64;
|
||||
}
|
||||
loadconfig(tilecfg_);
|
||||
}
|
||||
|
||||
inline void release(void (*tile_release)()) {
|
||||
tilecfg_.palette_id = 0;
|
||||
tile_release();
|
||||
}
|
||||
};
|
||||
|
@ -45,6 +45,16 @@ def _is_cpu_support_vnni() -> bool:
|
||||
return torch._C._cpu._is_cpu_support_avx512_vnni()
|
||||
|
||||
|
||||
def _is_cpu_support_amx_tile() -> bool:
|
||||
r"""Returns a bool indicating if CPU supports AMX_TILE."""
|
||||
return torch._C._cpu._is_cpu_support_amx_tile()
|
||||
|
||||
|
||||
def _init_amx() -> bool:
|
||||
r"""Initializes AMX instructions."""
|
||||
return torch._C._cpu._init_amx()
|
||||
|
||||
|
||||
def is_available() -> bool:
|
||||
r"""Returns a bool indicating if CPU is currently available.
|
||||
|
||||
|
@ -11,6 +11,8 @@ void initModule(PyObject* module) {
|
||||
cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2);
|
||||
cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512);
|
||||
cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni);
|
||||
cpu.def("_is_cpu_support_amx_tile", at::cpu::is_cpu_support_amx_tile);
|
||||
cpu.def("_init_amx", at::cpu::init_amx);
|
||||
}
|
||||
|
||||
} // namespace torch::cpu
|
||||
|
Reference in New Issue
Block a user