Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[CPU][Inductor] Improve performance of A16W8 GEMM template (#161148)
**Summary**
This PR improves the performance of the A16W8 GEMM template by:
- Removing the config with block_n=48 & block_m=16, as it is not very efficient.
- Using the AMX microkernel when M >= 5, so that we use AMX instead of AVX512 for M = 5~31.
- Converting int8 values to bf16 with intrinsics instead of `at::vec::convert`, as the latter does not have an optimized implementation for this case.

We saw up to >10% performance gain in various cases when running Llama-3.1-8b-instruct.

**Test plan**
Already covered by UT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161148
Approved by: https://github.com/CaoE, https://github.com/jansel
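For reference, here is a minimal standalone sketch of the intrinsics-based int8 -> bf16 conversion that the template change below relies on. The helper name `dequant_int8_to_bf16_32` is hypothetical; the sketch assumes signed int8 input and a compiler/CPU with AVX-512 BW and AVX-512 BF16 support (e.g. `-mavx512bw -mavx512bf16`):

```cpp
#include <cstdint>
#include <immintrin.h>

// Hypothetical helper: convert 32 signed int8 values to 32 bf16 values.
// Mirrors the intrinsic sequence used in the GEMM template diff below.
static inline void dequant_int8_to_bf16_32(const int8_t* src, uint16_t* dst) {
  // 1) Load 32 x int8 (256 bits).
  __m256i v8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
  // 2) Sign-extend to 32 x int16 (512 bits).
  __m512i v16 = _mm512_cvtepi8_epi16(v8);
  // 3) Split into two 16-lane halves and widen each to int32.
  __m512i v32_lo = _mm512_cvtepi16_epi32(_mm512_castsi512_si256(v16));
  __m512i v32_hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v16, 1));
  // 4) Convert both halves to float32.
  __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
  __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
  // 5) float32 -> bf16 (round-to-nearest-even); the second operand (f_lo)
  //    fills the lower 16 bf16 lanes, the first (f_hi) the upper 16.
  __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
  // 6) Store 32 x bf16 (512 bits).
  _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), bf);
}
```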
Commit 75bc23cfc3 (parent 377033757a), committed by PyTorch MergeBot.
@@ -963,6 +963,15 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
     return k % vnni_size == 0 and alpha == 1
 
 
+def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs):
+    # We need avx512_bf16 to dequant int8 to bf16
+    vec_isa = kwargs.get("vec_isa", None)
+    assert vec_isa is not None
+    return vec_isa.is_avx512_bf16_supported() and check_amx_extra(
+        config, m, n, k, alpha, num_threads, **kwargs
+    )
+
+
 # amx_fp16 need to be checked separately since it is not always supported when amx is supported
 def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
     assert config.input_dtype == torch.float16 and config.output_dtype == torch.float
@@ -984,12 +993,12 @@ def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs):
     ),
     *generate_gemm_config(
         VecAMX,
-        [(32, 32, 32), (48, 16, 32), (16, 48, 32)],
+        [(32, 32, 32), (48, 16, 32)],
         input_dtype=torch.bfloat16,
         input2_dtype=torch.int8,
         output_dtype=torch.float,
         compute_dtype=torch.float,
-        extra_check=check_amx_extra,
+        extra_check=check_int8_bf16_amx_extra,
     ),
     *generate_gemm_config(
         VecAMX,
@@ -1041,12 +1050,38 @@ class CppMicroGemmAMX(CppMicroGemm):
     for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
 {%- for vec_idx in range(0, block_n, 32) %}
 {%- if (block_n - vec_idx) >= 32 %}
-        auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized<int8_t>::loadu(
-            base_addr + idx_q + {{vec_idx}},
-            static_cast<int64_t>(32)
-        );
-        auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}});
-        b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}});
+        // 1) Load 32 x int8
+        __m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
+        // 2) Widen: 32 x i8 -> 32 x i16
+        __m512i v16 = _mm512_cvtepi8_epi16(v8); // sign-extend. Use _mm512_cvtepu8_epi16 for unsigned
+        // Split the 32 x i16 into two 16-lane halves
+        __m256i v16_lo = _mm512_castsi512_si256(v16);
+        __m256i v16_hi = _mm512_extracti64x4_epi64(v16, 1);
+        // 3) Widen each half to i32
+        __m512i v32_lo = _mm512_cvtepi16_epi32(v16_lo);
+        __m512i v32_hi = _mm512_cvtepi16_epi32(v16_hi);
+        // 4) Convert to f32
+        __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
+        __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
+        // 5) f32 -> bf16 (round-to-nearest-even) and pack 32 lanes to 512b
+        // Packs the second operand (f_lo) into the lower 16 bf16 lanes and the first (f_hi) into the upper 16.
+        __m512i bf = (__m512i)_mm512_cvtne2ps_pbh(f_hi, f_lo);
+        // 6) Store 32 x bf16 (512 bits)
+        _mm512_storeu_si512((__m512i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf);
+{%- elif (block_n - vec_idx) >= 16 %}
+        // 1) Load 16 x int8 (128 bits)
+        __m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
+        // 2) Widen: 16 x i8 -> 16 x i16
+        __m256i v16 = _mm256_cvtepi8_epi16(v8); // for signed
+        // use _mm256_cvtepu8_epi16 for unsigned
+        // 3) Widen further: 16 x i16 -> 16 x i32
+        __m512i v32 = _mm512_cvtepi16_epi32(v16);
+        // 4) Convert to f32
+        __m512 f32 = _mm512_cvtepi32_ps(v32);
+        // 5) Convert f32 -> bf16 (round-to-nearest-even)
+        __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
+        // 6) Store 16 x bf16 (256 bits)
+        _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
 {%- else %}
         auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
             base_addr + idx_q + {{block_n - (block_n % 32)}},
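As a scalar reference for the vectorized dequantization in the hunk above, the per-element computation is simply `bf16(float(v))`. Here is a hedged sketch (hypothetical helper, not part of the PR); since |int8| <= 128 and bf16 carries 8 significand bits, the rounding step never changes the value for this input range:

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical scalar reference: int8 -> fp32 -> bf16 with round-to-nearest-even,
// returning the bf16 bit pattern (upper 16 bits of the rounded fp32).
static inline uint16_t int8_to_bf16_scalar(int8_t v) {
  float f = static_cast<float>(v);
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round-to-nearest-even on the 16 mantissa bits that are dropped.
  uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
```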
@@ -1915,7 +1950,7 @@ def create_micro_gemm(
         alpha,
     )
 
-    def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls):
+    def skip_amx_kernel_for_woq(dynamic_M):
         # For WoQ GEMM, AMX micro-kernel may not perform well if m is small.
         # Exception: for dynamic shapes, we consider using the AMX micro-kernel.
         if (
@@ -1924,11 +1959,7 @@ def create_micro_gemm(
             or input2_dtype not in [torch.int8, torch.uint8]
         ):
             return False
-        # For WOQ INT8, use AMX for m >= block_m
-        # For WOQ INT4, use AMX for m >= 5
-        block_m, *_ = config.register_blocking
-        is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx
-        m_threshold = 5 if is_woq_int4 else block_m
+        m_threshold = 5
         return m < m_threshold
 
     assert isinstance(n, int) or n.is_number, n
@@ -1974,9 +2005,7 @@ def create_micro_gemm(
         ):
             continue
         block_m, block_n, block_k = config.register_blocking
-        if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(
-            config, dynamic_M, cls
-        ):
+        if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(dynamic_M):
             continue
         # Criteria on the ranking of configurations
         # 1. ISA: AMX > VEC

@@ -200,12 +200,51 @@ class VecAVX512(VecISA):
         else "/arch:AVX512"
     )  # TODO: use cflags
     _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
+    _is_avx512_bf16_supported = False
 
     def __str__(self) -> str:
         return "avx512"
 
     __hash__: Callable[[VecISA], Any] = VecISA.__hash__  # type: ignore[assignment]
 
+    _avx512_bf16_code = """
+#include <cstdint>
+#include <immintrin.h>
+
+extern "C" __m512bh __avx512_bf16_chk_kernel(__m512 a, __m512 b) {
+    return _mm512_cvtne2ps_pbh(a, b);
+}
+"""
+
+    @functools.cache  # noqa: B019
+    def __bool__(self) -> bool:
+        if super().__bool__():
+            if config.is_fbcode():
+                return False
+            # check avx512_bf16
+            if torch.cpu._is_avx512_bf16_supported() and not _IS_WINDOWS:
+                # save _arch_flags
+                base_flags = self._arch_flags
+                # temporarily change _arch_flags for avx512_bf16 check_build
+                self._arch_flags += " -mavx512bf16"
+                if self.check_build(VecAMX._avx512_bf16_code):
+                    self._is_avx512_bf16_supported = True
+                # restore _arch_flags
+                self._arch_flags = base_flags
+
+            return True
+        return False
+
+    @functools.lru_cache(None)  # noqa: B019
+    def is_avx512_bf16_supported(self) -> bool:
+        return self._is_avx512_bf16_supported
+
+    def build_arch_flags(self) -> str:
+        if self._is_avx512_bf16_supported:
+            return self._arch_flags + " -mavx512bf16"
+        else:
+            return self._arch_flags
+
 
 @dataclasses.dataclass
 class VecAMX(VecAVX512):
@@ -267,10 +306,14 @@ extern "C" void __amx_chk_kernel() {
         return self._is_amx_fp16_supported
 
     def build_arch_flags(self) -> str:
+        extra_flags = ""
+        if self._is_avx512_bf16_supported:
+            # avx512_bf16 is not among the base flags, so we need to check and add it here
+            # And we need this flag in the WOQ case for dequantization
+            extra_flags += " -mavx512bf16"
         if self._is_amx_fp16_supported:
-            return self._arch_flags + " -mamx-fp16"
-        else:
-            return self._arch_flags
+            extra_flags += " -mamx-fp16"
+        return self._arch_flags + extra_flags
 
 
 @dataclasses.dataclass