[CPU][Inductor] Improve performance of A16W8 GEMM template (#161148)

**Summary**
This PR improves the performance of the A16W8 GEMM template by
- Removing the config with block_m=16 & block_n=48, as it is not very efficient.
- Using the AMX micro-kernel when M >= 5 (previously the threshold was block_m for int8 weights), so that AMX rather than AVX512 is used for M in the range 5-31.
- Converting int8 values to bf16 with AVX512 intrinsics instead of `at::vec::convert`, as the latter does not have an optimized implementation for this case (a standalone sketch follows below).

We measured performance gains of more than 10% in various cases when running Llama-3.1-8b-instruct.
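
For reference, the conversion the template now emits can be extracted into a standalone, compilable form. This is a minimal sketch of the intrinsic sequence used in the diff below, assuming signed int8 inputs and AVX512-BF16 support; the function name and the compile command are illustrative, not part of the PR:

```cpp
// Widen 32 contiguous signed int8 values to 32 bf16 values.
// Build (illustrative): g++ -O2 -mavx512f -mavx512bw -mavx512bf16 -c dequant_sketch.cpp
#include <cstdint>
#include <immintrin.h>

void int8x32_to_bf16(const int8_t* src, uint16_t* dst) {
  // 1) Load 32 x int8 (256 bits)
  __m256i v8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
  // 2) Sign-extend: 32 x i8 -> 32 x i16 (use _mm512_cvtepu8_epi16 for unsigned)
  __m512i v16 = _mm512_cvtepi8_epi16(v8);
  // 3) Split into two 16-lane halves and widen each to i32
  __m512i v32_lo = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(v16));
  __m512i v32_hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v16, 1));
  // 4) Convert each half to f32
  __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
  __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
  // 5) f32 -> bf16 (round-to-nearest-even); the second operand (f_lo) fills
  //    the lower 16 bf16 lanes, the first (f_hi) the upper 16
  __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
  // 6) Store 32 x bf16 (512 bits)
  _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), bf);
}
```

The previous path went through `at::vec::convert` on a `Vectorized<int8_t>`, which, per the summary above, lacks an optimized int8 -> bf16 implementation.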

**Test plan**
Already covered by existing unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161148
Approved by: https://github.com/CaoE, https://github.com/jansel
Xia, Weiwen
2025-08-31 09:56:29 +00:00
committed by PyTorch MergeBot
parent 377033757a
commit 75bc23cfc3
2 changed files with 92 additions and 20 deletions


@@ -963,6 +963,15 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
return k % vnni_size == 0 and alpha == 1
def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
# We need avx512_bf16 to dequant int8 to bf16
vec_isa = kwargs.get("vec_isa", None)
assert vec_isa is not None
return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
config, m, n, k, alpha, num_threads, **kwargs
)
# amx_fp16 need to be checked separately since it is not always supported when amx is supported
def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
@@ -984,12 +993,12 @@ def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
),
*generate_gemm_config(
VecAMX,
[(32, 32, 32), (48, 16, 32), (16, 48, 32)],
[(32, 32, 32), (48, 16, 32)],
input_dtype=torch.bfloat16,
input2_dtype=torch.int8,
output_dtype=torch.float,
compute_dtype=torch.float,
extra_check=check_amx_extra,
extra_check=check_int8_bf16_amx_extra,
),
*generate_gemm_config(
VecAMX,
@@ -1041,12 +1050,38 @@ class CppMicroGemmAMX(CppMicroGemm):
for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
{%- for vec_idx in range(0, block_n, 32) %}
{%- if (block_n - vec_idx) >= 32 %}
auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized<int8_t>::loadu(
base_addr + idx_q + {{vec_idx}} ,
static_cast<int64_t>(32)
);
auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}});
b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}});
// 1) Load 32 x int8
__m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
// 2) Widen: 32 x i8 -> 32 x i16
__m512i v16 = _mm512_cvtepi8_epi16(v8); // sign-extend. Use _mm512_cvtepu8_epi16 for unsigned
// Split the 32 x i16 into two 16-lane halves
__m256i v16_lo = _mm512_castsi512_si256(v16);
__m256i v16_hi = _mm512_extracti64x4_epi64(v16, 1);
// 3) Widen each half to i32
__m512i v32_lo = _mm512_cvtepi16_epi32(v16_lo);
__m512i v32_hi = _mm512_cvtepi16_epi32(v16_hi);
// 4) Convert to f32
__m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
__m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
// 5) f32 -> bf16 (round-to-nearest-even) and pack 32 lanes to 512b
// Packs the second operand (f_lo) into the lower 16 bf16 lanes and the first (f_hi) into the upper 16.
__m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
// 6) Store 32 x bf16 (512 bits)
_mm512_storeu_si512((__m512i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf);
{%- elif (block_n - vec_idx) >= 16 %}
// 1) Load 16 x int8 (128 bits)
__m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
// 2) Widen: 16 x i8 -> 16 x i16
__m256i v16 = _mm256_cvtepi8_epi16(v8); // for signed
// use _mm256_cvtepu8_epi16 for unsigned
// 3) Widen further: 16 x i16 -> 16 x i32
__m512i v32 = _mm512_cvtepi16_epi32(v16);
// 4) Convert to f32
__m512 f32 = _mm512_cvtepi32_ps(v32);
// 5) Convert f32 -> bf16 (round-to-nearest-even)
__m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
// 6) Store 16 x bf16 (256 bits)
_mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
{%- else %}
auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
base_addr + idx_q + {{block_n - (block_n % 32)}},
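
As a sanity check for the widened output produced by the branches above, here is a hedged scalar reference; the helper names are illustrative and not part of the template:

```cpp
// Scalar reference for the int8 -> bf16 widening above (illustrative only).
#include <cstdint>
#include <cstring>

static inline uint16_t fp32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round-to-nearest-even on the discarded low 16 bits, then truncate to bf16.
  uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

void dequant_row_ref(const int8_t* src, uint16_t* dst, int n) {
  for (int i = 0; i < n; ++i) {
    dst[i] = fp32_to_bf16_rne(static_cast<float>(src[i]));
  }
}
```

Since every int8 value is exactly representable in bf16 (8-bit significand), the scalar reference and the intrinsic path should agree bit for bit.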
@@ -1915,7 +1950,7 @@ def create_micro_gemm(
alpha,
)
def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls):
def skip_amx_kernel_for_woq(dynamic_M):
# For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
# Exception: for dynamic shapes, we consider using the AMX micro-kernel.
if (
@@ -1924,11 +1959,7 @@ def create_micro_gemm(
or input2_dtype not in [torch.int8, torch.uint8]
):
return False
# For WOQ INT8, use AMX for m >= block_m
# For WOQ INT4, use AMX for m >= 5
block_m, *_ = config.register_blocking
is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx
m_threshold = 5 if is_woq_int4 else block_m
m_threshold = 5
return m < m_threshold
assert isinstance(n, int) or n.is_number, n
@@ -1974,9 +2005,7 @@
):
continue
block_m, block_n, block_k = config.register_blocking
if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(
config, dynamic_M, cls
):
if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(dynamic_M):
continue
# Criteria on the ranking of configurations
# 1. ISA: AMX > VEC


@@ -200,12 +200,51 @@ class VecAVX512(VecISA):
else "/arch:AVX512"
) # TODO: use cflags
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
_is_avx512_bf16_supported = False
def __str__(self) -> str:
return "avx512"
__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
_avx512_bf16_code = """
#include <cstdint>
#include <immintrin.h>
extern "C" __m512bh __avx512_bf16_chk_kernel(__m512 a, __m512 b) {
return _mm512_cvtne2ps_pbh(a, b);
}
"""
@functools.cache # noqa: B019
def __bool__(self) -> bool:
if super().__bool__():
if config.is_fbcode():
return False
# check avx512_bf16
if torch.cpu._is_avx512_bf16_supported() and not _IS_WINDOWS:
# save _arch_flags
base_flags = self._arch_flags
# temporarily change _arch_flags for avx512_bf16 check_build
self._arch_flags += " -mavx512bf16"
if self.check_build(VecAMX._avx512_bf16_code):
self._is_avx512_bf16_supported = True
# restore _arch_flags
self._arch_flags = base_flags
return True
return False
@functools.lru_cache(None) # noqa: B019
def is_avx512_bf16_supported(self) -> bool:
return self._is_avx512_bf16_supported
def build_arch_flags(self) -> str:
if self._is_avx512_bf16_supported:
return self._arch_flags + " -mavx512bf16"
else:
return self._arch_flags
@dataclasses.dataclass
class VecAMX(VecAVX512):
@@ -267,10 +306,14 @@ extern "C" void __amx_chk_kernel() {
return self._is_amx_fp16_supported
def build_arch_flags(self) -> str:
extra_flags = ""
if self._is_avx512_bf16_supported:
# avx512_bf16 is not among the base flags, so we need to check and add it here
# And we need this flag in the WOQ case for dequantization
extra_flags += " -mavx512bf16"
if self._is_amx_fp16_supported:
return self._arch_flags + " -mamx-fp16"
else:
return self._arch_flags
extra_flags += " -mamx-fp16"
return self._arch_flags + extra_flags
@dataclasses.dataclass