mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] enable CK backend for bf16/fp16 on gfx11 (#143971)
this change enables enable CK backend for fp16 on Gfx11 @jeffdaily Pull Request resolved: https://github.com/pytorch/pytorch/pull/143971 Approved by: https://github.com/jeffdaily
This commit is contained in:
@ -1079,7 +1079,13 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
c10::string_view arch(dprops->gcnArchName);
|
||||
if (arch == "gfx1100") { //no CK GEMM version for gfx1100
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
} else{
|
||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
|
@ -469,11 +469,315 @@ void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
bool transa_ = std::tolower(transa) != 'n';
|
||||
bool transb_ = std::tolower(transb) != 'n';
|
||||
|
||||
if (use_padding) {
|
||||
if(transa_ && transb_) { // col , col
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(transa_ && !transb_) { // row, col
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && transb_) { //col, row
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && !transb_) { //row, row
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
} else {
|
||||
if(transa_ && transb_) { // col , col
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(transa_ && !transb_) { // row, col
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && transb_) { //col, row
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && !transb_) { //row, row
|
||||
gemm_impl_wmma<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>, 8,
|
||||
false,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
c10::string_view arch(dprops->gcnArchName);
|
||||
if (arch == "gfx1100") {
|
||||
dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
} else{
|
||||
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -297,10 +297,314 @@ void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
}
|
||||
#endif
|
||||
}
|
||||
void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
bool transa_ = std::tolower(transa) != 'n';
|
||||
bool transb_ = std::tolower(transb) != 'n';
|
||||
|
||||
if (use_padding) {
|
||||
if(transa_ && transb_) { // col , col
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(transa_ && !transb_) { // row, col
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && transb_) { //col, row
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && !transb_) { //row, row
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
} else {
|
||||
if(transa_ && transb_) { // col , col
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(transa_ && !transb_) { // row, col
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && transb_) { //col, row
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
false,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && !transb_) { //row, row
|
||||
gemm_impl_wmma<
|
||||
at::Half,
|
||||
256,
|
||||
128,
|
||||
256,
|
||||
64,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>, 8,
|
||||
false,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
c10::string_view arch(dprops->gcnArchName);
|
||||
if (arch == "gfx1100") {
|
||||
dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
} else{
|
||||
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -30,6 +30,7 @@
|
||||
#include <ck/library/utility/literals.hpp>
|
||||
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp>
|
||||
|
||||
// Define commonly used types.
|
||||
template <ck::index_t... Is>
|
||||
@ -236,4 +237,180 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
typename Dtype,
|
||||
int BLOCK_SIZE,
|
||||
int MBLOCK,
|
||||
int NBLOCK,
|
||||
int KBLOCK,
|
||||
int K1,
|
||||
int MPER_WMMA,
|
||||
int NPER_WMMA,
|
||||
int MPER_WAVE,
|
||||
int NPER_WAVE,
|
||||
typename ABLOCK_CLUSTER_LENS,
|
||||
typename ABLOCK_CLUSTER_ORDER,
|
||||
typename ABLOCK_SRC_ORDER,
|
||||
int ABLOCK_VECTOR_DIM,
|
||||
int ABLOCK_SCALAR_VEC,
|
||||
int ABLOCK_SCALAR_VEC_K1,
|
||||
bool ABLOCK_LDS_EXTRAM,
|
||||
typename BBLOCK_CLUSTER_LENS,
|
||||
typename BBLOCK_CLUSTER_ORDER,
|
||||
typename BBLOCK_SRC_ORDER,
|
||||
int BBLOCK_VECTOR_DIM,
|
||||
int BBLOCK_SCALAR_VEC,
|
||||
int BBLOCK_SCALAR_VEC_AK1,
|
||||
bool BBLOCK_LDS_EXTRAN,
|
||||
int CMPER_WAVE,
|
||||
int CNPER_WAVE,
|
||||
typename CBLOCK_CLUSTER_LENS,
|
||||
int CNPER_BLOCK,
|
||||
bool PADDING = false,
|
||||
bool TRANSA = false,
|
||||
bool TRANSB = false>
|
||||
void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
// Get input information.
|
||||
int M = m;
|
||||
int N = n;
|
||||
int K = k;
|
||||
|
||||
int StrideA = lda;
|
||||
int StrideB = ldb;
|
||||
int StrideC = ldc;
|
||||
|
||||
int KBatch = 1;
|
||||
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
|
||||
using ADataType = typename CkMathType<Dtype>::dtype;
|
||||
using BDataType = typename CkMathType<Dtype>::dtype;
|
||||
using CDataType = typename CkMathType<Dtype>::dtype;
|
||||
using DDataType = typename CkMathType<Dtype>::dtype;
|
||||
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = typename CkMathType<Dtype>::dtype;
|
||||
|
||||
using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
|
||||
using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;
|
||||
|
||||
using DLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
|
||||
static constexpr auto GemmDefault =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;
|
||||
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1, // NumPrefetch
|
||||
BLOCK_SIZE,
|
||||
MBLOCK,
|
||||
NBLOCK,
|
||||
KBLOCK,
|
||||
K1,
|
||||
MPER_WMMA,
|
||||
NPER_WMMA,
|
||||
MPER_WAVE,
|
||||
NPER_WAVE,
|
||||
ABLOCK_CLUSTER_LENS,
|
||||
ABLOCK_CLUSTER_ORDER,
|
||||
ABLOCK_SRC_ORDER,
|
||||
ABLOCK_VECTOR_DIM,
|
||||
ABLOCK_SCALAR_VEC,
|
||||
ABLOCK_SCALAR_VEC_K1,
|
||||
ABLOCK_LDS_EXTRAM,
|
||||
BBLOCK_CLUSTER_LENS,
|
||||
BBLOCK_CLUSTER_ORDER,
|
||||
BBLOCK_SRC_ORDER,
|
||||
BBLOCK_VECTOR_DIM,
|
||||
BBLOCK_SCALAR_VEC,
|
||||
BBLOCK_SCALAR_VEC_AK1,
|
||||
BBLOCK_LDS_EXTRAN,
|
||||
CMPER_WAVE,
|
||||
CNPER_WAVE,
|
||||
CBLOCK_CLUSTER_LENS,
|
||||
CNPER_BLOCK>;
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
|
||||
using DDataArrayType = std::array<const void*, 0>;
|
||||
DDataArrayType DDataArray;
|
||||
|
||||
// We swap A and B inputs here as a temporary workaround
|
||||
auto argument = gemm.MakeArgument(
|
||||
reinterpret_cast<const ADataType*>(b),
|
||||
reinterpret_cast<const BDataType*>(a),
|
||||
reinterpret_cast<CDataType*>(c),
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
StrideB,
|
||||
StrideA,
|
||||
StrideC,
|
||||
b_element_op,
|
||||
a_element_op,
|
||||
c_element_op);
|
||||
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
printf("error shape = %d %d %d TRANSA=%d TRANSB=%d \n",
|
||||
n, m, k,TRANSA, TRANSB);
|
||||
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
#if 1
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
#else
|
||||
float ave_time = invoker.Run(argument, StreamConfig{stream, true});
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << N <<" " <<M<<" " << k <<" "
|
||||
<< "stride: "<<StrideA <<" "<<StrideB <<" "<<StrideC <<" "
|
||||
<< gemm.GetTypeString()
|
||||
<< std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
Reference in New Issue
Block a user