mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] TunableOp (#114894)
Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one know which implementation is the fastest and should be chosen? That's what TunableOp provides.
See the README.md for additional details.
TunableOp was ported from onnxruntime starting from commit 08dce54266
. The content was significantly modified and reorganized for use within PyTorch. The files copied and their approximate new names or source content location within aten/src/ATen/cuda/tunable include the following:
- onnxruntime/core/framework/tunable.h -> Tunable.h
- onnxruntime/core/framework/tuning_context.h -> Tunable.h
- onnxruntime/core/framework/tuning_context_impl.h -> Tunable.cpp
- onnxruntime/core/providers/rocm/tunable/gemm_common.h -> GemmCommon.h
- onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h -> GemmHipblaslt.h
- onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h -> GemmRocblas.h
- onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh -> TunableGemm.h
- onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc -> Tunable.cpp
- onnxruntime/core/providers/rocm/tunable/util.h -> StreamTimer.h
- onnxruntime/core/providers/rocm/tunable/util.cc -> StreamTimer.cpp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114894
Approved by: https://github.com/xw285cornell, https://github.com/jianyuh
This commit is contained in:
committed by
PyTorch MergeBot
parent
90f785dc34
commit
0e6eee3c89
@ -228,6 +228,7 @@ filegroup(
|
||||
[
|
||||
"aten/src/ATen/cuda/*.cpp",
|
||||
"aten/src/ATen/cuda/detail/*.cpp",
|
||||
"aten/src/ATen/cuda/tunable/*.cpp",
|
||||
"aten/src/ATen/cudnn/*.cpp",
|
||||
"aten/src/ATen/native/cuda/*.cpp",
|
||||
"aten/src/ATen/native/cuda/linalg/*.cpp",
|
||||
|
@ -60,11 +60,11 @@ endif()
|
||||
|
||||
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
|
||||
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
|
||||
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
|
||||
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
|
||||
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
|
||||
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")
|
||||
file(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h")
|
||||
file(GLOB cuda_nvrtc_stub_cpp "cuda/nvrtc_stub/*.cpp")
|
||||
file(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu")
|
||||
file(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu" "cuda/tunable/*.cu")
|
||||
file(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh")
|
||||
file(GLOB cudnn_cpp "cudnn/*.cpp")
|
||||
file(GLOB ops_h "ops/*.h")
|
||||
@ -72,10 +72,10 @@ file(GLOB ops_h "ops/*.h")
|
||||
file(GLOB xpu_h "xpu/*.h" "xpu/detail/*.h")
|
||||
file(GLOB xpu_cpp "xpu/*.cpp" "xpu/detail/*.cpp")
|
||||
|
||||
file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h")
|
||||
file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp")
|
||||
file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h" "hip/tunable/*.cuh" "hip/tunable/*.h")
|
||||
file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp" "hip/tunable/*.cpp")
|
||||
list(REMOVE_ITEM hip_cpp "${CMAKE_CURRENT_SOURCE_DIR}/hip/detail/LazyNVRTC.cpp")
|
||||
file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip")
|
||||
file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip" "hip/tunable/*.hip")
|
||||
file(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h")
|
||||
file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp")
|
||||
file(GLOB miopen_h "miopen/*.h")
|
||||
|
@ -6,6 +6,8 @@
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
#include <c10/macros/Export.h>
|
||||
@ -232,7 +234,7 @@ namespace at::cuda::blas {
|
||||
} while (0)
|
||||
|
||||
template <>
|
||||
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
|
||||
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -245,7 +247,7 @@ void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
|
||||
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -258,7 +260,7 @@ void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
|
||||
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -273,7 +275,7 @@ void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
|
||||
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -288,7 +290,7 @@ void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -335,7 +337,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
BGEMM_CHECK_ARGVALUES(at::BFloat16);
|
||||
@ -361,8 +363,119 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
inline void bgemm_tunable(CUDABLAS_BGEMM_ARGTYPES(DType)) {
|
||||
tunable::GemmStridedBatchedParams<DType> params;
|
||||
params.transa = transa;
|
||||
params.transb = transb;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
params.alpha = alpha;
|
||||
params.a = a;
|
||||
params.lda = lda;
|
||||
params.stride_a = stridea;
|
||||
params.b = b;
|
||||
params.ldb = ldb;
|
||||
params.stride_b = strideb;
|
||||
params.beta = beta;
|
||||
params.c = c;
|
||||
params.ldc = ldc;
|
||||
params.stride_c = stridec;
|
||||
params.batch = num_batches;
|
||||
|
||||
bool transa_ = ((transa != 'n') && (transa != 'N'));
|
||||
bool transb_ = ((transb != 'n') && (transb != 'N'));
|
||||
|
||||
if (transa_ && transb_) {
|
||||
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> bgemm{};
|
||||
bgemm(¶ms);
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> bgemm{};
|
||||
bgemm(¶ms);
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> bgemm{};
|
||||
bgemm(¶ms);
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> bgemm{};
|
||||
bgemm(¶ms);
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<double>(CUDABLAS_BGEMM_ARGS(double));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<double>(CUDABLAS_BGEMM_ARGS(double));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
bgemm_tunable<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -375,7 +488,7 @@ void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -388,7 +501,7 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -403,7 +516,7 @@ void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -418,7 +531,7 @@ void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
// See Note [Writing Nondeterministic Operations]
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
@ -514,7 +627,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
globalContext().alertCuBLASConfigNotDeterministic();
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasOperation_t opa = _cublasOpFromChar(transa);
|
||||
@ -558,6 +671,113 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES(DType)) {
|
||||
tunable::GemmParams<DType> params;
|
||||
params.transa = transa;
|
||||
params.transb = transb;
|
||||
params.m = m;
|
||||
params.n = n;
|
||||
params.k = k;
|
||||
params.alpha = alpha;
|
||||
params.a = a;
|
||||
params.lda = lda;
|
||||
params.b = b;
|
||||
params.ldb = ldb;
|
||||
params.beta = beta;
|
||||
params.c = c;
|
||||
params.ldc = ldc;
|
||||
|
||||
bool transa_ = ((transa != 'n') && (transa != 'N'));
|
||||
bool transb_ = ((transb != 'n') && (transb != 'N'));
|
||||
|
||||
if (transa_ && transb_) {
|
||||
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
|
||||
gemm(¶ms);
|
||||
}
|
||||
else if (transa_ && !transb_) {
|
||||
static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
|
||||
gemm(¶ms);
|
||||
}
|
||||
else if (!transa_ && transb_) {
|
||||
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
|
||||
gemm(¶ms);
|
||||
}
|
||||
else if (!transa_ && !transb_) {
|
||||
static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
|
||||
gemm(¶ms);
|
||||
}
|
||||
else {
|
||||
TORCH_CHECK(false, "unreachable");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
else {
|
||||
gemm_internal<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
else {
|
||||
gemm_internal<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
|
||||
}
|
||||
else {
|
||||
gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
|
||||
}
|
||||
else {
|
||||
gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
else {
|
||||
gemm_internal<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
gemm_tunable<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
else {
|
||||
gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
||||
|
@ -44,6 +44,8 @@ private:
|
||||
const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
|
||||
Dtype *c, int64_t ldc
|
||||
|
||||
#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name());
|
||||
@ -62,6 +64,24 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name());
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
||||
template <>
|
||||
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
||||
template <>
|
||||
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
||||
template <>
|
||||
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
||||
template <>
|
||||
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
enum GEMMAndBiasActivationEpilogue {
|
||||
None,
|
||||
@ -131,6 +151,9 @@ void scaled_gemm(
|
||||
const Dtype *b, int64_t ldb, int64_t strideb, \
|
||||
at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
|
||||
|
||||
#define CUDABLAS_BGEMM_ARGS(Dtype) \
|
||||
transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
|
||||
|
||||
template <typename Dtype>
|
||||
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
|
||||
@ -149,6 +172,24 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name());
|
||||
}
|
||||
|
||||
template <>
|
||||
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
|
||||
template <>
|
||||
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
|
||||
template <>
|
||||
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
|
||||
template <>
|
||||
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
|
||||
template <>
|
||||
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION <= 50500
|
||||
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
|
||||
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
|
||||
|
174
aten/src/ATen/cuda/tunable/GemmCommon.h
Normal file
174
aten/src/ATen/cuda/tunable/GemmCommon.h
Normal file
@ -0,0 +1,174 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
enum class BlasOp {
|
||||
N = 0,
|
||||
T = 1
|
||||
};
|
||||
|
||||
inline std::string BlasOpToString(BlasOp op) {
|
||||
switch (op) {
|
||||
case BlasOp::N:
|
||||
return "N";
|
||||
case BlasOp::T:
|
||||
return "T";
|
||||
}
|
||||
TORCH_CHECK(false, "unrecognized BlasOp");
|
||||
return "N";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GemmParams : OpParams {
|
||||
std::string Signature() const override {
|
||||
return c10::str(transa, transb, "_", m, "_", n, "_", k);
|
||||
}
|
||||
|
||||
GemmParams* DeepCopy() const {
|
||||
GemmParams* copy = new GemmParams;
|
||||
*copy = *this;
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
size_t c_size = m * n * sizeof(T);
|
||||
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
||||
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
||||
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
||||
return copy;
|
||||
}
|
||||
|
||||
// only call on object returned by DeepCopy
|
||||
void Delete() {
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmParams<T> *other) {
|
||||
auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
|
||||
// comparison done as 1D tensor
|
||||
at::Tensor ref = at::from_blob(c, {m*n}, options);
|
||||
at::Tensor oth = at::from_blob(other->c, {m*n}, options);
|
||||
at::Tensor ref_float = ref.to(at::kFloat);
|
||||
at::Tensor oth_float = oth.to(at::kFloat);
|
||||
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
double last_succeed_atol = 1;
|
||||
double last_succeed_rtol = 1;
|
||||
for (auto& atol : atols) {
|
||||
for (auto& rtol : rtols) {
|
||||
if (at::allclose(ref_float, oth_float, rtol, atol)) {
|
||||
last_succeed_atol = atol;
|
||||
last_succeed_rtol = rtol;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_succeed_atol == 1) {
|
||||
return FAIL;
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
char transa;
|
||||
char transb;
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
at::opmath_type<T> alpha;
|
||||
const T* a;
|
||||
int64_t lda;
|
||||
const T* b;
|
||||
int64_t ldb;
|
||||
at::opmath_type<T> beta;
|
||||
T* c;
|
||||
int64_t ldc;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GemmStridedBatchedParams : OpParams {
|
||||
std::string Signature() const override {
|
||||
return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
|
||||
}
|
||||
|
||||
GemmStridedBatchedParams* DeepCopy() const {
|
||||
GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
|
||||
*copy = *this;
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
size_t c_size = batch * stride_c * sizeof(T);
|
||||
copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
|
||||
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
|
||||
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
|
||||
return copy;
|
||||
}
|
||||
|
||||
// only call on object returned by DeepCopy
|
||||
void Delete() {
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(c);
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
|
||||
auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
|
||||
// comparison done as 1D tensor
|
||||
at::Tensor ref = at::from_blob(c, {batch*stride_c}, options);
|
||||
at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options);
|
||||
at::Tensor ref_float = ref.to(at::kFloat);
|
||||
at::Tensor oth_float = oth.to(at::kFloat);
|
||||
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
double last_succeed_atol = 1;
|
||||
double last_succeed_rtol = 1;
|
||||
for (auto& atol : atols) {
|
||||
for (auto& rtol : rtols) {
|
||||
if (at::allclose(ref_float, oth_float, rtol, atol)) {
|
||||
last_succeed_atol = atol;
|
||||
last_succeed_rtol = rtol;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_succeed_atol == 1) {
|
||||
return FAIL;
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
char transa;
|
||||
char transb;
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
at::opmath_type<T> alpha;
|
||||
const T* a;
|
||||
int64_t lda;
|
||||
int64_t stride_a;
|
||||
const T* b;
|
||||
int64_t ldb;
|
||||
int64_t stride_b;
|
||||
at::opmath_type<T> beta;
|
||||
T* c;
|
||||
int64_t ldc;
|
||||
int64_t stride_c;
|
||||
int64_t batch;
|
||||
};
|
||||
|
||||
} // namespace at::cuda::tunable
|
379
aten/src/ATen/cuda/tunable/GemmHipblaslt.h
Normal file
379
aten/src/ATen/cuda/tunable/GemmHipblaslt.h
Normal file
@ -0,0 +1,379 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
#include <hipblaslt/hipblaslt.h>
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
|
||||
#define TORCH_HIPBLASLT_CHECK(EXPR) \
|
||||
do { \
|
||||
hipblasStatus_t __err = EXPR; \
|
||||
TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \
|
||||
"hipblaslt error: ", \
|
||||
hipblasStatusToString(__err), \
|
||||
" when calling `" #EXPR "`"); \
|
||||
} while (0)
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO
|
||||
#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
|
||||
#else
|
||||
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) {
|
||||
int* algo_ptr = (int*)algo.data;
|
||||
if(*algo_ptr < 0) {
|
||||
return -1;
|
||||
}
|
||||
return *algo_ptr;
|
||||
}
|
||||
#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
|
||||
#endif
|
||||
|
||||
#ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE
|
||||
#define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32
|
||||
#else
|
||||
#define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F
|
||||
#endif
|
||||
|
||||
#ifdef HIPBLASLT_CUSTOM_DATA_TYPE
|
||||
|
||||
template <typename T>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
|
||||
return HIPBLASLT_R_32F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<Half>() {
|
||||
return HIPBLASLT_R_16F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
|
||||
return HIPBLASLT_R_16B;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
|
||||
return HIPBLASLT_R_64F;
|
||||
}
|
||||
|
||||
#define DATA_TYPE_R_32 HIPBLASLT_R_32F
|
||||
|
||||
#else
|
||||
|
||||
template <typename T>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
|
||||
return HIPBLAS_R_32F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor<Half>() {
|
||||
return HIPBLAS_R_16F;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
|
||||
return HIPBLAS_R_16B;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
|
||||
return HIPBLAS_R_64F;
|
||||
}
|
||||
|
||||
#ifdef HIPBLAS_V2
|
||||
#define DATA_TYPE_R_32 HIP_R_32F
|
||||
#else
|
||||
#define DATA_TYPE_R_32 HIPBLAS_R_32F
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T, typename ParamsT>
|
||||
int GetBatchFromParams(const ParamsT* params) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
|
||||
return params->batch;
|
||||
}
|
||||
|
||||
template <typename T, typename ParamsT>
|
||||
int GetStrideAFromParams(const ParamsT* params) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
|
||||
return params->stride_a;
|
||||
}
|
||||
|
||||
template <typename T, typename ParamsT>
|
||||
int GetStrideBFromParams(const ParamsT* params) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
|
||||
return params->stride_b;
|
||||
}
|
||||
|
||||
template <typename T, typename ParamsT>
|
||||
int GetStrideCFromParams(const ParamsT* params) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
|
||||
return params->stride_c;
|
||||
}
|
||||
|
||||
static hipblasOperation_t _hipblasOpFromChar(char op) {
|
||||
switch (op) {
|
||||
case 'n':
|
||||
case 'N':
|
||||
return HIPBLAS_OP_N;
|
||||
case 't':
|
||||
case 'T':
|
||||
return HIPBLAS_OP_T;
|
||||
case 'c':
|
||||
case 'C':
|
||||
return HIPBLAS_OP_C;
|
||||
}
|
||||
AT_ERROR(
|
||||
"_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
|
||||
}
|
||||
|
||||
static char _charFromhipblasOp(hipblasOperation_t op) {
|
||||
switch (op) {
|
||||
case HIPBLAS_OP_N:
|
||||
return 'N';
|
||||
case HIPBLAS_OP_T:
|
||||
return 'T';
|
||||
case HIPBLAS_OP_C:
|
||||
return 'C';
|
||||
}
|
||||
AT_ERROR(
|
||||
"_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
|
||||
}
|
||||
|
||||
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
|
||||
if (layout == BlasOp::N) {
|
||||
return HIPBLAS_OP_N;
|
||||
}
|
||||
return HIPBLAS_OP_T;
|
||||
}
|
||||
|
||||
static size_t GetHipblasltWorkspaceSize() {
|
||||
static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
||||
// 256MB is max workspace size allowed for hipblaslt
|
||||
// hipblaslt-bench uses 32MB
|
||||
// recommendation from hipblaslt author was 76MB
|
||||
size_t workspace_size = 2*128*1024*1024; // default 256MB
|
||||
if (env) {
|
||||
try {
|
||||
workspace_size = std::stoi(env);
|
||||
} catch(std::invalid_argument const& e) {
|
||||
TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
|
||||
" using default workspace size of ", workspace_size, " bytes.");
|
||||
} catch(std::out_of_range const& e) {
|
||||
TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
|
||||
" using default workspace size of ", workspace_size, " bytes.");
|
||||
}
|
||||
}
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
||||
class HipblasltGemmOp : public Callable<ParamsT> {
|
||||
public:
|
||||
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
|
||||
|
||||
TuningStatus Call(const ParamsT* params) override {
|
||||
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
||||
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
||||
auto in_out_datatype = HipBlasDataTypeFor<T>();
|
||||
auto opa = _hipblasOpFromChar(params->transa);
|
||||
auto opb = _hipblasOpFromChar(params->transb);
|
||||
|
||||
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
|
||||
|
||||
float alpha = static_cast<float>(params->alpha);
|
||||
float beta = static_cast<float>(params->beta);
|
||||
|
||||
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
|
||||
hipblasLtMatmulDesc_t matmul;
|
||||
if (opa == HIPBLAS_OP_N) {
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda));
|
||||
}
|
||||
else {
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda));
|
||||
}
|
||||
if (opb == HIPBLAS_OP_N) {
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb));
|
||||
}
|
||||
else {
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb));
|
||||
}
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32));
|
||||
|
||||
int batch = GetBatchFromParams<T>(params);
|
||||
if (batch > 1) {
|
||||
int64_t stride_a = GetStrideAFromParams<T>(params);
|
||||
int64_t stride_b = GetStrideBFromParams<T>(params);
|
||||
int64_t stride_c = GetStrideCFromParams<T>(params);
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
|
||||
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
|
||||
}
|
||||
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t)));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
|
||||
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t)));
|
||||
|
||||
size_t workspace_size = GetHipblasltWorkspaceSize();
|
||||
|
||||
auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
|
||||
size_t ret_workspace_size = 0;
|
||||
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
|
||||
matmul,
|
||||
&alpha,
|
||||
mat_a,
|
||||
mat_b,
|
||||
&beta,
|
||||
mat_c,
|
||||
mat_c,
|
||||
algo_,
|
||||
ret_workspace_size);
|
||||
|
||||
if (status == HIPBLAS_STATUS_SUCCESS) {
|
||||
if (ret_workspace_size >= workspace_size) {
|
||||
//TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large");
|
||||
return FAIL;
|
||||
}
|
||||
}
|
||||
else {
|
||||
//TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported");
|
||||
return FAIL;
|
||||
}
|
||||
|
||||
void* workspace_buffer = nullptr;
|
||||
if (workspace_size > 0) {
|
||||
workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
|
||||
}
|
||||
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
|
||||
matmul,
|
||||
&alpha,
|
||||
params->a,
|
||||
mat_a,
|
||||
params->b,
|
||||
mat_b,
|
||||
&beta,
|
||||
params->c,
|
||||
mat_c,
|
||||
params->c,
|
||||
mat_c,
|
||||
&algo_,
|
||||
workspace_buffer,
|
||||
workspace_size,
|
||||
at::cuda::getCurrentCUDAStream()));
|
||||
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
|
||||
if (workspace_size > 0) {
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
|
||||
}
|
||||
return OK;
|
||||
}
|
||||
|
||||
private:
|
||||
hipblasLtMatmulAlgo_t algo_;
|
||||
};
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
|
||||
auto GetHipBlasLtTypeStringAndOps() {
|
||||
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
|
||||
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
|
||||
auto in_out_datatype = HipBlasDataTypeFor<T>();
|
||||
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
|
||||
|
||||
hipblasLtHandle_t handle;
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
|
||||
TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
|
||||
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
|
||||
transa_outer,
|
||||
transb_outer,
|
||||
in_out_datatype,
|
||||
in_out_datatype,
|
||||
in_out_datatype,
|
||||
in_out_datatype,
|
||||
COMPUTE_TYPE_32,
|
||||
heuristic_result));
|
||||
TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
|
||||
|
||||
// Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
|
||||
std::sort(heuristic_result.begin(),
|
||||
heuristic_result.end(),
|
||||
[](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
|
||||
return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo);
|
||||
});
|
||||
|
||||
int returned_algo_count = heuristic_result.size();
|
||||
std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
|
||||
for (int i = 0; i < returned_algo_count; i++) {
|
||||
auto algo = heuristic_result[i].algo;
|
||||
int algo_index = GETINDEXFROMALGO(algo);
|
||||
auto callable = std::make_unique<HipblasltGemmOp<T, ALayout, BLayout, ParamsT>>(algo);
|
||||
std::string type_string = c10::str(
|
||||
"Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
|
||||
ret.emplace_back(type_string, std::move(callable));
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
||||
auto GetHipBlasLtGemmTypeStringAndOps() {
|
||||
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmParams<T>>();
|
||||
}
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
||||
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
|
||||
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
|
||||
}
|
||||
|
||||
#undef TORCH_HIPBLASLT_CHECK
|
||||
#undef GETINDEXFROMALGO
|
||||
#undef COMPUTE_TYPE_32
|
||||
#undef DATA_TYPE_R_32
|
||||
|
||||
} // namespace at::cuda::tunable
|
275
aten/src/ATen/cuda/tunable/GemmRocblas.h
Normal file
275
aten/src/ATen/cuda/tunable/GemmRocblas.h
Normal file
@ -0,0 +1,275 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
#define ROCBLAS_BETA_FEATURES_API
|
||||
#include <rocblas/rocblas.h>
|
||||
|
||||
#define TORCH_ROCBLAS_CHECK(EXPR) \
|
||||
do { \
|
||||
rocblas_status __err = EXPR; \
|
||||
TORCH_CHECK(__err == rocblas_status_success, \
|
||||
"rocblas error: ", \
|
||||
rocblas_status_to_string(__err), \
|
||||
" when calling `" #EXPR "`"); \
|
||||
} while (0)
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
template <typename T>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<float>() {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<double>() {
|
||||
return rocblas_datatype_f64_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<Half>() {
|
||||
return rocblas_datatype_f16_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<BFloat16>() {
|
||||
return rocblas_datatype_bf16_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<float>>() {
|
||||
return rocblas_datatype_f32_c;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasDataTypeFor<c10::complex<double>>() {
|
||||
return rocblas_datatype_f64_c;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor();
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<float>() {
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<double>() {
|
||||
return rocblas_datatype_f64_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<Half>() {
|
||||
// Note that we're returning the _compute_ type for a given datatype.
|
||||
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
||||
// slower than using compute type FP32. So we use FP32 compute even for
|
||||
// FP16 datatypes. This is how GEMM is implemented even in the function
|
||||
// rocblasGemmHelper (see fpgeneric.h)
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<BFloat16>() {
|
||||
// Note that we're returning the _compute_ type for a given datatype.
|
||||
// As of 12/2022, using compute type FP16 for 16-bit floats was much
|
||||
// slower than using compute type FP32. So we use FP32 compute even for
|
||||
// BF16 datatypes. This is how GEMM is implemented even in the function
|
||||
// rocblasGemmHelper (see fpgeneric.h)
|
||||
return rocblas_datatype_f32_r;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<float>>() {
|
||||
return rocblas_datatype_f32_c;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr rocblas_datatype RocBlasComputeTypeFor<c10::complex<double>>() {
|
||||
return rocblas_datatype_f64_c;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto DoCastForHalfOrBfloat16(const T fp) {
|
||||
return fp;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline auto DoCastForHalfOrBfloat16<Half>(const Half fp) {
|
||||
// alpha and beta should be the same as compute_type, in Half case it is float.
|
||||
float h = fp;
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline auto DoCastForHalfOrBfloat16<BFloat16>(const BFloat16 fp) {
|
||||
// alpha and beta should be the same as compute_type, in bfloat16 case it is float.
|
||||
float h = fp;
|
||||
return h;
|
||||
}
|
||||
|
||||
static rocblas_operation _rocblasOpFromChar(char op) {
|
||||
switch (op) {
|
||||
case 'n':
|
||||
case 'N':
|
||||
return rocblas_operation_none;
|
||||
case 't':
|
||||
case 'T':
|
||||
return rocblas_operation_transpose;
|
||||
case 'c':
|
||||
case 'C':
|
||||
return rocblas_operation_conjugate_transpose;
|
||||
}
|
||||
AT_ERROR(
|
||||
"_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class RocblasGemmOp : public Callable<GemmParams<T>> {
|
||||
public:
|
||||
RocblasGemmOp(int solution) : solution_{solution} {}
|
||||
|
||||
TuningStatus Call(const GemmParams<T>* params) override {
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
auto status = rocblas_gemm_ex(
|
||||
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
||||
_rocblasOpFromChar(params->transa),
|
||||
_rocblasOpFromChar(params->transb),
|
||||
params->m, params->n, params->k,
|
||||
&h_a,
|
||||
params->a, input_output_type, params->lda,
|
||||
params->b, input_output_type, params->ldb,
|
||||
&h_b,
|
||||
params->c, input_output_type, params->ldc,
|
||||
params->c, input_output_type, params->ldc,
|
||||
compute_type,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
solution_,
|
||||
rocblas_gemm_flags_none);
|
||||
if (status != rocblas_status_success) {
|
||||
return FAIL;
|
||||
}
|
||||
return OK;
|
||||
}
|
||||
|
||||
private:
|
||||
int solution_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
auto GetRocBlasGemmTypeStringAndOps() {
|
||||
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
||||
int solution_size;
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
// Get the number of available solutions
|
||||
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
nullptr,
|
||||
&solution_size));
|
||||
std::vector<int> solutions(solution_size);
|
||||
// Get the list of available solutions
|
||||
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&solution_size));
|
||||
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
||||
std::sort(solutions.begin(), solutions.end());
|
||||
|
||||
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
|
||||
for (size_t i = 0; i < solutions.size(); ++i) {
|
||||
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
|
||||
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
||||
public:
|
||||
RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {}
|
||||
|
||||
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||
auto h_b = DoCastForHalfOrBfloat16(params->beta);
|
||||
auto status = rocblas_gemm_strided_batched_ex(
|
||||
(rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
|
||||
_rocblasOpFromChar(params->transa),
|
||||
_rocblasOpFromChar(params->transb),
|
||||
params->m, params->n, params->k,
|
||||
&h_a,
|
||||
params->a, input_output_type, params->lda, params->stride_a,
|
||||
params->b, input_output_type, params->ldb, params->stride_b,
|
||||
&h_b,
|
||||
params->c, input_output_type, params->ldc, params->stride_c,
|
||||
params->c, input_output_type, params->ldc, params->stride_c,
|
||||
params->batch,
|
||||
compute_type,
|
||||
rocblas_gemm_algo_solution_index,
|
||||
solution_,
|
||||
rocblas_gemm_flags_none);
|
||||
if (status != rocblas_status_success) {
|
||||
return FAIL;
|
||||
}
|
||||
return OK;
|
||||
}
|
||||
|
||||
private:
|
||||
int solution_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
|
||||
rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
|
||||
int solution_size;
|
||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||
// Get the number of available solutions
|
||||
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
nullptr,
|
||||
&solution_size));
|
||||
std::vector<int> solutions(solution_size);
|
||||
// Get the list of available solutions
|
||||
TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
|
||||
input_output_type,
|
||||
input_output_type,
|
||||
compute_type,
|
||||
rocblas_gemm_flags_none,
|
||||
solutions.data(),
|
||||
&solution_size));
|
||||
// Sort the solutions in ascending order to make the solution vector deterministic across runs
|
||||
std::sort(solutions.begin(), solutions.end());
|
||||
|
||||
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
|
||||
for (size_t i = 0; i < solutions.size(); ++i) {
|
||||
auto callable = std::make_unique<RocblasGemmStridedBatchedOp<T>>(solutions[i]);
|
||||
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace at::cuda::tunable
|
88
aten/src/ATen/cuda/tunable/README.md
Normal file
88
aten/src/ATen/cuda/tunable/README.md
Normal file
@ -0,0 +1,88 @@
|
||||
# TunableOp
|
||||
|
||||
This directory implements a TunableOp interface.
|
||||
|
||||
Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For
|
||||
example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's
|
||||
rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one
|
||||
know which implementation is the fastest and should be chosen? That's what TunableOp provides.
|
||||
|
||||
The behavior of TunableOp is currently easily manipulated through environment variables, though you could use the C++
|
||||
interface of at::cuda::tunable::getTuningContext(). A Python interface to the TuningContext does not yet exist.
|
||||
|
||||
Currently only a TunableGemm for ROCm is implemented. Any call to at::cuda::blas::gemm() can optionally use the
|
||||
TunableGemm. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the
|
||||
fastest available implementation.
|
||||
|
||||
## Environment Variables
|
||||
|
||||
#### PYTORCH_TUNABLEOP_ENABLED
|
||||
Default is 0. Set to 1 to enable.
|
||||
This is the big on/off switch for all TunableOp implementations.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_TUNING
|
||||
Default is 1. Set to 0 to disable.
|
||||
When enabled, if a tuned entry isn't found, run the tuning step and record the entry.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_VERBOSE
|
||||
Default is 0. Set to 1 to enable.
|
||||
This will produce a lot of diagnostic messages but may be useful to see if TunableOp is being used at all.
|
||||
Otherwise, TunableOp is completely silent unless there is a warning or error during its use.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_FILENAME
|
||||
Default is 'tunableop_results.csv'. If you provide a filename, the TuningContext will attempt to read it the first time
|
||||
the context is used. If tuning is enabled and new tunings are discovered, it will also write out to this same filename
|
||||
with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can be used, for
|
||||
example, to build up a tunings file across many workloads by reusing the same file. Unsetting this variable is not
|
||||
recommended but can be done, in which case the tuning results will not be saved.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_NUMERICAL_CHECK
|
||||
Default is 1. Set to 0 to disable. Compare the results of each possible solution against the default solution and reject
|
||||
those with low accuracy.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED
|
||||
Default is 1. Set to 0 to disable hipblaslt being considered during tuning.
|
||||
|
||||
### Tuning Iterations
|
||||
By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations can
|
||||
be run within 30ms, whichever is smaller. Its average execution will be calculated. The fastest solution is chosen. In
|
||||
addition, a set of warm up iterations can optionally be run prior to the timed iterations. The following environment
|
||||
variables can be used to set either the maximum number of iterations to attempt or the maximum amount of time allowed in
|
||||
milliseconds, or both, in which case the smaller of the two values used.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS
|
||||
Default is 30.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS
|
||||
Default is 100.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS
|
||||
Default is 0, meaning it is not used.
|
||||
|
||||
#### PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS
|
||||
Default is 1.
|
||||
|
||||
## File Output
|
||||
|
||||
Assuming you specified a filename, you'll end up with a CSV file with contents like so:
|
||||
|
||||
```
|
||||
Validator,PT_VERSION,2.2.0
|
||||
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
|
||||
Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
|
||||
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
|
||||
GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
|
||||
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
|
||||
```
|
||||
|
||||
Note the "Validator" lines. If you change a library verison, or rocm version, or pytorch version, TunableOp will detect
|
||||
this and not load the tunings because they are likely affected by other software changes.
|
||||
|
||||
The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of
|
||||
4 comma-separated fields: operator name, operator parameters, solution name, and average execution time. The execution
|
||||
time is an optional field. The CSV file can be edited, but with caution. For example, the solution name (field 3) can be
|
||||
changed to "Default" and it will fall back to the original PyTorch untuned implementation. Or, in the case of ROCm's
|
||||
hipBLAS or hipBLASLt libraries, if you know the specific solution index you can override the solution that TunableOp
|
||||
selected by replacing the value. The operator name and parameters (fields 1 and 2) are internally named and should not
|
||||
be modified. In the case of GemmTunableOp, field 1 indicates the datatype and whether the inputs are transposed (T) or
|
||||
not (N) and field 2 indicates the M, N, K input shapes.
|
43
aten/src/ATen/cuda/tunable/StreamTimer.cpp
Normal file
43
aten/src/ATen/cuda/tunable/StreamTimer.cpp
Normal file
@ -0,0 +1,43 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/cuda/tunable/StreamTimer.h>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
StreamTimer::StreamTimer() {
|
||||
AT_CUDA_CHECK(cudaEventCreate(&start_));
|
||||
AT_CUDA_CHECK(cudaEventCreate(&end_));
|
||||
}
|
||||
|
||||
StreamTimer::~StreamTimer() {
|
||||
}
|
||||
|
||||
void StreamTimer::Start() {
|
||||
AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream()));
|
||||
}
|
||||
|
||||
void StreamTimer::End() {
|
||||
AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream()));
|
||||
AT_CUDA_CHECK(cudaEventSynchronize(end_));
|
||||
}
|
||||
|
||||
float StreamTimer::Duration() {
|
||||
float time;
|
||||
// time is in ms with a resolution of 1 us
|
||||
AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
|
||||
return time;
|
||||
}
|
||||
|
||||
} // namespace at::cuda::tunable
|
34
aten/src/ATen/cuda/tunable/StreamTimer.h
Normal file
34
aten/src/ATen/cuda/tunable/StreamTimer.h
Normal file
@ -0,0 +1,34 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
class StreamTimer : public ITimer {
|
||||
public:
|
||||
StreamTimer();
|
||||
virtual ~StreamTimer();
|
||||
|
||||
void Start() override;
|
||||
|
||||
void End() override;
|
||||
|
||||
float Duration() override;
|
||||
|
||||
private:
|
||||
cudaEvent_t start_;
|
||||
cudaEvent_t end_;
|
||||
};
|
||||
|
||||
} // namespace at::cuda::tunable
|
567
aten/src/ATen/cuda/tunable/Tunable.cpp
Normal file
567
aten/src/ATen/cuda/tunable/Tunable.cpp
Normal file
@ -0,0 +1,567 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContextLight.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <torch/version.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
namespace {
|
||||
|
||||
TuningContext tuning_context;
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TuningContext* getTuningContext() {
|
||||
return &tuning_context;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
|
||||
return stream << entry.key_ << "," << entry.time_;
|
||||
}
|
||||
|
||||
// TuningResultsManager
|
||||
|
||||
KernelMap TuningResultsManager::Lookup(const std::string& op_signature) {
|
||||
std::scoped_lock l{lock_};
|
||||
auto it = results_.find(op_signature);
|
||||
if (it == results_.cend()) {
|
||||
return {};
|
||||
}
|
||||
return it->second; // copied
|
||||
}
|
||||
|
||||
ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::string& params_signature) {
|
||||
std::scoped_lock l{lock_};
|
||||
auto kernel_map_it = results_.find(op_signature);
|
||||
if (kernel_map_it == results_.cend()) {
|
||||
TUNABLE_LOG("missing op_signature, returning null ResultEntry");
|
||||
return ResultEntry::Null();
|
||||
}
|
||||
|
||||
const auto& km = kernel_map_it->second;
|
||||
auto it = km.find(params_signature);
|
||||
if (it == km.cend()) {
|
||||
TUNABLE_LOG("missing params_signature, returning null ResultEntry");
|
||||
return ResultEntry::Null();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
inline void TuningResultsManager::AddImpl(const std::string& op_signature,
|
||||
const std::string& params_signature,
|
||||
ResultEntry best,
|
||||
KernelMap& kernel_map) {
|
||||
auto it = kernel_map.find(params_signature);
|
||||
if (it != kernel_map.end()) {
|
||||
if (it->second != best) {
|
||||
TUNABLE_LOG(op_signature, "(", params_signature, ") already has a best kernel ",
|
||||
"id=", it->second, " selected, want to add a different best kernel ", best,
|
||||
", the new kernel id will be ignored.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
TUNABLE_LOG(op_signature, "(", params_signature, ") -> ", best);
|
||||
kernel_map.emplace(params_signature, best);
|
||||
}
|
||||
|
||||
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
|
||||
std::scoped_lock l{lock_};
|
||||
|
||||
auto it = results_.find(op_signature);
|
||||
if (it == results_.end()) {
|
||||
it = results_.insert({op_signature, {}}).first;
|
||||
}
|
||||
|
||||
AddImpl(op_signature, params_signature, best, it->second);
|
||||
}
|
||||
|
||||
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
|
||||
std::scoped_lock l{lock_};
|
||||
|
||||
auto it = results_.find(op_signature);
|
||||
if (it == results_.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto it2 = it->second.find(params_signature);
|
||||
if (it2 == it->second.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
TUNABLE_LOG(op_signature, "(", params_signature, ")");
|
||||
it->second.erase(it2);
|
||||
}
|
||||
|
||||
inline void TuningResultsManager::DisjointMergeImpl(
|
||||
const std::string& op_signature,
|
||||
const KernelMap& kernel_map,
|
||||
/*out*/ std::unordered_map<std::string, KernelMap>& results) {
|
||||
auto it = results.find(op_signature);
|
||||
if (it == results.end()) {
|
||||
for (const auto& [param_sig, kernel_id] : kernel_map) {
|
||||
TUNABLE_LOG(op_signature, "(", param_sig, ") -> ", kernel_id);
|
||||
}
|
||||
results[op_signature] = kernel_map;
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto& [params_signature, best] : kernel_map) {
|
||||
AddImpl(op_signature, params_signature, best, it->second);
|
||||
}
|
||||
}
|
||||
|
||||
void TuningResultsManager::Load(const std::unordered_map<std::string, KernelMap>& results_to_load) {
|
||||
TUNABLE_LOG("Loading results");
|
||||
std::scoped_lock l{lock_};
|
||||
for (const auto& [op_signature, kernel_map] : results_to_load) {
|
||||
DisjointMergeImpl(op_signature, kernel_map, results_);
|
||||
}
|
||||
}
|
||||
|
||||
ResultsMap TuningResultsManager::Dump() {
|
||||
std::scoped_lock l{lock_};
|
||||
return results_;
|
||||
}
|
||||
|
||||
void TuningResultsManager::DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map) {
|
||||
std::scoped_lock l{lock_};
|
||||
DisjointMergeImpl(op_signature, kernel_map, results_);
|
||||
}
|
||||
|
||||
size_t TuningResultsManager::GetSize() {
|
||||
size_t size = 0;
|
||||
std::scoped_lock l{lock_};
|
||||
for (const auto& [op_signature, kernel_map] : results_) {
|
||||
size += kernel_map.size();
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// TuningResultsValidator
|
||||
|
||||
TuningResultsValidator::TuningResultsValidator() {
|
||||
RegisterValidator(
|
||||
"PT_VERSION",
|
||||
[this]() { return GetPyTorchVersion(); },
|
||||
[this](auto&& k) { return ValidatePyTorchVersion(std::forward<decltype(k)>(k)); });
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
|
||||
std::unordered_map<std::string, std::string> ret;
|
||||
for (const auto& [key, get_validate_func_pair] : validators_) {
|
||||
const GetFunc& getter = get_validate_func_pair.first;
|
||||
ret[key] = getter();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static bool CheckMandatoryKeys(
|
||||
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
|
||||
const std::unordered_map<std::string, std::string>& to_check) {
|
||||
bool passed = true;
|
||||
for (const auto& k : TuningResultsValidator::mandatory_keys) {
|
||||
if (gv_funcs.find(k) == gv_funcs.end()) {
|
||||
passed = false;
|
||||
TUNABLE_LOG("key=\"", k, "\" is not registered for Get and Validate. ");
|
||||
}
|
||||
|
||||
if (to_check.find(k) == to_check.end()) {
|
||||
passed = false;
|
||||
TUNABLE_LOG("key=\"", k, "\" is not provided for validation. ");
|
||||
}
|
||||
}
|
||||
return passed;
|
||||
}
|
||||
|
||||
static bool CheckKeysMatching(
|
||||
const TuningResultsValidator::GetValidateFuncs& gv_funcs,
|
||||
const std::unordered_map<std::string, std::string>& to_check) {
|
||||
auto get_keys = [](const auto& it) -> std::string { return it.first; };
|
||||
std::vector<std::string> required_keys;
|
||||
std::vector<std::string> provided_keys;
|
||||
std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys);
|
||||
std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys);
|
||||
std::sort(required_keys.begin(), required_keys.end());
|
||||
std::sort(provided_keys.begin(), provided_keys.end());
|
||||
|
||||
std::unordered_set<std::string> intersection;
|
||||
std::set_intersection(required_keys.cbegin(), required_keys.cend(),
|
||||
provided_keys.cbegin(), provided_keys.cend(),
|
||||
std::inserter(intersection, intersection.end()));
|
||||
bool matched = true;
|
||||
if (intersection.size() != required_keys.size()) {
|
||||
matched = false;
|
||||
for (const auto& k : required_keys) {
|
||||
if (intersection.find(k) == intersection.end()) {
|
||||
TORCH_WARN("Unmatched validator: \"", k, "\" is required, but the tuning results does not provide it. ");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (intersection.size() != provided_keys.size()) {
|
||||
matched = false;
|
||||
for (const auto& k : provided_keys) {
|
||||
if (intersection.find(k) == intersection.end()) {
|
||||
TORCH_WARN("Unmatched validator: \"", k, "\" is provided, but pytorch is unable to consume it. ");
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
TuningStatus TuningResultsValidator::ValidateAll(
|
||||
const std::unordered_map<std::string, std::string>& to_validate) const {
|
||||
if (!CheckMandatoryKeys(validators_, to_validate)) {
|
||||
return FAIL;
|
||||
}
|
||||
if (!CheckKeysMatching(validators_, to_validate)) {
|
||||
return FAIL;
|
||||
}
|
||||
|
||||
for (const auto& [key, value] : to_validate) {
|
||||
const auto& it = validators_.find(key);
|
||||
if (it == validators_.cend()) {
|
||||
TORCH_WARN("Failed to lookup validator using key ", key);
|
||||
for (const auto& [key2, val2] : validators_) {
|
||||
TORCH_WARN("available key ", key2);
|
||||
}
|
||||
return FAIL;
|
||||
}
|
||||
const ValidateFunc& validator = it->second.second;
|
||||
if (validator(value) != OK) {
|
||||
TORCH_WARN("Failed validator: ", key);
|
||||
return FAIL;
|
||||
}
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) {
|
||||
if (validators_.find(key) != validators_.end()) {
|
||||
TORCH_WARN("Attempting to re-register validator with key ", key);
|
||||
}
|
||||
else {
|
||||
validators_[key] = std::make_pair(gf, vf);
|
||||
}
|
||||
}
|
||||
|
||||
std::string TuningResultsValidator::GetPyTorchVersion() const {
|
||||
return TORCH_VERSION;
|
||||
}
|
||||
|
||||
TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& value) const {
|
||||
if (value == GetPyTorchVersion()) {
|
||||
return OK;
|
||||
}
|
||||
return FAIL;
|
||||
}
|
||||
|
||||
// TuningContext
|
||||
|
||||
TuningContext::TuningContext() :
|
||||
enable_{false},
|
||||
tuning_enable_{true},
|
||||
manager_initialized_{false},
|
||||
max_tuning_duration_ms_{30},
|
||||
max_tuning_iterations_{100},
|
||||
max_warmup_duration_ms_{0},
|
||||
max_warmup_iterations_{0},
|
||||
filename_{"tunableop_results.csv"},
|
||||
results_count_from_input_file_{0}
|
||||
{
|
||||
}
|
||||
|
||||
TuningContext::~TuningContext() {
|
||||
if (!manager_initialized_) {
|
||||
// TuningResultsManager was never initialized, no tuning requested or performed.
|
||||
// This can happen in a DDP job where a python process spawns other workers
|
||||
// but doesn't do any computation itself.
|
||||
return;
|
||||
}
|
||||
auto filename = GetFilename();
|
||||
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty()) {
|
||||
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
|
||||
if (results_count_from_input_file_ > 0) {
|
||||
TUNABLE_LOG("additional tuning results available, rewriting file ", filename);
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG("writing file ", filename);
|
||||
}
|
||||
if (!WriteFile(filename)) {
|
||||
TUNABLE_LOG("failed to write file ", filename);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TuningContext::EnableTunableOp() {
|
||||
TUNABLE_LOG("Enable TunableOp");
|
||||
enable_ = true;
|
||||
}
|
||||
|
||||
void TuningContext::DisableTunableOp() {
|
||||
TUNABLE_LOG("Disable TunableOp");
|
||||
enable_ = false;
|
||||
}
|
||||
|
||||
bool TuningContext::IsTunableOpEnabled() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED");
|
||||
if (env != nullptr && strcmp(env, "1") == 0) {
|
||||
//TUNABLE_LOG("PYTORCH_TUNABLEOP_ENABLED=1");
|
||||
return true;
|
||||
}
|
||||
return enable_;
|
||||
}
|
||||
|
||||
void TuningContext::EnableTuning() {
|
||||
TUNABLE_LOG("Enable Tuning for TunableOp");
|
||||
tuning_enable_ = true;
|
||||
}
|
||||
|
||||
void TuningContext::DisableTuning() {
|
||||
TUNABLE_LOG("Disable Tuning for TunableOp");
|
||||
tuning_enable_ = false;
|
||||
}
|
||||
|
||||
bool TuningContext::IsTuningEnabled() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING");
|
||||
if (env != nullptr && strcmp(env, "0") == 0) {
|
||||
//TUNABLE_LOG("PYTORCH_TUNABLEOP_TUNING=1");
|
||||
return false;
|
||||
}
|
||||
return tuning_enable_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
|
||||
max_tuning_duration_ms_ = max_duration_ms;
|
||||
}
|
||||
|
||||
int TuningContext::GetMaxTuningDurationMs() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS");
|
||||
if (env != nullptr) {
|
||||
return atoi(env);
|
||||
}
|
||||
return max_tuning_duration_ms_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxTuningIterations(int max_iter) {
|
||||
max_tuning_iterations_ = max_iter;
|
||||
}
|
||||
|
||||
int TuningContext::GetMaxTuningIterations() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS");
|
||||
if (env != nullptr) {
|
||||
return atoi(env);
|
||||
}
|
||||
return max_tuning_iterations_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) {
|
||||
max_warmup_duration_ms_ = max_duration_ms;
|
||||
}
|
||||
|
||||
int TuningContext::GetMaxWarmupDurationMs() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS");
|
||||
if (env != nullptr) {
|
||||
return atoi(env);
|
||||
}
|
||||
return max_warmup_duration_ms_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxWarmupIterations(int max_iter) {
|
||||
max_warmup_iterations_ = max_iter;
|
||||
}
|
||||
|
||||
int TuningContext::GetMaxWarmupIterations() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS");
|
||||
if (env != nullptr) {
|
||||
return atoi(env);
|
||||
}
|
||||
return max_warmup_iterations_;
|
||||
}
|
||||
|
||||
void TuningContext::EnableTunableOpAndTuning() {
|
||||
EnableTunableOp();
|
||||
EnableTuning();
|
||||
}
|
||||
|
||||
void TuningContext::DisableTunableOpAndTuning() {
|
||||
DisableTunableOp();
|
||||
DisableTuning();
|
||||
}
|
||||
|
||||
TuningResultsManager& TuningContext::GetTuningResultsManager() {
|
||||
c10::call_once(manager_init_once_, [this]() {
|
||||
manager_initialized_ = true;
|
||||
auto filename = GetFilename();
|
||||
if (!filename.empty()) {
|
||||
ReadFile(filename);
|
||||
// attempt immediately to open file for writing to catch errors early
|
||||
std::ofstream file(filename, std::ios::out | std::ios::app);
|
||||
if (!file.good()) {
|
||||
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
|
||||
}
|
||||
}
|
||||
});
|
||||
return manager_;
|
||||
}
|
||||
|
||||
TuningResultsValidator& TuningContext::GetTuningResultsValidator() {
|
||||
return validator_;
|
||||
}
|
||||
|
||||
TuningResults TuningContext::GetTuningResults() {
|
||||
TuningResults tr;
|
||||
tr.validators = GetTuningResultsValidator().GetAllValidators();
|
||||
tr.results = GetTuningResultsManager().Dump();
|
||||
return tr;
|
||||
}
|
||||
|
||||
TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) {
|
||||
TORCH_CHECK(GetTuningResultsValidator().ValidateAll(tr.validators));
|
||||
GetTuningResultsManager().Load(tr.results);
|
||||
return OK;
|
||||
}
|
||||
|
||||
void TuningContext::SetFilename(const std::string& filename) {
|
||||
filename_ = filename;
|
||||
}
|
||||
|
||||
std::string TuningContext::GetFilename() const {
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME");
|
||||
std::string filename = (env == nullptr) ? filename_ : env;
|
||||
if (filename.empty()) {
|
||||
TUNABLE_LOG("no filename from TuningContext::GetFilename()");
|
||||
return filename; // empty string
|
||||
}
|
||||
|
||||
// Using static with lambda here so that we don't make a cuda call during static shutdown.
|
||||
// Do this the first and only time GetFilename() is called because it is called
|
||||
// the first time a TunableOp is instantiated but also during static destruction
|
||||
// when the cuda or hip runtime is no longer available.
|
||||
static std::string device = []() {
|
||||
return c10::str(int(c10::cuda::current_device()));
|
||||
}();
|
||||
|
||||
// differentiate filename based on device ordinal to avoid
|
||||
// use case of one process per device writing to same file
|
||||
|
||||
// does filename contain %d to insert device ordinal in specific location?
|
||||
const std::string TOKEN("%d");
|
||||
std::size_t found = filename.find(TOKEN);
|
||||
if (found != std::string::npos) {
|
||||
filename.replace(found, TOKEN.length(), device);
|
||||
}
|
||||
else {
|
||||
// no %d present, so append device ordinal before final '.'
|
||||
found = filename.rfind(".");
|
||||
if (found != std::string::npos) {
|
||||
filename.insert(found, device);
|
||||
}
|
||||
else {
|
||||
// all else fails, just prepend
|
||||
filename.insert(0, device);
|
||||
}
|
||||
}
|
||||
return filename;
|
||||
}
|
||||
|
||||
bool TuningContext::ReadFile(const std::string& filename) {
|
||||
TUNABLE_LOG("reading tuning results from ", filename);
|
||||
ResultsMap results;
|
||||
std::unordered_map<std::string, std::string> validators;
|
||||
std::string line;
|
||||
std::ifstream file(filename);
|
||||
if (!file) {
|
||||
TUNABLE_LOG("could not open ", filename, " for reading tuning results");
|
||||
return false;
|
||||
}
|
||||
while (std::getline(file, line)) {
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
std::string part;
|
||||
std::vector<std::string> parts;
|
||||
std::stringstream line_as_stream(line);
|
||||
while (std::getline(line_as_stream, part, ',')) {
|
||||
parts.push_back(part);
|
||||
}
|
||||
if (parts[0] == "Validator" && parts.size() >= 3) {
|
||||
validators[parts[1]] = parts[2];
|
||||
TUNABLE_LOG("Validator ", parts[1], "=", parts[2]);
|
||||
}
|
||||
else if (parts.size() >= 4) {
|
||||
results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str())));
|
||||
}
|
||||
else if (parts.size() >= 3) {
|
||||
// the timestamp from the file is optional
|
||||
results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0));
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG("could not parse line: ", line);
|
||||
}
|
||||
}
|
||||
if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) {
|
||||
manager_.Load(results);
|
||||
results_count_from_input_file_ = manager_.GetSize();
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG("results validator check failed");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TuningContext::WriteFile(const std::string& filename) {
|
||||
std::ofstream file(filename, std::ios::out | std::ios::trunc);
|
||||
if (!file.good()) {
|
||||
TUNABLE_LOG("error opening tuning results file for writing ", filename);
|
||||
return false;
|
||||
}
|
||||
auto validators = GetTuningResultsValidator().GetAllValidators();
|
||||
for (const auto& [key, val] : validators) {
|
||||
file << "Validator," << key << "," << val << std::endl;
|
||||
}
|
||||
auto results = GetTuningResultsManager().Dump();
|
||||
for (const auto& [op_sig, kernelmap] : results) {
|
||||
for (const auto& [param_sig, result] : kernelmap) {
|
||||
file << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
}
|
||||
}
|
||||
file.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace at::cuda::tunable
|
205
aten/src/ATen/cuda/tunable/Tunable.h
Normal file
205
aten/src/ATen/cuda/tunable/Tunable.h
Normal file
@ -0,0 +1,205 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/CallOnce.h>
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
static void TunableLog(const std::string& msg) {
|
||||
static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE");
|
||||
if (env != nullptr && strcmp(env, "1") == 0) {
|
||||
std::cerr << msg << std::endl;
|
||||
}
|
||||
}
|
||||
#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__))
|
||||
|
||||
enum TuningStatus {
|
||||
OK = 0,
|
||||
FAIL = 1,
|
||||
UNSUPPORTED = 2,
|
||||
};
|
||||
|
||||
// Mapping from params signature to kernel id
|
||||
class ResultEntry {
|
||||
public:
|
||||
explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
|
||||
bool operator==(const ResultEntry& other) { return key_ == other.key_; }
|
||||
bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
|
||||
operator std::string () { return key_; }
|
||||
friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
|
||||
static ResultEntry Null() { return ResultEntry("Null", 0.0); }
|
||||
static ResultEntry Default() { return ResultEntry("Default", 0.0); }
|
||||
|
||||
private:
|
||||
std::string key_;
|
||||
double time_;
|
||||
};
|
||||
|
||||
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
|
||||
typedef std::unordered_map<std::string, KernelMap> ResultsMap;
|
||||
|
||||
struct TuningResults {
|
||||
// Validates if these results are compatible with the libraries
|
||||
std::unordered_map<std::string, std::string> validators;
|
||||
|
||||
// Mapping from Callable signature to Callable's tuning result
|
||||
ResultsMap results;
|
||||
};
|
||||
|
||||
class TuningResultsManager {
|
||||
public:
|
||||
TuningResultsManager() = default;
|
||||
~TuningResultsManager() = default;
|
||||
|
||||
KernelMap Lookup(const std::string& op_signature);
|
||||
|
||||
ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
|
||||
|
||||
inline void AddImpl(const std::string& op_signature,
|
||||
const std::string& params_signature,
|
||||
ResultEntry best,
|
||||
KernelMap& kernel_map);
|
||||
|
||||
void Add(const std::string& op_signature,
|
||||
const std::string& params_signature,
|
||||
ResultEntry best);
|
||||
|
||||
void Delete(const std::string& op_signature, const std::string& params_signature);
|
||||
|
||||
inline void DisjointMergeImpl(
|
||||
const std::string& op_signature,
|
||||
const KernelMap& kernel_map,
|
||||
/*out*/ ResultsMap& results);
|
||||
|
||||
void Load(const ResultsMap& results_to_load);
|
||||
|
||||
ResultsMap Dump();
|
||||
|
||||
void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
|
||||
|
||||
size_t GetSize();
|
||||
|
||||
private:
|
||||
std::mutex lock_;
|
||||
ResultsMap results_;
|
||||
};
|
||||
|
||||
class TuningResultsValidator {
|
||||
public:
|
||||
using GetFunc = std::function<std::string()>;
|
||||
using ValidateFunc = std::function<TuningStatus(const std::string&)>;
|
||||
using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
|
||||
|
||||
TuningResultsValidator();
|
||||
~TuningResultsValidator() = default;
|
||||
|
||||
std::unordered_map<std::string, std::string> GetAllValidators() const;
|
||||
TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
|
||||
void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
|
||||
|
||||
protected:
|
||||
std::string GetPyTorchVersion() const;
|
||||
TuningStatus ValidatePyTorchVersion(const std::string& value) const;
|
||||
|
||||
public:
|
||||
static constexpr const std::array mandatory_keys{"PT_VERSION"};
|
||||
|
||||
private:
|
||||
GetValidateFuncs validators_;
|
||||
};
|
||||
|
||||
class TuningContext {
|
||||
public:
|
||||
TuningContext();
|
||||
~TuningContext();
|
||||
TuningContext(TuningContext &) = delete;
|
||||
TuningContext(TuningContext &&) = delete;
|
||||
TuningContext &operator=(TuningContext &) = delete;
|
||||
TuningContext &operator=(TuningContext &&) = delete;
|
||||
|
||||
void EnableTunableOp();
|
||||
void DisableTunableOp();
|
||||
bool IsTunableOpEnabled() const;
|
||||
|
||||
void EnableTuning();
|
||||
void DisableTuning();
|
||||
bool IsTuningEnabled() const;
|
||||
|
||||
void SetMaxTuningDurationMs(int max_duration_ms);
|
||||
int GetMaxTuningDurationMs() const;
|
||||
|
||||
void SetMaxTuningIterations(int max_iter);
|
||||
int GetMaxTuningIterations() const;
|
||||
|
||||
void SetMaxWarmupDurationMs(int max_duration_ms);
|
||||
int GetMaxWarmupDurationMs() const;
|
||||
|
||||
void SetMaxWarmupIterations(int max_iter);
|
||||
int GetMaxWarmupIterations() const;
|
||||
|
||||
void EnableTunableOpAndTuning();
|
||||
void DisableTunableOpAndTuning();
|
||||
|
||||
TuningResultsManager& GetTuningResultsManager();
|
||||
|
||||
TuningResultsValidator& GetTuningResultsValidator();
|
||||
|
||||
TuningResults GetTuningResults();
|
||||
|
||||
TuningStatus LoadTuningResults(const TuningResults& tr);
|
||||
|
||||
void SetFilename(const std::string& filename);
|
||||
std::string GetFilename() const;
|
||||
|
||||
protected:
|
||||
bool ReadFile(const std::string& filename);
|
||||
bool WriteFile(const std::string& filename);
|
||||
|
||||
private:
|
||||
bool enable_;
|
||||
bool tuning_enable_;
|
||||
bool manager_initialized_;
|
||||
int max_tuning_duration_ms_;
|
||||
int max_tuning_iterations_;
|
||||
int max_warmup_duration_ms_;
|
||||
int max_warmup_iterations_;
|
||||
mutable TuningResultsManager manager_;
|
||||
mutable c10::once_flag manager_init_once_;
|
||||
TuningResultsValidator validator_;
|
||||
std::string filename_;
|
||||
size_t results_count_from_input_file_;
|
||||
};
|
||||
|
||||
TuningContext* getTuningContext();
|
||||
|
||||
class ITimer {
|
||||
public:
|
||||
ITimer() = default;
|
||||
virtual ~ITimer() = default;
|
||||
|
||||
virtual void Start() = 0;
|
||||
virtual void End() = 0;
|
||||
|
||||
/// Computes the elapsed time in milliseconds between Start() and End()
|
||||
virtual float Duration() = 0;
|
||||
};
|
||||
|
||||
} // namespace at::cuda::tunable
|
278
aten/src/ATen/cuda/tunable/TunableGemm.h
Normal file
278
aten/src/ATen/cuda/tunable/TunableGemm.h
Normal file
@ -0,0 +1,278 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#ifdef USE_ROCM
|
||||
#if ROCM_VERSION >= 50700
|
||||
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
||||
#endif
|
||||
#include <ATen/cuda/tunable/GemmRocblas.h>
|
||||
#endif
|
||||
#include <ATen/cuda/tunable/StreamTimer.h>
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <rocm-core/rocm_version.h>
|
||||
#endif
|
||||
|
||||
#define STRINGIFY(s) #s
|
||||
#define XSTRINGIFY(s) STRINGIFY(s)
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
template <typename T>
|
||||
class DefaultGemmOp : public Callable<GemmParams<T>> {
|
||||
public:
|
||||
TuningStatus Call(const GemmParams<T>* params) override {
|
||||
at::cuda::blas::gemm_internal<T>(
|
||||
params->transa, params->transb,
|
||||
params->m, params->n, params->k,
|
||||
params->alpha,
|
||||
params->a, params->lda,
|
||||
params->b, params->ldb,
|
||||
params->beta,
|
||||
params->c, params->ldc);
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
|
||||
public:
|
||||
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
||||
at::cuda::blas::bgemm_internal<T>(
|
||||
params->transa, params->transb,
|
||||
params->m, params->n, params->k,
|
||||
params->alpha,
|
||||
params->a, params->lda, params->stride_a,
|
||||
params->b, params->ldb, params->stride_b,
|
||||
params->beta,
|
||||
params->c, params->ldc, params->stride_c,
|
||||
params->batch);
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
bool IsZero(T v) {
|
||||
return v == 0.0f;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool IsZero(BFloat16 v) {
|
||||
return v.x == 0;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool IsZero(Half v) {
|
||||
return float(v) == 0.0f;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool IsZero(c10::complex<double> v) {
|
||||
return v == 0.0;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool IsZero(c10::complex<float> v) {
|
||||
return v == 0.0f;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string TypeName(T v) {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(float v) {
|
||||
return "float";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(double v) {
|
||||
return "double";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(BFloat16 v) {
|
||||
return "BFloat16";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(Half v) {
|
||||
return "Half";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(c10::complex<double> v) {
|
||||
return "c10::complex<double>";
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string TypeName(c10::complex<float> v) {
|
||||
return "c10::complex<float>";
|
||||
}
|
||||
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
||||
class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
public:
|
||||
GemmTunableOp() {
|
||||
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
|
||||
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
|
||||
if (validators.find("ROCM_VERSION") == validators.end()) {
|
||||
std::string rocm_version = ROCM_BUILD_INFO;
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"ROCM_VERSION",
|
||||
[rocm_version]() { return rocm_version; },
|
||||
[rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
|
||||
}
|
||||
|
||||
if (validators.find("GCN_ARCH_NAME") == validators.end()) {
|
||||
std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"GCN_ARCH_NAME",
|
||||
[gcn_arch_name]() { return gcn_arch_name; },
|
||||
[gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
|
||||
}
|
||||
|
||||
if (validators.find("ROCBLAS_VERSION") == validators.end()) {
|
||||
std::string rocblas_version = c10::str(
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"ROCBLAS_VERSION",
|
||||
[rocblas_version]() { return rocblas_version; },
|
||||
[rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env == nullptr || strcmp(env, "1") == 0) {
|
||||
// disallow tuning of hipblaslt with c10::complex
|
||||
if constexpr (
|
||||
!std::is_same_v<T, c10::complex<float>> &&
|
||||
!std::is_same_v<T, c10::complex<double>>) {
|
||||
for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
}
|
||||
|
||||
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
|
||||
std::string hipblaslt_version = c10::str(
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"HIPBLASLT_VERSION",
|
||||
[hipblaslt_version]() { return hipblaslt_version; },
|
||||
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string Signature() override {
|
||||
return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, BlasOp ALayout, BlasOp BLayout>
|
||||
class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
|
||||
public:
|
||||
GemmStridedBatchedTunableOp() {
|
||||
this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
|
||||
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
|
||||
if (validators.find("ROCM_VERSION") == validators.end()) {
|
||||
std::string rocm_version = ROCM_BUILD_INFO;
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"ROCM_VERSION",
|
||||
[rocm_version]() { return rocm_version; },
|
||||
[rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
|
||||
}
|
||||
|
||||
if (validators.find("GCN_ARCH_NAME") == validators.end()) {
|
||||
std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"GCN_ARCH_NAME",
|
||||
[gcn_arch_name]() { return gcn_arch_name; },
|
||||
[gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
|
||||
}
|
||||
|
||||
if (validators.find("ROCBLAS_VERSION") == validators.end()) {
|
||||
std::string rocblas_version = c10::str(
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
|
||||
XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"ROCBLAS_VERSION",
|
||||
[rocblas_version]() { return rocblas_version; },
|
||||
[rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env == nullptr || strcmp(env, "1") == 0) {
|
||||
// disallow tuning of hipblaslt with c10::complex
|
||||
if constexpr (
|
||||
!std::is_same_v<T, c10::complex<float>> &&
|
||||
!std::is_same_v<T, c10::complex<double>>) {
|
||||
for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
}
|
||||
|
||||
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
|
||||
std::string hipblaslt_version = c10::str(
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
|
||||
XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
|
||||
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
|
||||
"HIPBLASLT_VERSION",
|
||||
[hipblaslt_version]() { return hipblaslt_version; },
|
||||
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string Signature() override {
|
||||
return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
|
||||
}
|
||||
};
|
||||
|
||||
#undef XSTRINGIFY
|
||||
#undef STRINGIFY
|
||||
|
||||
} // namespace at::cuda::tunable
|
242
aten/src/ATen/cuda/tunable/TunableOp.h
Normal file
242
aten/src/ATen/cuda/tunable/TunableOp.h
Normal file
@ -0,0 +1,242 @@
|
||||
// Original TunableOp is from onnxruntime.
|
||||
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
|
||||
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
//
|
||||
// Adapting TunableOp into PyTorch
|
||||
// Copyright (c) Advanced Micro Devices, Inc.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
template <typename ParamsT>
|
||||
class Callable {
|
||||
public:
|
||||
Callable() = default;
|
||||
Callable(Callable&&) = default;
|
||||
virtual ~Callable() = default;
|
||||
virtual TuningStatus Call(const ParamsT*) {
|
||||
return FAIL;
|
||||
}
|
||||
virtual TuningStatus IsSupported(const ParamsT* params) {
|
||||
return Call(params);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ParamsT, typename TimerT>
|
||||
class TunableOp {
|
||||
public:
|
||||
TunableOp() = default;
|
||||
TunableOp(TunableOp&&) = default;
|
||||
virtual ~TunableOp() = default;
|
||||
|
||||
TuningStatus operator()(const ParamsT* params) {
|
||||
ResultEntry result = ResultEntry::Null();
|
||||
TuningContext* ctx = getTuningContext();
|
||||
if (ctx->IsTunableOpEnabled()) {
|
||||
auto& mgr = ctx->GetTuningResultsManager();
|
||||
auto op_sig = Signature();
|
||||
auto params_sig = params->Signature();
|
||||
result = mgr.Lookup(op_sig, params_sig);
|
||||
// If there is not previous tuning result been found, we do the tuning iff tuning is enabled
|
||||
if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
|
||||
result = FindFastest(params);
|
||||
mgr.Add(op_sig, params_sig, result);
|
||||
}
|
||||
}
|
||||
else {
|
||||
result = ResultEntry::Default();
|
||||
}
|
||||
if (result == ResultEntry::Null()) {
|
||||
TUNABLE_LOG("no result, using default");
|
||||
result = ResultEntry::Default();
|
||||
}
|
||||
auto iter = ops_.find(result);
|
||||
TORCH_CHECK(iter != ops_.end());
|
||||
return iter->second->Call(params);
|
||||
}
|
||||
|
||||
virtual std::string Signature() {
|
||||
// According to C++17 standard https://wg21.link/n4659 section 15.7.4
|
||||
// > if the operand of typeid refers to the
|
||||
// > object under construction or destruction, typeid yields the std::type_info object representing the constructor
|
||||
// > or destructor’s class.
|
||||
// So delay the op signature generation.
|
||||
c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
|
||||
return signature_;
|
||||
}
|
||||
|
||||
protected:
|
||||
void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
|
||||
this->op_names_.emplace_back(name);
|
||||
this->ops_.emplace(name, std::move(op));
|
||||
}
|
||||
|
||||
private:
|
||||
static void WarmUp(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
|
||||
for (size_t i = 0; i < num_iter; i++) {
|
||||
TORCH_CHECK(op->Call(param) == OK);
|
||||
}
|
||||
}
|
||||
|
||||
static double Profile(Callable<ParamsT> *op, ParamsT* param, size_t num_iter) {
|
||||
TimerT timer{};
|
||||
timer.Start();
|
||||
for (size_t i = 0; i < num_iter; i++) {
|
||||
TORCH_CHECK(op->Call(param) == OK);
|
||||
}
|
||||
timer.End();
|
||||
return timer.Duration() / num_iter;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool IsNumericsCheckEnabled() {
|
||||
static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
if (env != nullptr && strcmp(env, "0") == 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual ResultEntry FindFastest(const ParamsT* params) {
|
||||
TuningContext* ctx = getTuningContext();
|
||||
auto op_sig = Signature();
|
||||
auto params_sig = params->Signature();
|
||||
TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
|
||||
auto min_duration_ms = std::numeric_limits<double>::infinity();
|
||||
std::string id_name = "Default";
|
||||
|
||||
// calcaulte a reference answer for numerical check
|
||||
ParamsT* reference_params = params->DeepCopy();
|
||||
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
|
||||
|
||||
// need a copy of params to reuse
|
||||
ParamsT* reusable_params = params->DeepCopy();
|
||||
|
||||
for (size_t i = 0; i < op_names_.size(); i++) {
|
||||
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
|
||||
auto status = candidate->Call(reusable_params);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsNumericsCheckEnabled()) {
|
||||
ParamsT* numerical_params = params->DeepCopy();
|
||||
WarmUp(candidate, numerical_params, 1);
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// collect a small profile
|
||||
constexpr const int approx_num_iter = 3;
|
||||
auto approx_duration = Profile(candidate, reusable_params, approx_num_iter);
|
||||
// bail if too slow
|
||||
if (approx_duration > 2 * min_duration_ms) {
|
||||
TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
|
||||
// for warmup does user set max duration, max iters, or both?
|
||||
double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
|
||||
int max_warmup_iter = ctx->GetMaxWarmupIterations();
|
||||
int warmup_iter = 1; // default
|
||||
if (max_warmup_duration > 0) {
|
||||
int duration_iters = max_warmup_duration / approx_duration;
|
||||
if (max_warmup_iter > 0) {
|
||||
warmup_iter = std::min(max_warmup_iter, duration_iters);
|
||||
}
|
||||
else {
|
||||
warmup_iter = duration_iters;
|
||||
}
|
||||
}
|
||||
else if (max_warmup_iter > 0) {
|
||||
warmup_iter = max_warmup_iter;
|
||||
}
|
||||
|
||||
// for tuning does user set max duration, max iters, or both?
|
||||
double max_tuning_duration = ctx->GetMaxTuningDurationMs();
|
||||
int max_tuning_iter = ctx->GetMaxTuningIterations();
|
||||
int tuning_iter = 100; // default
|
||||
if (max_tuning_duration > 0) {
|
||||
int duration_iters = max_tuning_duration / approx_duration;
|
||||
if (max_tuning_iter > 0) {
|
||||
tuning_iter = std::min(max_tuning_iter, duration_iters);
|
||||
}
|
||||
else {
|
||||
tuning_iter = duration_iters;
|
||||
}
|
||||
}
|
||||
else if (max_tuning_iter > 0) {
|
||||
tuning_iter = max_tuning_iter;
|
||||
}
|
||||
|
||||
// do the full warmup followed by tuning
|
||||
double warmup_ms = warmup_iter * approx_duration;
|
||||
double tuning_ms = tuning_iter * approx_duration;
|
||||
TUNABLE_LOG("├──tuning using "
|
||||
"warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
|
||||
"and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
|
||||
"instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
|
||||
WarmUp(candidate, reusable_params, warmup_iter);
|
||||
auto duration_ms = Profile(candidate, reusable_params, tuning_iter);
|
||||
if (duration_ms < min_duration_ms) {
|
||||
TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
|
||||
min_duration_ms = duration_ms;
|
||||
id_name = op_names_[i];
|
||||
}
|
||||
}
|
||||
|
||||
reusable_params->Delete();
|
||||
reference_params->Delete();
|
||||
|
||||
TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
|
||||
return ResultEntry(id_name, min_duration_ms);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string CreateSignature() {
|
||||
#ifndef _WIN32
|
||||
const auto* name = typeid(*this).name();
|
||||
char buf[256];
|
||||
size_t buf_len = 256;
|
||||
abi::__cxa_demangle(name, buf, &buf_len, nullptr);
|
||||
buf[255] = '\0';
|
||||
return buf;
|
||||
#else
|
||||
return typeid(*this).name();
|
||||
#endif
|
||||
}
|
||||
|
||||
mutable c10::once_flag signature_init_once_;
|
||||
std::string signature_;
|
||||
|
||||
std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
|
||||
std::vector<std::string> op_names_;
|
||||
};
|
||||
|
||||
struct OpParams {
|
||||
OpParams() {}
|
||||
virtual ~OpParams() = default;
|
||||
virtual std::string Signature() const = 0;
|
||||
};
|
||||
|
||||
} // namespace at::cuda::tunable
|
@ -1410,6 +1410,8 @@ aten_cuda_cu_source_list = [
|
||||
"aten/src/ATen/cuda/CUDABlas.cpp",
|
||||
"aten/src/ATen/cuda/CUDASparseBlas.cpp",
|
||||
"aten/src/ATen/cuda/CublasHandlePool.cpp",
|
||||
"aten/src/ATen/cuda/tunable/StreamTimer.cpp",
|
||||
"aten/src/ATen/cuda/tunable/Tunable.cpp",
|
||||
"aten/src/ATen/native/cuda/Activation.cpp",
|
||||
"aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp",
|
||||
"aten/src/ATen/native/cuda/Blas.cpp",
|
||||
|
@ -1279,6 +1279,9 @@ if(USE_ROCM)
|
||||
if(HIPBLASLT_CUSTOM_COMPUTE_TYPE)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE)
|
||||
endif()
|
||||
if(HIPBLASLT_HAS_GETINDEXFROMALGO)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_HAS_GETINDEXFROMALGO)
|
||||
endif()
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
|
@ -202,12 +202,12 @@ if(HIP_FOUND)
|
||||
"}\n"
|
||||
)
|
||||
|
||||
try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
try_compile(hipblaslt_compile_result_custom_datatype ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hipblaslt_compile_output)
|
||||
|
||||
if(hipblaslt_compile_result)
|
||||
if(hipblaslt_compile_result_custom_datatype)
|
||||
set(HIPBLASLT_CUSTOM_DATA_TYPE ON)
|
||||
#message("hipblaslt is using custom data type: ${hipblaslt_compile_output}")
|
||||
message("hipblaslt is using custom data type")
|
||||
@ -227,12 +227,12 @@ if(HIP_FOUND)
|
||||
"}\n"
|
||||
)
|
||||
|
||||
try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
try_compile(hipblaslt_compile_result_custom_compute_type ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hipblaslt_compile_output)
|
||||
|
||||
if(hipblaslt_compile_result)
|
||||
if(hipblaslt_compile_result_custom_compute_type)
|
||||
set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON)
|
||||
#message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}")
|
||||
message("hipblaslt is using custom compute type")
|
||||
@ -241,6 +241,36 @@ if(HIP_FOUND)
|
||||
#message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}")
|
||||
message("hipblaslt is NOT using custom compute type")
|
||||
endif()
|
||||
|
||||
# check whether hipblaslt provides getIndexFromAlgo
|
||||
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_getIndexFromAlgo.cc")
|
||||
file(WRITE ${file} ""
|
||||
"#include <hipblaslt/hipblaslt.h>\n"
|
||||
"#include <hipblaslt/hipblaslt-ext.hpp>\n"
|
||||
"int main() {\n"
|
||||
" hipblasLtMatmulAlgo_t algo;\n"
|
||||
" return hipblaslt_ext::getIndexFromAlgo(algo);\n"
|
||||
" return 0;\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
try_compile(hipblaslt_compile_result_getindexfromalgo ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
CMAKE_FLAGS
|
||||
"-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
"-DLINK_DIRECTORIES=${ROCM_PATH}/lib"
|
||||
LINK_LIBRARIES ${hipblaslt_LIBRARIES}
|
||||
COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__
|
||||
OUTPUT_VARIABLE hipblaslt_compile_output)
|
||||
|
||||
if(hipblaslt_compile_result_getindexfromalgo)
|
||||
set(HIPBLASLT_HAS_GETINDEXFROMALGO ON)
|
||||
#message("hipblaslt provides getIndexFromAlgo: ${hipblaslt_compile_output}")
|
||||
message("hipblaslt provides getIndexFromAlgo")
|
||||
else()
|
||||
set(HAS_GETINDEXFROMALGO OFF)
|
||||
#message("hipblaslt does not provide getIndexFromAlgo: ${hipblaslt_compile_output}")
|
||||
message("hipblaslt does not provide getIndexFromAlgo")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# check whether HIP declares new types
|
||||
|
2
setup.py
2
setup.py
@ -1166,6 +1166,7 @@ def main():
|
||||
"include/ATen/cuda/*.h",
|
||||
"include/ATen/cuda/detail/*.cuh",
|
||||
"include/ATen/cuda/detail/*.h",
|
||||
"include/ATen/cuda/tunable/*.h",
|
||||
"include/ATen/cudnn/*.h",
|
||||
"include/ATen/functorch/*.h",
|
||||
"include/ATen/ops/*.h",
|
||||
@ -1174,6 +1175,7 @@ def main():
|
||||
"include/ATen/hip/detail/*.cuh",
|
||||
"include/ATen/hip/detail/*.h",
|
||||
"include/ATen/hip/impl/*.h",
|
||||
"include/ATen/hip/tunable/*.h",
|
||||
"include/ATen/mps/*.h",
|
||||
"include/ATen/miopen/*.h",
|
||||
"include/ATen/detail/*.h",
|
||||
|
Reference in New Issue
Block a user