[ROCm] remove HCC references (#111975)

- rename `__HIP_PLATFORM_HCC__` to `__HIP_PLATFORM_AMD__`
- rename `HIP_HCC_FLAGS` to `HIP_CLANG_FLAGS`
- rename `PYTORCH_HIP_HCC_LIBRARIES` to `PYTORCH_HIP_LIBRARIES`
- workaround in tools/amd_build/build_amd.py until submodules are updated

These symbols have had a long deprecation cycle and will finally be removed in ROCm 6.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111975
Approved by: https://github.com/ezyang, https://github.com/hongxiayang
This commit is contained in:
Jeff Daily
2023-10-26 02:39:01 +00:00
committed by PyTorch MergeBot
parent f1785373c0
commit 28c0b07d19
12 changed files with 65 additions and 38 deletions

View File

@ -1156,7 +1156,7 @@ void LayerNormBackwardKernelImplInternal(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
const int warp_size = at::cuda::warp_size(); const int warp_size = at::cuda::warp_size();
if (dX_data != nullptr) { if (dX_data != nullptr) {
#if defined __HIP_PLATFORM_HCC__ #ifdef USE_ROCM
if (M >= 32768) { if (M >= 32768) {
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1); const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1);

View File

@ -44,7 +44,7 @@ endif()
# ---[ Dependency of c10_hip # ---[ Dependency of c10_hip
target_link_libraries(c10_hip PUBLIC c10) target_link_libraries(c10_hip PUBLIC c10)
target_link_libraries(c10_hip PUBLIC ${PYTORCH_HIP_HCC_LIBRARIES}) target_link_libraries(c10_hip PUBLIC ${PYTORCH_HIP_LIBRARIES})
target_include_directories( target_include_directories(
c10_hip PUBLIC c10_hip PUBLIC

View File

@ -692,7 +692,7 @@ if(USE_ROCM)
# caffe2_nvrtc's stubs to driver APIs are useful for HIP. # caffe2_nvrtc's stubs to driver APIs are useful for HIP.
# See NOTE [ ATen NVRTC Stub and HIP ] # See NOTE [ ATen NVRTC Stub and HIP ]
add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
target_link_libraries(caffe2_nvrtc ${PYTORCH_HIP_HCC_LIBRARIES} ${ROCM_HIPRTC_LIB}) target_link_libraries(caffe2_nvrtc ${PYTORCH_HIP_LIBRARIES} ${ROCM_HIPRTC_LIB})
target_include_directories(caffe2_nvrtc PRIVATE ${CMAKE_BINARY_DIR}) target_include_directories(caffe2_nvrtc PRIVATE ${CMAKE_BINARY_DIR})
target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_HCC__) target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_HCC__)
install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
@ -1260,21 +1260,21 @@ endif()
if(USE_ROCM) if(USE_ROCM)
target_compile_definitions(torch_hip PRIVATE target_compile_definitions(torch_hip PRIVATE
USE_ROCM USE_ROCM
__HIP_PLATFORM_HCC__ __HIP_PLATFORM_AMD__
) )
# NB: Massive hack. torch/csrc/jit/codegen/fuser/codegen.cpp includes # NB: Massive hack. torch/csrc/jit/codegen/fuser/codegen.cpp includes
# torch/csrc/jit/codegen/fuser/cuda/resource_strings.h which changes the # torch/csrc/jit/codegen/fuser/cuda/resource_strings.h which changes the
# strings depending on if you're __HIP_PLATFORM_HCC__ or not. # strings depending on if you're __HIP_PLATFORM_AMD__ or not.
# But that file is in torch_cpu! So, against all odds, this macro # But that file is in torch_cpu! So, against all odds, this macro
# has to be set on torch_cpu too. I also added it to torch for # has to be set on torch_cpu too. I also added it to torch for
# better luck # better luck
target_compile_definitions(torch_cpu PRIVATE target_compile_definitions(torch_cpu PRIVATE
USE_ROCM USE_ROCM
__HIP_PLATFORM_HCC__ __HIP_PLATFORM_AMD__
) )
target_compile_definitions(torch PRIVATE target_compile_definitions(torch PRIVATE
USE_ROCM USE_ROCM
__HIP_PLATFORM_HCC__ __HIP_PLATFORM_AMD__
) )
target_include_directories(torch_hip PRIVATE target_include_directories(torch_hip PRIVATE
/opt/rocm/include /opt/rocm/include

View File

@ -46,9 +46,9 @@
// however hipblas v1 is still using its custom type // however hipblas v1 is still using its custom type
#define HIP_R_16F HIPBLAS_R_16F #define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F #define HIP_R_32F HIPBLAS_R_32F
#else // __HIP_PLATFORM_HCC #else // USE_ROCM
#define CUBLAS_HALF_TYPE __half #define CUBLAS_HALF_TYPE __half
#endif // __HIP_PLATFORM_HCC #endif // USE_ROCM
#include "caffe2/utils/math/utils.h" #include "caffe2/utils/math/utils.h"

View File

@ -1258,7 +1258,7 @@ if(USE_ROCM)
endif() endif()
list(APPEND HIP_CXX_FLAGS -fPIC) list(APPEND HIP_CXX_FLAGS -fPIC)
list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1) list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1) list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
list(APPEND HIP_CXX_FLAGS -DUSE_ROCM) list(APPEND HIP_CXX_FLAGS -DUSE_ROCM)
list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1) list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
@ -1294,7 +1294,7 @@ if(USE_ROCM)
hip_include_directories(${Caffe2_HIP_INCLUDE}) hip_include_directories(${Caffe2_HIP_INCLUDE})
set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB}) ${PYTORCH_HIP_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB})
list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver) roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)

