src: gpu: intel: gemm: add strided batch support to group sums

Author: Guskov, Andrey Y
Date: 2025-10-13 23:00:57 -07:00
Committed by: Andrey Guskov
Parent: dc1d8c3d55
Commit: 7b351ae955
7 changed files with 81 additions and 2 deletions
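
Strided batch support already exists for the A/B scales and zero-points; this change extends the same mechanism to the A/B group-sums tensors used by precomputed reductions. Each batch dimension gains a group_sums_stride_{A,B}<i> kernel argument, fed from pd()->eff_gs_stride() on the host side; the generator scales the stride to bytes, multiplies it by the batch ID, and accumulates the result into the group-sums base offsets (offsetAg/offsetBg), which now participate in the 2D group-sums address setup. Matching benchdnn matmul cases are added.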

@@ -185,6 +185,12 @@ status_t gen_t::launch_nocopy(const exec_ctx_t &ctx,
if (problem->hasBOffset()) {
arg_list.set(argn++, pd()->eff_zp_stride(i, DNNL_ARG_B));
}
if (problem->needsAGroupSums()) {
arg_list.set(argn++, pd()->eff_gs_stride(i, DNNL_ARG_A));
}
if (problem->needsBGroupSums()) {
arg_list.set(argn++, pd()->eff_gs_stride(i, DNNL_ARG_B));
}
}
for (int i = 0; i < po_count; i++) {
if (problem->postOps.binaryBatch[i]) {

@@ -1057,6 +1057,14 @@ void gen_kernel_t::init_interface() {
interface_.newArgument(
"offset_stride_B" + std::to_string(i), DataType::d);
}
if (problem.needsAGroupSums()) {
interface_.newArgument(
"group_sums_stride_A" + std::to_string(i), DataType::d);
}
if (problem.needsBGroupSums()) {
interface_.newArgument(
"group_sums_stride_B" + std::to_string(i), DataType::d);
}
}
for (size_t i = 0; i < problem.postOps.len(); i++) {
if (problem.postOps[i].is_binary()

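Because kernel arguments are positional, the conditional arg_list.set calls in launch_nocopy must line up one-for-one with the newArgument declarations in init_interface; both sides insert the group-sums strides at the same point, right after the zero-point strides, for each batch dimension. A minimal standalone sketch of that layout contract (the flags and names here are illustrative, not oneDNN API):

#include <iostream>
#include <string>
#include <vector>

int main() {
    // Stand-ins for problem->hasBOffset() / needs{A,B}GroupSums().
    bool has_b_offset = true, needs_a_gsums = true, needs_b_gsums = true;
    int batch_dims = 2;

    // Host and kernel sides must evaluate the same conditions in the
    // same order, or every later argument silently shifts position.
    std::vector<std::string> args;
    for (int i = 0; i < batch_dims; i++) {
        std::string istr = std::to_string(i);
        if (has_b_offset) args.push_back("offset_stride_B" + istr);
        if (needs_a_gsums) args.push_back("group_sums_stride_A" + istr);
        if (needs_b_gsums) args.push_back("group_sums_stride_B" + istr);
    }
    for (const auto &a : args)
        std::cout << a << "\n";
    return 0;
}
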
@@ -550,6 +550,7 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
Subregister bOffsetA[4], bOffsetB[4], bOffsetC[4];
Subregister bOffsetAs[4], bOffsetBs[4];
Subregister bOffsetAo[4], bOffsetBo[4];
Subregister bOffsetAg[4], bOffsetBg[4];
for (int b = 0; b < problem.batchDims; b++) {
bOffsetA[b] = state.inputs.strideA[b];
@@ -567,6 +568,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
if (problem.hasBOffset()) {
bOffsetBo[b] = state.inputs.strideOffsetB[b];
}
if (problem.needsAGroupSums()) {
bOffsetAg[b] = state.inputs.strideGroupSumsA[b];
}
if (problem.needsBGroupSums()) {
bOffsetBg[b] = state.inputs.strideGroupSumsB[b];
}
if (strategy.A.base.isStateless()) bOffsetA[b] = state.ra.alloc_sub<uint64_t>();
if (strategy.B.base.isStateless()) bOffsetB[b] = state.ra.alloc_sub<uint64_t>();
if (strategy.C.base.isStateless()) bOffsetC[b] = state.ra.alloc_sub<uint64_t>();
@@ -588,6 +595,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
if (problem.hasBOffset()) {
emul(1, bOffsetBo[b], state.inputs.strideOffsetB[b], state.batchID[b], strategy, state);
}
if (problem.needsAGroupSums()) {
emul(1, bOffsetAg[b], state.inputs.strideGroupSumsA[b], state.batchID[b], strategy, state);
}
if (problem.needsBGroupSums()) {
emul(1, bOffsetBg[b], state.inputs.strideGroupSumsB[b], state.batchID[b], strategy, state);
}
}
if (problem.hasAScale() && state.offsetAs.isInvalid()) {
@@ -606,6 +619,14 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
state.offsetBo = state.ra.alloc_sub(state.offsetB.getType());
emov(1, state.offsetBo, 0, strategy, state);
}
if (problem.needsAGroupSums() && state.inputs.offsetAg.isInvalid()) {
state.inputs.offsetAg = state.ra.alloc_sub(state.offsetA.getType());
emov(1, state.inputs.offsetAg, 0, strategy, state);
}
if (problem.needsBGroupSums() && state.inputs.offsetBg.isInvalid()) {
state.inputs.offsetBg = state.ra.alloc_sub(state.offsetB.getType());
emov(1, state.inputs.offsetBg, 0, strategy, state);
}
for (int b = 0; b < problem.batchDims; b++) {
eadd(1, state.offsetA, state.offsetA, bOffsetA[b], strategy, state);
@@ -626,6 +647,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
if (problem.hasBOffset()) {
eadd(1, state.offsetBo, state.offsetBo, bOffsetBo[b], strategy, state);
}
if (problem.needsAGroupSums()) {
eadd(1, state.inputs.offsetAg, state.inputs.offsetAg, bOffsetAg[b], strategy, state);
}
if (problem.needsBGroupSums()) {
eadd(1, state.inputs.offsetBg, state.inputs.offsetBg, bOffsetBg[b], strategy, state);
}
if (!strategy.persistentLoop()) {
state.ra.safeRelease(state.inputs.strideA[b]);
state.ra.safeRelease(state.inputs.strideB[b]);
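
Taken together, the hunks above compute, per group-sums tensor, offset = sum over batch dims of stride[b] * batchID[b], with the offset register initialized to zero first. A self-contained sketch of that accumulation (values and names are illustrative):

#include <cstdint>
#include <cstdio>

int main() {
    // Per-batch-dimension strides for the A group sums, in elements,
    // as passed via group_sums_stride_A<i>; 0 marks a broadcast dim.
    int64_t stride_ag[2] = {19456, 0};
    int64_t batch_id[2] = {1, 3};

    // Mirrors the emov/emul/eadd sequence in gemmOffsetBatchABC:
    // offsetAg starts at 0 and accumulates stride * batchID per dim.
    int64_t offset_ag = 0;
    for (int b = 0; b < 2; b++)
        offset_ag += stride_ag[b] * batch_id[b];

    std::printf("offsetAg = %lld elements\n", (long long)offset_ag);
    return 0;
}
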
@@ -877,6 +904,16 @@ void Generator<hw>::gemmScaleInputs(const GEMMProblem &problem, const GEMMStrate
scale(problem.Tbo, state.inputs.strideOffsetB[b]);
}
}
if (problem.needsAGroupSums()) {
for (int b = 0; b < problem.batchDims; b++) {
scale(problem.Tag, state.inputs.strideGroupSumsA[b]);
}
}
if (problem.needsBGroupSums()) {
for (int b = 0; b < problem.batchDims; b++) {
scale(problem.Tbg, state.inputs.strideGroupSumsB[b]);
}
}
for (int b = 0; b < problem.batchDims; b++) {
scale(Ta_ext, inputs.strideA[b]);
scale(Tb_ext, inputs.strideB[b]);
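
gemmScaleInputs converts element-count strides to byte strides before the offset math runs; note the group-sums strides are scaled by their own element types (Tag, Tbg), not by the A/B matrix types. A tiny sketch, assuming scale() multiplies by the element size:

#include <cstdint>

// Assumed behavior of scale(T, stride): element stride -> byte stride.
static int64_t scale_to_bytes(int64_t stride_elems, int64_t type_size) {
    return stride_elems * type_size;
}

int main() {
    // s32 (4-byte) group sums with an element stride of 19456:
    return scale_to_bytes(19456, 4) == 77824 ? 0 : 1;
}
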
@@ -1907,7 +1944,7 @@ bool Generator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStrategy &str
}
if (ag2D) {
setupQAddr(Tag, state.Ag_addrs, state.Ag_layout, state.inputs.agPtr,
i0qLate, A_h0qLate, state.inputs.ldag);
i0qLate, A_h0qLate, state.inputs.ldag, state.inputs.offsetAg);
}
if (bo2D) {
auto &j0o = lateOffsetB ? j0qLate : j0q;
@@ -1923,7 +1960,7 @@ bool Generator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStrategy &str
}
if (bg2D) {
setupQAddr(Tbg, state.Bg_addrs, state.Bg_layout, state.inputs.bgPtr,
B_h0qLate, j0qLate, state.inputs.ldbg);
B_h0qLate, j0qLate, state.inputs.ldbg, state.inputs.offsetBg);
}
if (i0qLate != state.i0) state.ra.safeRelease(i0qLate);
@@ -2629,6 +2666,12 @@ void Generator<hw>::gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strate
if (problem.hasBOffset()) {
state.inputs.strideOffsetB.push_back(interface.getArgument("offset_stride_B" + istr));
}
if (problem.needsAGroupSums()) {
state.inputs.strideGroupSumsA.push_back(interface.getArgument("group_sums_stride_A" + istr));
}
if (problem.needsBGroupSums()) {
state.inputs.strideGroupSumsB.push_back(interface.getArgument("group_sums_stride_B" + istr));
}
if (i < problem.batchDims - 1) {
state.inputs.batchSize.push_back(interface.getArgument("batch_size" + istr));
state.inputs.recipBatchSize.push_back(interface.getArgument("recip_batch_size" + istr));
@@ -2914,6 +2957,12 @@ void Generator<hw>::gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strate
if (problem.hasBOffset()) {
state.ra.claim(state.inputs.strideOffsetB[i]);
}
if (problem.needsAGroupSums()) {
state.ra.claim(state.inputs.strideGroupSumsA[i]);
}
if (problem.needsBGroupSums()) {
state.ra.claim(state.inputs.strideGroupSumsB[i]);
}
}
for (int i = 0; i < problem.batchDims - 1; i++) {
state.ra.claim(state.inputs.batchSize[i]);

@@ -236,6 +236,8 @@ struct GEMMState : public CommonState {
std::vector<ngen::Subregister> strideScaleB; // ud
std::vector<ngen::Subregister> strideOffsetA; // ud
std::vector<ngen::Subregister> strideOffsetB; // ud
std::vector<ngen::Subregister> strideGroupSumsA; // ud
std::vector<ngen::Subregister> strideGroupSumsB; // ud
std::vector<ngen::Subregister> batchSize; // ud
std::vector<ngen::Subregister> recipBatchSize; // ud
ngen::Subregister offsetBatch; // ud, used for non-strided batch.

@@ -463,6 +463,15 @@ dim_t pd_t::eff_zp_stride(int idx, int arg) const {
return zp_md.format_desc.blocking.strides[idx];
}
dim_t pd_t::eff_gs_stride(int idx, int arg) const {
gpu_assert(utils::one_of(arg, DNNL_ARG_A, DNNL_ARG_B));
auto gs_md = ((DNNL_ARG_A == arg) ^ swap_ab()) ? a_gs_md_ : b_gs_md_;
gpu_assert(memory_desc_wrapper(gs_md).is_plain())
<< "Expected plain gs_md_";
if (gs_md.dims[idx] == 1) return 0;
return gs_md.format_desc.blocking.strides[idx];
}
} // namespace jit
} // namespace gemm
} // namespace intel

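eff_gs_stride mirrors eff_zp_stride: a batch dimension of size 1 is a broadcast, so its effective stride is 0 and the tensor contributes nothing to that dimension's batch offset. A standalone sketch of the rule (plain arrays instead of a memory_desc; the shapes are illustrative group-sums layouts for a 2x1024x2432 : 1x2432x2432 case with k-groups of 128):

#include <cassert>
#include <cstdint>

// Stand-in for eff_gs_stride(): dims of size 1 are broadcast (stride 0),
// otherwise the plain blocking stride is returned.
int64_t eff_stride(const int64_t *dims, const int64_t *strides, int idx) {
    return dims[idx] == 1 ? 0 : strides[idx];
}

int main() {
    // A group sums: batch x m x ceil(k/128) = 2 x 1024 x 19, batched.
    int64_t a_dims[] = {2, 1024, 19}, a_strides[] = {1024 * 19, 19, 1};
    // B group sums: batch x ceil(k/128) x n = 1 x 19 x 2432, broadcast.
    int64_t b_dims[] = {1, 19, 2432}, b_strides[] = {19 * 2432, 2432, 1};

    assert(eff_stride(a_dims, a_strides, 0) == 1024 * 19); // strided batch
    assert(eff_stride(b_dims, b_strides, 0) == 0);         // broadcast
    return 0;
}
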
@@ -212,6 +212,7 @@ struct pd_t : public gemm::pd_t {
}
dim_t eff_scale_stride(int idx, int arg) const;
dim_t eff_zp_stride(int idx, int arg) const;
dim_t eff_gs_stride(int idx, int arg) const;
bool a_scales_grouped() const {
bool k_grouped
= 1 < a_scales_group_k_ && a_scales_group_k_ < desc()->k();

@@ -1,3 +1,7 @@
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x1024x2432:1x2432x2432
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x1024x9728:1x9728x2432
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x333x2432:1x2432x2432
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x333x9728:1x9728x2432
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --bia_mask=4 --bia-dt=f16 --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1024x1x10240:1x10240x2560
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-post-ops=eltwise_swish:1.0+binary_mul:f16:6:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1x4342x1536:1x1536x3840
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-post-ops=binary_add:f16:6:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1x2172x3840:1x3840x1536
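
These benchdnn matmul cases pair --attr-precomputed-reductions (the group sums) with batched shapes: several batch the source (leading dimension 2 or 1024) against broadcast weights (leading dimension 1), exercising both the nonzero-stride and the zero-stride (broadcast) paths above. Assuming this list is wired into a matmul harness input file, it can be run as:

./benchdnn --matmul --batch=<path to this input file>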