Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[CPU][GEMM Template] Improve A16W8 performance (#162479)
**Summary**
Improve A16W8 performance by:
1. supporting GQA concat linear,
2. using a smaller cache blocking size, and
3. improving the weight-dequantization code (fewer instructions plus prefetching).

We saw a > 5% E2E next-token performance gain when running Llama3.1-8B-instruct.

**Test plan**
Already covered by UT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162479
Approved by: https://github.com/mingfeima, https://github.com/CaoE, https://github.com/jansel
committed by PyTorch MergeBot
parent f17e2ab1f9
commit 48a7e8cc70
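For context before the diffs: A16W8 here means 16-bit (bfloat16) activations with int8 weight-only quantization (WOQ). The int8 weight is widened to the activation dtype and the per-output-channel scales are applied after the matmul, as in the fused pattern touched below. The following is a minimal reference sketch of that computation; the function name and shapes are illustrative, not the kernel's API.

```python
import torch

def woq_int8_linear_ref(x, w_int8, scales):
    # x: [M, K] bfloat16 activations; w_int8: [N, K] int8 weight; scales: [N] per-channel scales.
    w_bf16 = w_int8.to(torch.bfloat16)   # dequantize by widening int8 -> bf16
    return (x @ w_bf16.t()) * scales     # GEMM in bf16, then apply per-channel scales

x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randint(-128, 127, (128, 64), dtype=torch.int8)
s = torch.rand(128, dtype=torch.bfloat16)
print(woq_int8_linear_ref(x, w, s).shape)  # torch.Size([4, 128])
```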
@@ -809,13 +809,27 @@ class CppGemmTemplate(CppTemplate):
         if (
             config.cpp.use_small_dequant_buffer
             and dtype_A is torch.bfloat16
-            and dtype_B is torch.uint8
             and Mt_blocks == 1
         ):
-            # Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
-            # Since when Mt_blocks == 1, L1-reside B block can't be reused by A.
-            if Kc_blocks * Kr >= self.q_group_size():
-                Kc_blocks = self.q_group_size() // Kr
+            if dtype_B is torch.uint8:
+                # A16W4
+                # Make a small dequant_B buffer for woq int4 [q_group_size, Nr]
+                # Since when Mt_blocks == 1, L1-reside B block can't be reused by A.
+                if Kc_blocks * Kr >= self.q_group_size():
+                    Kc_blocks = self.q_group_size() // Kr
+
+            elif dtype_B is torch.int8:
+                # A16W8
+                # Make A, B, C buffer in L1
+                A_buf_size_div_K = self.m * num_byte_A
+                B_buf_size_div_K = Nr * num_byte_B
+                # assume acc in float32/int32 and Mc_blocks = Nc_blocks = 1
+                C_buf_size = Mr * Nr * 4
+                K_block_size = (L1 - C_buf_size) // (
+                    A_buf_size_div_K + B_buf_size_div_K
+                )
+                if Kc_blocks * Kr >= K_block_size:
+                    Kc_blocks = (K_block_size + Kr - 1) // Kr

         # Step 2: Decide Mc assuming A block is L2-reside.
         min_Mc_ratio = 2  # TODO(jgong5): something to tune?
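The new A16W8 branch above caps the K cache block so that an A panel, a B panel, and a float32 accumulator tile fit together in the L1 data cache. A worked sketch of that arithmetic under assumed values (the L1 size, tile sizes, and byte widths below are illustrative, not the values the template actually uses):

```python
# Assumed values for illustration: 48 KB L1D, bf16 A (2 bytes/elt), int8 B (1 byte/elt),
# register tile Mr x Nr = 16 x 32, Kr = 32, and m = 16 rows of A.
L1 = 48 * 1024
num_byte_A, num_byte_B = 2, 1
m, Mr, Nr, Kr = 16, 16, 32, 32

A_buf_size_div_K = m * num_byte_A    # bytes of the A panel per unit of K
B_buf_size_div_K = Nr * num_byte_B   # bytes of the B panel per unit of K
C_buf_size = Mr * Nr * 4             # float32 accumulator tile (Mc_blocks = Nc_blocks = 1)
K_block_size = (L1 - C_buf_size) // (A_buf_size_div_K + B_buf_size_div_K)
Kc_blocks_cap = (K_block_size + Kr - 1) // Kr  # round up to whole Kr-wide blocks

print(K_block_size, Kc_blocks_cap)  # 736 and 23 with these assumed numbers
```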
@@ -1049,17 +1049,16 @@ class CppMicroGemmAMX(CppMicroGemm):
     {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx;
     for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) {
     {%- for vec_idx in range(0, block_n, 32) %}
+        _mm_prefetch(base_addr + idx_q + 64 * ldb, _MM_HINT_T0);
     {%- if (block_n - vec_idx) >= 32 %}
         // 1) Load 32 x int8
         __m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}}));
-        // 2) Widen: 32 x i8 -> 32 x i16
-        __m512i v16 = _mm512_cvtepi8_epi16(v8); // sign-extend. Use _mm512_cvtepu8_epi16 for unsigned
-        // Split the 32 x i16 into two 16-lane halves
-        __m256i v16_lo = _mm512_castsi512_si256(v16);
-        __m256i v16_hi = _mm512_extracti64x4_epi64(v16, 1);
+        // 2) Extract two halves
+        __m128i v8_lo = _mm256_extracti128_si256(v8, 0);
+        __m128i v8_hi = _mm256_extracti128_si256(v8, 1);
         // 3) Widen each half to i32
-        __m512i v32_lo = _mm512_cvtepi16_epi32(v16_lo);
-        __m512i v32_hi = _mm512_cvtepi16_epi32(v16_hi);
+        __m512i v32_lo = _mm512_cvtepi8_epi32(v8_lo);
+        __m512i v32_hi = _mm512_cvtepi8_epi32(v8_hi);
         // 4) Convert to f32
         __m512 f_lo = _mm512_cvtepi32_ps(v32_lo);
         __m512 f_hi = _mm512_cvtepi32_ps(v32_hi);
@@ -1071,16 +1070,13 @@ class CppMicroGemmAMX(CppMicroGemm):
     {%- elif (block_n - vec_idx) >= 16 %}
         // 1) Load 16 x int8 (128 bits)
         __m128i v8 = _mm_loadu_si128((const __m128i*)(base_addr + idx_q + {{vec_idx}}));
-        // 2) Widen: 16 x i8 -> 16 x i16
-        __m256i v16 = _mm256_cvtepi8_epi16(v8); // for signed
-        // use _mm256_cvtepu8_epi16 for unsigned
-        // 3) Widen further: 16 x i16 -> 16 x i32
-        __m512i v32 = _mm512_cvtepi16_epi32(v16);
-        // 4) Convert to f32
+        // 2) Widen: 16 x i8 -> 16 x i32
+        __m512i v32 = _mm512_cvtepi8_epi32(v8);
+        // 3) Convert to f32
         __m512 f32 = _mm512_cvtepi32_ps(v32);
-        // 5) Convert f32 -> bf16 (round-to-nearest-even)
+        // 4) Convert f32 -> bf16 (round-to-nearest-even)
         __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32);
-        // 6) Store 16 x bf16 (256 bits)
+        // 5) Store 16 x bf16 (256 bits)
         _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16);
     {%- else %}
         auto b_int8_tail = at::vec::Vectorized<int8_t>::loadu(
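The two hunks above shorten the int8 -> bf16 dequantization: instead of widening int8 to int16 and then to int32, the new code widens int8 directly to int32 with a single `_mm512_cvtepi8_epi32`, and prefetches upcoming rows of B. Both widening orders produce identical bf16 values; a small torch sketch of that equivalence (purely illustrative, not the generated kernel):

```python
import torch

w_int8 = torch.randint(-128, 127, (32,), dtype=torch.int8)

# Old path: widen in two steps (i8 -> i16 -> i32), then convert to f32 and round to bf16.
old = w_int8.to(torch.int16).to(torch.int32).to(torch.float32).to(torch.bfloat16)
# New path: widen in one step (i8 -> i32), then convert to f32 and round to bf16.
new = w_int8.to(torch.int32).to(torch.float32).to(torch.bfloat16)

assert torch.equal(old, new)  # same values, fewer widening steps per element
```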
@@ -160,13 +160,6 @@ def addmm_patterns_init():
         ):
             return False

-        equal_shape_inputs = [weight_inputs]
-        for equal_shape_group in equal_shape_inputs:
-            inps = [match.kwargs[name] for name in equal_shape_group]
-            if not all(
-                inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps
-            ):
-                return False
         return True

     def check_concat_weights(match):
@@ -205,7 +198,8 @@ def addmm_patterns_init():
         cat_w = torch.cat((w1, w2, w3), dim=1)
         cat_s = torch.cat((s1, s2, s3), dim=0)
         mm = (inp @ cat_w).mul(cat_s)
-        return mm.chunk(3, dim=1)
+        n1, n2 = w1.size(1), w2.size(1)
+        return mm.tensor_split([n1, n1 + n2], dim=-1)

     register_replacement(
         int8_woq_fusion_pattern,
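The addmm-pattern change above is what enables GQA concat linear: with grouped-query attention, the Q/K/V projections have different output widths, so the fused result can no longer be split with `chunk(3, dim=1)`; the replacement records the first two widths and uses `tensor_split` instead. A standalone sketch with assumed GQA-like shapes (the sizes below are illustrative):

```python
import torch

K_dim, n_q, n_kv = 64, 128, 32  # assumed sizes: Q projection wider than K/V
inp = torch.randn(4, K_dim, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (K_dim, n_q), dtype=torch.int8).to(torch.bfloat16)
w2 = torch.randint(-128, 127, (K_dim, n_kv), dtype=torch.int8).to(torch.bfloat16)
w3 = torch.randint(-128, 127, (K_dim, n_kv), dtype=torch.int8).to(torch.bfloat16)
s1, s2, s3 = (torch.rand(n, dtype=torch.bfloat16) for n in (n_q, n_kv, n_kv))

# One concatenated GEMM followed by the concatenated scales, as in the pattern.
cat_w = torch.cat((w1, w2, w3), dim=1)
cat_s = torch.cat((s1, s2, s3), dim=0)
mm = (inp @ cat_w).mul(cat_s)

# chunk(3, dim=1) would split the 192 columns into 64/64/64 and mix up Q/K/V here;
# tensor_split at the recorded offsets recovers the per-projection outputs.
n1, n2 = w1.size(1), w2.size(1)
q, k, v = mm.tensor_split([n1, n1 + n2], dim=-1)
print(q.shape[-1], k.shape[-1], v.shape[-1])  # 128 32 32
```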
@@ -1104,12 +1104,6 @@ def _is_valid_concat_linear_int8_woq_optimization_pattern():
         w1_cols = match.kwargs["w1"].meta["val"].size()[0]
         w2_cols = match.kwargs["w2"].meta["val"].size()[0]
         w3_cols = match.kwargs["w3"].meta["val"].size()[0]
-        # Technically, the shapes of the three weights need not be equal.
-        # But currently, we only enable replacement in this case.
-        if w1_cols != w2_cols or w2_cols != w3_cols:
-            return False
-        if 3 * w1_cols != num_scales:
-            return False
         return (
             # For now, we only support woq mm kernels
             # with x.type=bfloat16 and w.type=int8
@@ -1125,6 +1119,7 @@ def _is_valid_concat_linear_int8_woq_optimization_pattern():
             and w1.device == w2.device
             and w2.device == w3.device
             and x.device == scales.device
+            and num_scales == w1_cols + w2_cols + w3_cols
         )

     return fn
@@ -1162,7 +1157,7 @@ def _register_concat_linear_int8_woq_lowering(
         extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(),
         pass_number=4,
     )
-    def woq(match: Match, *args, **kwargs):
+    def woq_int8(match: Match, *args, **kwargs):
         x = kwargs["x"]
         w1 = kwargs["w1"]
         w2 = kwargs["w2"]
@@ -1218,7 +1213,7 @@ def _register_concat_linear_int8_woq_lowering(
         match.graph.erase_node(cat_wgt_node)
         match.graph.lint()

-    return woq
+    return woq_int8


 def _register_woq_lowering(pattern, computation_woq, computation_reshape):
@@ -1226,7 +1221,7 @@ def _register_woq_lowering(pattern, computation_woq, computation_reshape):
         pattern,
         extra_check=_is_valid_woq_optimization_pattern(),
     )
-    def woq(match: Match, *args, **kwargs):
+    def woq_int8(match: Match, *args, **kwargs):
         x = kwargs["x"]
         weight = kwargs["weight"]
         scales = kwargs["scales"]
@@ -1242,7 +1237,7 @@ def _register_woq_lowering(pattern, computation_woq, computation_reshape):
         func2 = L[computation_woq](func1, weight, scales)
         return L[computation_reshape](func2, out_shape)

-    return woq
+    return woq_int8


 def _register_woq_mm_int8_pattern1():
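To summarize the pattern-matcher hunks above: the extra check no longer requires the three weights to have equal column counts; it only requires the concatenated scales to cover w1_cols + w2_cols + w3_cols, which is what the GQA case needs, and the matched handlers are renamed from `woq` to `woq_int8`. A minimal sketch of the relaxed check on plain tensors (the helper name and the [N, K] weight layout mirror the pattern's use of `meta["val"].size()[0]`, but are assumptions for illustration):

```python
import torch

def scales_cover_concat_weights(w1, w2, w3, scales) -> bool:
    # size(0) is the output-column count when weights are stored as [N, K]
    w1_cols, w2_cols, w3_cols = w1.size(0), w2.size(0), w3.size(0)
    return scales.numel() == w1_cols + w2_cols + w3_cols

w1 = torch.empty(128, 64, dtype=torch.int8)  # unequal Q/K/V widths are now acceptable
w2 = torch.empty(32, 64, dtype=torch.int8)
w3 = torch.empty(32, 64, dtype=torch.int8)
assert scales_cover_concat_weights(w1, w2, w3, torch.empty(192))
assert not scales_cover_concat_weights(w1, w2, w3, torch.empty(3 * 128))
```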