Move thnvrtc and DynamicLibrary to ATen (#22362)

Summary:
Having the NVRTC stub in ATen is necessary to call driver APIs in ATen. This is currently blocking https://github.com/pytorch/pytorch/pull/22229.

`DynamicLibrary` is also moved as it is used in the stub code, and seems general enough.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22362

Differential Revision: D16131787

Pulled By: ezyang

fbshipit-source-id: add2ee8a8865229578aa00001a00d5a6671e0e73
This commit is contained in:
SsnL
2019-07-09 07:16:46 -07:00
committed by Facebook Github Bot
parent 74883d4865
commit 31d821e267
28 changed files with 357 additions and 271 deletions

View File

@ -19,17 +19,17 @@ if not TEST_CUDA:
TestCase = object # noqa: F811
_thnvrtc = None
_caffe2_nvrtc = None
def get_is_primary_context_created(device):
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
global _thnvrtc
if _thnvrtc is None:
path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
_thnvrtc = ctypes.cdll.LoadLibrary(path)
result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
global _caffe2_nvrtc
if _caffe2_nvrtc is None:
path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
_caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
return bool(active[0])