View File

@ -21,13 +21,6 @@ if(NOT EXISTS ${HIP_PATH})
return() return()
endif() endif()
# HCC_PATH
if(NOT DEFINED ENV{HCC_PATH})
set(HCC_PATH ${ROCM_PATH}/hcc)
else()
set(HCC_PATH $ENV{HCC_PATH})
endif()
# HSA_PATH # HSA_PATH
if(NOT DEFINED ENV{HSA_PATH}) if(NOT DEFINED ENV{HSA_PATH})
set(HSA_PATH ${ROCM_PATH}/hsa) set(HSA_PATH ${ROCM_PATH}/hsa)
@ -240,8 +233,8 @@ if(HIP_FOUND)
message("\n***** Library versions from cmake find_package *****\n") message("\n***** Library versions from cmake find_package *****\n")
set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) set(CMAKE_HIP_CLANG_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) set(CMAKE_HIP_CLANG_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.### ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
# As of ROCm 5.1.x, all *.cmake files are under /opt/rocm/lib/cmake/<package> # As of ROCm 5.1.x, all *.cmake files are under /opt/rocm/lib/cmake/<package>
@ -303,18 +296,8 @@ if(HIP_FOUND)
find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(rocthrust REQUIRED)
find_package_and_print_version(hipsolver REQUIRED) find_package_and_print_version(hipsolver REQUIRED)
if(HIP_COMPILER STREQUAL clang)
set(hip_library_name amdhip64)
else()
set(hip_library_name hip_hcc)
endif()
message("HIP library name: ${hip_library_name}")
# TODO: hip_hcc has an interface include flag "-hc" which is only find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS ${HIP_PATH}/lib)
# recognizable by hcc, but not gcc and clang. Right now in our
# setup, hcc is only used for linking, but it should be used to
# compile the *_hip.cc files as well.
find_library(PYTORCH_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib)
# TODO: miopen_LIBRARIES should return fullpath to the library file, # TODO: miopen_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name # however currently it's just the lib name
if(TARGET ${miopen_LIBRARIES}) if(TARGET ${miopen_LIBRARIES})
@ -330,7 +313,7 @@ if(HIP_FOUND)
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
endif() endif()
# hiprtc is part of HIP # hiprtc is part of HIP
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib) find_library(ROCM_HIPRTC_LIB amdhip64 HINTS ${HIP_PATH}/lib)
# roctx is part of roctracer # roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
endif() endif()

View File

@ -147,7 +147,7 @@ if(USE_CUDA)
elseif(USE_ROCM) elseif(USE_ROCM)
target_link_libraries(test_jit PRIVATE target_link_libraries(test_jit PRIVATE
${ROCM_HIPRTC_LIB} ${ROCM_HIPRTC_LIB}
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_HIP_LIBRARIES}
${TORCH_CUDA_LIBRARIES}) ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(test_jit PRIVATE USE_ROCM) target_compile_definitions(test_jit PRIVATE USE_ROCM)

View File

@ -37,7 +37,7 @@ if(USE_CUDA)
elseif(USE_ROCM) elseif(USE_ROCM)
target_link_libraries(test_lazy PRIVATE target_link_libraries(test_lazy PRIVATE
${ROCM_HIPRTC_LIB} ${ROCM_HIPRTC_LIB}
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_HIP_LIBRARIES}
${TORCH_CUDA_LIBRARIES}) ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(test_lazy PRIVATE USE_ROCM) target_compile_definitions(test_lazy PRIVATE USE_ROCM)

View File

@ -62,13 +62,13 @@ if(USE_CUDA)
elseif(USE_ROCM) elseif(USE_ROCM)
target_link_libraries(test_tensorexpr PRIVATE target_link_libraries(test_tensorexpr PRIVATE
${ROCM_HIPRTC_LIB} ${ROCM_HIPRTC_LIB}
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_HIP_LIBRARIES}
${TORCH_CUDA_LIBRARIES}) ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM)
target_link_libraries(tutorial_tensorexpr PRIVATE target_link_libraries(tutorial_tensorexpr PRIVATE
${ROCM_HIPRTC_LIB} ${ROCM_HIPRTC_LIB}
${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_HIP_LIBRARIES}
${TORCH_CUDA_LIBRARIES}) ${TORCH_CUDA_LIBRARIES})
target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM)
endif() endif()

View File

@ -144,6 +144,50 @@ def is_hip_clang() -> bool:
return False return False
# TODO Remove once the following submodules are updated
hip_platform_files = [
"third_party/fbgemm/fbgemm_gpu/CMakeLists.txt",
"third_party/fbgemm/fbgemm_gpu/cmake/Hip.cmake",
"third_party/fbgemm/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp",
"third_party/fbgemm/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp",
"third_party/fbgemm/fbgemm_gpu/codegen/embedding_backward_split_template.cu",
"third_party/fbgemm/fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu",
"third_party/fbgemm/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh",
"third_party/fbgemm/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh",
"third_party/fbgemm/fbgemm_gpu/src/jagged_tensor_ops.cu",
"third_party/fbgemm/fbgemm_gpu/src/quantize_ops.cu",
"third_party/fbgemm/fbgemm_gpu/src/sparse_ops.cu",
"third_party/fbgemm/fbgemm_gpu/src/split_embeddings_cache_cuda.cu",
"third_party/fbgemm/fbgemm_gpu/src/topology_utils.cpp",
"third_party/fbgemm/src/EmbeddingSpMDM.cc",
"third_party/gloo/cmake/Dependencies.cmake",
"third_party/gloo/gloo/cuda.cu",
"third_party/kineto/libkineto/CMakeLists.txt",
"third_party/nvfuser/CMakeLists.txt",
"third_party/tensorpipe/cmake/Hip.cmake",
]
def remove_hcc(line: str) -> str:
line = line.replace("HIP_PLATFORM_HCC", "HIP_PLATFORM_AMD")
line = line.replace("HIP_HCC_FLAGS", "HIP_CLANG_FLAGS")
return line
for hip_platform_file in hip_platform_files:
do_write = False
if os.path.exists(hip_platform_file):
with open(hip_platform_file) as sources:
lines = sources.readlines()
newlines = [remove_hcc(line) for line in lines]
if lines == newlines:
print(f"{hip_platform_file} skipped")
else:
with open(hip_platform_file, "w") as sources:
for line in newlines:
sources.write(line)
print(f"{hip_platform_file} updated")
hipify_python.hipify( hipify_python.hipify(
project_directory=proj_dir, project_directory=proj_dir,
output_directory=out_dir, output_directory=out_dir,

View File

@ -140,7 +140,7 @@ if(USE_ROCM)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS
USE_ROCM USE_ROCM
__HIP_PLATFORM_HCC__ __HIP_PLATFORM_AMD__
) )
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB}) list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB})
endif() endif()

View File

@ -234,7 +234,7 @@ COMMON_NVCC_FLAGS = [
COMMON_HIP_FLAGS = [ COMMON_HIP_FLAGS = [
'-fPIC', '-fPIC',
'-D__HIP_PLATFORM_HCC__=1', '-D__HIP_PLATFORM_AMD__=1',
'-DUSE_ROCM=1', '-DUSE_ROCM=1',
] ]