[ROCm] CK-based GEMM (#131004)

- composable_kernel as a third_party submodule
- "ck" as a `torch.backends.cuda.preferred_linalg_library()`
- reference CK gemm implementations for float, bfloat16, and half types
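
A minimal usage sketch (mirrors the new test added in this PR; assumes a ROCm
build of PyTorch with the CK backend compiled in):

    import torch

    # Route eligible GEMMs through composable_kernel ("ck").
    torch.backends.cuda.preferred_blas_library("ck")

    a = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

    # linear computes a @ b.T, dispatching the GEMM via the preferred backend.
    out = torch.nn.functional.linear(a, b)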

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131004
Approved by: https://github.com/xw285cornell, https://github.com/pruthvistony

Co-authored-by: Andres Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>
commit 3f3b692a00 (parent 0a2407b93c)
Author: Jeff Daily
Date: 2024-10-20 02:57:41 +00:00
Committed by: PyTorch MergeBot
20 changed files with 1645 additions and 32 deletions

----
.gitmodules

@@ -127,3 +127,7 @@
[submodule "third_party/NVTX"]
path = third_party/NVTX
url = https://github.com/NVIDIA/NVTX.git
[submodule "third_party/composable_kernel"]
path = third_party/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = develop

----

@@ -7,7 +7,7 @@
namespace at {
-enum class BlasBackend : int8_t { Cublas, Cublaslt };
+enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };
inline std::string BlasBackendToString(at::BlasBackend backend) {
switch (backend) {
@@ -15,6 +15,8 @@ inline std::string BlasBackendToString(at::BlasBackend backend) {
return "at::BlasBackend::Cublas";
case BlasBackend::Cublaslt:
return "at::BlasBackend::Cublaslt";
case BlasBackend::Ck:
return "at::BlasBackend::Ck";
default:
TORCH_CHECK(false, "Unknown blas backend");
}

----

@@ -321,6 +321,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
if (b != at::BlasBackend::Cublas) {
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "

----

@@ -149,6 +149,9 @@ class TORCH_API Context {
static bool hasCuBLASLt() {
return detail::getCUDAHooks().hasCuBLASLt();
}
static bool hasROCM() {
return detail::getCUDAHooks().hasROCM();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}

----

@@ -19,6 +19,7 @@
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#include <ATen/native/hip/ck_gemm.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
@@ -793,6 +794,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented");
}
template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
// See Note [Writing Nondeterministic Operations]
@@ -1001,6 +1003,11 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
#endif
else {
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1012,6 +1019,11 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
#endif
else {
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
}
@@ -1055,6 +1067,11 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#endif
else {
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1066,6 +1083,11 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#ifdef USE_ROCM
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#endif
else {
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}

----

@@ -79,6 +79,7 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
transpose_tensor = tensor.is_contiguous();
return resolve_conj_if_indicated(tensor, true);
}
IntArrayRef tensor_strides = tensor.strides();
IntArrayRef tensor_sizes = tensor.sizes();
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {

----

@@ -7,7 +7,6 @@
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {

----

@@ -0,0 +1,24 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/hip/HIPBlas.h>
namespace at::native {
template <typename Dtype>
inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false && sizeof(Dtype), "at::native::gemm_internal_ck: not implemented");
}
template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
} // namespace at::native

----

@@ -0,0 +1,479 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at::native {
void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
// If any of the shapes can't be tiled, we must use padding.
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
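// (padded variants instantiate gemm_impl with PADDING=true, which selects CK's
// GemmSpecialization::MNKPadding in ck_gemm_template.h)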
// Dispatch to best implementation.
// TODO add more configurations. Optimize.
bool transa_ = std::tolower(transa) != 'n';
bool transb_ = std::tolower(transb) != 'n';
if (use_padding) {
if (m <= 128) {
if(transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "unreachable");
}
} else {
if(transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
} else {
{
if(transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
true,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
true,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
false,
true>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::BFloat16,
256,
128,
128,
64,
8,
8,
32,
32,
2,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
8,
8,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
8,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
false,
false>
(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
}
}
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
} // namespace at::native

----

@@ -0,0 +1,486 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at::native {
void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) {
// If any of the shapes can't be tiled, we must use padding.
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
// Dispatch to best implementation.
// TODO add more configurations. Optimize.
bool transa_ = std::tolower(transa) != 'n';
bool transb_ = std::tolower(transb) != 'n';
if (use_padding) {
if (m <= 128) {
if(transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else {
TORCH_CHECK(false, "unreachable");
}
} else {
if(transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
} else {
{
if(transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
true,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
true,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
false,
true>
(CUDABLAS_GEMM_ARGS(float));
}
else if(!transa_ && !transb_) {
gemm_impl<
float,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
2,
4,
4,
0,
S<8,32,1>,
S<0,2,1>,
S<0,2,1>,
1,
4,
4,
0,
1,
1,
S<1, 32, 1, 8>,
S<4>,
false,
false,
false>
(CUDABLAS_GEMM_ARGS(float));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
}
}
template <>
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
dispatch_float_gemm(CUDABLAS_GEMM_ARGS(float));
}
// temporarily put this here until we implement double support
template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
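// CK double-precision GEMM is not implemented yet; this stub returns without
// computing anything.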
return;
}
} // namespace at::native

----

@@ -0,0 +1,306 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
namespace at::native {
void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
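// NOTE: the configurations below are disabled via #if 0, so until they are
// enabled dispatch_half_gemm is effectively a no-op for at::Half.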
#if 0
// If any of the shapes can't be tiled, we must use padding.
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
// Dispatch to best implementation.
// TODO add more configurations. Optimize.
bool transa_ = std::tolower(transa) != 'n';
bool transb_ = std::tolower(transb) != 'n';
if (use_padding) {
if (m <= 128) {
if(transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "unreachable");
}
} else {
if(transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
false,
true>(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
} else {
{
if(transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
1,
true,
true,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
1,
true,
true,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
1,
true,
false,
true>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else if(!transa_ && !transb_) {
gemm_impl<
at::Half,
256,
256,
128,
32,
4,
4,
32,
32,
4,
2,
S<8,32,1>,
S<1,0,2>,
S<1,0,2>,
1,
true,
false,
false>
(CUDABLAS_GEMM_ARGS(at::Half));
}
else {
TORCH_CHECK(false, "unreachable");
}
}
}
#endif
}
template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
}
} // namespace at::native

----

@@ -0,0 +1,289 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#undef __HIP_NO_HALF_CONVERSIONS__
#include <cstdlib>
#include <initializer_list>
#include <numeric>
#include <ATen/ATen.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <ATen/native/hip/ck_gemm.h>
#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/data_type.hpp>
#include <ck/library/reference_tensor_operation/cpu/reference_gemm.hpp>
#include <ck/library/utility/check_err.hpp>
#include <ck/library/utility/device_memory.hpp>
#include <ck/library/utility/fill.hpp>
#include <ck/library/utility/host_tensor.hpp>
#include <ck/library/utility/host_tensor_generator.hpp>
#include <ck/library/utility/literals.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp>
// Define commonly used types.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace at::native {
template <typename T>
struct CkMathType {
using dtype = T;
};
template <>
struct CkMathType<at::BFloat16> {
using dtype = ck::bhalf_t;
};
template <>
struct CkMathType<at::Half> {
using dtype = ck::half_t;
};
template <bool A, bool B>
struct CkTensorLayout {
// default goes to row-wise for now
using a_layout = Row;
using b_layout = Row;
};
// True denotes transpose is necessary. Default is Col, so return Row
template <>
struct CkTensorLayout<true, true> {
using a_layout = Col;
using b_layout = Col;
};
template <>
struct CkTensorLayout<true, false> {
using a_layout = Row;
using b_layout = Col;
};
template <>
struct CkTensorLayout<false, true> {
using a_layout = Col;
using b_layout = Row;
};
template <>
struct CkTensorLayout<false, false> {
using a_layout = Row;
using b_layout = Row;
};
// Elementwise Operators
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}
template <typename C, typename AB>
__host__ __device__ constexpr void operator()(C& c, const AB& ab) const;
template<>
__host__ __device__ constexpr void operator()<float, float>
(float& c, const float& ab) const
{
c = alpha_ * ab;
};
template<>
__host__ __device__ constexpr void operator()<ck::bhalf_t, ck::bhalf_t>
(ck::bhalf_t& c, const ck::bhalf_t& ab) const
{
c = alpha_ * ab;
};
template<>
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>
(ck::half_t& c, const ck::half_t& ab) const
{
c = alpha_ * ab;
};
float alpha_;
// TODO: Leaving for now, will use later
float beta_;
};
template <
typename Dtype,
int BLOCK_SIZE,
int MBLOCK,
int NBLOCK,
int KBLOCK,
int AK1,
int BK1,
int MPER_XDL,
int NPER_XDL,
int MPER_WAVE,
int NPER_WAVE,
typename ABLOCK_CLUSTER_LENS,
typename ABLOCK_CLUSTER_ORDER,
typename ABLOCK_SRC_ORDER,
int ABLOCK_VECTOR_DIM,
int ABLOCK_SCALAR_VEC,
int ABLOCK_SCALAR_VEC_AK1,
bool ABLOCK_LDS_EXTRAM,
typename BBLOCK_CLUSTER_LENS,
typename BBLOCK_CLUSTER_ORDER,
typename BBLOCK_SRC_ORDER,
int BBLOCK_VECTOR_DIM,
int BBLOCK_SCALAR_VEC,
int BBLOCK_SCALAR_VEC_AK1,
bool BBLOCK_LDS_EXTRAN,
int CMPER_WAVE,
int CNPER_WAVE,
typename BLOCK_CLUSTER_LENS,
typename CDE_SCALAR_VEC,
bool PADDING = false,
bool TRANSA = false,
bool TRANSB = false>
void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
// Get input information.
int M = m;
int N = n;
int K = k;
int StrideA = lda;
int StrideB = ldb;
int StrideC = ldc;
int KBatch = 1;
float falpha = alpha;
float fbeta = beta;
using ADataType = typename CkMathType<Dtype>::dtype;
using BDataType = typename CkMathType<Dtype>::dtype;
using CDataType = typename CkMathType<Dtype>::dtype;
using DDataType = typename CkMathType<Dtype>::dtype;
using AccDataType = float;
using CShuffleDataType = typename CkMathType<Dtype>::dtype;
using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;
using DLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = AlphaBetaAdd;
static constexpr auto GemmDefault =
ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding =
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmSpec,
BLOCK_SIZE,
MBLOCK,
NBLOCK,
KBLOCK,
AK1,
BK1,
MPER_XDL,
NPER_XDL,
MPER_WAVE,
NPER_WAVE,
ABLOCK_CLUSTER_LENS,
ABLOCK_CLUSTER_ORDER,
ABLOCK_SRC_ORDER,
ABLOCK_VECTOR_DIM,
ABLOCK_SCALAR_VEC,
ABLOCK_SCALAR_VEC_AK1,
ABLOCK_LDS_EXTRAM,
BBLOCK_CLUSTER_LENS,
BBLOCK_CLUSTER_ORDER,
BBLOCK_SRC_ORDER,
BBLOCK_VECTOR_DIM,
BBLOCK_SCALAR_VEC,
BBLOCK_SCALAR_VEC_AK1,
BBLOCK_LDS_EXTRAN,
CMPER_WAVE,
CNPER_WAVE,
BLOCK_CLUSTER_LENS,
CDE_SCALAR_VEC>;
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{alpha, beta};
using DDataArrayType = std::array<const void*, 0>;
DDataArrayType DDataArray;
// We swap A and B inputs here as a temporary workaround
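// (this entry point follows the column-major BLAS convention while the CK
// instance writes a row-major C; swapping A/B together with M/N and the
// strides yields the transposed product, whose memory layout matches the
// expected column-major result)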
auto argument = gemm.MakeArgument(
reinterpret_cast<const void*>(b),
reinterpret_cast<const void*>(a),
DDataArray,
reinterpret_cast<void*>(c),
N,
M,
K,
StrideB,
StrideA,
std::array<ck::index_t, 0>{},
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
auto stream = at::cuda::getCurrentHIPStream().stream();
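// Launch on PyTorch's current HIP stream; the second StreamConfig field
// disables CK's kernel timing.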
invoker.Run(argument, StreamConfig{stream, false});
}
} // namespace at::native

----

@@ -57,24 +57,6 @@ inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
// 6.2.0 introduced __hip_bfloat16_raw
#if defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = __hip_bfloat16_raw(value).x;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16(__hip_bfloat16_raw{x});
}
#else // !defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
x = value.data;
}
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
return __hip_bfloat16{x};
}
#endif // !defined(__BF16_HOST_DEVICE__)
#endif // defined(__HIPCC__) && defined(USE_ROCM)
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16::BFloat16(

----

@@ -13,9 +13,6 @@
#if defined(__CUDACC__) && !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
#if defined(CL_SYCL_LANGUAGE_VERSION)
@@ -110,10 +107,6 @@ struct alignas(2) BFloat16 {
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
#endif
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);

----

@@ -8487,6 +8487,22 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(out1, out2)
self.assertEqual(out_ref, out2.cpu())
@skipCUDAIfNotRocm
@unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
@setBlasBackendsToDefaultFinally
def test_ck_blas_library(self):
m1 = torch.randint(2, 5, (7168, 8192), device='cuda', dtype=torch.float)
m2 = torch.randint(2, 5, (1280, 8192), device='cuda', dtype=torch.float)
torch.backends.cuda.preferred_blas_library('ck')
ck_out = torch.nn.functional.linear(m1, m2)
cpu_out = torch.nn.functional.linear(m1.cpu(), m2.cpu())
self.assertEqual(ck_out, cpu_out)
def test_permute_matmul(self):
a = torch.ones([2, 5, 24, 24])
b = torch.ones([3, 2, 5, 24, 24])

----

@@ -1281,6 +1281,7 @@ def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ...
class _BlasBackend:
Cublas: _BlasBackend
Cublaslt: _BlasBackend
Ck: _BlasBackend
class ConvBackend(Enum): ...

----

@@ -216,6 +216,7 @@ _BlasBackends = {
"cublas": torch._C._BlasBackend.Cublas,
"cublaslt": torch._C._BlasBackend.Cublaslt,
"hipblaslt": torch._C._BlasBackend.Cublaslt, # alias
"ck": torch._C._BlasBackend.Ck,
}
_BlasBackends_str = ", ".join(_BlasBackends.keys())
@@ -224,16 +225,17 @@ def preferred_blas_library(
backend: Union[None, str, torch._C._BlasBackend] = None
) -> torch._C._BlasBackend:
r"""
-Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt.
+Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only].
.. warning:: This flag is experimental and subject to change.
When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
-For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance.
+For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance.
This flag (a :class:`str`) allows overriding which BLAS library to use.
* If `"cublas"` is set then cuBLAS will be used wherever possible.
* If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
* If `"ck"` is set then CK will be used wherever possible.
* When no input is given, this function returns the currently preferred library.
* User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
globally.
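
A sketch of the query/set behavior described above (assumes a ROCm build with
CK available):

    import torch

    # No argument: returns the currently preferred torch._C._BlasBackend.
    current = torch.backends.cuda.preferred_blas_library()

    # Strings are looked up in _BlasBackends; "ck" maps to torch._C._BlasBackend.Ck.
    torch.backends.cuda.preferred_blas_library("ck")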

----

@@ -2079,7 +2079,8 @@ Call this whenever a new thread is created in order to propagate values from
py::enum_<at::BlasBackend>(py_module, "_BlasBackend")
.value("Cublas", at::BlasBackend::Cublas)
.value("Cublaslt", at::BlasBackend::Cublaslt);
.value("Cublaslt", at::BlasBackend::Cublaslt)
.value("Ck", at::BlasBackend::Ck);
py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) {
at::globalContext().setBlasPreferredBackend(b);

----

@@ -5,9 +5,9 @@
#endif
#include <ATen/ATen.h>
#if !defined(USE_ROCM)
#include <cuda_bf16.h>
#endif
namespace c10d::symmetric_memory {
template <typename T>