[inductor] [cpp] improve cache blocking for is_dynamic_M (#131306)
## Performance

Models with >= 3% performance speedup are listed below:

### AMP single-thread dynamic shape (measured on CPU with AMX support)

No regressions.

| Model Family | Model Name | Speedup |
|--------------|------------|---------|
| torchbench | soft_actor_critic | 3% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131306
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel
ghstack dependencies: #135275

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
Commit: 13bae39e22
Parent: 4ef6c05f65
Committed by: PyTorch MergeBot
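At a high level, the hunks below replace the hard-coded cache blocking in the generated GEMM kernel (`Mc_blocks = Mt_blocks`, `Nc_blocks = 1`, `Kc_blocks = Kt_blocks`) with a call to a new `mm_get_cache_blocking` helper that sizes the cache blocks against the per-core L1/L2 capacities, queried from `torch._C._cpu` when the template is rendered.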
```diff
@@ -64,11 +64,28 @@ extern "C" {{export_declaration}}
         const auto Nt_blocks = Nr_blocks;
         const auto Kt_blocks = Kr_blocks;
     {%- endif %}
-    const int64_t Mc_blocks = Mt_blocks;
-    const int64_t Nc_blocks = 1;
-    const int64_t Kc_blocks = Kt_blocks;
+    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
+    uint32_t L1_cache_size = {{L1_cache_size}};
+    uint32_t L2_cache_size = {{L2_cache_size}};
+    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
+        num_threads,
+        M,
+        N,
+        K,
+        Mr,
+        Nr,
+        Kr,
+        Mt_blocks,
+        Nt_blocks,
+        Kt_blocks,
+        Mc_blocks,
+        Nc_blocks,
+        Kc_blocks,
+        L1_cache_size,
+        L2_cache_size
+    );
     const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
-    const int64_t num_Nc_blocks = Nr_blocks;
+    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
     const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
 {%- else %}
     constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
```
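For intuition on the derived counts: each `num_*` value is a ceiling division of the problem's register-block extent by the chosen cache-block size. A minimal sketch with made-up numbers (the names mirror the template; none of these values come from the PR):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Hypothetical extents, in units of register tiles (Mr x Nr x Kr).
Mr_blocks, Nr_blocks, Kr_blocks = 100, 8, 32
# Hypothetical cache-block sizes as mm_get_cache_blocking might choose them.
Mc_blocks, Nc_blocks, Kt_blocks = 12, 2, 16

num_Mc_blocks = ceil_div(Mr_blocks, Mc_blocks)  # 9
num_Nc_blocks = ceil_div(Nr_blocks, Nc_blocks)  # 4; previously just Nr_blocks, since Nc_blocks was fixed to 1
num_k_slices = ceil_div(Kr_blocks, Kt_blocks)   # 2
print(num_Mc_blocks, num_Nc_blocks, num_k_slices)
```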
```diff
@@ -979,6 +996,12 @@ class CppPackedGemmTemplate(CppTemplate):
         if isinstance(micro_gemm, CppMicroGemmAMX):
             counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
 
+        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
+        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+
+        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
+        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+
         options = dict(
             X=X,
             W=W,
```
```diff
@@ -1008,6 +1031,8 @@ class CppPackedGemmTemplate(CppTemplate):
             w_zp=w_zp,
             acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
             DTYPE_TO_CPP=DTYPE_TO_CPP,
+            L1_cache_size=L1_cache_size,
+            L2_cache_size=L2_cache_size,
         )
         with contextlib.ExitStack() as stack:
             for buf in fake_buffers:
```
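The cache capacities fed into `options` come from the two private helpers used above. A quick way to inspect them on a given machine (these are internal `torch._C._cpu` bindings, so they may not be stable across PyTorch versions):

```python
import torch

# Per-core data cache sizes in bytes, as queried by the template code above.
L1 = torch._C._cpu._L1d_cache_size()
L2 = torch._C._cpu._L2_cache_size()
print(f"L1d: {L1 // 1024} KiB per core, L2: {L2 // 1024} KiB per core")
```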
```diff
@@ -734,6 +734,62 @@ void mm_get_thread_blocking(
     assert(Mt != 0);
 }
 
+template<typename X_t, typename W_t>
+void mm_get_cache_blocking(
+    int num_threads,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t Mr,
+    int64_t Nr,
+    int64_t Kr,
+    int64_t Mt_blocks,
+    int64_t Nt_blocks,
+    int64_t Kt_blocks,
+    int64_t& Mc_blocks,
+    int64_t& Nc_blocks,
+    int64_t& Kc_blocks,
+    uint32_t L1_cache_size,
+    uint32_t L2_cache_size) {
+  // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm.
+  // TODO(jgong5): cache the cache blocking results
+  // TODO: tune the factor here
+  float L1_limit_factor = 1.0;
+  float L2_limit_factor = 0.5;
+
+  auto L1 = L1_cache_size * L1_limit_factor;
+  auto L2 = L2_cache_size * L2_limit_factor;
+
+  constexpr size_t num_byte_A = sizeof(X_t);
+  constexpr size_t num_byte_B = sizeof(W_t);
+
+  int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
+  Kc_blocks = Kt_blocks;
+  if (size_cache_B > L1) {
+    Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B));
+  }
+
+  float min_Mc_ratio = 2;
+  int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr);
+  auto Kt_bytes = Kt_blocks * Kr * num_byte_A;
+  if (min_Mc_blocks * Mr * Kt_bytes < L2) {
+    Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes)));
+    Nc_blocks = 1;
+  } else {
+    Mc_blocks = Mt_blocks;
+    Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
+    auto Nc_bytes = Nc_blocks * Nr * 4;
+    auto Kc_bytes = Kc_blocks * Kr * num_byte_A;
+    if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) {
+      auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8;
+      if (M_max < Mc_blocks * Mr) {
+        Mc_blocks = (int64_t)std::floor(M_max / Mr);
+        Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
+      }
+    }
+  }
+}
+
 inline void mm_get_thread_blocks(
     int thread_id,
     int64_t M_blocks,
```
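Two notes on the heuristic, read off the code above rather than stated anywhere in the PR. First, the `M_max` bound is the positive root of an L2 capacity constraint: with `Nc_blocks ≈ Mc_blocks * Mr / Nr`, each of the `M = Mc_blocks * Mr` rows keeps `Kc_bytes` of A plus `Nc_bytes ≈ 4 * M` of float output in L2, so requiring `M * (Kc_bytes + 4 * M) <= L2` gives `4 * M^2 + Kc_bytes * M - L2 <= 0`, i.e. `M <= (sqrt(Kc_bytes^2 + 16 * L2) - Kc_bytes) / 8`, which is exactly the expression used. Second, `num_threads`, `M`, `N`, and `K` are accepted but unused in the body shown here.

A Python mirror of the heuristic for offline experimentation (a sketch under the assumption that the hunk above is complete; the tile sizes and cache sizes in the example are made up):

```python
import math

def mm_get_cache_blocking_sketch(
    Mr, Nr, Kr,
    Mt_blocks, Nt_blocks, Kt_blocks,
    num_byte_A, num_byte_B,
    L1_cache_size, L2_cache_size,
):
    """Python mirror of the C++ mm_get_cache_blocking above (sketch only)."""
    L1 = L1_cache_size * 1.0   # L1_limit_factor
    L2 = L2_cache_size * 0.5   # L2_limit_factor

    # Kc: one thread's B panel should fit in L1.
    Kc_blocks = Kt_blocks
    if Kr * Kt_blocks * Nr * num_byte_B > L1:
        Kc_blocks = int(L1 // (Kr * Nr * num_byte_B))

    # Mc/Nc: keep each row's A panel (and float output panel) resident in L2.
    min_Mc_blocks = math.ceil(2 * Mr / Nr)  # min_Mc_ratio = 2
    Kt_bytes = Kt_blocks * Kr * num_byte_A
    if min_Mc_blocks * Mr * Kt_bytes < L2:
        Mc_blocks = min(Mt_blocks, int(L2 // (Mr * Kt_bytes)))
        Nc_blocks = 1
    else:
        Mc_blocks = Mt_blocks
        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
        Nc_bytes = Nc_blocks * Nr * 4           # float accumulator bytes per row
        Kc_bytes = Kc_blocks * Kr * num_byte_A  # A bytes per row
        if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
            # Positive root of 4*M^2 + Kc_bytes*M - L2 <= 0 (see derivation above).
            M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
            if M_max < Mc_blocks * Mr:
                Mc_blocks = int(M_max // Mr)
                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
    return Mc_blocks, Nc_blocks, Kc_blocks

# Example: bf16 A and B (2 bytes each), 48 KiB L1d / 2 MiB L2 per core.
print(mm_get_cache_blocking_sketch(
    Mr=32, Nr=32, Kr=32,
    Mt_blocks=64, Nt_blocks=8, Kt_blocks=16,
    num_byte_A=2, num_byte_B=2,
    L1_cache_size=48 * 1024, L2_cache_size=2 * 1024 * 1024,
))  # -> (32, 1, 16) with these made-up inputs
```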