Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Lazily load libcuda and libnvrtc from C++ (#17317)
Summary: Fixes https://github.com/pytorch/pytorch/issues/16860
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17317
Differential Revision: D14157877
Pulled By: zdevito
fbshipit-source-id: c37aec2d77c2e637d4fc6ceffe2bd32901c70317
Committed by: Facebook GitHub Bot
Parent commit: 81b43202ae
This commit: 356a94b64e
setup.py (25 lines changed)
@@ -638,31 +638,6 @@ if not IS_WINDOWS:
                    )
     extensions.append(DL)
 
-
-    if USE_CUDA:
-        thnvrtc_link_flags = extra_link_args + [make_relative_rpath('lib')]
-        if IS_LINUX:
-            thnvrtc_link_flags = thnvrtc_link_flags + ['-Wl,--no-as-needed']
-        # these have to be specified as -lcuda in link_flags because they
-        # have to come right after the `no-as-needed` option
-        if IS_WINDOWS:
-            thnvrtc_link_flags += ['cuda.lib', 'nvrtc.lib']
-        else:
-            thnvrtc_link_flags += ['-lcuda', '-lnvrtc']
-        cuda_stub_path = [cuda_lib_path + '/stubs']
-        if IS_DARWIN:
-            # on macOS this is where the CUDA stub is installed according to the manual
-            cuda_stub_path = ["/usr/local/cuda/lib"]
-        THNVRTC = Extension("torch._nvrtc",
-                            sources=['torch/csrc/nvrtc.cpp'],
-                            language='c++',
-                            extra_compile_args=main_compile_args + extra_compile_args,
-                            include_dirs=[cwd],
-                            library_dirs=library_dirs + cuda_stub_path,
-                            extra_link_args=thnvrtc_link_flags,
-                            )
-        extensions.append(THNVRTC)
-
     # These extensions are built by cmake and copied manually in build_extensions()
     # inside the build_ext implementaiton
     extensions.append(
@@ -2,6 +2,8 @@ import ctypes
 import torch
 from common_utils import TestCase, run_tests, skipIfRocm
 import unittest
+import glob
+import os
 
 # NOTE: this needs to be run in a brand new process
 
@@ -17,10 +19,17 @@ if not TEST_CUDA:
     TestCase = object  # noqa: F811
 
 
+_thnvrtc = None
+
+
 def get_is_primary_context_created(device):
     flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
     active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
-    result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
+    global _thnvrtc
+    if _thnvrtc is None:
+        path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
+        _thnvrtc = ctypes.cdll.LoadLibrary(path)
+    result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
     assert result == 0, 'cuDevicePrimaryCtxGetState failed'
     return bool(active[0])
@@ -97,7 +97,6 @@ if not args.out_of_place_only:
         # These files use nvrtc, hip doesn't have equivalent
         "csrc/autograd/profiler.h",
         "csrc/autograd/profiler.cpp",
-        "csrc/cuda/cuda_check.h",
         # These files are compatible with both cuda and hip
         "csrc/autograd/engine.cpp"
     ]
@@ -122,6 +122,7 @@ libtorch_cuda_sources = [
     "torch/csrc/cuda/comm.cpp",
     "torch/csrc/cuda/nccl.cpp",
     "torch/csrc/jit/fuser/cuda/fused_kernel.cpp",
+    "torch/csrc/jit/fuser/cuda/thnvrtc.cpp",
     "torch/csrc/autograd/profiler_cuda.cpp",
     "torch/csrc/autograd/functions/comm.cpp"
 ]
@@ -213,6 +214,7 @@ def add_torch_libs():
         link_whole=True,
         propagated_pp_flags=[
             "-DUSE_CUDA",
+            "-DUSE_DIRECT_NVRTC",
         ],
         deps=[
             ":generated-autograd-headers",
@@ -212,6 +212,10 @@ else ()
     list(APPEND TORCH_SRCS
       ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp
     )
+    add_library(thnvrtc SHARED ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/thnvrtc.cpp)
+    target_link_libraries(thnvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB})
+    target_include_directories(thnvrtc PRIVATE ${CUDA_INCLUDE_DIRS})
+    install(TARGETS thnvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}")
   endif()
 endif ()
 
@@ -94,11 +94,6 @@ else:
 
     del _dl_flags
 
-try:
-    import torch._nvrtc
-except ImportError:
-    pass
-
 from torch._C import *
 
 __all__ += [name for name in dir(_C)
@@ -1,5 +1,4 @@
 #include <torch/csrc/autograd/profiler.h>
-#include <torch/csrc/cuda/cuda_check.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <nvToolsExt.h>
 
@@ -9,6 +8,15 @@ namespace torch { namespace autograd { namespace profiler {
 
 namespace {
 
+static inline void cudaCheck(cudaError_t result, const char * file, int line) {
+  if(result != cudaSuccess) {
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << cudaGetErrorString(result);
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__);
+
 struct CUDAMethods : public CUDAStubs {
   void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override {
     TORCH_CUDA_CHECK(cudaGetDevice(device));
@@ -1,38 +0,0 @@
-#pragma once
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-#include <nvrtc.h>
-
-namespace torch {
-// We're using three CUDA APIs, so define a few helpers for error handling
-static inline void nvrtcCheck(nvrtcResult result,const char * file, int line) {
-  if(result != NVRTC_SUCCESS) {
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << nvrtcGetErrorString(result);
-    throw std::runtime_error(ss.str());
-  }
-}
-#define TORCH_NVRTC_CHECK(result) ::torch::nvrtcCheck(result,__FILE__,__LINE__);
-
-static inline void cuCheck(CUresult result, const char * file, int line) {
-  if(result != CUDA_SUCCESS) {
-    const char * str;
-    cuGetErrorString(result, &str);
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << str;
-    throw std::runtime_error(ss.str());
-  }
-}
-#define TORCH_CU_CHECK(result) ::torch::cuCheck(result,__FILE__,__LINE__);
-
-static inline void cudaCheck(cudaError_t result, const char * file, int line) {
-  if(result != cudaSuccess) {
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << cudaGetErrorString(result);
-    throw std::runtime_error(ss.str());
-  }
-}
-#define TORCH_CUDA_CHECK(result) ::torch::cudaCheck(result,__FILE__,__LINE__);
-
-}
@@ -17,6 +17,8 @@ struct DynamicLibrary {
 
   ~DynamicLibrary();
 
+  static std::string directoryOf(void* addr);
+
  private:
  void* handle = nullptr;
 };
@@ -3,6 +3,7 @@
 #include <torch/csrc/utils/disallow_copy.h>
 
 #include <dlfcn.h>
+#include <libgen.h>
 
 namespace torch {
 namespace jit {
@@ -32,6 +33,17 @@ DynamicLibrary::~DynamicLibrary() {
   dlclose(handle);
 }
 
+std::string DynamicLibrary::directoryOf(void* addr) {
+  Dl_info info = {};
+  if (!dladdr(addr, &info)) {
+    AT_ERROR("could not look up address: ", addr);
+  }
+  std::string name = info.dli_fname;
+  std::vector<char> path(name.begin(), name.end());
+  char* directory = dirname(path.data());
+  return directory;
+}
+
 } // namespace cpu
 } // namespace fuser
 } // namespace jit
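The new directoryOf helper is what lets the fuser locate libthnvrtc at runtime: dladdr maps an address inside an already-loaded shared object back to that object's file path, and dirname strips the file name. A minimal standalone sketch of the same idea follows; the names (directory_of, anchor, library_dir) are illustrative and not part of the patch, and on older glibc you may need to link with -ldl.

#include <dlfcn.h>   // dladdr, Dl_info
#include <libgen.h>  // dirname
#include <stdexcept>
#include <string>
#include <vector>

// Return the directory containing the shared object (or executable) that
// holds the given address, e.g. the address of a function defined in it.
static std::string directory_of(void* addr) {
  Dl_info info = {};
  if (!dladdr(addr, &info) || info.dli_fname == nullptr) {
    throw std::runtime_error("dladdr could not resolve the address");
  }
  // dirname may modify its argument, so hand it a writable, NUL-terminated copy.
  std::string name = info.dli_fname;
  std::vector<char> buf(name.c_str(), name.c_str() + name.size() + 1);
  return dirname(buf.data());
}

static void anchor() {}  // any symbol defined in this library works as an anchor

std::string library_dir() {
  // Passing the address of one of our own functions makes dladdr report the
  // path of *this* library rather than of the caller's executable.
  return directory_of(reinterpret_cast<void*>(&anchor));
}

Passing the address of a function defined in the calling library is the key trick: it pins the lookup to that library, which is why fused_kernel.cpp below calls directoryOf((void*)checkCUDAVersion).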
@@ -16,6 +16,10 @@ void* DynamicLibrary::sym(const char* name) {
   AT_ERROR("NYI: DynamicLibrary on Windows");
 }
 
+std::string DynamicLibrary::directoryOf(void* addr) {
+  AT_ERROR("NYI: DynamicLibrary on Windows");
+}
+
 DynamicLibrary::~DynamicLibrary() {}
 
 } // namespace cpu
@@ -4,7 +4,8 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <THC/THC.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/csrc/cuda/cuda_check.h>
+#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
+#include <torch/csrc/jit/fuser/cuda/thnvrtc.h>
 #include <torch/csrc/jit/resource_guard.h>
 
 // Note: unclear why this forward declaration is necessary
@@ -12,9 +13,7 @@
 #include <THC/THCTensorRandom.h>
 THCGenerator* THCRandom_getGenerator(THCState* state);
 
-#include <cuda.h>
 #include <cuda_runtime.h>
-#include <nvrtc.h>
 
 #include <algorithm>
 #include <cmath>
@@ -28,6 +27,18 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+// [USE OF NVRTC AND DRIVER API]
+// libtorch does not directly link to either libnvrtc or libcuda because
+// they require libcuda to be installed. Normal CUDA code in torch uses the cuda
+// runtime libraries which can be installed even if the driver is not installed,
+// but here we specifically need to use the driver API to load JIT compiled
+// code. To accomplish this, we lazily link libthnvrtc which provides a struct
+// THNVRTC that contains function pointers to all of the apis we need.
+//
+// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY.
+// INSTEAD USE, e.g. nvrtc().cuLoadModule(...)
+// If a function is missing add it to the list in thnvrtc.
+
 void checkCUDAVersion(const cudaDeviceProp& prop) {
   if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
       (prop.major >= 7 && CUDA_VERSION < 9000)) {
@@ -40,12 +51,58 @@ void checkCUDAVersion(const cudaDeviceProp& prop) {
   }
 }
 
+#ifdef USE_DIRECT_NVRTC
+std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
+  return std::make_pair(nullptr, torch_load_nvrtc());
+}
+#else
+std::pair<std::unique_ptr<cpu::DynamicLibrary>, THNVRTC*> loadNVRTC() {
+  std::string path = cpu::DynamicLibrary::directoryOf((void*)checkCUDAVersion);
+#ifdef __APPLE__
+  std::string libthnvrtc = path + "/libthnvrtc.dylib";
+#else
+  std::string libthnvrtc = path + "/libthnvrtc.so";
+#endif
+  std::unique_ptr<cpu::DynamicLibrary> libnvrtc_stub(
+      new cpu::DynamicLibrary(libthnvrtc.c_str()));
+  auto fn = (THNVRTC * (*)()) libnvrtc_stub->sym("torch_load_nvrtc");
+  return std::make_pair(std::move(libnvrtc_stub), fn());
+}
+#endif
+
+const THNVRTC& nvrtc() {
+  // must hold onto DynamicLibrary otherwise it will unload
+  static auto handle = loadNVRTC();
+  return *handle.second;
+}
+
+// We're using three CUDA APIs, so define a few helpers for error handling
+static inline void nvrtcCheck(nvrtcResult result, const char* file, int line) {
+  if (result != NVRTC_SUCCESS) {
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << nvrtc().nvrtcGetErrorString(result);
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_NVRTC_CHECK(result) nvrtcCheck(result, __FILE__, __LINE__);
+
+static inline void cuCheck(CUresult result, const char* file, int line) {
+  if (result != CUDA_SUCCESS) {
+    const char* str;
+    nvrtc().cuGetErrorString(result, &str);
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << str;
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_CU_CHECK(result) cuCheck(result, __FILE__, __LINE__);
+
 static void getMajorMinor(
     const cudaDeviceProp* const prop,
     int& major,
     int& minor) {
   int nvrtc_major, nvrtc_minor;
-  TORCH_NVRTC_CHECK(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
+  TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));
 
   // Short-circuits if NVRTC version too low
   AT_ASSERT(nvrtc_major >= 6);
@@ -97,7 +154,7 @@ FusedKernelCUDA::FusedKernelCUDA(
       device_(device) {
   // Initializes driver's API context (if necessary)
   CUcontext pctx = 0;
-  TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
+  TORCH_CU_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
   if (!pctx) {
     std::unique_lock<std::mutex> cudaFreeMutexLock(
         *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
@@ -117,14 +174,15 @@ FusedKernelCUDA::FusedKernelCUDA(
 
   // Creates the NVRTC program
   nvrtcProgram program;
-  TORCH_NVRTC_CHECK(nvrtcCreateProgram(
+  TORCH_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
       &program, code_.c_str(), nullptr, 0, nullptr, nullptr));
 
   const std::string compute = "--gpu-architecture=compute_" +
       std::to_string(major) + std::to_string(minor);
   const std::vector<const char*> args = {
       "--std=c++11", compute.c_str(), "-default-device"};
-  const auto result = nvrtcCompileProgram(program, args.size(), args.data());
+  const auto result =
+      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
   if (result == NVRTC_ERROR_COMPILATION) {
     size_t logsize;
     nvrtcGetProgramLogSize(program, &logsize);
@@ -135,18 +193,19 @@ FusedKernelCUDA::FusedKernelCUDA(
     throw std::runtime_error(cu.str());
   }
   ResourceGuard holdProgram(
-      [&] { TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); });
+      [&] { TORCH_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
   TORCH_NVRTC_CHECK(result);
   size_t ptx_size;
-  TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
+  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
   ptx_.resize(ptx_size);
-  TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx_.data()));
+  TORCH_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx_.data()));
 
-  TORCH_CU_CHECK(cuModuleLoadData(&module_, ptx_.data()));
-  TORCH_CU_CHECK(cuModuleGetFunction(&function_, module_, name_.c_str()));
+  TORCH_CU_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
+  TORCH_CU_CHECK(
+      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));
 
   // Computes max blocks
-  TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
+  TORCH_CU_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
       &maxBlocks_, function_, 128, 0));
   maxBlocks_ *= prop_->multiProcessorCount;
 
@@ -182,7 +241,7 @@ void FusedKernelCUDA::launch_raw(
 
   // Launches kernel on current stream (device was set by executor)
   auto stream = at::cuda::getCurrentCUDAStream();
-  TORCH_CU_CHECK(cuLaunchKernel(
+  TORCH_CU_CHECK(nvrtc().cuLaunchKernel(
       function_,
       nBlocks,
      1,
@@ -199,6 +258,10 @@ void FusedKernelCUDA::launch_raw(
   at::cuda::set_device(prior_device);
 }
 
+FusedKernelCUDA::~FusedKernelCUDA() {
+  nvrtc().cuModuleUnload(module_);
+}
+
 static std::shared_ptr<FusedKernel> createFusionKernel(
     int16_t device,
     std::string name,
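Taken together, loadNVRTC and the nvrtc() accessor implement a small lazy-binding pattern: libtorch itself never links against libcuda or libnvrtc; a thin stub library (libthnvrtc) carries those link dependencies, and libtorch dlopens the stub sitting next to its own shared object on first use and asks it for a table of function pointers. A condensed sketch of that pattern follows, using hypothetical names (FooApi, Library, libfoo_stub.so, foo_load_api) rather than the actual PyTorch symbols.

#include <dlfcn.h>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Table of function pointers exported by the stub library. In the patch this
// role is played by THNVRTC; the member here is a made-up example.
struct FooApi {
  int (*foo_version)(int* major, int* minor);
};

// Minimal RAII wrapper around dlopen/dlsym, standing in for cpu::DynamicLibrary.
// (Copy control is omitted for brevity; a real wrapper would disable copying.)
struct Library {
  explicit Library(const char* path)
      : handle_(dlopen(path, RTLD_LOCAL | RTLD_NOW)) {
    if (!handle_) throw std::runtime_error(std::string("could not load ") + path);
  }
  void* sym(const char* name) const {
    void* p = dlsym(handle_, name);
    if (!p) throw std::runtime_error(std::string("missing symbol ") + name);
    return p;
  }
  ~Library() { if (handle_) dlclose(handle_); }
  void* handle_;
};

// Load the stub lazily: dlopen it, look up its factory, and keep the Library
// alive so the stub (and the real library it links against) stays mapped.
std::pair<std::unique_ptr<Library>, FooApi*> load_foo_api() {
  auto lib = std::make_unique<Library>("libfoo_stub.so");  // hypothetical name
  auto factory = reinterpret_cast<FooApi* (*)()>(lib->sym("foo_load_api"));
  return std::make_pair(std::move(lib), factory());
}

// Function-local static: the first caller pays for the dlopen, later callers
// just read the cached pointer table. Mirrors the nvrtc() accessor above.
const FooApi& foo() {
  static auto handle = load_foo_api();
  return *handle.second;
}

Holding the library handle in the same pair as the pointer table is what keeps the stub mapped for the lifetime of the process (the diff's own comment: "must hold onto DynamicLibrary otherwise it will unload"), and the function-local static means the dlopen happens at most once and only on machines that actually reach the CUDA fuser.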
@@ -30,9 +30,7 @@ struct TORCH_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel {
       std::vector<PartitionDesc> concat_desc,
       bool has_random);
 
-  ~FusedKernelCUDA() override {
-    cuModuleUnload(module_);
-  }
+  ~FusedKernelCUDA() override;
 
   void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
       const override;
torch/csrc/jit/fuser/cuda/thnvrtc.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include <torch/csrc/jit/fuser/cuda/thnvrtc.h>
+#include <iostream>
+
+THNVRTC* torch_load_nvrtc() {
+  auto self = new THNVRTC();
+#define CREATE_ASSIGN(name) self->name = name;
+  TORCH_FORALL_NVRTC(CREATE_ASSIGN)
+  return self;
+}
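torch_load_nvrtc is compiled into libthnvrtc, which (per the CMake change above) really does link against libcuda and libnvrtc, so the unqualified names listed in TORCH_FORALL_NVRTC resolve to the real driver/NVRTC entry points there. The X-macro simply copies their addresses into the struct; the approximate preprocessor output for the first few entries looks like this (an illustrative expansion, not generated source):

// Approximate expansion of TORCH_FORALL_NVRTC(CREATE_ASSIGN) inside torch_load_nvrtc():
//   THNVRTC* torch_load_nvrtc() {
//     auto self = new THNVRTC();
//     self->nvrtcVersion = nvrtcVersion;              // address of the real libnvrtc function
//     self->nvrtcCreateProgram = nvrtcCreateProgram;
//     self->cuModuleLoadData = cuModuleLoadData;      // address of the real libcuda (driver) function
//     /* ... one assignment per entry in TORCH_FORALL_NVRTC ... */
//     return self;
//   }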
torch/csrc/jit/fuser/cuda/thnvrtc.h (new file, 31 lines)
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <cuda.h>
+#include <nvrtc.h>
+
+// See [USE OF NVRTC AND DRIVER API]
+
+#define TORCH_FORALL_NVRTC(_) \
+  _(nvrtcVersion) \
+  _(nvrtcCreateProgram) \
+  _(nvrtcDestroyProgram) \
+  _(nvrtcGetPTXSize) \
+  _(nvrtcGetPTX) \
+  _(cuModuleLoadData) \
+  _(cuModuleGetFunction) \
+  _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \
+  _(cuGetErrorString) \
+  _(nvrtcGetErrorString) \
+  _(cuLaunchKernel) \
+  _(nvrtcCompileProgram) \
+  _(cuCtxGetCurrent) \
+  _(cuModuleUnload) \
+  _(cuDevicePrimaryCtxGetState)
+
+extern "C" typedef struct THNVRTC {
+#define CREATE_MEMBER(name) decltype(&name) name;
+  TORCH_FORALL_NVRTC(CREATE_MEMBER)
+#undef CREATE_MEMBER
+} THNVRTC;
+
+extern "C" THNVRTC* torch_load_nvrtc();
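On the consumer side, CREATE_MEMBER expands the same list into a struct of function pointers whose types are taken from the real declarations via decltype, so calls through the table keep the exact driver/NVRTC signatures. An illustrative expansion of a few members, together with how fused_kernel.cpp calls through the nvrtc() accessor (assumes the CUDA headers above are available):

#include <cuda.h>
#include <nvrtc.h>

// Approximate expansion of the THNVRTC struct for a few entries:
extern "C" typedef struct THNVRTC_expanded {
  decltype(&nvrtcVersion) nvrtcVersion;            // nvrtcResult (*)(int*, int*)
  decltype(&nvrtcCreateProgram) nvrtcCreateProgram;
  decltype(&cuLaunchKernel) cuLaunchKernel;
  /* ... one member per entry in TORCH_FORALL_NVRTC ... */
} THNVRTC_expanded;

// Call sites in fused_kernel.cpp go through the accessor instead of the raw API:
//   int major, minor;
//   TORCH_NVRTC_CHECK(nvrtc().nvrtcVersion(&major, &minor));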
@@ -1,44 +0,0 @@
-#include <torch/csrc/python_headers.h>
-
-static PyObject* module;
-
-static PyMethodDef TorchNvrtcMethods[] = {
-  {nullptr, nullptr, 0, nullptr}
-};
-
-#if PY_MAJOR_VERSION != 2
-static struct PyModuleDef torchnvrtcmodule = {
-  PyModuleDef_HEAD_INIT,
-  "torch._nvrtc",
-  nullptr,
-  -1,
-  TorchNvrtcMethods
-};
-#endif
-
-#if PY_MAJOR_VERSION == 2
-PyMODINIT_FUNC init_nvrtc(void)
-#else
-PyMODINIT_FUNC PyInit__nvrtc(void)
-#endif
-{
-
-#if PY_MAJOR_VERSION == 2
-#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error in torch._nvrtc"); return;}
-#else
-#define ASSERT_TRUE(cmd) if (!(cmd)) return nullptr
-#endif
-
-#if PY_MAJOR_VERSION == 2
-  ASSERT_TRUE(module = Py_InitModule("torch._nvrtc", TorchNvrtcMethods));
-#else
-  ASSERT_TRUE(module = PyModule_Create(&torchnvrtcmodule));
-#endif
-
-#if PY_MAJOR_VERSION == 2
-#else
-  return module;
-#endif
-
-#undef ASSERT_TRUE
-}