[ROCm] Update spack includes (#152569)

* Cleans up code in `caffe2/CMakeLists.txt` to remove individual ROCm library include paths and use `ROCM_INCLUDE_DIRS` CMake var instead
* `ROCM_INCLUDE_DIRS` CMake var is set in `cmake/public/LoadHIP.cmake` by adding all the ROCm packages that PyTorch depends on
* `rocm_version.h` is provided by the `rocm-core` package, so use the include directory for that component to be compliant with Spack
* Move `find_package_and_print_version(hip REQUIRED CONFIG)` earlier so that `hip_version.h` can be located in the hip package include dir for Spack
* `list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)` to remove duplicate `/opt/rocm/include` entries in the non-Spack case
* Remove user-provided env var `ROCM_INCLUDE_DIRS` since `ROCM_PATH` already exists as a user-provided env var, which should be sufficient to locate the include directories for ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152569
Approved by: https://github.com/renjithravindrankannath, https://github.com/jeffdaily

Co-authored-by: Renjith Ravindran <Renjith.RavindranKannath@amd.com>
This commit is contained in:
Jithun Nair
2025-05-09 21:36:38 +00:00
committed by PyTorch MergeBot
parent 4f425a0397
commit f11d7a5978
2 changed files with 16 additions and 17 deletions

View File

@ -1420,13 +1420,6 @@ if(USE_ROCM)
set(ROCM_SOURCE_DIR "/opt/rocm")
endif()
message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}")
target_include_directories(torch_hip PRIVATE
${ROCM_SOURCE_DIR}/include
${ROCM_SOURCE_DIR}/hcc/include
${ROCM_SOURCE_DIR}/rocblas/include
${ROCM_SOURCE_DIR}/hipsparse/include
${ROCM_SOURCE_DIR}/include/rccl/
)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE
USE_FLASH_ATTENTION
@ -1760,7 +1753,8 @@ if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE})
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
target_include_directories(torch_hip INTERFACE $<INSTALL_INTERFACE:include>)
endif()

View File

@ -26,12 +26,6 @@ else()
endif()
endif()
if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS})
set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
else()
set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
endif()
# MAGMA_HOME
if(NOT DEFINED ENV{MAGMA_HOME})
set(MAGMA_HOME ${ROCM_PATH}/magma)
@ -72,6 +66,7 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
macro(find_package_and_print_version PACKAGE_NAME)
find_package("${PACKAGE_NAME}" ${ARGN})
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
endmacro()
# Find the HIP Package
@ -82,9 +77,16 @@ find_package_and_print_version(HIP 1.0 MODULE)
if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)
find_package_and_print_version(hip REQUIRED CONFIG)
# Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h hip_version.h
HINTS ${ROCM_INCLUDE_DIRS}/rocm-core ${ROCM_INCLUDE_DIRS}/hip /usr/include)
if(UNIX)
find_package_and_print_version(rocm-core REQUIRED CONFIG)
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h
HINTS ${rocm_core_INCLUDE_DIR}/rocm-core /usr/include)
else() # Win32
find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h
HINTS ${hip_INCLUDE_DIR}/hip)
endif()
get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
if(EXISTS ${ROCM_VERSION_HEADER_PATH})
@ -141,7 +143,6 @@ if(HIP_FOUND)
# Find ROCM components using Config mode
# These components will be searced for recursively in ${ROCM_PATH}
message("\n***** Library versions from cmake find_package *****\n")
find_package_and_print_version(hip REQUIRED CONFIG)
find_package_and_print_version(amd_comgr REQUIRED)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
@ -168,7 +169,11 @@ if(HIP_FOUND)
if(UNIX)
find_package_and_print_version(rccl)
find_package_and_print_version(hsa-runtime64 REQUIRED)
endif()
list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)
if(UNIX)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)