[ROCm] enable CK backend for bf16/fp16 on gfx11 (#143971)

this change enables enable CK backend for fp16 on Gfx11
@jeffdaily

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143971
Approved by: https://github.com/jeffdaily
This commit is contained in:
jzhou
2025-03-18 18:18:18 +00:00
committed by PyTorch MergeBot
parent e0e8639a10
commit dfdf58f8cb
4 changed files with 794 additions and 3 deletions

View File

@ -1079,7 +1079,13 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
auto dprops = at::cuda::getCurrentDeviceProperties();
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") { //no CK GEMM version for gfx1100
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
} else{
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
}
#endif
else {

View File

@ -469,11 +469,315 @@ void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
// If any of the shapes cant be tiled, we must use padding.
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
// Dispatch to best implementation.
// TODO add more configurations. Optimize.
bool transa_ = std::tolower(transa) != 'n';
bool transb_ = std::tolower(transb) != 'n';
if (use_padding) {
if(transa_ && transb_) { // col , col
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(transa_ && !transb_) { // row, col
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && transb_) { //col, row
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && !transb_) { //row, row
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "unreachable");
}
} else {
if(transa_ && transb_) { // col , col
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
true,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(transa_ && !transb_) { // row, col
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
true,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && transb_) { //col, row
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
false,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && !transb_) { //row, row
gemm_impl_wmma<
at::BFloat16,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>, 8,
false,
false,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
}
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
auto dprops = at::cuda::getCurrentDeviceProperties();
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
} else{
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
}
} // namespace at::native

View File

@ -297,10 +297,314 @@ void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
}
#endif
}
void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
// If any of the shapes cant be tiled, we must use padding.
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
// Dispatch to best implementation.
// TODO add more configurations. Optimize.
bool transa_ = std::tolower(transa) != 'n';
bool transb_ = std::tolower(transb) != 'n';
if (use_padding) {
if(transa_ && transb_) { // col , col
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(transa_ && !transb_) { // row, col
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && transb_) { //col, row
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && !transb_) { //row, row
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "unreachable");
}
} else {
if(transa_ && transb_) { // col , col
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
true,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(transa_ && !transb_) { // row, col
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
true,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && transb_) { //col, row
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
false,
false,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && !transb_) { //row, row
gemm_impl_wmma<
at::Half,
256,
128,
256,
64,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>, 8,
false,
false,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
}
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
auto dprops = at::cuda::getCurrentDeviceProperties();
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
} else{
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
}
}
} // namespace at::native

View File

@ -30,6 +30,7 @@
#include <ck/library/utility/literals.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp>
// Define commonly used types.
template <ck::index_t... Is>
@ -236,4 +237,180 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
invoker.Run(argument, StreamConfig{stream, false});
}
template <
typename Dtype,
int BLOCK_SIZE,
int MBLOCK,
int NBLOCK,
int KBLOCK,
int K1,
int MPER_WMMA,
int NPER_WMMA,
int MPER_WAVE,
int NPER_WAVE,
typename ABLOCK_CLUSTER_LENS,
typename ABLOCK_CLUSTER_ORDER,
typename ABLOCK_SRC_ORDER,
int ABLOCK_VECTOR_DIM,
int ABLOCK_SCALAR_VEC,
int ABLOCK_SCALAR_VEC_K1,
bool ABLOCK_LDS_EXTRAM,
typename BBLOCK_CLUSTER_LENS,
typename BBLOCK_CLUSTER_ORDER,
typename BBLOCK_SRC_ORDER,
int BBLOCK_VECTOR_DIM,
int BBLOCK_SCALAR_VEC,
int BBLOCK_SCALAR_VEC_AK1,
bool BBLOCK_LDS_EXTRAN,
int CMPER_WAVE,
int CNPER_WAVE,
typename CBLOCK_CLUSTER_LENS,
int CNPER_BLOCK,
bool PADDING = false,
bool TRANSA = false,
bool TRANSB = false>
void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
// Get input information.
int M = m;
int N = n;
int K = k;
int StrideA = lda;
int StrideB = ldb;
int StrideC = ldc;
int KBatch = 1;
float falpha = alpha;
float fbeta = beta;
using ADataType = typename CkMathType<Dtype>::dtype;
using BDataType = typename CkMathType<Dtype>::dtype;
using CDataType = typename CkMathType<Dtype>::dtype;
using DDataType = typename CkMathType<Dtype>::dtype;
using AccDataType = float;
using CShuffleDataType = typename CkMathType<Dtype>::dtype;
using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;
using DLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault =
ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGemmWmma_CShuffle<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmSpec,
1, // NumPrefetch
BLOCK_SIZE,
MBLOCK,
NBLOCK,
KBLOCK,
K1,
MPER_WMMA,
NPER_WMMA,
MPER_WAVE,
NPER_WAVE,
ABLOCK_CLUSTER_LENS,
ABLOCK_CLUSTER_ORDER,
ABLOCK_SRC_ORDER,
ABLOCK_VECTOR_DIM,
ABLOCK_SCALAR_VEC,
ABLOCK_SCALAR_VEC_K1,
ABLOCK_LDS_EXTRAM,
BBLOCK_CLUSTER_LENS,
BBLOCK_CLUSTER_ORDER,
BBLOCK_SRC_ORDER,
BBLOCK_VECTOR_DIM,
BBLOCK_SCALAR_VEC,
BBLOCK_SCALAR_VEC_AK1,
BBLOCK_LDS_EXTRAN,
CMPER_WAVE,
CNPER_WAVE,
CBLOCK_CLUSTER_LENS,
CNPER_BLOCK>;
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
using DDataArrayType = std::array<const void*, 0>;
DDataArrayType DDataArray;
// We swap A and B inputs here as a temporary workaround
auto argument = gemm.MakeArgument(
reinterpret_cast<const ADataType*>(b),
reinterpret_cast<const BDataType*>(a),
reinterpret_cast<CDataType*>(c),
N,
M,
K,
StrideB,
StrideA,
StrideC,
b_element_op,
a_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
printf("error shape = %d %d %d TRANSA=%d TRANSB=%d \n",
n, m, k,TRANSA, TRANSB);
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
auto stream = at::cuda::getCurrentHIPStream().stream();
#if 1
invoker.Run(argument, StreamConfig{stream, false});
#else
float ave_time = invoker.Run(argument, StreamConfig{stream, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << N <<" " <<M<<" " << k <<" "
<< "stride: "<<StrideA <<" "<<StrideB <<" "<<StrideC <<" "
<< gemm.GetTypeString()
<< std::endl;
#endif
}
} // namespace at::native