[inductor] [cpp] improve cache blocking for is_dynamic_M (#131306)

## Performance
Models with >= 3% performance speedup are listed below:

### AMP single-thread dynamic shape (measured on CPU with AMX support)
No regressions

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | soft_actor_critic | 3% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131306
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
ghstack dependencies: #135275

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
This commit is contained in:
Wu, Chunyuan
2024-09-05 22:36:00 -07:00
committed by PyTorch MergeBot
parent 4ef6c05f65
commit 13bae39e22
2 changed files with 85 additions and 4 deletions

View File

@@ -64,11 +64,28 @@ extern "C" {{export_declaration}}
const auto Nt_blocks = Nr_blocks;
const auto Kt_blocks = Kr_blocks;
{%- endif %}
const int64_t Mc_blocks = Mt_blocks;
const int64_t Nc_blocks = 1;
const int64_t Kc_blocks = Kt_blocks;
int64_t Mc_blocks, Nc_blocks, Kc_blocks;
uint32_t L1_cache_size = {{L1_cache_size}};
uint32_t L2_cache_size = {{L2_cache_size}};
mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
num_threads,
M,
N,
K,
Mr,
Nr,
Kr,
Mt_blocks,
Nt_blocks,
Kt_blocks,
Mc_blocks,
Nc_blocks,
Kc_blocks,
L1_cache_size,
L2_cache_size
);
const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
const int64_t num_Nc_blocks = Nr_blocks;
const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
@@ -979,6 +996,12 @@ class CppPackedGemmTemplate(CppTemplate):
if isinstance(micro_gemm, CppMicroGemmAMX):
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
options = dict(
X=X,
W=W,
@@ -1008,6 +1031,8 @@ class CppPackedGemmTemplate(CppTemplate):
w_zp=w_zp,
acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
DTYPE_TO_CPP=DTYPE_TO_CPP,
L1_cache_size=L1_cache_size,
L2_cache_size=L2_cache_size,
)
with contextlib.ExitStack() as stack:
for buf in fake_buffers:

View File

@@ -734,6 +734,62 @@ void mm_get_thread_blocking(
assert(Mt != 0);
}
template<typename X_t, typename W_t>
void mm_get_cache_blocking(
    int num_threads,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t Mr,
    int64_t Nr,
    int64_t Kr,
    int64_t Mt_blocks,
    int64_t Nt_blocks,
    int64_t Kt_blocks,
    int64_t& Mc_blocks,
    int64_t& Nc_blocks,
    int64_t& Kc_blocks,
    uint32_t L1_cache_size,
    uint32_t L2_cache_size) {
  // Choose per-thread cache-block sizes (outputs Mc_blocks/Nc_blocks/Kc_blocks,
  // in units of the register blocks Mr/Nr/Kr) such that a B panel fits in L1
  // and the A panel plus the accumulator panel fit in (a fraction of) L2.
  // M/N/K and num_threads are part of the blocking interface but are not
  // consulted here; only the per-thread block counts Mt/Nt/Kt_blocks matter.
  // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm.
  // TODO(jgong5): cache cache blocking results
  // TODO: tune the factor here
  float L1_limit_factor = 1.0;
  float L2_limit_factor = 0.5;
  auto L1 = L1_cache_size * L1_limit_factor;
  auto L2 = L2_cache_size * L2_limit_factor;
  constexpr size_t num_byte_A = sizeof(X_t);
  constexpr size_t num_byte_B = sizeof(W_t);
  // Kc: largest K blocking whose B panel (Kc_blocks*Kr x Nr) still fits in L1.
  // Clamp to >= 1: with a small L1 the floor could otherwise yield 0, and a
  // zero block size divides-by-zero in the downstream block-count math
  // (e.g. (Kr_blocks + Kc_blocks - 1) / Kc_blocks).
  int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
  Kc_blocks = Kt_blocks;
  if (size_cache_B > L1) {
    Kc_blocks = std::max<int64_t>(
        1, (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B)));
  }
  // Require Mc to cover at least ~2x Nr worth of rows so B reuse is amortized.
  float min_Mc_ratio = 2;
  int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr);
  auto Kt_bytes = Kt_blocks * Kr * num_byte_A;
  if (min_Mc_blocks * Mr * Kt_bytes < L2) {
    // The full per-thread K extent of A fits in L2 for a reasonable Mc:
    // grow Mc as far as L2 allows and keep Nc minimal.
    Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes)));
    Nc_blocks = 1;
  } else {
    // Otherwise take all of Mt and widen Nc roughly proportionally to Mc*Mr.
    Mc_blocks = Mt_blocks;
    Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
    // Accumulator tile uses 4 bytes per element (fp32/int32 accumulation).
    auto Nc_bytes = Nc_blocks * Nr * 4;
    auto Kc_bytes = Kc_blocks * Kr * num_byte_A;
    if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) {
      // Shrink M so that M*Kc_bytes + 4*M^2 <= L2 (approximating
      // Nc_bytes ~= 4*M): M_max is the positive root of
      // 4*M^2 + Kc_bytes*M - L2 = 0.
      auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8;
      if (M_max < Mc_blocks * Mr) {
        // Clamp to >= 1 so a tiny M_max cannot produce a zero block size.
        Mc_blocks = std::max<int64_t>(1, (int64_t)std::floor(M_max / Mr));
        Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
      }
    }
  }
}
inline void mm_get_thread_blocks(
int thread_id,
int64_t M_blocks,