[ROCm] add hipblaslt support (#114329)

Disabled by default. Enable with env var DISABLE_ADDMM_HIP_LT=0. Tested on both ROCm 5.7 and 6.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114329
Approved by: https://github.com/malfet
This commit is contained in:
Jeff Daily
2023-12-15 15:36:46 +00:00
committed by PyTorch MergeBot
parent 287a865677
commit b062ea3803
9 changed files with 363 additions and 32 deletions

View File

@ -237,6 +237,9 @@ COMMON_HIP_FLAGS = [
'-DUSE_ROCM=1',
]
if ROCM_VERSION is not None and ROCM_VERSION >= (6, 0):
COMMON_HIP_FLAGS.append('-DHIPBLAS_V2')
COMMON_HIPCC_FLAGS = [
'-DCUDA_HAS_FP16=1',
'-D__HIP_NO_HALF_OPERATORS__=1',