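// Python bindings for NVTX profiling ranges and markers, exposed to Python
// as the `_nvtx` submodule. On ROCm Windows builds, where roctracer is
// unavailable, warn-once no-op stubs are bound instead (see the
// ROCM_ON_WINDOWS branch below).
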
#ifdef _WIN32
#include <wchar.h> // _wgetenv for nvtx
#endif

#ifndef ROCM_ON_WINDOWS
#ifdef TORCH_CUDA_USE_NVTX3
#include <nvtx3/nvtx3.hpp>
#else // TORCH_CUDA_USE_NVTX3
#include <nvToolsExt.h>
#endif // TORCH_CUDA_USE_NVTX3
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::cuda::shared {

#ifndef ROCM_ON_WINDOWS
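
// Heap-allocated handle that pairs an NVTX range id with an owned copy of
// the range's message. Ownership travels through the host-function
// callbacks below and is released in device_callback_range_end.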
struct RangeHandle {
  nvtxRangeId_t id;
  const char* msg;
};
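
// Host callback enqueued on a CUDA stream: ends the NVTX range, then frees
// the message and handle allocated by device_nvtxRangeStart.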
static void device_callback_range_end(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  nvtxRangeEnd(handle->id);
  free((void*)handle->msg);
  free((void*)handle);
}
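
// Ends a device-side range: schedules device_callback_range_end to run on
// the host once all prior work on `stream` has completed.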
static void device_nvtxRangeEnd(void* handle, std::intptr_t stream) {
  C10_CUDA_CHECK(cudaLaunchHostFunc(
      (cudaStream_t)stream, device_callback_range_end, handle));
}
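
// Host callback enqueued on a CUDA stream: starts the NVTX range, so the
// range's begin timestamp reflects when the stream reached this point
// rather than when Python requested it.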
static void device_callback_range_start(void* userData) {
  RangeHandle* handle = ((RangeHandle*)userData);
  handle->id = nvtxRangeStartA(handle->msg);
}
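
// Starts a device-side range: allocates a RangeHandle, copies the message,
// and schedules device_callback_range_start on `stream`. The returned
// handle is later consumed by device_nvtxRangeEnd.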
static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
  auto handle = static_cast<RangeHandle*>(calloc(1, sizeof(RangeHandle)));
  handle->msg = strdup(msg);
  handle->id = 0;
  TORCH_CHECK(
      cudaLaunchHostFunc(
          (cudaStream_t)stream, device_callback_range_start, (void*)handle) ==
      cudaSuccess);
  return handle;
}
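
// Registers the NVTX functions on the given module under the `_nvtx`
// submodule.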
void initNvtxBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

#ifdef TORCH_CUDA_USE_NVTX3
  auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings");
#else
  auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
#endif
  nvtx.def("rangePushA", nvtxRangePushA);
  nvtx.def("rangePop", nvtxRangePop);
  nvtx.def("rangeStartA", nvtxRangeStartA);
  nvtx.def("rangeEnd", nvtxRangeEnd);
  nvtx.def("markA", nvtxMarkA);
  nvtx.def("deviceRangeStart", device_nvtxRangeStart);
  nvtx.def("deviceRangeEnd", device_nvtxRangeEnd);
}
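
// A minimal usage sketch for the bindings above, from Python (in PyTorch
// they are typically reached through torch.cuda.nvtx, which wraps this
// `_nvtx` submodule):
//
//   import torch
//   torch.cuda.nvtx.range_push("forward")  # calls _nvtx.rangePushA
//   out = model(inp)                       # work to annotate
//   torch.cuda.nvtx.range_pop()            # calls _nvtx.rangePop
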
#else // ROCM_ON_WINDOWS

static void printUnavailableWarning() {
  TORCH_WARN_ONCE("Warning: roctracer isn't available on Windows");
}
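
// Stub bindings: each function warns once and otherwise does nothing, so
// Python code that annotates ranges keeps working on ROCm Windows builds.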
static int rangePushA(const std::string&) {
  printUnavailableWarning();
  return 0;
}

static int rangePop() {
  printUnavailableWarning();
  return 0;
}

static int rangeStartA(const std::string&) {
  printUnavailableWarning();
  return 0;
}

static void rangeEnd(int) {
  printUnavailableWarning();
}

static void markA(const std::string&) {
  printUnavailableWarning();
}

static py::object deviceRangeStart(const std::string&, std::intptr_t) {
  printUnavailableWarning();
  return py::none(); // No range handle exists on this platform
}

static void deviceRangeEnd(py::object, std::intptr_t) {
  printUnavailableWarning();
}

void initNvtxBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto nvtx = m.def_submodule("_nvtx", "unavailable");

  nvtx.def("rangePushA", rangePushA);
  nvtx.def("rangePop", rangePop);
  nvtx.def("rangeStartA", rangeStartA);
  nvtx.def("rangeEnd", rangeEnd);
  nvtx.def("markA", markA);
  nvtx.def("deviceRangeStart", deviceRangeStart);
  nvtx.def("deviceRangeEnd", deviceRangeEnd);
}
#endif // ROCM_ON_WINDOWS

} // namespace torch::cuda::shared