diff --git a/BUILD.bazel b/BUILD.bazel index 0afee2d8d71c..a2902c0e5e17 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -228,6 +228,7 @@ filegroup( [ "aten/src/ATen/cuda/*.cpp", "aten/src/ATen/cuda/detail/*.cpp", + "aten/src/ATen/cuda/tunable/*.cpp", "aten/src/ATen/cudnn/*.cpp", "aten/src/ATen/native/cuda/*.cpp", "aten/src/ATen/native/cuda/linalg/*.cpp", diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 3e424a79f343..bf425af5fa9d 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -60,11 +60,11 @@ endif() file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") -file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") -file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") +file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h") +file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp") file(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h") file(GLOB cuda_nvrtc_stub_cpp "cuda/nvrtc_stub/*.cpp") -file(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu") +file(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu" "cuda/tunable/*.cu") file(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh") file(GLOB cudnn_cpp "cudnn/*.cpp") file(GLOB ops_h "ops/*.h") @@ -72,10 +72,10 @@ file(GLOB ops_h "ops/*.h") file(GLOB xpu_h "xpu/*.h" "xpu/detail/*.h") file(GLOB xpu_cpp "xpu/*.cpp" "xpu/detail/*.cpp") -file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h") -file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp") +file(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh" "hip/impl/*.h" "hip/tunable/*.cuh" "hip/tunable/*.h") +file(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp" "hip/tunable/*.cpp") list(REMOVE_ITEM hip_cpp "${CMAKE_CURRENT_SOURCE_DIR}/hip/detail/LazyNVRTC.cpp") -file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip") +file(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip" "hip/tunable/*.hip") file(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h") file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") file(GLOB miopen_h "miopen/*.h") diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 1a358626f418..0a3de5f9d775 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include #include @@ -232,7 +234,7 @@ namespace at::cuda::blas { } while (0) template <> -void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)) { +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -245,7 +247,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)) { } template <> -void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)) { +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -258,7 +260,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)) { } template <> -void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -273,7 +275,7 @@ void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) } template <> -void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -288,7 +290,7 @@ void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { } template <> -void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -335,7 +337,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { } template <> -void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); BGEMM_CHECK_ARGVALUES(at::BFloat16); @@ -361,8 +363,119 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } +template +inline void bgemm_tunable(CUDABLAS_BGEMM_ARGTYPES(DType)) { + tunable::GemmStridedBatchedParams params; + params.transa = transa; + params.transb = transb; + params.m = m; + params.n = n; + params.k = k; + params.alpha = alpha; + params.a = a; + params.lda = lda; + params.stride_a = stridea; + params.b = b; + params.ldb = ldb; + params.stride_b = strideb; + params.beta = beta; + params.c = c; + params.ldc = ldc; + params.stride_c = stridec; + params.batch = num_batches; + + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); + + if (transa_ && transb_) { + static tunable::GemmStridedBatchedTunableOp bgemm{}; + bgemm(¶ms); + } + else if (transa_ && !transb_) { + static tunable::GemmStridedBatchedTunableOp bgemm{}; + bgemm(¶ms); + } + else if (!transa_ && transb_) { + static tunable::GemmStridedBatchedTunableOp bgemm{}; + bgemm(¶ms); + } + else if (!transa_ && !transb_) { + static tunable::GemmStridedBatchedTunableOp bgemm{}; + bgemm(¶ms); + } + else { + TORCH_CHECK(false, "unreachable"); + } +} + template <> -void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { +void bgemm(CUDABLAS_BGEMM_ARGTYPES(double)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable(CUDABLAS_BGEMM_ARGS(double)); + } + else { + bgemm_internal(CUDABLAS_BGEMM_ARGS(double)); + } +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(float)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable(CUDABLAS_BGEMM_ARGS(float)); + } + else { + bgemm_internal(CUDABLAS_BGEMM_ARGS(float)); + } +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } + else { + bgemm_internal>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } +} + +template <> +void bgemm>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } + else { + bgemm_internal>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable(CUDABLAS_BGEMM_ARGS(at::Half)); + } + else { + bgemm_internal(CUDABLAS_BGEMM_ARGS(at::Half)); + } +} + +template <> +void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + bgemm_tunable(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + } + else { + bgemm_internal(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + } +} + +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -375,7 +488,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { } template <> -void gemm(CUDABLAS_GEMM_ARGTYPES(float)) { +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -388,7 +501,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(float)) { } template <> -void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -403,7 +516,7 @@ void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { } template <> -void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -418,7 +531,7 @@ void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { } template <> -void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) { // See Note [Writing Nondeterministic Operations] globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); @@ -514,7 +627,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { } template <> -void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { globalContext().alertCuBLASConfigNotDeterministic(); cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasOperation_t opa = _cublasOpFromChar(transa); @@ -558,6 +671,113 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } +template +inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES(DType)) { + tunable::GemmParams params; + params.transa = transa; + params.transb = transb; + params.m = m; + params.n = n; + params.k = k; + params.alpha = alpha; + params.a = a; + params.lda = lda; + params.b = b; + params.ldb = ldb; + params.beta = beta; + params.c = c; + params.ldc = ldc; + + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); + + if (transa_ && transb_) { + static tunable::GemmTunableOp gemm{}; + gemm(¶ms); + } + else if (transa_ && !transb_) { + static tunable::GemmTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && transb_) { + static tunable::GemmTunableOp gemm{}; + gemm(¶ms); + } + else if (!transa_ && !transb_) { + static tunable::GemmTunableOp gemm{}; + gemm(¶ms); + } + else { + TORCH_CHECK(false, "unreachable"); + } +} + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(double)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable(CUDABLAS_GEMM_ARGS(double)); + } + else { + gemm_internal(CUDABLAS_GEMM_ARGS(double)); + } +} + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(float)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable(CUDABLAS_GEMM_ARGS(float)); + } + else { + gemm_internal(CUDABLAS_GEMM_ARGS(float)); + } +} + +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable>(CUDABLAS_GEMM_ARGS(c10::complex)); + } + else { + gemm_internal>(CUDABLAS_GEMM_ARGS(c10::complex)); + } +} + +template <> +void gemm>(CUDABLAS_GEMM_ARGTYPES(c10::complex)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable>(CUDABLAS_GEMM_ARGS(c10::complex)); + } + else { + gemm_internal>(CUDABLAS_GEMM_ARGS(c10::complex)); + } +} + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable(CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + gemm_internal(CUDABLAS_GEMM_ARGS(at::Half)); + } +} + +template <> +void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + gemm_tunable(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + gemm_internal(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } +} + #if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000 diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index ee3b41b4376a..eb12bb350c59 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -44,6 +44,8 @@ private: const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type beta,\ Dtype *c, int64_t ldc +#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc + template inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) { AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name()); @@ -62,6 +64,24 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +template +inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name()); +} + +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); + #if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700) enum GEMMAndBiasActivationEpilogue { None, @@ -131,6 +151,9 @@ void scaled_gemm( const Dtype *b, int64_t ldb, int64_t strideb, \ at::opmath_type beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches +#define CUDABLAS_BGEMM_ARGS(Dtype) \ + transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches + template inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name()); @@ -149,6 +172,24 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::Half)); template <> void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); +template +inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name()); +} + +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)); +template <> +void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)); + #if defined(USE_ROCM) && ROCM_VERSION <= 50500 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not. #define CUDABLAS_TRSM_ARGTYPES(Dtype) \ diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h new file mode 100644 index 000000000000..ab3ed56796b2 --- /dev/null +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -0,0 +1,174 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include +#include +#include + +namespace at::cuda::tunable { + +enum class BlasOp { + N = 0, + T = 1 +}; + +inline std::string BlasOpToString(BlasOp op) { + switch (op) { + case BlasOp::N: + return "N"; + case BlasOp::T: + return "T"; + } + TORCH_CHECK(false, "unrecognized BlasOp"); + return "N"; +} + +template +struct GemmParams : OpParams { + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k); + } + + GemmParams* DeepCopy() const { + GemmParams* copy = new GemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = m * n * sizeof(T); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + } + + TuningStatus NumericalCheck(GemmParams *other) { + auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType::value).device(at::kCUDA); + // comparison done as 1D tensor + at::Tensor ref = at::from_blob(c, {m*n}, options); + at::Tensor oth = at::from_blob(other->c, {m*n}, options); + at::Tensor ref_float = ref.to(at::kFloat); + at::Tensor oth_float = oth.to(at::kFloat); + std::vector atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + std::vector rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + double last_succeed_atol = 1; + double last_succeed_rtol = 1; + for (auto& atol : atols) { + for (auto& rtol : rtols) { + if (at::allclose(ref_float, oth_float, rtol, atol)) { + last_succeed_atol = atol; + last_succeed_rtol = rtol; + } + } + } + if (last_succeed_atol == 1) { + return FAIL; + } + else { + TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + } + + return OK; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + at::opmath_type beta; + T* c; + int64_t ldc; +}; + +template +struct GemmStridedBatchedParams : OpParams { + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + } + + GemmStridedBatchedParams* DeepCopy() const { + GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = batch * stride_c * sizeof(T); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + } + + TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { + auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType::value).device(at::kCUDA); + // comparison done as 1D tensor + at::Tensor ref = at::from_blob(c, {batch*stride_c}, options); + at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options); + at::Tensor ref_float = ref.to(at::kFloat); + at::Tensor oth_float = oth.to(at::kFloat); + std::vector atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + std::vector rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + double last_succeed_atol = 1; + double last_succeed_rtol = 1; + for (auto& atol : atols) { + for (auto& rtol : rtols) { + if (at::allclose(ref_float, oth_float, rtol, atol)) { + last_succeed_atol = atol; + last_succeed_rtol = rtol; + } + } + } + if (last_succeed_atol == 1) { + return FAIL; + } + else { + TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + } + + return OK; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + int64_t stride_a; + const T* b; + int64_t ldb; + int64_t stride_b; + at::opmath_type beta; + T* c; + int64_t ldc; + int64_t stride_c; + int64_t batch; +}; + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h new file mode 100644 index 000000000000..fbed75d513b5 --- /dev/null +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#define TORCH_HIPBLASLT_CHECK(EXPR) \ + do { \ + hipblasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ + "hipblaslt error: ", \ + hipblasStatusToString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +#ifdef HIPBLASLT_HAS_GETINDEXFROMALGO +#define GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo) +#else +static int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo) { + int* algo_ptr = (int*)algo.data; + if(*algo_ptr < 0) { + return -1; + } + return *algo_ptr; +} +#define GETINDEXFROMALGO(algo) getIndexFromAlgo(algo) +#endif + +#ifdef HIPBLASLT_CUSTOM_COMPUTE_TYPE +#define COMPUTE_TYPE_32 HIPBLASLT_COMPUTE_F32 +#else +#define COMPUTE_TYPE_32 HIPBLAS_COMPUTE_32F +#endif + +#ifdef HIPBLASLT_CUSTOM_DATA_TYPE + +template +constexpr hipblasltDatatype_t HipBlasDataTypeFor(); + +template <> +constexpr hipblasltDatatype_t HipBlasDataTypeFor() { + return HIPBLASLT_R_32F; +} + +template <> +constexpr hipblasltDatatype_t HipBlasDataTypeFor() { + return HIPBLASLT_R_16F; +} + +template <> +constexpr hipblasltDatatype_t HipBlasDataTypeFor() { + return HIPBLASLT_R_16B; +} + +template <> +constexpr hipblasltDatatype_t HipBlasDataTypeFor() { + return HIPBLASLT_R_64F; +} + +#define DATA_TYPE_R_32 HIPBLASLT_R_32F + +#else + +template +constexpr hipblasDatatype_t HipBlasDataTypeFor(); + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_32F; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16F; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16B; +} + +template <> +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_64F; +} + +#ifdef HIPBLAS_V2 +#define DATA_TYPE_R_32 HIP_R_32F +#else +#define DATA_TYPE_R_32 HIPBLAS_R_32F +#endif + +#endif + +template +int GetBatchFromParams(const ParamsT* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmStridedBatchedParams* params) { + return params->batch; +} + +template +int GetStrideAFromParams(const ParamsT* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmStridedBatchedParams* params) { + return params->stride_a; +} + +template +int GetStrideBFromParams(const ParamsT* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmStridedBatchedParams* params) { + return params->stride_b; +} + +template +int GetStrideCFromParams(const ParamsT* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmStridedBatchedParams* params) { + return params->stride_c; +} + +static hipblasOperation_t _hipblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + AT_ERROR( + "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static char _charFromhipblasOp(hipblasOperation_t op) { + switch (op) { + case HIPBLAS_OP_N: + return 'N'; + case HIPBLAS_OP_T: + return 'T'; + case HIPBLAS_OP_C: + return 'C'; + } + AT_ERROR( + "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); +} + +static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { + if (layout == BlasOp::N) { + return HIPBLAS_OP_N; + } + return HIPBLAS_OP_T; +} + +static size_t GetHipblasltWorkspaceSize() { + static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE"); + // 256MB is max workspace size allowed for hipblaslt + // hipblaslt-bench uses 32MB + // recommendation from hipblaslt author was 76MB + size_t workspace_size = 2*128*1024*1024; // default 256MB + if (env) { + try { + workspace_size = std::stoi(env); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " bytes."); + } catch(std::out_of_range const& e) { + TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " bytes."); + } + } + return workspace_size; +} + +template +class HipblasltGemmOp : public Callable { + public: + HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {} + + TuningStatus Call(const ParamsT* params) override { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto in_out_datatype = HipBlasDataTypeFor(); + auto opa = _hipblasOpFromChar(params->transa); + auto opb = _hipblasOpFromChar(params->transb); + + TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); + + float alpha = static_cast(params->alpha); + float beta = static_cast(params->beta); + + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + hipblasLtMatmulDesc_t matmul; + if (opa == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda)); + } + if (opb == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb)); + } + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32)); + + int batch = GetBatchFromParams(params); + if (batch > 1) { + int64_t stride_a = GetStrideAFromParams(params); + int64_t stride_b = GetStrideBFromParams(params); + int64_t stride_c = GetStrideCFromParams(params); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t))); + + size_t workspace_size = GetHipblasltWorkspaceSize(); + + auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); + + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, + matmul, + &alpha, + mat_a, + mat_b, + &beta, + mat_c, + mat_c, + algo_, + ret_workspace_size); + + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size >= workspace_size) { + //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " workspace too large"); + return FAIL; + } + } + else { + //TUNABLE_LOG("[hipBLASLt] Solution #", algo_index, " not supported"); + return FAIL; + } + + void* workspace_buffer = nullptr; + if (workspace_size > 0) { + workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size); + } + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, + matmul, + &alpha, + params->a, + mat_a, + params->b, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &algo_, + workspace_buffer, + workspace_size, + at::cuda::getCurrentCUDAStream())); + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); + if (workspace_size > 0) { + c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer); + } + return OK; + } + + private: + hipblasLtMatmulAlgo_t algo_; +}; + +template +auto GetHipBlasLtTypeStringAndOps() { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto in_out_datatype = HipBlasDataTypeFor(); + std::vector heuristic_result; + + hipblasLtHandle_t handle; + TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle)); + TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + transa_outer, + transb_outer, + in_out_datatype, + in_out_datatype, + in_out_datatype, + in_out_datatype, + COMPUTE_TYPE_32, + heuristic_result)); + TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); + + // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic. + std::sort(heuristic_result.begin(), + heuristic_result.end(), + [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) { + return GETINDEXFROMALGO(a.algo) < GETINDEXFROMALGO(b.algo); + }); + + int returned_algo_count = heuristic_result.size(); + std::vector>>> ret; + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristic_result[i].algo; + int algo_index = GETINDEXFROMALGO(algo); + auto callable = std::make_unique>(algo); + std::string type_string = c10::str( + "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index); + ret.emplace_back(type_string, std::move(callable)); + } + + return ret; +} + +template +auto GetHipBlasLtGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +#undef TORCH_HIPBLASLT_CHECK +#undef GETINDEXFROMALGO +#undef COMPUTE_TYPE_32 +#undef DATA_TYPE_R_32 + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/GemmRocblas.h b/aten/src/ATen/cuda/tunable/GemmRocblas.h new file mode 100644 index 000000000000..f096ff00fd9b --- /dev/null +++ b/aten/src/ATen/cuda/tunable/GemmRocblas.h @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#define ROCBLAS_BETA_FEATURES_API +#include + +#define TORCH_ROCBLAS_CHECK(EXPR) \ + do { \ + rocblas_status __err = EXPR; \ + TORCH_CHECK(__err == rocblas_status_success, \ + "rocblas error: ", \ + rocblas_status_to_string(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr rocblas_datatype RocBlasDataTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_bf16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +constexpr rocblas_datatype RocBlasComputeTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // FP16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // BF16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +auto DoCastForHalfOrBfloat16(const T fp) { + return fp; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const Half fp) { + // alpha and beta should be the same as compute_type, in Half case it is float. + float h = fp; + return h; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { + // alpha and beta should be the same as compute_type, in bfloat16 case it is float. + float h = fp; + return h; +} + +static rocblas_operation _rocblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return rocblas_operation_none; + case 't': + case 'T': + return rocblas_operation_transpose; + case 'c': + case 'C': + return rocblas_operation_conjugate_transpose; + } + AT_ERROR( + "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +template +class RocblasGemmOp : public Callable> { + public: + RocblasGemmOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, + params->b, input_output_type, params->ldb, + &h_b, + params->c, input_output_type, params->ldc, + params->c, input_output_type, params->ldc, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +template +class RocblasGemmStridedBatchedOp : public Callable> { + public: + RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmStridedBatchedParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_strided_batched_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, params->stride_a, + params->b, input_output_type, params->ldb, params->stride_b, + &h_b, + params->c, input_output_type, params->ldc, params->stride_c, + params->c, input_output_type, params->ldc, params->stride_c, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmStridedBatchedTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md new file mode 100644 index 000000000000..364e6975c6c6 --- /dev/null +++ b/aten/src/ATen/cuda/tunable/README.md @@ -0,0 +1,88 @@ +# TunableOp + +This directory implements a TunableOp interface. + +Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For +example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's +rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one +know which implementation is the fastest and should be chosen? That's what TunableOp provides. + +The behavior of TunableOp is currently easily manipulated through environment variables, though you could use the C++ +interface of at::cuda::tunable::getTuningContext(). A Python interface to the TuningContext does not yet exist. + +Currently only a TunableGemm for ROCm is implemented. Any call to at::cuda::blas::gemm() can optionally use the +TunableGemm. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the +fastest available implementation. + +## Environment Variables + +#### PYTORCH_TUNABLEOP_ENABLED +Default is 0. Set to 1 to enable. +This is the big on/off switch for all TunableOp implementations. + +#### PYTORCH_TUNABLEOP_TUNING +Default is 1. Set to 0 to disable. +When enabled, if a tuned entry isn't found, run the tuning step and record the entry. + +#### PYTORCH_TUNABLEOP_VERBOSE +Default is 0. Set to 1 to enable. +This will produce a lot of diagnostic messages but may be useful to see if TunableOp is being used at all. +Otherwise, TunableOp is completely silent unless there is a warning or error during its use. + +#### PYTORCH_TUNABLEOP_FILENAME +Default is 'tunableop_results.csv'. If you provide a filename, the TuningContext will attempt to read it the first time +the context is used. If tuning is enabled and new tunings are discovered, it will also write out to this same filename +with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can be used, for +example, to build up a tunings file across many workloads by reusing the same file. Unsetting this variable is not +recommended but can be done, in which case the tuning results will not be saved. + +#### PYTORCH_TUNABLEOP_NUMERICAL_CHECK +Default is 1. Set to 0 to disable. Compare the results of each possible solution against the default solution and reject +those with low accuracy. + +#### PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED +Default is 1. Set to 0 to disable hipblaslt being considered during tuning. + +### Tuning Iterations +By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations can +be run within 30ms, whichever is smaller. Its average execution will be calculated. The fastest solution is chosen. In +addition, a set of warm up iterations can optionally be run prior to the timed iterations. The following environment +variables can be used to set either the maximum number of iterations to attempt or the maximum amount of time allowed in +milliseconds, or both, in which case the smaller of the two values used. + +#### PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS +Default is 30. + +#### PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS +Default is 100. + +#### PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS +Default is 0, meaning it is not used. + +#### PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS +Default is 1. + +## File Output + +Assuming you specified a filename, you'll end up with a CSV file with contents like so: + +``` +Validator,PT_VERSION,2.2.0 +Validator,ROCM_VERSION,6.0.0.0-12969-1544e39 +Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7 +Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty +GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262 +GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033 +``` + +Note the "Validator" lines. If you change a library verison, or rocm version, or pytorch version, TunableOp will detect +this and not load the tunings because they are likely affected by other software changes. + +The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of +4 comma-separated fields: operator name, operator parameters, solution name, and average execution time. The execution +time is an optional field. The CSV file can be edited, but with caution. For example, the solution name (field 3) can be +changed to "Default" and it will fall back to the original PyTorch untuned implementation. Or, in the case of ROCm's +hipBLAS or hipBLASLt libraries, if you know the specific solution index you can override the solution that TunableOp +selected by replacing the value. The operator name and parameters (fields 1 and 2) are internally named and should not +be modified. In the case of GemmTunableOp, field 1 indicates the datatype and whether the inputs are transposed (T) or +not (N) and field 2 indicates the M, N, K input shapes. diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.cpp b/aten/src/ATen/cuda/tunable/StreamTimer.cpp new file mode 100644 index 000000000000..1407c32dbb35 --- /dev/null +++ b/aten/src/ATen/cuda/tunable/StreamTimer.cpp @@ -0,0 +1,43 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#include + +#include +#include +#include + +namespace at::cuda::tunable { + +StreamTimer::StreamTimer() { + AT_CUDA_CHECK(cudaEventCreate(&start_)); + AT_CUDA_CHECK(cudaEventCreate(&end_)); +} + +StreamTimer::~StreamTimer() { +} + +void StreamTimer::Start() { + AT_CUDA_CHECK(cudaDeviceSynchronize()); + AT_CUDA_CHECK(cudaEventRecord(start_, at::cuda::getCurrentCUDAStream())); +} + +void StreamTimer::End() { + AT_CUDA_CHECK(cudaEventRecord(end_, at::cuda::getCurrentCUDAStream())); + AT_CUDA_CHECK(cudaEventSynchronize(end_)); +} + +float StreamTimer::Duration() { + float time; + // time is in ms with a resolution of 1 us + AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_)); + return time; +} + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.h b/aten/src/ATen/cuda/tunable/StreamTimer.h new file mode 100644 index 000000000000..69889cbbcbfc --- /dev/null +++ b/aten/src/ATen/cuda/tunable/StreamTimer.h @@ -0,0 +1,34 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include + +namespace at::cuda::tunable { + +class StreamTimer : public ITimer { + public: + StreamTimer(); + virtual ~StreamTimer(); + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_; + cudaEvent_t end_; +}; + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp new file mode 100644 index 000000000000..ad855229643f --- /dev/null +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -0,0 +1,567 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#include + +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +namespace { + +TuningContext tuning_context; + +} // anonymous namespace + +TuningContext* getTuningContext() { + return &tuning_context; +} + +std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) { + return stream << entry.key_ << "," << entry.time_; +} + +// TuningResultsManager + +KernelMap TuningResultsManager::Lookup(const std::string& op_signature) { + std::scoped_lock l{lock_}; + auto it = results_.find(op_signature); + if (it == results_.cend()) { + return {}; + } + return it->second; // copied +} + +ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + auto kernel_map_it = results_.find(op_signature); + if (kernel_map_it == results_.cend()) { + TUNABLE_LOG("missing op_signature, returning null ResultEntry"); + return ResultEntry::Null(); + } + + const auto& km = kernel_map_it->second; + auto it = km.find(params_signature); + if (it == km.cend()) { + TUNABLE_LOG("missing params_signature, returning null ResultEntry"); + return ResultEntry::Null(); + } + return it->second; +} + +inline void TuningResultsManager::AddImpl(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best, + KernelMap& kernel_map) { + auto it = kernel_map.find(params_signature); + if (it != kernel_map.end()) { + if (it->second != best) { + TUNABLE_LOG(op_signature, "(", params_signature, ") already has a best kernel ", + "id=", it->second, " selected, want to add a different best kernel ", best, + ", the new kernel id will be ignored."); + } + return; + } + + TUNABLE_LOG(op_signature, "(", params_signature, ") -> ", best); + kernel_map.emplace(params_signature, best); +} + +void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + it = results_.insert({op_signature, {}}).first; + } + + AddImpl(op_signature, params_signature, best, it->second); +} + +void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + + auto it = results_.find(op_signature); + if (it == results_.end()) { + return; + } + + auto it2 = it->second.find(params_signature); + if (it2 == it->second.end()) { + return; + } + + TUNABLE_LOG(op_signature, "(", params_signature, ")"); + it->second.erase(it2); +} + +inline void TuningResultsManager::DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ std::unordered_map& results) { + auto it = results.find(op_signature); + if (it == results.end()) { + for (const auto& [param_sig, kernel_id] : kernel_map) { + TUNABLE_LOG(op_signature, "(", param_sig, ") -> ", kernel_id); + } + results[op_signature] = kernel_map; + return; + } + + for (const auto& [params_signature, best] : kernel_map) { + AddImpl(op_signature, params_signature, best, it->second); + } +} + +void TuningResultsManager::Load(const std::unordered_map& results_to_load) { + TUNABLE_LOG("Loading results"); + std::scoped_lock l{lock_}; + for (const auto& [op_signature, kernel_map] : results_to_load) { + DisjointMergeImpl(op_signature, kernel_map, results_); + } +} + +ResultsMap TuningResultsManager::Dump() { + std::scoped_lock l{lock_}; + return results_; +} + +void TuningResultsManager::DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map) { + std::scoped_lock l{lock_}; + DisjointMergeImpl(op_signature, kernel_map, results_); +} + +size_t TuningResultsManager::GetSize() { + size_t size = 0; + std::scoped_lock l{lock_}; + for (const auto& [op_signature, kernel_map] : results_) { + size += kernel_map.size(); + } + return size; +} + +// TuningResultsValidator + +TuningResultsValidator::TuningResultsValidator() { + RegisterValidator( + "PT_VERSION", + [this]() { return GetPyTorchVersion(); }, + [this](auto&& k) { return ValidatePyTorchVersion(std::forward(k)); }); +} + +std::unordered_map TuningResultsValidator::GetAllValidators() const { + std::unordered_map ret; + for (const auto& [key, get_validate_func_pair] : validators_) { + const GetFunc& getter = get_validate_func_pair.first; + ret[key] = getter(); + } + return ret; +} + +static bool CheckMandatoryKeys( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + bool passed = true; + for (const auto& k : TuningResultsValidator::mandatory_keys) { + if (gv_funcs.find(k) == gv_funcs.end()) { + passed = false; + TUNABLE_LOG("key=\"", k, "\" is not registered for Get and Validate. "); + } + + if (to_check.find(k) == to_check.end()) { + passed = false; + TUNABLE_LOG("key=\"", k, "\" is not provided for validation. "); + } + } + return passed; +} + +static bool CheckKeysMatching( + const TuningResultsValidator::GetValidateFuncs& gv_funcs, + const std::unordered_map& to_check) { + auto get_keys = [](const auto& it) -> std::string { return it.first; }; + std::vector required_keys; + std::vector provided_keys; + std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys); + std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys); + std::sort(required_keys.begin(), required_keys.end()); + std::sort(provided_keys.begin(), provided_keys.end()); + + std::unordered_set intersection; + std::set_intersection(required_keys.cbegin(), required_keys.cend(), + provided_keys.cbegin(), provided_keys.cend(), + std::inserter(intersection, intersection.end())); + bool matched = true; + if (intersection.size() != required_keys.size()) { + matched = false; + for (const auto& k : required_keys) { + if (intersection.find(k) == intersection.end()) { + TORCH_WARN("Unmatched validator: \"", k, "\" is required, but the tuning results does not provide it. "); + } + } + } + if (intersection.size() != provided_keys.size()) { + matched = false; + for (const auto& k : provided_keys) { + if (intersection.find(k) == intersection.end()) { + TORCH_WARN("Unmatched validator: \"", k, "\" is provided, but pytorch is unable to consume it. "); + } + } + } + return matched; +} + +TuningStatus TuningResultsValidator::ValidateAll( + const std::unordered_map& to_validate) const { + if (!CheckMandatoryKeys(validators_, to_validate)) { + return FAIL; + } + if (!CheckKeysMatching(validators_, to_validate)) { + return FAIL; + } + + for (const auto& [key, value] : to_validate) { + const auto& it = validators_.find(key); + if (it == validators_.cend()) { + TORCH_WARN("Failed to lookup validator using key ", key); + for (const auto& [key2, val2] : validators_) { + TORCH_WARN("available key ", key2); + } + return FAIL; + } + const ValidateFunc& validator = it->second.second; + if (validator(value) != OK) { + TORCH_WARN("Failed validator: ", key); + return FAIL; + } + } + + return OK; +} + +void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) { + if (validators_.find(key) != validators_.end()) { + TORCH_WARN("Attempting to re-register validator with key ", key); + } + else { + validators_[key] = std::make_pair(gf, vf); + } +} + +std::string TuningResultsValidator::GetPyTorchVersion() const { + return TORCH_VERSION; +} + +TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& value) const { + if (value == GetPyTorchVersion()) { + return OK; + } + return FAIL; +} + +// TuningContext + +TuningContext::TuningContext() : + enable_{false}, + tuning_enable_{true}, + manager_initialized_{false}, + max_tuning_duration_ms_{30}, + max_tuning_iterations_{100}, + max_warmup_duration_ms_{0}, + max_warmup_iterations_{0}, + filename_{"tunableop_results.csv"}, + results_count_from_input_file_{0} +{ +} + +TuningContext::~TuningContext() { + if (!manager_initialized_) { + // TuningResultsManager was never initialized, no tuning requested or performed. + // This can happen in a DDP job where a python process spawns other workers + // but doesn't do any computation itself. + return; + } + auto filename = GetFilename(); + if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty()) { + if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) { + if (results_count_from_input_file_ > 0) { + TUNABLE_LOG("additional tuning results available, rewriting file ", filename); + } + else { + TUNABLE_LOG("writing file ", filename); + } + if (!WriteFile(filename)) { + TUNABLE_LOG("failed to write file ", filename); + } + } + } +} + +void TuningContext::EnableTunableOp() { + TUNABLE_LOG("Enable TunableOp"); + enable_ = true; +} + +void TuningContext::DisableTunableOp() { + TUNABLE_LOG("Disable TunableOp"); + enable_ = false; +} + +bool TuningContext::IsTunableOpEnabled() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED"); + if (env != nullptr && strcmp(env, "1") == 0) { + //TUNABLE_LOG("PYTORCH_TUNABLEOP_ENABLED=1"); + return true; + } + return enable_; +} + +void TuningContext::EnableTuning() { + TUNABLE_LOG("Enable Tuning for TunableOp"); + tuning_enable_ = true; +} + +void TuningContext::DisableTuning() { + TUNABLE_LOG("Disable Tuning for TunableOp"); + tuning_enable_ = false; +} + +bool TuningContext::IsTuningEnabled() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING"); + if (env != nullptr && strcmp(env, "0") == 0) { + //TUNABLE_LOG("PYTORCH_TUNABLEOP_TUNING=1"); + return false; + } + return tuning_enable_; +} + +void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) { + max_tuning_duration_ms_ = max_duration_ms; +} + +int TuningContext::GetMaxTuningDurationMs() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"); + if (env != nullptr) { + return atoi(env); + } + return max_tuning_duration_ms_; +} + +void TuningContext::SetMaxTuningIterations(int max_iter) { + max_tuning_iterations_ = max_iter; +} + +int TuningContext::GetMaxTuningIterations() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"); + if (env != nullptr) { + return atoi(env); + } + return max_tuning_iterations_; +} + +void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) { + max_warmup_duration_ms_ = max_duration_ms; +} + +int TuningContext::GetMaxWarmupDurationMs() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"); + if (env != nullptr) { + return atoi(env); + } + return max_warmup_duration_ms_; +} + +void TuningContext::SetMaxWarmupIterations(int max_iter) { + max_warmup_iterations_ = max_iter; +} + +int TuningContext::GetMaxWarmupIterations() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS"); + if (env != nullptr) { + return atoi(env); + } + return max_warmup_iterations_; +} + +void TuningContext::EnableTunableOpAndTuning() { + EnableTunableOp(); + EnableTuning(); +} + +void TuningContext::DisableTunableOpAndTuning() { + DisableTunableOp(); + DisableTuning(); +} + +TuningResultsManager& TuningContext::GetTuningResultsManager() { + c10::call_once(manager_init_once_, [this]() { + manager_initialized_ = true; + auto filename = GetFilename(); + if (!filename.empty()) { + ReadFile(filename); + // attempt immediately to open file for writing to catch errors early + std::ofstream file(filename, std::ios::out | std::ios::app); + if (!file.good()) { + TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved"); + } + } + }); + return manager_; +} + +TuningResultsValidator& TuningContext::GetTuningResultsValidator() { + return validator_; +} + +TuningResults TuningContext::GetTuningResults() { + TuningResults tr; + tr.validators = GetTuningResultsValidator().GetAllValidators(); + tr.results = GetTuningResultsManager().Dump(); + return tr; +} + +TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) { + TORCH_CHECK(GetTuningResultsValidator().ValidateAll(tr.validators)); + GetTuningResultsManager().Load(tr.results); + return OK; +} + +void TuningContext::SetFilename(const std::string& filename) { + filename_ = filename; +} + +std::string TuningContext::GetFilename() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME"); + std::string filename = (env == nullptr) ? filename_ : env; + if (filename.empty()) { + TUNABLE_LOG("no filename from TuningContext::GetFilename()"); + return filename; // empty string + } + + // Using static with lambda here so that we don't make a cuda call during static shutdown. + // Do this the first and only time GetFilename() is called because it is called + // the first time a TunableOp is instantiated but also during static destruction + // when the cuda or hip runtime is no longer available. + static std::string device = []() { + return c10::str(int(c10::cuda::current_device())); + }(); + + // differentiate filename based on device ordinal to avoid + // use case of one process per device writing to same file + + // does filename contain %d to insert device ordinal in specific location? + const std::string TOKEN("%d"); + std::size_t found = filename.find(TOKEN); + if (found != std::string::npos) { + filename.replace(found, TOKEN.length(), device); + } + else { + // no %d present, so append device ordinal before final '.' + found = filename.rfind("."); + if (found != std::string::npos) { + filename.insert(found, device); + } + else { + // all else fails, just prepend + filename.insert(0, device); + } + } + return filename; +} + +bool TuningContext::ReadFile(const std::string& filename) { + TUNABLE_LOG("reading tuning results from ", filename); + ResultsMap results; + std::unordered_map validators; + std::string line; + std::ifstream file(filename); + if (!file) { + TUNABLE_LOG("could not open ", filename, " for reading tuning results"); + return false; + } + while (std::getline(file, line)) { + if (line.empty()) { + continue; + } + std::string part; + std::vector parts; + std::stringstream line_as_stream(line); + while (std::getline(line_as_stream, part, ',')) { + parts.push_back(part); + } + if (parts[0] == "Validator" && parts.size() >= 3) { + validators[parts[1]] = parts[2]; + TUNABLE_LOG("Validator ", parts[1], "=", parts[2]); + } + else if (parts.size() >= 4) { + results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str()))); + } + else if (parts.size() >= 3) { + // the timestamp from the file is optional + results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0)); + } + else { + TUNABLE_LOG("could not parse line: ", line); + } + } + if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) { + manager_.Load(results); + results_count_from_input_file_ = manager_.GetSize(); + } + else { + TUNABLE_LOG("results validator check failed"); + return false; + } + return true; +} + +bool TuningContext::WriteFile(const std::string& filename) { + std::ofstream file(filename, std::ios::out | std::ios::trunc); + if (!file.good()) { + TUNABLE_LOG("error opening tuning results file for writing ", filename); + return false; + } + auto validators = GetTuningResultsValidator().GetAllValidators(); + for (const auto& [key, val] : validators) { + file << "Validator," << key << "," << val << std::endl; + } + auto results = GetTuningResultsManager().Dump(); + for (const auto& [op_sig, kernelmap] : results) { + for (const auto& [param_sig, result] : kernelmap) { + file << op_sig << "," << param_sig << "," << result << std::endl; + } + } + file.close(); + return true; +} + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h new file mode 100644 index 000000000000..eb849a213fe5 --- /dev/null +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -0,0 +1,205 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +static void TunableLog(const std::string& msg) { + static const char *env = getenv("PYTORCH_TUNABLEOP_VERBOSE"); + if (env != nullptr && strcmp(env, "1") == 0) { + std::cerr << msg << std::endl; + } +} +#define TUNABLE_LOG(...) TunableLog(c10::str(__VA_ARGS__)) + +enum TuningStatus { + OK = 0, + FAIL = 1, + UNSUPPORTED = 2, +}; + +// Mapping from params signature to kernel id +class ResultEntry { + public: + explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {} + bool operator==(const ResultEntry& other) { return key_ == other.key_; } + bool operator!=(const ResultEntry& other) { return key_ != other.key_; } + operator std::string () { return key_; } + friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry); + static ResultEntry Null() { return ResultEntry("Null", 0.0); } + static ResultEntry Default() { return ResultEntry("Default", 0.0); } + + private: + std::string key_; + double time_; +}; + +typedef std::unordered_map KernelMap; +typedef std::unordered_map ResultsMap; + +struct TuningResults { + // Validates if these results are compatible with the libraries + std::unordered_map validators; + + // Mapping from Callable signature to Callable's tuning result + ResultsMap results; +}; + +class TuningResultsManager { + public: + TuningResultsManager() = default; + ~TuningResultsManager() = default; + + KernelMap Lookup(const std::string& op_signature); + + ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature); + + inline void AddImpl(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best, + KernelMap& kernel_map); + + void Add(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best); + + void Delete(const std::string& op_signature, const std::string& params_signature); + + inline void DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ ResultsMap& results); + + void Load(const ResultsMap& results_to_load); + + ResultsMap Dump(); + + void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map); + + size_t GetSize(); + + private: + std::mutex lock_; + ResultsMap results_; +}; + +class TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + TuningStatus ValidateAll(const std::unordered_map& to_validate) const; + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + protected: + std::string GetPyTorchVersion() const; + TuningStatus ValidatePyTorchVersion(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"PT_VERSION"}; + + private: + GetValidateFuncs validators_; +}; + +class TuningContext { + public: + TuningContext(); + ~TuningContext(); + TuningContext(TuningContext &) = delete; + TuningContext(TuningContext &&) = delete; + TuningContext &operator=(TuningContext &) = delete; + TuningContext &operator=(TuningContext &&) = delete; + + void EnableTunableOp(); + void DisableTunableOp(); + bool IsTunableOpEnabled() const; + + void EnableTuning(); + void DisableTuning(); + bool IsTuningEnabled() const; + + void SetMaxTuningDurationMs(int max_duration_ms); + int GetMaxTuningDurationMs() const; + + void SetMaxTuningIterations(int max_iter); + int GetMaxTuningIterations() const; + + void SetMaxWarmupDurationMs(int max_duration_ms); + int GetMaxWarmupDurationMs() const; + + void SetMaxWarmupIterations(int max_iter); + int GetMaxWarmupIterations() const; + + void EnableTunableOpAndTuning(); + void DisableTunableOpAndTuning(); + + TuningResultsManager& GetTuningResultsManager(); + + TuningResultsValidator& GetTuningResultsValidator(); + + TuningResults GetTuningResults(); + + TuningStatus LoadTuningResults(const TuningResults& tr); + + void SetFilename(const std::string& filename); + std::string GetFilename() const; + + protected: + bool ReadFile(const std::string& filename); + bool WriteFile(const std::string& filename); + + private: + bool enable_; + bool tuning_enable_; + bool manager_initialized_; + int max_tuning_duration_ms_; + int max_tuning_iterations_; + int max_warmup_duration_ms_; + int max_warmup_iterations_; + mutable TuningResultsManager manager_; + mutable c10::once_flag manager_init_once_; + TuningResultsValidator validator_; + std::string filename_; + size_t results_count_from_input_file_; +}; + +TuningContext* getTuningContext(); + +class ITimer { + public: + ITimer() = default; + virtual ~ITimer() = default; + + virtual void Start() = 0; + virtual void End() = 0; + + /// Computes the elapsed time in milliseconds between Start() and End() + virtual float Duration() = 0; +}; + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h new file mode 100644 index 000000000000..3ba0d761277b --- /dev/null +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -0,0 +1,278 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#ifdef USE_ROCM +#if ROCM_VERSION >= 50700 +#include +#endif +#include +#endif +#include +#include +#include +#include + +#ifdef USE_ROCM +#include +#endif + +#define STRINGIFY(s) #s +#define XSTRINGIFY(s) STRINGIFY(s) + +namespace at::cuda::tunable { + +template +class DefaultGemmOp : public Callable> { + public: + TuningStatus Call(const GemmParams* params) override { + at::cuda::blas::gemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->beta, + params->c, params->ldc); + return OK; + } +}; + +template +class DefaultGemmStridedBatchedOp : public Callable> { + public: + TuningStatus Call(const GemmStridedBatchedParams* params) override { + at::cuda::blas::bgemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, params->stride_a, + params->b, params->ldb, params->stride_b, + params->beta, + params->c, params->ldc, params->stride_c, + params->batch); + return OK; + } +}; + +template +bool IsZero(T v) { + return v == 0.0f; +} + +template <> +bool IsZero(BFloat16 v) { + return v.x == 0; +} + +template <> +bool IsZero(Half v) { + return float(v) == 0.0f; +} + +template <> +bool IsZero(c10::complex v) { + return v == 0.0; +} + +template <> +bool IsZero(c10::complex v) { + return v == 0.0f; +} + +template +std::string TypeName(T v) { + return "unknown"; +} + +template <> +std::string TypeName(float v) { + return "float"; +} + +template <> +std::string TypeName(double v) { + return "double"; +} + +template <> +std::string TypeName(BFloat16 v) { + return "BFloat16"; +} + +template <> +std::string TypeName(Half v) { + return "Half"; +} + +template <> +std::string TypeName(c10::complex v) { + return "c10::complex"; +} + +template <> +std::string TypeName(c10::complex v) { + return "c10::complex"; +} + + +template +class GemmTunableOp : public TunableOp, StreamTimer> { + public: + GemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + +#ifdef USE_ROCM + for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + + if (validators.find("ROCM_VERSION") == validators.end()) { + std::string rocm_version = ROCM_BUILD_INFO; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCM_VERSION", + [rocm_version]() { return rocm_version; }, + [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + } + + if (validators.find("GCN_ARCH_NAME") == validators.end()) { + std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "GCN_ARCH_NAME", + [gcn_arch_name]() { return gcn_arch_name; }, + [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + } + + if (validators.find("ROCBLAS_VERSION") == validators.end()) { + std::string rocblas_version = c10::str( + XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", + XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCBLAS_VERSION", + [rocblas_version]() { return rocblas_version; }, + [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + } +#endif + +#if defined(USE_ROCM) && ROCM_VERSION >= 50700 + static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env == nullptr || strcmp(env, "1") == 0) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + if (validators.find("HIPBLASLT_VERSION") == validators.end()) { + std::string hipblaslt_version = c10::str( + XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", + XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "HIPBLASLT_VERSION", + [hipblaslt_version]() { return hipblaslt_version; }, + [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + } + } +#endif + } + + std::string Signature() override { + return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmStridedBatchedTunableOp : public TunableOp, StreamTimer> { + public: + GemmStridedBatchedTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + + auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); + +#ifdef USE_ROCM + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + + if (validators.find("ROCM_VERSION") == validators.end()) { + std::string rocm_version = ROCM_BUILD_INFO; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCM_VERSION", + [rocm_version]() { return rocm_version; }, + [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + } + + if (validators.find("GCN_ARCH_NAME") == validators.end()) { + std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName; + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "GCN_ARCH_NAME", + [gcn_arch_name]() { return gcn_arch_name; }, + [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + } + + if (validators.find("ROCBLAS_VERSION") == validators.end()) { + std::string rocblas_version = c10::str( + XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".", + XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-", + XSTRINGIFY(ROCBLAS_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "ROCBLAS_VERSION", + [rocblas_version]() { return rocblas_version; }, + [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + } +#endif + +#if defined(USE_ROCM) && ROCM_VERSION >= 50700 + static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env == nullptr || strcmp(env, "1") == 0) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + if (validators.find("HIPBLASLT_VERSION") == validators.end()) { + std::string hipblaslt_version = c10::str( + XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".", + XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-", + XSTRINGIFY(HIPBLASLT_VERSION_TWEAK)); + getTuningContext()->GetTuningResultsValidator().RegisterValidator( + "HIPBLASLT_VERSION", + [hipblaslt_version]() { return hipblaslt_version; }, + [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + } + } +#endif + } + + std::string Signature() override { + return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +#undef XSTRINGIFY +#undef STRINGIFY + +} // namespace at::cuda::tunable diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h new file mode 100644 index 000000000000..65257974ab0c --- /dev/null +++ b/aten/src/ATen/cuda/tunable/TunableOp.h @@ -0,0 +1,242 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class Callable { + public: + Callable() = default; + Callable(Callable&&) = default; + virtual ~Callable() = default; + virtual TuningStatus Call(const ParamsT*) { + return FAIL; + } + virtual TuningStatus IsSupported(const ParamsT* params) { + return Call(params); + } +}; + +template +class TunableOp { + public: + TunableOp() = default; + TunableOp(TunableOp&&) = default; + virtual ~TunableOp() = default; + + TuningStatus operator()(const ParamsT* params) { + ResultEntry result = ResultEntry::Null(); + TuningContext* ctx = getTuningContext(); + if (ctx->IsTunableOpEnabled()) { + auto& mgr = ctx->GetTuningResultsManager(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + result = mgr.Lookup(op_sig, params_sig); + // If there is not previous tuning result been found, we do the tuning iff tuning is enabled + if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) { + result = FindFastest(params); + mgr.Add(op_sig, params_sig, result); + } + } + else { + result = ResultEntry::Default(); + } + if (result == ResultEntry::Null()) { + TUNABLE_LOG("no result, using default"); + result = ResultEntry::Default(); + } + auto iter = ops_.find(result); + TORCH_CHECK(iter != ops_.end()); + return iter->second->Call(params); + } + + virtual std::string Signature() { + // According to C++17 standard https://wg21.link/n4659 section 15.7.4 + // > if the operand of typeid refers to the + // > object under construction or destruction, typeid yields the std::type_info object representing the constructor + // > or destructor’s class. + // So delay the op signature generation. + c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); }); + return signature_; + } + + protected: + void RegisterOp(const std::string& name, std::unique_ptr> op) { + this->op_names_.emplace_back(name); + this->ops_.emplace(name, std::move(op)); + } + + private: + static void WarmUp(Callable *op, ParamsT* param, size_t num_iter) { + for (size_t i = 0; i < num_iter; i++) { + TORCH_CHECK(op->Call(param) == OK); + } + } + + static double Profile(Callable *op, ParamsT* param, size_t num_iter) { + TimerT timer{}; + timer.Start(); + for (size_t i = 0; i < num_iter; i++) { + TORCH_CHECK(op->Call(param) == OK); + } + timer.End(); + return timer.Duration() / num_iter; + } + + protected: + bool IsNumericsCheckEnabled() { + static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); + if (env != nullptr && strcmp(env, "0") == 0) { + return false; + } + return true; + } + + virtual ResultEntry FindFastest(const ParamsT* params) { + TuningContext* ctx = getTuningContext(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + TUNABLE_LOG("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + auto min_duration_ms = std::numeric_limits::infinity(); + std::string id_name = "Default"; + + // calcaulte a reference answer for numerical check + ParamsT* reference_params = params->DeepCopy(); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + + // need a copy of params to reuse + ParamsT* reusable_params = params->DeepCopy(); + + for (size_t i = 0; i < op_names_.size(); i++) { + auto* candidate = ops_[op_names_[i]].get(); // borrow pointer + auto status = candidate->Call(reusable_params); + if (status != OK) { + TUNABLE_LOG("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + if (IsNumericsCheckEnabled()) { + ParamsT* numerical_params = params->DeepCopy(); + WarmUp(candidate, numerical_params, 1); + status = reference_params->NumericalCheck(numerical_params); + numerical_params->Delete(); + if (status != OK) { + TUNABLE_LOG("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + + // collect a small profile + constexpr const int approx_num_iter = 3; + auto approx_duration = Profile(candidate, reusable_params, approx_num_iter); + // bail if too slow + if (approx_duration > 2 * min_duration_ms) { + TUNABLE_LOG("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // for warmup does user set max duration, max iters, or both? + double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); + int max_warmup_iter = ctx->GetMaxWarmupIterations(); + int warmup_iter = 1; // default + if (max_warmup_duration > 0) { + int duration_iters = max_warmup_duration / approx_duration; + if (max_warmup_iter > 0) { + warmup_iter = std::min(max_warmup_iter, duration_iters); + } + else { + warmup_iter = duration_iters; + } + } + else if (max_warmup_iter > 0) { + warmup_iter = max_warmup_iter; + } + + // for tuning does user set max duration, max iters, or both? + double max_tuning_duration = ctx->GetMaxTuningDurationMs(); + int max_tuning_iter = ctx->GetMaxTuningIterations(); + int tuning_iter = 100; // default + if (max_tuning_duration > 0) { + int duration_iters = max_tuning_duration / approx_duration; + if (max_tuning_iter > 0) { + tuning_iter = std::min(max_tuning_iter, duration_iters); + } + else { + tuning_iter = duration_iters; + } + } + else if (max_tuning_iter > 0) { + tuning_iter = max_tuning_iter; + } + + // do the full warmup followed by tuning + double warmup_ms = warmup_iter * approx_duration; + double tuning_ms = tuning_iter * approx_duration; + TUNABLE_LOG("├──tuning using " + "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " + "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", + "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); + WarmUp(candidate, reusable_params, warmup_iter); + auto duration_ms = Profile(candidate, reusable_params, tuning_iter); + if (duration_ms < min_duration_ms) { + TUNABLE_LOG("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); + min_duration_ms = duration_ms; + id_name = op_names_[i]; + } + } + + reusable_params->Delete(); + reference_params->Delete(); + + TUNABLE_LOG("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + return ResultEntry(id_name, min_duration_ms); + } + + private: + std::string CreateSignature() { +#ifndef _WIN32 + const auto* name = typeid(*this).name(); + char buf[256]; + size_t buf_len = 256; + abi::__cxa_demangle(name, buf, &buf_len, nullptr); + buf[255] = '\0'; + return buf; +#else + return typeid(*this).name(); +#endif + } + + mutable c10::once_flag signature_init_once_; + std::string signature_; + + std::unordered_map>> ops_; + std::vector op_names_; +}; + +struct OpParams { + OpParams() {} + virtual ~OpParams() = default; + virtual std::string Signature() const = 0; +}; + +} // namespace at::cuda::tunable diff --git a/build_variables.bzl b/build_variables.bzl index 8e446c8d11fd..d490b4462656 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1410,6 +1410,8 @@ aten_cuda_cu_source_list = [ "aten/src/ATen/cuda/CUDABlas.cpp", "aten/src/ATen/cuda/CUDASparseBlas.cpp", "aten/src/ATen/cuda/CublasHandlePool.cpp", + "aten/src/ATen/cuda/tunable/StreamTimer.cpp", + "aten/src/ATen/cuda/tunable/Tunable.cpp", "aten/src/ATen/native/cuda/Activation.cpp", "aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp", "aten/src/ATen/native/cuda/Blas.cpp", diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 4dd80420584b..b7ffbeb07dcf 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1279,6 +1279,9 @@ if(USE_ROCM) if(HIPBLASLT_CUSTOM_COMPUTE_TYPE) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_COMPUTE_TYPE) endif() + if(HIPBLASLT_HAS_GETINDEXFROMALGO) + list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_HAS_GETINDEXFROMALGO) + endif() if(HIP_NEW_TYPE_ENUMS) list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS) endif() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1abeb06228ec..f6ca263c5e5b 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -202,12 +202,12 @@ if(HIP_FOUND) "}\n" ) - try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + try_compile(hipblaslt_compile_result_custom_datatype ${PROJECT_RANDOM_BINARY_DIR} ${file} CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ OUTPUT_VARIABLE hipblaslt_compile_output) - if(hipblaslt_compile_result) + if(hipblaslt_compile_result_custom_datatype) set(HIPBLASLT_CUSTOM_DATA_TYPE ON) #message("hipblaslt is using custom data type: ${hipblaslt_compile_output}") message("hipblaslt is using custom data type") @@ -227,12 +227,12 @@ if(HIP_FOUND) "}\n" ) - try_compile(hipblaslt_compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} + try_compile(hipblaslt_compile_result_custom_compute_type ${PROJECT_RANDOM_BINARY_DIR} ${file} CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ OUTPUT_VARIABLE hipblaslt_compile_output) - if(hipblaslt_compile_result) + if(hipblaslt_compile_result_custom_compute_type) set(HIPBLASLT_CUSTOM_COMPUTE_TYPE ON) #message("hipblaslt is using custom compute type: ${hipblaslt_compile_output}") message("hipblaslt is using custom compute type") @@ -241,6 +241,36 @@ if(HIP_FOUND) #message("hipblaslt is NOT using custom compute type: ${hipblaslt_compile_output}") message("hipblaslt is NOT using custom compute type") endif() + + # check whether hipblaslt provides getIndexFromAlgo + set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_getIndexFromAlgo.cc") + file(WRITE ${file} "" + "#include \n" + "#include \n" + "int main() {\n" + " hipblasLtMatmulAlgo_t algo;\n" + " return hipblaslt_ext::getIndexFromAlgo(algo);\n" + " return 0;\n" + "}\n" + ) + + try_compile(hipblaslt_compile_result_getindexfromalgo ${PROJECT_RANDOM_BINARY_DIR} ${file} + CMAKE_FLAGS + "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" + "-DLINK_DIRECTORIES=${ROCM_PATH}/lib" + LINK_LIBRARIES ${hipblaslt_LIBRARIES} + COMPILE_DEFINITIONS -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_HCC__ + OUTPUT_VARIABLE hipblaslt_compile_output) + + if(hipblaslt_compile_result_getindexfromalgo) + set(HIPBLASLT_HAS_GETINDEXFROMALGO ON) + #message("hipblaslt provides getIndexFromAlgo: ${hipblaslt_compile_output}") + message("hipblaslt provides getIndexFromAlgo") + else() + set(HAS_GETINDEXFROMALGO OFF) + #message("hipblaslt does not provide getIndexFromAlgo: ${hipblaslt_compile_output}") + message("hipblaslt does not provide getIndexFromAlgo") + endif() endif() # check whether HIP declares new types diff --git a/setup.py b/setup.py index b135ab58bd1a..a009bd358159 100644 --- a/setup.py +++ b/setup.py @@ -1166,6 +1166,7 @@ def main(): "include/ATen/cuda/*.h", "include/ATen/cuda/detail/*.cuh", "include/ATen/cuda/detail/*.h", + "include/ATen/cuda/tunable/*.h", "include/ATen/cudnn/*.h", "include/ATen/functorch/*.h", "include/ATen/ops/*.h", @@ -1174,6 +1175,7 @@ def main(): "include/ATen/hip/detail/*.cuh", "include/ATen/hip/detail/*.h", "include/ATen/hip/impl/*.h", + "include/ATen/hip/tunable/*.h", "include/ATen/mps/*.h", "include/ATen/miopen/*.h", "include/ATen/detail/*.h",