mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] CK-based GEMM (#131004)
- composable_kernel as a third_party submodule - "ck" as a `torch.backends.cuda.preferred_linalg_library()` - reference CK gemm implementations for float, bfloat16, and half types Pull Request resolved: https://github.com/pytorch/pytorch/pull/131004 Approved by: https://github.com/xw285cornell, https://github.com/pruthvistony Co-authored-by: Andres Lugo <Andy.LugoReyes@amd.com> Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
0a2407b93c
commit
3f3b692a00
4
.gitmodules
vendored
4
.gitmodules
vendored
@ -127,3 +127,7 @@
|
||||
[submodule "third_party/NVTX"]
|
||||
path = third_party/NVTX
|
||||
url = https://github.com/NVIDIA/NVTX.git
|
||||
[submodule "third_party/composable_kernel"]
|
||||
path = third_party/composable_kernel
|
||||
url = https://github.com/ROCm/composable_kernel.git
|
||||
branch = develop
|
||||
|
@ -7,7 +7,7 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
enum class BlasBackend : int8_t { Cublas, Cublaslt };
|
||||
enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck };
|
||||
|
||||
inline std::string BlasBackendToString(at::BlasBackend backend) {
|
||||
switch (backend) {
|
||||
@ -15,6 +15,8 @@ inline std::string BlasBackendToString(at::BlasBackend backend) {
|
||||
return "at::BlasBackend::Cublas";
|
||||
case BlasBackend::Cublaslt:
|
||||
return "at::BlasBackend::Cublaslt";
|
||||
case BlasBackend::Ck:
|
||||
return "at::BlasBackend::Ck";
|
||||
default:
|
||||
TORCH_CHECK(false, "Unknown blas backend");
|
||||
}
|
||||
|
@ -321,6 +321,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
||||
#else
|
||||
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
|
||||
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
|
||||
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
|
||||
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
|
||||
if (b != at::BlasBackend::Cublas) {
|
||||
TORCH_WARN_ONCE(
|
||||
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
|
||||
|
@ -149,6 +149,9 @@ class TORCH_API Context {
|
||||
static bool hasCuBLASLt() {
|
||||
return detail::getCUDAHooks().hasCuBLASLt();
|
||||
}
|
||||
static bool hasROCM() {
|
||||
return detail::getCUDAHooks().hasROCM();
|
||||
}
|
||||
static bool hasHIP() {
|
||||
return detail::getHIPHooks().hasHIP();
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
// until hipblas has an API to accept flags, we must use rocblas here
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
|
||||
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
|
||||
// needed to work around calling rocblas API instead of hipblas API
|
||||
@ -793,6 +794,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented");
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
@ -1001,6 +1003,11 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
#endif
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
@ -1012,6 +1019,11 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
@ -1055,6 +1067,11 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
@ -1066,6 +1083,11 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
#ifdef USE_ROCM
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
|
@ -79,6 +79,7 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
|
||||
transpose_tensor = tensor.is_contiguous();
|
||||
return resolve_conj_if_indicated(tensor, true);
|
||||
}
|
||||
|
||||
IntArrayRef tensor_strides = tensor.strides();
|
||||
IntArrayRef tensor_sizes = tensor.sizes();
|
||||
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
|
||||
|
@ -7,7 +7,6 @@
|
||||
|
||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60201
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
||||
|
24
aten/src/ATen/native/hip/ck_gemm.h
Normal file
24
aten/src/ATen/native/hip/ck_gemm.h
Normal file
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/hip/HIPBlas.h>
|
||||
namespace at::native {
|
||||
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
||||
template <>
|
||||
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
||||
template <>
|
||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
|
||||
|
||||
} // namespace at::native
|
479
aten/src/ATen/native/hip/ck_gemm_bfloat16.hip
Normal file
479
aten/src/ATen/native/hip/ck_gemm_bfloat16.hip
Normal file
@ -0,0 +1,479 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
bool transa_ = std::tolower(transa) != 'n';
|
||||
bool transb_ = std::tolower(transb) != 'n';
|
||||
|
||||
if (use_padding) {
|
||||
if (m <= 128) {
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
|
||||
} else {
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
{
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::BFloat16,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
2,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
|
||||
} // namespace at::native
|
486
aten/src/ATen/native/hip/ck_gemm_float.hip
Normal file
486
aten/src/ATen/native/hip/ck_gemm_float.hip
Normal file
@ -0,0 +1,486 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
bool transa_ = std::tolower(transa) != 'n';
|
||||
bool transb_ = std::tolower(transb) != 'n';
|
||||
|
||||
if (use_padding) {
|
||||
if (m <= 128) {
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
{
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
float,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
S<8,32,1>,
|
||||
S<0,2,1>,
|
||||
S<0,2,1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<4>,
|
||||
false,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
dispatch_float_gemm(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
|
||||
// temporarily put this here until we implement double support
|
||||
template <>
|
||||
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
306
aten/src/ATen/native/hip/ck_gemm_half.hip
Normal file
306
aten/src/ATen/native/hip/ck_gemm_half.hip
Normal file
@ -0,0 +1,306 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
|
||||
#include <ck/utility/sequence.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
#if 0
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
bool transa_ = std::tolower(transa) != 'n';
|
||||
bool transb_ = std::tolower(transb) != 'n';
|
||||
|
||||
if (use_padding) {
|
||||
if (m <= 128) {
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
|
||||
|
||||
|
||||
} else {
|
||||
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
true>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
{
|
||||
if(transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
true,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
true>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else if(!transa_ && !transb_) {
|
||||
gemm_impl<
|
||||
at::Half,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
4,
|
||||
4,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<8,32,1>,
|
||||
S<1,0,2>,
|
||||
S<1,0,2>,
|
||||
1,
|
||||
true,
|
||||
false,
|
||||
false>
|
||||
(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
|
||||
} // namespace at::native
|
289
aten/src/ATen/native/hip/ck_gemm_template.h
Normal file
289
aten/src/ATen/native/hip/ck_gemm_template.h
Normal file
@ -0,0 +1,289 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <numeric>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
|
||||
#include <ck/library/reference_tensor_operation/cpu/reference_gemm.hpp>
|
||||
#include <ck/library/utility/check_err.hpp>
|
||||
#include <ck/library/utility/device_memory.hpp>
|
||||
#include <ck/library/utility/fill.hpp>
|
||||
#include <ck/library/utility/host_tensor.hpp>
|
||||
#include <ck/library/utility/host_tensor_generator.hpp>
|
||||
#include <ck/library/utility/literals.hpp>
|
||||
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp>
|
||||
|
||||
// Define commonly used types.
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
namespace at::native {
|
||||
|
||||
template <typename T>
|
||||
struct CkMathType {
|
||||
using dtype = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CkMathType<at::BFloat16> {
|
||||
using dtype = ck::bhalf_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CkMathType<at::Half> {
|
||||
using dtype = ck::half_t;
|
||||
};
|
||||
|
||||
|
||||
template <bool A, bool B>
|
||||
struct CkTensorLayout {
|
||||
// default goes to row-wise for now
|
||||
using a_layout = Row;
|
||||
using b_layout = Row;
|
||||
};
|
||||
|
||||
// True denotes transpose is necessary. Default is Col, so return Row
|
||||
template <>
|
||||
struct CkTensorLayout<true, true> {
|
||||
using a_layout = Col;
|
||||
using b_layout = Col;
|
||||
};
|
||||
|
||||
|
||||
template <>
|
||||
struct CkTensorLayout<true, false> {
|
||||
using a_layout = Row;
|
||||
using b_layout = Col;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CkTensorLayout<false, true> {
|
||||
using a_layout = Col;
|
||||
using b_layout = Row;
|
||||
};
|
||||
|
||||
|
||||
template <>
|
||||
struct CkTensorLayout<false, false> {
|
||||
using a_layout = Row;
|
||||
using b_layout = Row;
|
||||
};
|
||||
|
||||
|
||||
// Elementwise Operators
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename C, typename AB>
|
||||
__host__ __device__ constexpr void operator()(C& c, const AB& ab) const;
|
||||
|
||||
template<>
|
||||
__host__ __device__ constexpr void operator()<float, float>
|
||||
(float& c, const float& ab) const
|
||||
{
|
||||
c = alpha_ * ab;
|
||||
};
|
||||
|
||||
template<>
|
||||
__host__ __device__ constexpr void operator()<ck::bhalf_t, ck::bhalf_t>
|
||||
(ck::bhalf_t& c, const ck::bhalf_t& ab) const
|
||||
{
|
||||
c = alpha_ * ab;
|
||||
};
|
||||
|
||||
template<>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>
|
||||
(ck::half_t& c, const ck::half_t& ab) const
|
||||
{
|
||||
c = alpha_ * ab;
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
// TODO: Leaving for now, will use later
|
||||
float beta_;
|
||||
};
|
||||
|
||||
template <
|
||||
typename Dtype,
|
||||
int BLOCK_SIZE,
|
||||
int MBLOCK,
|
||||
int NBLOCK,
|
||||
int KBLOCK,
|
||||
int AK1,
|
||||
int BK1,
|
||||
int MPER_XDL,
|
||||
int NPER_XDL,
|
||||
int MPER_WAVE,
|
||||
int NPER_WAVE,
|
||||
typename ABLOCK_CLUSTER_LENS,
|
||||
typename ABLOCK_CLUSTER_ORDER,
|
||||
typename ABLOCK_SRC_ORDER,
|
||||
int ABLOCK_VECTOR_DIM,
|
||||
int ABLOCK_SCALAR_VEC,
|
||||
int ABLOCK_SCALAR_VEC_AK1,
|
||||
bool ABLOCK_LDS_EXTRAM,
|
||||
typename BBLOCK_CLUSTER_LENS,
|
||||
typename BBLOCK_CLUSTER_ORDER,
|
||||
typename BBLOCK_SRC_ORDER,
|
||||
int BBLOCK_VECTOR_DIM,
|
||||
int BBLOCK_SCALAR_VEC,
|
||||
int BBLOCK_SCALAR_VEC_AK1,
|
||||
bool BBLOCK_LDS_EXTRAN,
|
||||
int CMPER_WAVE,
|
||||
int CNPER_WAVE,
|
||||
typename BLOCK_CLUSTER_LENS,
|
||||
typename CDE_SCALAR_VEC,
|
||||
bool PADDING = false,
|
||||
bool TRANSA = false,
|
||||
bool TRANSB = false>
|
||||
void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
// Get input information.
|
||||
int M = m;
|
||||
int N = n;
|
||||
int K = k;
|
||||
|
||||
int StrideA = lda;
|
||||
int StrideB = ldb;
|
||||
int StrideC = ldc;
|
||||
|
||||
int KBatch = 1;
|
||||
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
|
||||
using ADataType = typename CkMathType<Dtype>::dtype;
|
||||
using BDataType = typename CkMathType<Dtype>::dtype;
|
||||
using CDataType = typename CkMathType<Dtype>::dtype;
|
||||
using DDataType = typename CkMathType<Dtype>::dtype;
|
||||
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = typename CkMathType<Dtype>::dtype;
|
||||
|
||||
using ALayout = typename CkTensorLayout<TRANSA, TRANSB>::a_layout;
|
||||
using BLayout = typename CkTensorLayout<TRANSA, TRANSB>::b_layout;
|
||||
|
||||
using DLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = AlphaBetaAdd;
|
||||
|
||||
|
||||
static constexpr auto GemmDefault =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding =
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault;
|
||||
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
BLOCK_SIZE,
|
||||
MBLOCK,
|
||||
NBLOCK,
|
||||
KBLOCK,
|
||||
AK1,
|
||||
BK1,
|
||||
MPER_XDL,
|
||||
NPER_XDL,
|
||||
MPER_WAVE,
|
||||
NPER_WAVE,
|
||||
ABLOCK_CLUSTER_LENS,
|
||||
ABLOCK_CLUSTER_ORDER,
|
||||
ABLOCK_SRC_ORDER,
|
||||
ABLOCK_VECTOR_DIM,
|
||||
ABLOCK_SCALAR_VEC,
|
||||
ABLOCK_SCALAR_VEC_AK1,
|
||||
ABLOCK_LDS_EXTRAM,
|
||||
BBLOCK_CLUSTER_LENS,
|
||||
BBLOCK_CLUSTER_ORDER,
|
||||
BBLOCK_SRC_ORDER,
|
||||
BBLOCK_VECTOR_DIM,
|
||||
BBLOCK_SCALAR_VEC,
|
||||
BBLOCK_SCALAR_VEC_AK1,
|
||||
BBLOCK_LDS_EXTRAN,
|
||||
CMPER_WAVE,
|
||||
CNPER_WAVE,
|
||||
BLOCK_CLUSTER_LENS,
|
||||
CDE_SCALAR_VEC>;
|
||||
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{alpha, beta};
|
||||
|
||||
|
||||
using DDataArrayType = std::array<const void*, 0>;
|
||||
DDataArrayType DDataArray;
|
||||
|
||||
// We swap A and B inputs here as a temporary workaround
|
||||
auto argument = gemm.MakeArgument(
|
||||
reinterpret_cast<const void*>(b),
|
||||
reinterpret_cast<const void*>(a),
|
||||
DDataArray,
|
||||
reinterpret_cast<void*>(c),
|
||||
N,
|
||||
M,
|
||||
K,
|
||||
StrideB,
|
||||
StrideA,
|
||||
std::array<ck::index_t, 0>{},
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
|
||||
auto stream = at::cuda::getCurrentHIPStream().stream();
|
||||
invoker.Run(argument, StreamConfig{stream, false});
|
||||
}
|
||||
|
||||
} // namespace at::native
|
@ -57,24 +57,6 @@ inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
|
||||
return *reinterpret_cast<const __nv_bfloat16*>(&x);
|
||||
}
|
||||
#endif
|
||||
#if defined(__HIPCC__) && defined(USE_ROCM)
|
||||
// 6.2.0 introduced __hip_bfloat16_raw
|
||||
#if defined(__BF16_HOST_DEVICE__)
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
|
||||
x = __hip_bfloat16_raw(value).x;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
|
||||
return __hip_bfloat16(__hip_bfloat16_raw{x});
|
||||
}
|
||||
#else // !defined(__BF16_HOST_DEVICE__)
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
|
||||
x = value.data;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
|
||||
return __hip_bfloat16{x};
|
||||
}
|
||||
#endif // !defined(__BF16_HOST_DEVICE__)
|
||||
#endif // defined(__HIPCC__) && defined(USE_ROCM)
|
||||
|
||||
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(
|
||||
|
@ -13,9 +13,6 @@
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
#if defined(__HIPCC__) && defined(USE_ROCM)
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
|
||||
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
||||
@ -110,10 +107,6 @@ struct alignas(2) BFloat16 {
|
||||
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
|
||||
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
|
||||
#endif
|
||||
#if defined(__HIPCC__) && defined(USE_ROCM)
|
||||
inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
|
||||
explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
|
||||
#endif
|
||||
|
||||
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
|
||||
|
@ -8487,6 +8487,22 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
self.assertEqual(out1, out2)
|
||||
self.assertEqual(out_ref, out2.cpu())
|
||||
|
||||
@skipCUDAIfNotRocm
|
||||
@unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
|
||||
@setBlasBackendsToDefaultFinally
|
||||
def test_ck_blas_library(self):
|
||||
m1 = torch.randint(2, 5, (7168, 8192), device='cuda', dtype=torch.float)
|
||||
m2 = torch.randint(2, 5, (1280, 8192), device='cuda', dtype=torch.float)
|
||||
|
||||
torch.backends.cuda.preferred_blas_library('ck')
|
||||
ck_out = torch.nn.functional.linear(m1, m2)
|
||||
|
||||
cpu_out = torch.nn.functional.linear(m1.cpu(), m2.cpu())
|
||||
|
||||
self.assertEqual(ck_out, cpu_out)
|
||||
|
||||
|
||||
|
||||
def test_permute_matmul(self):
|
||||
a = torch.ones([2, 5, 24, 24])
|
||||
b = torch.ones([3, 2, 5, 24, 24])
|
||||
|
1
third_party/composable_kernel
vendored
Submodule
1
third_party/composable_kernel
vendored
Submodule
Submodule third_party/composable_kernel added at 11b7a4db00
@ -1281,6 +1281,7 @@ def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ...
|
||||
class _BlasBackend:
|
||||
Cublas: _BlasBackend
|
||||
Cublaslt: _BlasBackend
|
||||
Ck: _BlasBackend
|
||||
|
||||
class ConvBackend(Enum): ...
|
||||
|
||||
|
@ -216,6 +216,7 @@ _BlasBackends = {
|
||||
"cublas": torch._C._BlasBackend.Cublas,
|
||||
"cublaslt": torch._C._BlasBackend.Cublaslt,
|
||||
"hipblaslt": torch._C._BlasBackend.Cublaslt, # alias
|
||||
"ck": torch._C._BlasBackend.Ck,
|
||||
}
|
||||
_BlasBackends_str = ", ".join(_BlasBackends.keys())
|
||||
|
||||
@ -224,16 +225,17 @@ def preferred_blas_library(
|
||||
backend: Union[None, str, torch._C._BlasBackend] = None
|
||||
) -> torch._C._BlasBackend:
|
||||
r"""
|
||||
Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt.
|
||||
Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only].
|
||||
|
||||
.. warning:: This flag is experimental and subject to change.
|
||||
|
||||
When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
|
||||
For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance.
|
||||
For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance.
|
||||
This flag (a :class:`str`) allows overriding which BLAS library to use.
|
||||
|
||||
* If `"cublas"` is set then cuBLAS will be used wherever possible.
|
||||
* If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
|
||||
* If `"ck"` is set then CK will be used wherever possible.
|
||||
* When no input is given, this function returns the currently preferred library.
|
||||
* User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
|
||||
globally.
|
||||
|
@ -2079,7 +2079,8 @@ Call this whenever a new thread is created in order to propagate values from
|
||||
|
||||
py::enum_<at::BlasBackend>(py_module, "_BlasBackend")
|
||||
.value("Cublas", at::BlasBackend::Cublas)
|
||||
.value("Cublaslt", at::BlasBackend::Cublaslt);
|
||||
.value("Cublaslt", at::BlasBackend::Cublaslt)
|
||||
.value("Ck", at::BlasBackend::Ck);
|
||||
|
||||
py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) {
|
||||
at::globalContext().setBlasPreferredBackend(b);
|
||||
|
@ -5,9 +5,9 @@
|
||||
#endif
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#endif
|
||||
namespace c10d::symmetric_memory {
|
||||
|
||||
template <typename T>
|
||||
|
Reference in New Issue
Block a user