[CPU][GEMM Template] Improve A16W8 performance (#162479)

**Summary**
Improve A16W8 (16-bit activation, 8-bit weight) performance by:
1. supporting GQA concat linear (sketched below)
2. using a smaller cache blocking size
3. improving the weight dequantization code (reducing instruction count and adding prefetching)

We saw a >5% end-to-end (E2E) next-token performance gain when running Llama3.1-8B-instruct.
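For context on item 1: in grouped-query attention (GQA) the q/k/v projections share an input but have different output widths, so the three weight-only-quantized linears can be fused into one GEMM and split back at the per-projection boundaries. A minimal sketch of the idea with made-up shapes (not taken from this PR):

```python
import torch

# Illustrative GQA widths only; q is wider than k/v.
hidden, n_q, n_kv = 4096, 4096, 1024
x = torch.randn(1, hidden, dtype=torch.bfloat16)
wq, wk, wv = (torch.randint(-128, 128, (hidden, n), dtype=torch.int8)
              for n in (n_q, n_kv, n_kv))
sq, sk, sv = (torch.rand(n, dtype=torch.bfloat16) for n in (n_q, n_kv, n_kv))

# Fused form: one GEMM over the concatenated int8 weights and scales,
# then split at the real q/k/v boundaries instead of into equal thirds.
cat_w = torch.cat((wq, wk, wv), dim=1).to(torch.bfloat16)
cat_s = torch.cat((sq, sk, sv), dim=0)
q, k, v = ((x @ cat_w) * cat_s).tensor_split([n_q, n_q + n_kv], dim=-1)
print(q.shape, k.shape, v.shape)  # (1, 4096) (1, 1024) (1, 1024)
```

The pattern-matcher change further down replaces `chunk(3)` with `tensor_split` precisely so this unequal-width case is handled.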

**Test plan**
Covered by existing unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162479
Approved by: https://github.com/mingfeima, https://github.com/CaoE, https://github.com/jansel
Author: Xia, Weiwen
Date: 2025-09-18 01:28:37 +00:00
Committed by: PyTorch MergeBot
Parent: f17e2ab1f9
Commit: 48a7e8cc70
4 changed files with 37 additions and 38 deletions


@@ -809,13 +809,27 @@ class CppGemmTemplate(CppTemplate):
if (
config.cpp.use_small_dequant_buffer
and dtype_A is torch.bfloat16
and dtype_B is torch.uint8
and Mt_blocks == 1
):
# Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
# Since when Mt_blocks == 1, the L1-resident B block can't be reused by A.
if Kc_blocks * Kr >= self.q_group_size():
Kc_blocks = self.q_group_size() // Kr
if dtype_B is torch.uint8:
# A16W4
# Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
# Since when Mt_blocks == 1, the L1-resident B block can't be reused by A.
if Kc_blocks * Kr >= self.q_group_size():
Kc_blocks = self.q_group_size() // Kr
elif dtype_B is torch.int8:
# A16W8
# Make A, B, C buffer in L1
A_buf_size_div_K = self.m * num_byte_A
B_buf_size_div_K = Nr * num_byte_B
# assume acc in float32/int32 and Mc_blocks = Nc_blocks = 1
C_buf_size = Mr * Nr * 4
K_block_size = (L1 - C_buf_size) // (
A_buf_size_div_K + B_buf_size_div_K
)
if Kc_blocks * Kr >= K_block_size:
Kc_blocks = (K_block_size + Kr - 1) // Kr
# Step 2: Decide Mc assuming the A block is L2-resident.
min_Mc_ratio = 2 # TODO(jgong5): something to tune?
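A rough worked example of the new A16W8 K-blocking above, with made-up sizes (L1, m, Mr, Nr, Kr are illustrative assumptions, not the template's actual values):

```python
# Hypothetical sizes: 48 KiB L1, m = 4 rows of A, Mr x Nr = 4 x 32 register tile,
# Kr = 2 elements per K step; bf16 A (2 bytes), int8 B (1 byte), fp32 accumulator.
L1, m, Mr, Nr, Kr = 48 * 1024, 4, 4, 32, 2
num_byte_A, num_byte_B = 2, 1
A_buf_size_div_K = m * num_byte_A        # bytes of A consumed per K element
B_buf_size_div_K = Nr * num_byte_B       # bytes of B consumed per K element
C_buf_size = Mr * Nr * 4                 # fp32 accumulator tile (Mc = Nc = 1)
K_block_size = (L1 - C_buf_size) // (A_buf_size_div_K + B_buf_size_div_K)
Kc_blocks = (K_block_size + Kr - 1) // Kr
print(K_block_size, Kc_blocks)           # 1216 K elements -> 608 Kr-sized blocks
```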


@@ -1049,17 +1049,16 @@ class CppMicroGemmAMX(CppMicroGemm):
{{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
{%- for vec_idx in range(0, block_n, 32) %}
_mm_prefetch(base_addr + idx_q + 64 * ldb, _MM_HINT_T0);
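// Prefetch the quantized B data 64*ldb elements ahead of the current load so later idx_q iterations hit in cache.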
{%- if (block_n - vec_idx) >= 32 %}
// 1) Load 32 x int8
__m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
// 2) Widen: 32 x i8 -> 32 x i16
__m512i v16 = _mm512_cvtepi8_epi16(v8); // sign-extend. Use _mm512_cvtepu8_epi16 for unsigned
// Split the 32 x i16 into two 16-lane halves
__m256i v16_lo = _mm512_castsi512_si256(v16);
__m256i v16_hi = _mm512_extracti64x4_epi64(v16, 1);
// 2) Extract two halves
__m128i v8_lo = _mm256_extracti128_si256(v8, 0);
__m128i v8_hi = _mm256_extracti128_si256(v8, 1);
// 3) Widen each half to i32
__m512i v32_lo = _mm512_cvtepi16_epi32(v16_lo);
__m512i v32_hi = _mm512_cvtepi16_epi32(v16_hi);
__m512i v32_lo = _mm512_cvtepi8_epi32(v8_lo);
__m512i v32_hi = _mm512_cvtepi8_epi32(v8_hi);
// 4) Convert to f32
__m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
__m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
@@ -1071,16 +1070,13 @@ class CppMicroGemmAMX(CppMicroGemm):
{%- elif (block_n - vec_idx) >= 16 %}
// 1) Load 16 x int8 (128 bits)
__m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
// 2) Widen: 16 x i8 -> 16 x i16
__m256i v16 = _mm256_cvtepi8_epi16(v8); // for signed
// use _mm256_cvtepu8_epi16 for unsigned
// 3) Widen further: 16 x i16 -> 16 x i32
__m512i v32 = _mm512_cvtepi16_epi32(v16);
// 4) Convert to f32
// 2) Widen: 16 x i8 -> 16 x i32
__m512i v32 = _mm512_cvtepi8_epi32(v8);
// 3) Convert to f32
__m512 f32 = _mm512_cvtepi32_ps(v32);
// 5) Convert f32 -> bf16 (round-to-nearest-even)
// 4) Convert f32 -> bf16 (round-to-nearest-even)
__m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
// 6) Store 16 x bf16 (256 bits)
// 5) Store 16 x bf16 (256 bits)
_mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
{%- else %}
auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(


@@ -160,13 +160,6 @@ def addmm_patterns_init():
):
return False
equal_shape_inputs = [weight_inputs]
for equal_shape_group in equal_shape_inputs:
inps = [match.kwargs[name] for name in equal_shape_group]
if not all(
inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps
):
return False
return True
def check_concat_weights(match):
@@ -205,7 +198,8 @@ def addmm_patterns_init():
cat_w = torch.cat((w1, w2, w3), dim=1)
cat_s = torch.cat((s1, s2, s3), dim=0)
mm = (inp @ cat_w).mul(cat_s)
return mm.chunk(3, dim=1)
n1, n2 = w1.size(1), w2.size(1)
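# The q/k/v widths may differ (GQA), so split at the real boundaries
# instead of chunk(3)'s equal thirds.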
return mm.tensor_split([n1, n1 + n2], dim=-1)
register_replacement(
int8_woq_fusion_pattern,


@@ -1104,12 +1104,6 @@ def _is_valid_concat_linear_int8_woq_optimization_pattern():
w1_cols = match.kwargs["w1"].meta["val"].size()[0]
w2_cols = match.kwargs["w2"].meta["val"].size()[0]
w3_cols = match.kwargs["w3"].meta["val"].size()[0]
# Technically, the shapes of the three weights need not be equal.
# But currently, we only enable replacement in this case.
if w1_cols != w2_cols or w2_cols != w3_cols:
return False
if 3 * w1_cols != num_scales:
return False
return (
# For now, we only support woq mm kernels
# with x.type=bfloat16 and w.type=int8
@@ -1125,6 +1119,7 @@ def _is_valid_concat_linear_int8_woq_optimization_pattern():
and w1.device == w2.device
and w2.device == w3.device
and x.device == scales.device
and num_scales == w1_cols + w2_cols + w3_cols
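# The three weight widths may differ (GQA); the concatenated scales just
# need to cover all output channels.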
)
return fn
@@ -1162,7 +1157,7 @@ def _register_concat_linear_int8_woq_lowering(
extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(),
pass_number=4,
)
def woq(match: Match, *args, **kwargs):
def woq_int8(match: Match, *args, **kwargs):
x = kwargs["x"]
w1 = kwargs["w1"]
w2 = kwargs["w2"]
@@ -1218,7 +1213,7 @@ def _register_concat_linear_int8_woq_lowering(
match.graph.erase_node(cat_wgt_node)
match.graph.lint()
return woq
return woq_int8
def _register_woq_lowering(pattern, computation_woq, computation_reshape):
@@ -1226,7 +1221,7 @@ def _register_woq_lowering(pattern, computation_woq, computation_reshape):
pattern,
extra_check=_is_valid_woq_optimization_pattern(),
)
def woq(match: Match, *args, **kwargs):
def woq_int8(match: Match, *args, **kwargs):
x = kwargs["x"]
weight = kwargs["weight"]
scales = kwargs["scales"]
@@ -1242,7 +1237,7 @@ def _register_woq_lowering(pattern, computation_woq, computation_reshape):
func2 = L[computation_woq](func1, weight, scales)
return L[computation_reshape](func2, out_shape)
return woq
return woq_int8
def _register_woq_mm_int8_pattern1():