mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Update spack includes (#152569)
* Cleans up code in `caffe2/CMakeLists.txt` to remove individual ROCm library include paths and use `ROCM_INCLUDE_DIRS` CMake var instead * `ROCM_INCLUDE_DIRS` CMake var is set in `cmake/public/LoadHIP.cmake` by adding all the ROCm packages that PyTorch depends on * `rocm_version.h` is provided by the `rocm-core` package, so use the include directory for that component to be compliant with Spack * Move `find_package_and_print_version(hip REQUIRED CONFIG)` earlier so that `hip_version.h` can be located in the hip package include dir for Spack * `list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)` to remove duplicate `/opt/rocm/include` entries in the non-Spack case * Remove user-provided env var `ROCM_INCLUDE_DIRS` since `ROCM_PATH` already exists as a user-provided env var, which should be sufficient to locate the include directories for ROCm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152569 Approved by: https://github.com/renjithravindrankannath, https://github.com/jeffdaily Co-authored-by: Renjith Ravindran <Renjith.RavindranKannath@amd.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
4f425a0397
commit
f11d7a5978
@ -1420,13 +1420,6 @@ if(USE_ROCM)
|
||||
set(ROCM_SOURCE_DIR "/opt/rocm")
|
||||
endif()
|
||||
message(INFO "caffe2 ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}")
|
||||
target_include_directories(torch_hip PRIVATE
|
||||
${ROCM_SOURCE_DIR}/include
|
||||
${ROCM_SOURCE_DIR}/hcc/include
|
||||
${ROCM_SOURCE_DIR}/rocblas/include
|
||||
${ROCM_SOURCE_DIR}/hipsparse/include
|
||||
${ROCM_SOURCE_DIR}/include/rccl/
|
||||
)
|
||||
if(USE_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE
|
||||
USE_FLASH_ATTENTION
|
||||
@ -1760,7 +1753,8 @@ if(USE_ROCM)
|
||||
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
|
||||
|
||||
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
|
||||
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE})
|
||||
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
|
||||
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
|
||||
target_include_directories(torch_hip INTERFACE $<INSTALL_INTERFACE:include>)
|
||||
endif()
|
||||
|
||||
|
@ -26,12 +26,6 @@ else()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS})
|
||||
set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
|
||||
else()
|
||||
set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
# MAGMA_HOME
|
||||
if(NOT DEFINED ENV{MAGMA_HOME})
|
||||
set(MAGMA_HOME ${ROCM_PATH}/magma)
|
||||
@ -72,6 +66,7 @@ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
|
||||
macro(find_package_and_print_version PACKAGE_NAME)
|
||||
find_package("${PACKAGE_NAME}" ${ARGN})
|
||||
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
|
||||
list(APPEND ROCM_INCLUDE_DIRS ${${PACKAGE_NAME}_INCLUDE_DIR})
|
||||
endmacro()
|
||||
|
||||
# Find the HIP Package
|
||||
@ -82,9 +77,16 @@ find_package_and_print_version(HIP 1.0 MODULE)
|
||||
if(HIP_FOUND)
|
||||
set(PYTORCH_FOUND_HIP TRUE)
|
||||
|
||||
find_package_and_print_version(hip REQUIRED CONFIG)
|
||||
# Find ROCM version for checks. UNIX filename is rocm_version.h, Windows is hip_version.h
|
||||
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h hip_version.h
|
||||
HINTS ${ROCM_INCLUDE_DIRS}/rocm-core ${ROCM_INCLUDE_DIRS}/hip /usr/include)
|
||||
if(UNIX)
|
||||
find_package_and_print_version(rocm-core REQUIRED CONFIG)
|
||||
find_file(ROCM_VERSION_HEADER_PATH NAMES rocm_version.h
|
||||
HINTS ${rocm_core_INCLUDE_DIR}/rocm-core /usr/include)
|
||||
else() # Win32
|
||||
find_file(ROCM_VERSION_HEADER_PATH NAMES hip_version.h
|
||||
HINTS ${hip_INCLUDE_DIR}/hip)
|
||||
endif()
|
||||
get_filename_component(ROCM_HEADER_NAME ${ROCM_VERSION_HEADER_PATH} NAME)
|
||||
|
||||
if(EXISTS ${ROCM_VERSION_HEADER_PATH})
|
||||
@ -141,7 +143,6 @@ if(HIP_FOUND)
|
||||
# Find ROCM components using Config mode
|
||||
# These components will be searced for recursively in ${ROCM_PATH}
|
||||
message("\n***** Library versions from cmake find_package *****\n")
|
||||
find_package_and_print_version(hip REQUIRED CONFIG)
|
||||
find_package_and_print_version(amd_comgr REQUIRED)
|
||||
find_package_and_print_version(rocrand REQUIRED)
|
||||
find_package_and_print_version(hiprand REQUIRED)
|
||||
@ -168,7 +169,11 @@ if(HIP_FOUND)
|
||||
if(UNIX)
|
||||
find_package_and_print_version(rccl)
|
||||
find_package_and_print_version(hsa-runtime64 REQUIRED)
|
||||
endif()
|
||||
|
||||
list(REMOVE_DUPLICATES ROCM_INCLUDE_DIRS)
|
||||
|
||||
if(UNIX)
|
||||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
|
||||
|
||||
|
Reference in New Issue
Block a user