Compare commits

...

1 Commit

Author SHA1 Message Date
04a786e334 Always use NVTX3 2025-11-14 16:44:52 +04:00
6 changed files with 2 additions and 30 deletions

View File

@@ -1643,8 +1643,6 @@ if(USE_CUDA)
target_link_libraries(torch_cuda PUBLIC c10_cuda)
if(TARGET torch::nvtx3)
target_link_libraries(torch_cuda PRIVATE torch::nvtx3)
else()
target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext)
endif()
target_include_directories(
@@ -1741,9 +1739,6 @@ if(BUILD_SHARED_LIBS)
if(USE_CUDA)
target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS})
target_link_libraries(torch_global_deps torch::cudart)
if(TARGET torch::nvtoolsext)
target_link_libraries(torch_global_deps torch::nvtoolsext)
endif()
endif()
install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

View File

@@ -968,11 +968,8 @@ find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir)
if(nvtx3_FOUND)
add_library(torch::nvtx3 INTERFACE IMPORTED)
target_include_directories(torch::nvtx3 INTERFACE "${nvtx3_dir}")
target_compile_definitions(torch::nvtx3 INTERFACE TORCH_CUDA_USE_NVTX3)
else()
message(WARNING "Cannot find NVTX3, find old NVTX instead")
add_library(torch::nvtoolsext INTERFACE IMPORTED)
set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt)
message(FATAL_ERROR "Cannot find NVTX3!")
endif()

View File

@@ -132,9 +132,6 @@ if(@USE_CUDA@)
else()
set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB})
endif()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext)
endif()
if(@BUILD_SHARED_LIBS@)
find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib")

View File

@@ -150,10 +150,6 @@ if(USE_CUDA)
if(TARGET torch::nvtx3)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
else()
if(TARGET torch::nvtoolsext)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtoolsext)
endif()
endif()
endif()

View File

@@ -2,18 +2,13 @@
#include <wchar.h> // _wgetenv for nvtx
#endif
#include <cuda_runtime.h>
#ifndef ROCM_ON_WINDOWS
#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvtx3/nvtx3.hpp>
#else // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvToolsExt.h>
#endif // CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::cuda::shared {
@@ -55,11 +50,7 @@ static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
void initNvtxBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
#ifdef TORCH_CUDA_USE_NVTX3
auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings");
#else
auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
#endif
nvtx.def("rangePushA", nvtxRangePushA);
nvtx.def("rangePop", nvtxRangePop);
nvtx.def("rangeStartA", nvtxRangeStartA);

View File

@@ -1,11 +1,7 @@
#include <sstream>
#ifndef ROCM_ON_WINDOWS
#if CUDART_VERSION >= 13000 || defined(TORCH_CUDA_USE_NVTX3)
#include <nvtx3/nvtx3.hpp>
#else
#include <nvToolsExt.h>
#endif
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS