Mirror of https://github.com/pytorch/pytorch.git
[ROCm] OCP FP8 Support for new GPUs (#146632)
TL;DR: Follow-up to, and built on top of, https://github.com/pytorch/pytorch/pull/144476. Adds OCP FP8 support for gfx950; see https://github.com/pytorch/ao/pull/1677.

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates add support for new ROCm versions and GPU architectures, update data type handling, and remove outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:

* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for the new GPU architectures `gfx1200`, `gfx1201`, and `gfx950`, gated on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated the supported-architecture lists in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950`, gated on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:

* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Extended the data type conversion to cover the new float8 types in both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated the `HipDataTypeFor` template to handle the new float8 types, with hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:

* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS`, which is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes improve compatibility and performance on newer hardware and software stacks, particularly for users running ROCm and CUDA for deep learning and scientific computing workloads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
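For context, the `GemmHipblaslt.h` change described above boils down to a version-gated mapping from the c10 float8 element types to `hipDataType` enums. Below is a minimal sketch of that pattern, assuming a ROCm build of PyTorch where `ROCM_VERSION` is defined; the 6.3 cutoff follows the description above, but the named HIP enums and the hard-coded fallback values (28/29) are illustrative assumptions, not the merged code.

```cpp
// Minimal sketch of the version-gating pattern; not the merged PyTorch code.
// The ROCm 6.3 cutoff and the hard-coded fallback values 28/29 are assumptions
// used for illustration only.
#include <hip/library_types.h>       // hipDataType
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

template <typename T>
constexpr hipDataType HipDataTypeFor();

// OCP FP8 e4m3: use the named enum when the ROCm headers provide it,
// otherwise fall back to a hard-coded enum value on older ROCm.
template <>
constexpr hipDataType HipDataTypeFor<c10::Float8_e4m3fn>() {
#if ROCM_VERSION >= 60300
  return HIP_R_8F_E4M3;
#else
  return static_cast<hipDataType>(28);  // assumed fallback value
#endif
}

// OCP FP8 e5m2: same pattern.
template <>
constexpr hipDataType HipDataTypeFor<c10::Float8_e5m2>() {
#if ROCM_VERSION >= 60300
  return HIP_R_8F_E5M2;
#else
  return static_cast<hipDataType>(29);  // assumed fallback value
#endif
}
```

Keeping the fallback as a plain integer cast lets the same template compile against older ROCm headers that do not yet declare the OCP FP8 enums.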
Committed by: PyTorch MergeBot
Parent: b1a81a4a65
Commit: f95ab46797
```diff
@@ -221,12 +221,16 @@ def max_clock_rate():
         return 1700
     elif "gfx908" in gcn_arch:
         return 1502
+    elif "gfx12" in gcn_arch:
+        return 1700
     elif "gfx11" in gcn_arch:
         return 1700
     elif "gfx103" in gcn_arch:
         return 1967
     elif "gfx101" in gcn_arch:
         return 1144
+    elif "gfx95" in gcn_arch:
+        return 1700  # TODO: placeholder, get actual value
     else:
         return 1100
```