Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Back out "[pytorch][PR] Move thnvrtc and DynamicLibrary to ATen" (#22749)
Summary: Original commit changeset: add2ee8a8865
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22749
ghstack-source-id: 86323899
Differential Revision: D16203552
fbshipit-source-id: 227df3b85316315c15d2cb7b6a5c884096a82e9e
Committed by: Facebook Github Bot
Parent: 8bdda03ae1
Commit: ac78a86e1d
@@ -24,7 +24,6 @@ set(ATen_THIRD_PARTY_INCLUDE)
set(ATen_CUDA_SRCS)
set(ATen_CUDA_TEST_SRCS)
set(ATen_CUDA_INCLUDE)
set(ATen_NVRTC_STUB_SRCS)
set(ATen_HIP_SRCS)
set(ATen_HIP_TEST_SRCS)
set(ATen_HIP_INCLUDE)
@@ -102,7 +101,6 @@ add_subdirectory(src/ATen)
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)
set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)
set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
@@ -7,7 +7,6 @@
#include <ATen/DeviceGuard.h>
#include <ATen/DimVector.h>
#include <ATen/Dispatch.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/Formatting.h>
#include <ATen/Functions.h>
#ifdef BUILD_NAMEDTENSOR
@@ -39,8 +39,6 @@ FILE(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
add_subdirectory(core)
FILE(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
FILE(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
FILE(GLOB cuda_nvrtc_stub_h "cuda/nvrtc_stub/*.h")
FILE(GLOB cuda_nvrtc_stub_cpp "cuda/nvrtc_stub/*.cpp")
FILE(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu")
FILE(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh")
FILE(GLOB cudnn_cpp "cudnn/*.cpp")
@@ -48,8 +46,6 @@ FILE(GLOB cudnn_cpp "cudnn/*.cpp")
FILE(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh" "hip/detail/*.cuh")
FILE(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp")
FILE(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip")
FILE(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h")
FILE(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp")
FILE(GLOB miopen_h "miopen/*.h")
FILE(GLOB miopen_cpp "miopen/*.cpp")

@@ -360,7 +356,6 @@ endif()

if(USE_CUDA)
  set(ATen_CUDA_SRCS ${all_cuda_cpp})
  set(ATen_NVRTC_STUB_SRCS ${cuda_nvrtc_stub_cpp})
  if(AT_LINK_STYLE STREQUAL "INTERFACE")
    # Source code can't be added to an interface library, so it is
    # passed back to be compiled into the containing library
@@ -373,9 +368,6 @@ endif()

if(USE_ROCM)
  set(ATen_HIP_SRCS ${all_hip_cpp})
  # caffe2_nvrtc's stubs to driver APIs are useful for HIP.
  # See NOTE [ ATen NVRTC Stub and HIP ]
  set(ATen_NVRTC_STUB_SRCS ${hip_nvrtc_stub_cpp})
  if(AT_LINK_STYLE STREQUAL "INTERFACE")
    # Source code can't be added to an interface library, so it is
    # passed back to be compiled into the containing library
@@ -447,7 +439,6 @@ endif()
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} PARENT_SCOPE)
set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)
set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE)
set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
@@ -77,9 +77,7 @@ class CAFFE2_API Context {
    });
    return thh_state.get();
  }
  const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  THCState* getTHCState() {
    // AT_ASSERT(thc_state);
    return thc_state.get();
@@ -1,74 +0,0 @@
#include <c10/util/Exception.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/Utils.h>

#ifndef _WIN32
#include <dlfcn.h>
#include <libgen.h>
#else
#include <Windows.h>
#endif

namespace at {


#ifndef _WIN32

// Unix

static void* checkDL(void* x) {
  if (!x) {
    AT_ERROR("Error in dlopen or dlsym: ", dlerror());
  }

  return x;
}
DynamicLibrary::DynamicLibrary(const char* name) {
  // NOLINTNEXTLINE(hicpp-signed-bitwise)
  handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW));
}

void* DynamicLibrary::sym(const char* name) {
  AT_ASSERT(handle);
  return checkDL(dlsym(handle, name));
}

DynamicLibrary::~DynamicLibrary() {
  if (!handle)
    return;
  dlclose(handle);
}

#else

// Windows

DynamicLibrary::DynamicLibrary(const char* name) {
  // NOLINTNEXTLINE(hicpp-signed-bitwise)
  HMODULE theModule = LoadLibraryA(name);
  if (theModule) {
    handle = theModule;
  } else {
    AT_ERROR("error in LoadLibraryA");
  }
}

void* DynamicLibrary::sym(const char* name) {
  AT_ASSERT(handle);
  FARPROC procAddress = GetProcAddress((HMODULE)handle, name);
  if (!procAddress) {
    AT_ERROR("error in GetProcAddress");
  }
  return (void*)procAddress;
}

DynamicLibrary::~DynamicLibrary() {
  if (!handle) {
    return;
  }
  FreeLibrary((HMODULE)handle);
}

#endif

} // namespace at
@@ -1,21 +0,0 @@
#pragma once

#include <c10/macros/Export.h>
#include <ATen/Utils.h>

namespace at {

struct DynamicLibrary {
  AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);

  CAFFE2_API DynamicLibrary(const char* name);

  CAFFE2_API void* sym(const char* name);

  CAFFE2_API ~DynamicLibrary();

 private:
  void* handle = nullptr;
};

} // namespace at
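The two deleted ATen files above are a thin RAII wrapper over dlopen/dlsym (LoadLibraryA/GetProcAddress on Windows); this backout moves the equivalent wrapper back under torch/csrc/jit/fuser/cpu, as the new files later in this diff show. A minimal usage sketch of such a wrapper, assuming the pre-backout tree where at::DynamicLibrary exists; the library name and symbol below are purely illustrative and not part of this change:

#include <ATen/DynamicLibrary.h>

// Hypothetical plugin entry point: int example_add(int, int) exported by libexample.so.
using add_fn = int (*)(int, int);

int call_example(int a, int b) {
  at::DynamicLibrary lib("libexample.so");    // dlopen(name, RTLD_LOCAL | RTLD_NOW); errors become AT_ERROR
  auto add = (add_fn)lib.sym("example_add");  // dlsym plus a null check
  return add(a, b);                           // the handle is dlclose()d when `lib` goes out of scope
}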
@@ -24,10 +24,6 @@
#define __ubsan_ignore_vptr__
#endif

#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete

namespace at {

CAFFE2_API int _crash_if_asan(int);
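For reference, the macro above is the usual "delete the copy operations" idiom; a small sketch of how a use site behaves, assuming the pre-backout tree where ATen/Utils.h still defines it (NonCopyable is an illustrative name, not from this diff):

#include <ATen/Utils.h>

struct NonCopyable {
  NonCopyable() = default;
  AT_DISALLOW_COPY_AND_ASSIGN(NonCopyable);  // deletes the copy constructor and copy assignment
};

// NonCopyable a;
// NonCopyable b = a;  // would not compile: use of a deleted copy constructor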
@@ -3,8 +3,8 @@
#include <ATen/core/ATenGeneral.h>
#include <ATen/Context.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAFunctions.h>

#include <cstdint>
@@ -1,6 +1,5 @@
#pragma once

#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>

@@ -21,57 +20,3 @@
} while (0)

#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#ifndef __HIP_PLATFORM_HCC__

#define AT_CUDA_DRIVER_CHECK(EXPR)                                             \
  do {                                                                         \
    CUresult __err = EXPR;                                                     \
    if (__err != CUDA_SUCCESS) {                                               \
      const char* err_str;                                                     \
      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
      if (get_error_str_err != CUDA_SUCCESS) {                                 \
        AT_ERROR("CUDA driver error: unknown error");                          \
      } else {                                                                 \
        AT_ERROR("CUDA driver error: ", err_str);                              \
      }                                                                        \
    }                                                                          \
  } while (0)

#else

#define AT_CUDA_DRIVER_CHECK(EXPR)                                \
  do {                                                            \
    CUresult __err = EXPR;                                        \
    if (__err != CUDA_SUCCESS) {                                  \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));   \
    }                                                             \
  } while (0)

#endif

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR)                                              \
  do {                                                                         \
    nvrtcResult __err = EXPR;                                                  \
    if (__err != NVRTC_SUCCESS) {                                              \
      if (static_cast<int>(__err) != 7) {                                      \
        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
      } else {                                                                 \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");   \
      }                                                                        \
    }                                                                          \
  } while (0)
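The two checking macros above (removed from ATen's Exceptions.h by this backout) wrap driver and NVRTC calls that are reached through the lazily loaded stub. A minimal usage sketch, assuming the pre-backout tree where the macros and Context::getNVRTC() exist:

#include <ATen/Context.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda.h>

// Loads a PTX image through the driver API obtained from the NVRTC stub.
// A failing CUresult is turned into a readable AT_ERROR by the macro.
void load_ptx_module(const void* ptx_image, CUmodule* module) {
  AT_CUDA_DRIVER_CHECK(
      at::globalContext().getNVRTC().cuModuleLoadData(module, ptx_image));
}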
@@ -2,11 +2,9 @@

#include <ATen/CUDAGenerator.h>
#include <ATen/Context.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <c10/util/Exception.h>
@@ -79,32 +77,6 @@ bool CUDAHooks::hasCuDNN() const {
  return AT_CUDNN_ENABLED();
}

#ifdef USE_DIRECT_NVRTC
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
  return std::make_pair(nullptr, at::cuda::load_nvrtc());
}
#else
static std::pair<std::unique_ptr<at::DynamicLibrary>, at::cuda::NVRTC*> load_nvrtc() {
#if defined(_WIN32)
  std::string libcaffe2_nvrtc = "caffe2_nvrtc.dll";
#elif defined(__APPLE__)
  std::string libcaffe2_nvrtc = "libcaffe2_nvrtc.dylib";
#else
  std::string libcaffe2_nvrtc = "libcaffe2_nvrtc.so";
#endif
  std::unique_ptr<at::DynamicLibrary> libnvrtc_stub(
      new at::DynamicLibrary(libcaffe2_nvrtc.c_str()));
  auto fn = (at::cuda::NVRTC * (*)()) libnvrtc_stub->sym("load_nvrtc");
  return std::make_pair(std::move(libnvrtc_stub), fn());
}
#endif

const at::cuda::NVRTC& CUDAHooks::nvrtc() const {
  // must hold onto DynamicLibrary otherwise it will unload
  static auto handle = load_nvrtc();
  return *handle.second;
}

int64_t CUDAHooks::current_device() const {
  int device;
  cudaError_t err = cudaGetDevice(&device);
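The block removed above is the ATen side of a "load the stub once and keep it alive" pattern; the same pattern reappears in fused_kernel.cpp later in this diff. A generic sketch of the idea under the pre-backout tree, with purely illustrative names (ApiTable, libexample_stub.so, load_api_table are not part of this change):

#include <ATen/DynamicLibrary.h>
#include <memory>
#include <utility>

struct ApiTable;  // hypothetical table of function pointers returned by the stub's factory symbol

std::pair<std::unique_ptr<at::DynamicLibrary>, ApiTable*> load_api() {
  auto lib = std::make_unique<at::DynamicLibrary>("libexample_stub.so");  // illustrative library name
  auto factory = (ApiTable * (*)())lib->sym("load_api_table");            // illustrative exported symbol
  return std::make_pair(std::move(lib), factory());
}

const ApiTable& api() {
  // A function-local static keeps both the DynamicLibrary and the table alive
  // for the whole process, so the stub is never unloaded behind our back.
  static auto handle = load_api();
  return *handle.second;
}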
@@ -16,7 +16,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
  bool hasCUDA() const override;
  bool hasMAGMA() const override;
  bool hasCuDNN() const override;
  const at::cuda::NVRTC& nvrtc() const override;
  int64_t current_device() const override;
  Allocator* getPinnedMemoryAllocator() const override;
  bool compiledWithCuDNN() const override;
@@ -1,13 +0,0 @@
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <iostream>

namespace at { namespace cuda {

NVRTC* load_nvrtc() {
  auto self = new NVRTC();
#define CREATE_ASSIGN(name) self->name = name;
  AT_FORALL_NVRTC(CREATE_ASSIGN)
  return self;
}

}} // at::cuda
@@ -1,88 +0,0 @@
#pragma once

#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef __HIP_PLATFORM_HCC__
#include <nvrtc.h>
#endif

namespace at { namespace cuda {


// NOTE [ USE OF NVRTC AND DRIVER API ]
//
// ATen does not directly link to either libnvrtc or libcuda because they
// require libcuda to be installed, yet we want our GPU build to work on CPU
// machines as long as CUDA is not initialized.
//
// Normal CUDA code in torch uses the cuda runtime libraries which can be
// installed even if the driver is not installed, but sometimes we specifically
// need to use the driver API (e.g., to load JIT compiled code).
// To accomplish this, we lazily link libcaffe2_nvrtc which provides a struct
// at::cuda::NVRTC that contains function pointers to all of the apis we need.
//
// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY.
// INSTEAD USE, e.g.
//   detail::getCUDAHooks().nvrtc().cuLoadModule(...)
// or
//   globalContext().getNVRTC().cuLoadModule(...)
//
// If a function is missing add it to the list in ATen/cuda/nvrtc_stub/ATenNVRTC.h.

#ifndef __HIP_PLATFORM_HCC__

#define AT_FORALL_NVRTC(_)  \
  _(nvrtcVersion)           \
  _(nvrtcCreateProgram)     \
  _(nvrtcDestroyProgram)    \
  _(nvrtcGetPTXSize)        \
  _(nvrtcGetPTX)            \
  _(nvrtcCompileProgram)    \
  _(nvrtcGetErrorString)    \
  _(nvrtcGetProgramLogSize) \
  _(nvrtcGetProgramLog)     \
  _(cuModuleLoadData)       \
  _(cuModuleGetFunction)    \
  _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \
  _(cuGetErrorString)       \
  _(cuLaunchKernel)         \
  _(cuCtxGetCurrent)        \
  _(cuModuleUnload)         \
  _(cuDevicePrimaryCtxGetState)

#else

// NOTE [ ATen NVRTC Stub and HIP ]
//
// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
// NVRTC and driver APIs. While the former is not yet supported for HIP, the
// latter is supported and needed.
//
// The macro below strips out certain unsupported operations on HIP from the full
// list above.
//
// HIP doesn't have
//   nvrtc*
//   cuOccupancyMaxActiveBlocksPerMultiprocessor
//   cuGetErrorString (maps to non-functional hipGetErrorString___)

#define AT_FORALL_NVRTC(_) \
  _(cuModuleLoadData)      \
  _(cuModuleGetFunction)   \
  _(cuLaunchKernel)        \
  _(cuCtxGetCurrent)       \
  _(cuModuleUnload)        \
  _(cuDevicePrimaryCtxGetState)

#endif

extern "C" typedef struct NVRTC {
#define CREATE_MEMBER(name) decltype(&name) name;
  AT_FORALL_NVRTC(CREATE_MEMBER)
#undef CREATE_MEMBER
} NVRTC;

extern "C" AT_CUDA_API NVRTC* load_nvrtc();

}} // at::cuda
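The NOTE above hinges on an X-macro: AT_FORALL_NVRTC is expanded once with CREATE_MEMBER to declare one function pointer per API, and once with CREATE_ASSIGN (in the ATenNVRTC.cpp removed above) to fill those pointers in from the real libraries the stub links against. A small self-contained sketch of the same technique with generic stand-in functions; FnTable, add, and mul are illustrative names, not part of this change:

#include <cstdio>

// Two ordinary functions standing in for the nvrtc*/cu* entry points.
int add(int a, int b) { return a + b; }
int mul(int a, int b) { return a * b; }

// The X-macro: one call site per API we want in the table.
#define FORALL_FNS(_) \
  _(add)              \
  _(mul)

// Expansion 1: declare a struct with one function pointer per entry
// (this is what CREATE_MEMBER does for struct NVRTC above).
struct FnTable {
#define CREATE_MEMBER(name) decltype(&name) name;
  FORALL_FNS(CREATE_MEMBER)
#undef CREATE_MEMBER
};

// Expansion 2: fill the table, exactly like load_nvrtc()'s CREATE_ASSIGN.
FnTable* load_table() {
  auto self = new FnTable();
#define CREATE_ASSIGN(name) self->name = name;
  FORALL_FNS(CREATE_ASSIGN)
#undef CREATE_ASSIGN
  return self;
}

int main() {
  FnTable* table = load_table();
  std::printf("%d %d\n", table->add(2, 3), table->mul(2, 3));  // prints "5 6"
  delete table;
}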
@@ -13,11 +13,6 @@
// Forward-declares THCState
struct THCState;

// Forward-declares at::cuda::NVRTC
namespace at { namespace cuda {
struct NVRTC;
}} // at::cuda

namespace at {
class Context;
}
@@ -83,10 +78,6 @@ struct CAFFE2_API CUDAHooksInterface {
    return false;
  }

  virtual const at::cuda::NVRTC& nvrtc() const {
    AT_ERROR("NVRTC requires CUDA. ", CUDA_HELP);
  }

  virtual int64_t current_device() const {
    return -1;
  }
@@ -11,7 +11,6 @@
// macro and a function implementation if we pass along __LINE__
// and __FILE__, but no one has found this worth doing.

// For CUDA Runtime API
#define C10_CUDA_CHECK(EXPR)  \
  do {                        \
    cudaError_t __err = EXPR; \
@@ -454,27 +454,42 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
      ${TORCH_ROOT}/test/cpp/jit/test.cpp
    )

    if (NOT WIN32)
    if (WIN32)
      list(APPEND TORCH_SRCS
        ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
        ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_win.cpp
      )
    endif ()

    if (USE_CUDA)
      if (NOT USE_ROCM)
    if (USE_CUDA AND NOT USE_ROCM)
      list(APPEND Caffe2_GPU_SRCS
        ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp
      )
      add_library(thnvrtc SHARED ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/thnvrtc.cpp)
      target_link_libraries(thnvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
      target_include_directories(thnvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
      install(TARGETS thnvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
    endif()
    else ()
      list(APPEND TORCH_SRCS
        ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_unix.cpp
        ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
      )
      if (USE_CUDA AND NOT USE_ROCM)
        list(APPEND Caffe2_GPU_SRCS
          ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp
        )
        add_library(thnvrtc SHARED ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/thnvrtc.cpp)
        target_link_libraries(thnvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
        target_include_directories(thnvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
        install(TARGETS thnvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")

      endif()
    endif ()

    if (USE_CUDA)
      list(APPEND Caffe2_GPU_SRCS
        ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp
        ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
        ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
      )
      add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
      target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
      target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
      install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
    endif()

    if (USE_ROCM)
@@ -482,13 +497,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
        ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp
        ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
      )
      # caffe2_nvrtc's stubs to driver APIs are useful for HIP.
      # See NOTE [ ATen NVRTC Stub and HIP ]
      add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS})
      target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
      target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
      target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_HCC__)
      install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
    endif()

    if (NOT NO_API)
@@ -65,8 +65,8 @@ if (@USE_CUDA@)
      ${NVTOOLEXT_HOME}/lib/x64/nvToolsExt64_1.lib
      ${CUDA_LIBRARIES})
    list(APPEND TORCH_INCLUDE_DIRS ${NVTOOLEXT_HOME}/include)
    find_library(CAFFE2_NVRTC_LIBRARY caffe2_nvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib")
    list(APPEND TORCH_CUDA_LIBRARIES ${CAFFE2_NVRTC_LIBRARY})
    find_library(THNVRTC_LIBRARY thnvrtc PATHS "${TORCH_INSTALL_PREFIX}/lib")
    list(APPEND TORCH_CUDA_LIBRARIES ${THNVRTC_LIBRARY})
  elseif(APPLE)
    set(TORCH_CUDA_LIBRARIES
      ${CUDA_TOOLKIT_ROOT_DIR}/lib/libcudart.dylib
@@ -19,17 +19,17 @@ if not TEST_CUDA:
    TestCase = object  # noqa: F811


_caffe2_nvrtc = None
_thnvrtc = None


def get_is_primary_context_created(device):
    flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
    active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    global _caffe2_nvrtc
    if _caffe2_nvrtc is None:
        path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
        _caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
    result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
    global _thnvrtc
    if _thnvrtc is None:
        path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
        _thnvrtc = ctypes.cdll.LoadLibrary(path)
    result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
    assert result == 0, 'cuDevicePrimaryCtxGetState failed'
    return bool(active[0])
@@ -138,6 +138,7 @@ libtorch_sources = [
    "torch/csrc/jit/fuser/codegen.cpp",
    "torch/csrc/jit/fuser/fallback.cpp",
    "torch/csrc/jit/fuser/cpu/fused_kernel.cpp",
    "torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp",
    "torch/csrc/jit/fuser/interface.cpp",
    "torch/csrc/jit/function.cpp",
    "test/cpp/jit/test.cpp",
@@ -147,6 +148,7 @@ libtorch_cuda_sources = [
    "torch/csrc/cuda/comm.cpp",
    "torch/csrc/cuda/nccl.cpp",
    "torch/csrc/jit/fuser/cuda/fused_kernel.cpp",
    "torch/csrc/jit/fuser/cuda/thnvrtc.cpp",
    "torch/csrc/autograd/profiler_cuda.cpp",
    "torch/csrc/autograd/functions/comm.cpp"
]
@@ -349,9 +351,6 @@ def add_torch_libs():
        # TODO: putting USE_CUDA in propagated_pp_flags is error-prone
        propagated_pp_flags=propagated_pp_flags + [
            "-DUSE_CUDA",
            # The dynamically loaded NVRTC trick doesn't work in fbcode,
            # and it's not necessary anyway, because we have a stub
            # nvrtc library which we load canonically anyway
            "-DUSE_DIRECT_NVRTC",
        ],
        deps=[
torch/csrc/jit/fuser/cpu/dynamic_library.h (new file, 27 lines)
@@ -0,0 +1,27 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/disallow_copy.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cpu {

struct DynamicLibrary {
  TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);

  TORCH_API DynamicLibrary(const char* name);

  TORCH_API void* sym(const char* name);

  TORCH_API ~DynamicLibrary();

 private:
  void* handle = nullptr;
};

} // namespace cpu
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp (new file, 39 lines)
@@ -0,0 +1,39 @@
#include <c10/util/Exception.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/utils/disallow_copy.h>

#include <dlfcn.h>
#include <libgen.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cpu {

static void* checkDL(void* x) {
  if (!x) {
    AT_ERROR("error in dlopen or dlsym: ", dlerror());
  }

  return x;
}
DynamicLibrary::DynamicLibrary(const char* name) {
  // NOLINTNEXTLINE(hicpp-signed-bitwise)
  handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW));
}

void* DynamicLibrary::sym(const char* name) {
  AT_ASSERT(handle);
  return checkDL(dlsym(handle, name));
}

DynamicLibrary::~DynamicLibrary() {
  if (!handle)
    return;
  dlclose(handle);
}

} // namespace cpu
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp (new file, 41 lines)
@@ -0,0 +1,41 @@
#include <Windows.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/utils/disallow_copy.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cpu {


DynamicLibrary::DynamicLibrary(const char* name) {
  // NOLINTNEXTLINE(hicpp-signed-bitwise)
  HMODULE theModule = LoadLibraryA(name);
  if (theModule) {
    handle = theModule;
  } else {
    AT_ERROR("error in LoadLibraryA");
  }
}

void* DynamicLibrary::sym(const char* name) {
  AT_ASSERT(handle);
  FARPROC procAddress = GetProcAddress((HMODULE)handle, name);
  if (!procAddress) {
    AT_ERROR("error in GetProcAddress");
  }
  return (void*)procAddress;
}

DynamicLibrary::~DynamicLibrary() {
  if (!handle) {
    return;
  }
  FreeLibrary((HMODULE)handle);
}

} // namespace cpu
} // namespace fuser
} // namespace jit
} // namespace torch
@@ -2,6 +2,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/jit/fuser/cpu/temp_file.h>
#include <torch/csrc/utils/memory.h>

@@ -122,7 +123,7 @@ FusedKernelCPU::FusedKernelCPU(
  runCompiler(cpp_file.name(), so_file.name());
  if (debugFuser() >= 2)
    disas(so_file.name());
  so_lib = make_unique<at::DynamicLibrary>(so_file.name().c_str());
  so_lib = make_unique<DynamicLibrary>(so_file.name().c_str());
#pragma GCC diagnostic ignored "-Wpedantic"
  kernel =
      reinterpret_cast<void (*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
@@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/jit/fuser/fused_kernel.h>
#include <torch/csrc/utils/disallow_copy.h>

@@ -35,7 +36,7 @@ struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel {
  }

 private:
  std::unique_ptr<at::DynamicLibrary> so_lib;
  std::unique_ptr<DynamicLibrary> so_lib;
  void (*kernel)(uint32_t, void**) = nullptr;
};
@@ -1,12 +1,12 @@
#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/fuser/compiler.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/CUDAGenerator.h>
#include <THC/THC.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
#include <torch/csrc/jit/fuser/cuda/thnvrtc.h>
#include <torch/csrc/jit/resource_guard.h>

#include <cuda_runtime.h>
@@ -23,17 +23,77 @@ namespace jit {
namespace fuser {
namespace cuda {

// See NOTE [ USE OF NVRTC AND DRIVER API ]
const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
// [USE OF NVRTC AND DRIVER API]
// libtorch does not directly link to either libnvrtc or libcuda because
// they require libcuda to be installed. Normal CUDA code in torch uses the cuda
// runtime libraries which can be installed even if the driver is not installed,
// but here we specifically need to use the driver API to load JIT compiled
// code. To accomplish this, we lazily link libthnvrtc which provides a struct
// THNVRTC that contains function pointers to all of the apis we need.
//
// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY.
// INSTEAD USE, e.g. nvrtc().cuLoadModule(...)
// If a function is missing add it to the list in thnvrtc.

#ifdef USE_DIRECT_NVRTC
std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
  return std::make_pair(nullptr, torch_load_nvrtc());
}
#else
std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
#if defined(_WIN32)
  std::string libthnvrtc = "thnvrtc.dll";
#elif defined(__APPLE__)
  std::string libthnvrtc = "libthnvrtc.dylib";
#else
  std::string libthnvrtc = "libthnvrtc.so";
#endif
  std::unique_ptr<cpu::DynamicLibrary> libnvrtc_stub(
      new cpu::DynamicLibrary(libthnvrtc.c_str()));
  auto fn = (THNVRTC * (*)()) libnvrtc_stub->sym("torch_load_nvrtc");
  return std::make_pair(std::move(libnvrtc_stub), fn());
}
#endif

const THNVRTC& nvrtc() {
  // must hold onto DynamicLibrary otherwise it will unload
  static auto handle = loadNVRTC();
  return *handle.second;
}

// We're using three CUDA APIs, so define a few helpers for error handling
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, incorrectly produces the error string
// "NVRTC unknown error." The following maps it correctly.
static inline void nvrtcCheck(nvrtcResult result, const char* file, int line) {
  if (result != NVRTC_SUCCESS) {
    std::stringstream ss;
    ss << file << ":" << line << ": ";
    if (static_cast<int>(result) != 7)
      ss << nvrtc().nvrtcGetErrorString(result);
    else
      ss << "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE";
    throw std::runtime_error(ss.str());
  }
}
#define TORCH_NVRTC_CHECK(result) nvrtcCheck(result, __FILE__, __LINE__);

static inline void cuCheck(CUresult result, const char* file, int line) {
  if (result != CUDA_SUCCESS) {
    const char* str;
    nvrtc().cuGetErrorString(result, &str);
    std::stringstream ss;
    ss << file << ":" << line << ": " << str;
    throw std::runtime_error(ss.str());
  }
}
#define TORCH_CU_CHECK(result) cuCheck(result, __FILE__, __LINE__);

static void getMajorMinor(
    const cudaDeviceProp* const prop,
    int& major,
    int& minor) {
  int nvrtc_major, nvrtc_minor;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
  TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));

  // Short-circuits if NVRTC version too low
  AT_ASSERT(nvrtc_major >= 6);
@@ -85,7 +145,7 @@ FusedKernelCUDA::FusedKernelCUDA(
      device_(device) {
  // Initializes driver's API context (if necessary)
  CUcontext pctx = 0;
  AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
  TORCH_CU_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
  if (!pctx) {
    std::unique_lock<std::mutex> cudaFreeMutexLock(
        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
@@ -105,7 +165,7 @@ FusedKernelCUDA::FusedKernelCUDA(

  // Creates the NVRTC program
  nvrtcProgram program;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
  TORCH_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code_.c_str(), nullptr, 0, nullptr, nullptr));

  const std::string compute = "--gpu-architecture=compute_" +
@@ -124,19 +184,19 @@ FusedKernelCUDA::FusedKernelCUDA(
    throw std::runtime_error(cu.str());
  }
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
  AT_CUDA_NVRTC_CHECK(result);
      [&] { TORCH_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
  TORCH_NVRTC_CHECK(result);
  size_t ptx_size;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
  ptx_.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data()));
  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data()));

  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
  AT_CUDA_DRIVER_CHECK(
  TORCH_CU_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
  TORCH_CU_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));

  // Computes max blocks
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
  TORCH_CU_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &maxBlocks_, function_, 128, 0));
  maxBlocks_ *= prop_->multiProcessorCount;

@@ -176,7 +236,7 @@ void FusedKernelCUDA::launch_raw(

  // Launches kernel on current stream (device was set by executor)
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
  TORCH_CU_CHECK(nvrtc().cuLaunchKernel(
      function_,
      nBlocks,
      1,
torch/csrc/jit/fuser/cuda/thnvrtc.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
#include <torch/csrc/jit/fuser/cuda/thnvrtc.h>
#include <iostream>

THNVRTC* torch_load_nvrtc() {
  auto self = new THNVRTC();
#define CREATE_ASSIGN(name) self->name = name;
  TORCH_FORALL_NVRTC(CREATE_ASSIGN)
  return self;
}
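Since thnvrtc.cpp is compiled into the thnvrtc shared library that actually links against libnvrtc/libcuda (see the add_library(thnvrtc ...) lines earlier in this diff), the CREATE_ASSIGN expansion simply copies each real entry point into the THNVRTC table. Abbreviated, the expanded function body is equivalent to the sketch below; the function name is hypothetical and only two of the entries are shown:

THNVRTC* torch_load_nvrtc_expanded_sketch() {   // illustrative name, not the real exported symbol
  auto self = new THNVRTC();
  self->nvrtcVersion = nvrtcVersion;            // real function from libnvrtc
  self->cuModuleLoadData = cuModuleLoadData;    // real function from libcuda
  // ... one assignment per entry in TORCH_FORALL_NVRTC ...
  return self;
}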
torch/csrc/jit/fuser/cuda/thnvrtc.h (new file, 34 lines)
@@ -0,0 +1,34 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <cuda.h>
#include <nvrtc.h>

// See [USE OF NVRTC AND DRIVER API]

#define TORCH_FORALL_NVRTC(_) \
  _(nvrtcVersion)             \
  _(nvrtcCreateProgram)       \
  _(nvrtcDestroyProgram)      \
  _(nvrtcGetPTXSize)          \
  _(nvrtcGetPTX)              \
  _(cuModuleLoadData)         \
  _(cuModuleGetFunction)      \
  _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \
  _(cuGetErrorString)         \
  _(nvrtcGetErrorString)      \
  _(nvrtcGetProgramLogSize)   \
  _(nvrtcGetProgramLog)       \
  _(cuLaunchKernel)           \
  _(nvrtcCompileProgram)      \
  _(cuCtxGetCurrent)          \
  _(cuModuleUnload)           \
  _(cuDevicePrimaryCtxGetState)

extern "C" typedef struct THNVRTC {
#define CREATE_MEMBER(name) decltype(&name) name;
  TORCH_FORALL_NVRTC(CREATE_MEMBER)
#undef CREATE_MEMBER
} THNVRTC;

extern "C" TORCH_API THNVRTC* torch_load_nvrtc();
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Utils.h>
|
||||
|
||||
#define TH_DISALLOW_COPY_AND_ASSIGN AT_DISALLOW_COPY_AND_ASSIGN
|
||||
#define TH_DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&) = delete; \
|
||||
void operator=(const TypeName&) = delete
|
||||
|