mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 10:03:50 +08:00
src: gpu: intel: gemm: add strided batch support to group sums
This commit is contained in:
committed by
Andrey Guskov
parent
dc1d8c3d55
commit
7b351ae955
@ -185,6 +185,12 @@ status_t gen_t::launch_nocopy(const exec_ctx_t &ctx,
|
||||
if (problem->hasBOffset()) {
|
||||
arg_list.set(argn++, pd()->eff_zp_stride(i, DNNL_ARG_B));
|
||||
}
|
||||
if (problem->needsAGroupSums()) {
|
||||
arg_list.set(argn++, pd()->eff_gs_stride(i, DNNL_ARG_A));
|
||||
}
|
||||
if (problem->needsBGroupSums()) {
|
||||
arg_list.set(argn++, pd()->eff_gs_stride(i, DNNL_ARG_B));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < po_count; i++) {
|
||||
if (problem->postOps.binaryBatch[i]) {
|
||||
|
@ -1057,6 +1057,14 @@ void gen_kernel_t::init_interface() {
|
||||
interface_.newArgument(
|
||||
"offset_stride_B" + std::to_string(i), DataType::d);
|
||||
}
|
||||
if (problem.needsAGroupSums()) {
|
||||
interface_.newArgument(
|
||||
"group_sums_stride_A" + std::to_string(i), DataType::d);
|
||||
}
|
||||
if (problem.needsBGroupSums()) {
|
||||
interface_.newArgument(
|
||||
"group_sums_stride_B" + std::to_string(i), DataType::d);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < problem.postOps.len(); i++) {
|
||||
if (problem.postOps[i].is_binary()
|
||||
|
@ -550,6 +550,7 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
|
||||
Subregister bOffsetA[4], bOffsetB[4], bOffsetC[4];
|
||||
Subregister bOffsetAs[4], bOffsetBs[4];
|
||||
Subregister bOffsetAo[4], bOffsetBo[4];
|
||||
Subregister bOffsetAg[4], bOffsetBg[4];
|
||||
|
||||
for (int b = 0; b < problem.batchDims; b++) {
|
||||
bOffsetA[b] = state.inputs.strideA[b];
|
||||
@ -567,6 +568,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
|
||||
if(problem.hasBOffset()){
|
||||
bOffsetBo[b] = state.inputs.strideOffsetB[b];
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
bOffsetAg[b] = state.inputs.strideGroupSumsA[b];
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
bOffsetBg[b] = state.inputs.strideGroupSumsB[b];
|
||||
}
|
||||
if (strategy.A.base.isStateless()) bOffsetA[b] = state.ra.alloc_sub<uint64_t>();
|
||||
if (strategy.B.base.isStateless()) bOffsetB[b] = state.ra.alloc_sub<uint64_t>();
|
||||
if (strategy.C.base.isStateless()) bOffsetC[b] = state.ra.alloc_sub<uint64_t>();
|
||||
@ -588,6 +595,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
|
||||
if(problem.hasBOffset()){
|
||||
emul(1, bOffsetBo[b], state.inputs.strideOffsetB[b], state.batchID[b], strategy, state);
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
emul(1, bOffsetAg[b], state.inputs.strideGroupSumsA[b], state.batchID[b], strategy, state);
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
emul(1, bOffsetBg[b], state.inputs.strideGroupSumsB[b], state.batchID[b], strategy, state);
|
||||
}
|
||||
}
|
||||
|
||||
if(problem.hasAScale() && state.offsetAs.isInvalid()){
|
||||
@ -606,6 +619,14 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
|
||||
state.offsetBo = state.ra.alloc_sub(state.offsetB.getType());
|
||||
emov(1, state.offsetBo, 0, strategy, state);
|
||||
}
|
||||
if(problem.needsAGroupSums() && state.inputs.offsetAg.isInvalid()){
|
||||
state.inputs.offsetAg = state.ra.alloc_sub(state.offsetA.getType());
|
||||
emov(1, state.inputs.offsetAg, 0, strategy, state);
|
||||
}
|
||||
if(problem.needsBGroupSums() && state.inputs.offsetBg.isInvalid()){
|
||||
state.inputs.offsetBg = state.ra.alloc_sub(state.offsetB.getType());
|
||||
emov(1, state.inputs.offsetBg, 0, strategy, state);
|
||||
}
|
||||
|
||||
for (int b = 0; b < problem.batchDims; b++) {
|
||||
eadd(1, state.offsetA, state.offsetA, bOffsetA[b], strategy, state);
|
||||
@ -626,6 +647,12 @@ void Generator<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, const GEMMStr
|
||||
if(problem.hasBOffset()){
|
||||
eadd(1, state.offsetBo, state.offsetBo, bOffsetBo[b], strategy, state);
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
eadd(1, state.inputs.offsetAg, state.inputs.offsetAg, bOffsetAg[b], strategy, state);
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
eadd(1, state.inputs.offsetBg, state.inputs.offsetBg, bOffsetBg[b], strategy, state);
|
||||
}
|
||||
if (!strategy.persistentLoop()) {
|
||||
state.ra.safeRelease(state.inputs.strideA[b]);
|
||||
state.ra.safeRelease(state.inputs.strideB[b]);
|
||||
@ -877,6 +904,16 @@ void Generator<hw>::gemmScaleInputs(const GEMMProblem &problem, const GEMMStrate
|
||||
scale(problem.Tbo, state.inputs.strideOffsetB[b]);
|
||||
}
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
for (int b = 0; b < problem.batchDims; b++) {
|
||||
scale(problem.Tag, state.inputs.strideGroupSumsA[b]);
|
||||
}
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
for (int b = 0; b < problem.batchDims; b++) {
|
||||
scale(problem.Tbg, state.inputs.strideGroupSumsB[b]);
|
||||
}
|
||||
}
|
||||
for (int b = 0; b < problem.batchDims; b++) {
|
||||
scale(Ta_ext, inputs.strideA[b]);
|
||||
scale(Tb_ext, inputs.strideB[b]);
|
||||
@ -1907,7 +1944,7 @@ bool Generator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStrategy &str
|
||||
}
|
||||
if (ag2D) {
|
||||
setupQAddr(Tag, state.Ag_addrs, state.Ag_layout, state.inputs.agPtr,
|
||||
i0qLate, A_h0qLate, state.inputs.ldag);
|
||||
i0qLate, A_h0qLate, state.inputs.ldag, state.inputs.offsetAg);
|
||||
}
|
||||
if (bo2D) {
|
||||
auto &j0o = lateOffsetB ? j0qLate : j0q;
|
||||
@ -1923,7 +1960,7 @@ bool Generator<hw>::gemmAccumulateCSetup(GEMMProblem &problem, GEMMStrategy &str
|
||||
}
|
||||
if (bg2D) {
|
||||
setupQAddr(Tbg, state.Bg_addrs, state.Bg_layout, state.inputs.bgPtr,
|
||||
B_h0qLate, j0qLate, state.inputs.ldbg);
|
||||
B_h0qLate, j0qLate, state.inputs.ldbg, state.inputs.offsetBg);
|
||||
}
|
||||
|
||||
if (i0qLate != state.i0) state.ra.safeRelease(i0qLate);
|
||||
@ -2629,6 +2666,12 @@ void Generator<hw>::gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strate
|
||||
if(problem.hasBOffset()){
|
||||
state.inputs.strideOffsetB.push_back(interface.getArgument("offset_stride_B" + istr));
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
state.inputs.strideGroupSumsA.push_back(interface.getArgument("group_sums_stride_A" + istr));
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
state.inputs.strideGroupSumsB.push_back(interface.getArgument("group_sums_stride_B" + istr));
|
||||
}
|
||||
if (i < problem.batchDims - 1) {
|
||||
state.inputs.batchSize.push_back(interface.getArgument("batch_size" + istr));
|
||||
state.inputs.recipBatchSize.push_back(interface.getArgument("recip_batch_size" + istr));
|
||||
@ -2914,6 +2957,12 @@ void Generator<hw>::gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strate
|
||||
if(problem.hasBOffset()){
|
||||
state.ra.claim(state.inputs.strideOffsetB[i]);
|
||||
}
|
||||
if(problem.needsAGroupSums()){
|
||||
state.ra.claim(state.inputs.strideGroupSumsA[i]);
|
||||
}
|
||||
if(problem.needsBGroupSums()){
|
||||
state.ra.claim(state.inputs.strideGroupSumsB[i]);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < problem.batchDims - 1; i++) {
|
||||
state.ra.claim(state.inputs.batchSize[i]);
|
||||
|
@ -236,6 +236,8 @@ struct GEMMState : public CommonState {
|
||||
std::vector<ngen::Subregister> strideScaleB; // ud
|
||||
std::vector<ngen::Subregister> strideOffsetA; // ud
|
||||
std::vector<ngen::Subregister> strideOffsetB; // ud
|
||||
std::vector<ngen::Subregister> strideGroupSumsA; // ud
|
||||
std::vector<ngen::Subregister> strideGroupSumsB; // ud
|
||||
std::vector<ngen::Subregister> batchSize; // ud
|
||||
std::vector<ngen::Subregister> recipBatchSize; // ud
|
||||
ngen::Subregister offsetBatch; // ud, used for non-strided batch.
|
||||
|
@ -463,6 +463,15 @@ dim_t pd_t::eff_zp_stride(int idx, int arg) const {
|
||||
return zp_md.format_desc.blocking.strides[idx];
|
||||
}
|
||||
|
||||
dim_t pd_t::eff_gs_stride(int idx, int arg) const {
|
||||
gpu_assert(utils::one_of(arg, DNNL_ARG_A, DNNL_ARG_B));
|
||||
auto gs_md = ((DNNL_ARG_A == arg) ^ swap_ab()) ? a_gs_md_ : b_gs_md_;
|
||||
gpu_assert(memory_desc_wrapper(gs_md).is_plain())
|
||||
<< "Expected plain gs_md_";
|
||||
if (gs_md.dims[idx] == 1) return 0;
|
||||
return gs_md.format_desc.blocking.strides[idx];
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace gemm
|
||||
} // namespace intel
|
||||
|
@ -212,6 +212,7 @@ struct pd_t : public gemm::pd_t {
|
||||
}
|
||||
dim_t eff_scale_stride(int idx, int arg) const;
|
||||
dim_t eff_zp_stride(int idx, int arg) const;
|
||||
dim_t eff_gs_stride(int idx, int arg) const;
|
||||
bool a_scales_grouped() const {
|
||||
bool k_grouped
|
||||
= 1 < a_scales_group_k_ && a_scales_group_k_ < desc()->k();
|
||||
|
@ -1,3 +1,7 @@
|
||||
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x1024x2432:1x2432x2432
|
||||
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x1024x9728:1x9728x2432
|
||||
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x333x2432:1x2432x2432
|
||||
--reset --skip-impl=ref --bia_mask=4 --bia-dt=f16 --dt=s8:u8:f16 --stag=abc --wtag=acb --dtag=abc --attr-post-ops=binary_mul:f16:5:abc+binary_add:f16:7:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src0:per_tensor:s32:1x128 --attr-scratchpad=user 2x333x9728:1x9728x2432
|
||||
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --bia_mask=4 --bia-dt=f16 --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1024x1x10240:1x10240x2560
|
||||
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-post-ops=eltwise_swish:1.0+binary_mul:f16:6:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1x4342x1536:1x1536x3840
|
||||
--reset --skip-impl=ref --dt=s8:u8:f16 --stag=abc --wtag=cab --dtag=abc --attr-post-ops=binary_add:f16:6:abc --attr-scales=src0:per_tensor:f16:1x128+wei:per_oc:f16 --attr-zero-points=wei:per_oc:u8 --attr-precomputed-reductions=src:per_tensor:s32:1x128 --attr-scratchpad=user 1x2172x3840:1x3840x1536
|
||||
|
Reference in New Issue
Block a user