[ROCm] TunableOp (#114894)

Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one know which implementation is the fastest and should be chosen? That's what TunableOp provides.

See the README.md for additional details.

TunableOp was ported from onnxruntime starting from commit 08dce54266.  The content was significantly modified and reorganized for use within PyTorch.  The files copied and their approximate new names or source content location within aten/src/ATen/cuda/tunable include the following:

- onnxruntime/core/framework/tunable.h -> Tunable.h
- onnxruntime/core/framework/tuning_context.h -> Tunable.h
- onnxruntime/core/framework/tuning_context_impl.h -> Tunable.cpp
- onnxruntime/core/providers/rocm/tunable/gemm_common.h -> GemmCommon.h
- onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h -> GemmHipblaslt.h
- onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h -> GemmRocblas.h
- onnxruntime/core/providers/rocm/tunable/gemm_tunable.cuh -> TunableGemm.h
- onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc -> Tunable.cpp
- onnxruntime/core/providers/rocm/tunable/util.h -> StreamTimer.h
- onnxruntime/core/providers/rocm/tunable/util.cc -> StreamTimer.cpp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114894
Approved by: https://github.com/xw285cornell, https://github.com/jianyuh
This commit is contained in:
Jeff Daily
2024-02-14 19:03:49 +00:00
committed by PyTorch MergeBot
parent 90f785dc34
commit 0e6eee3c89
18 changed files with 2606 additions and 22 deletions

View File

@ -1166,6 +1166,7 @@ def main():
"include/ATen/cuda/*.h",
"include/ATen/cuda/detail/*.cuh",
"include/ATen/cuda/detail/*.h",
"include/ATen/cuda/tunable/*.h",
"include/ATen/cudnn/*.h",
"include/ATen/functorch/*.h",
"include/ATen/ops/*.h",
@ -1174,6 +1175,7 @@ def main():
"include/ATen/hip/detail/*.cuh",
"include/ATen/hip/detail/*.h",
"include/ATen/hip/impl/*.h",
"include/ATen/hip/tunable/*.h",
"include/ATen/mps/*.h",
"include/ATen/miopen/*.h",
"include/ATen/detail/*.h",