mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move thnvrtc and DynamicLibrary to ATen (#22362)
Summary: Having the NVRTC stub in ATen is necessary to call driver APIs in ATen. This is currently blocking https://github.com/pytorch/pytorch/pull/22229. `DynamicLibrary` is also moved as it is used in the stub code, and seems general enough. Pull Request resolved: https://github.com/pytorch/pytorch/pull/22362 Differential Revision: D16131787 Pulled By: ezyang fbshipit-source-id: add2ee8a8865229578aa00001a00d5a6671e0e73
This commit is contained in:
committed by
Facebook Github Bot
parent
74883d4865
commit
31d821e267
@ -19,17 +19,17 @@ if not TEST_CUDA:
|
||||
TestCase = object # noqa: F811
|
||||
|
||||
|
||||
_thnvrtc = None
|
||||
_caffe2_nvrtc = None
|
||||
|
||||
|
||||
def get_is_primary_context_created(device):
|
||||
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
|
||||
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
||||
global _thnvrtc
|
||||
if _thnvrtc is None:
|
||||
path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
|
||||
_thnvrtc = ctypes.cdll.LoadLibrary(path)
|
||||
result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
|
||||
global _caffe2_nvrtc
|
||||
if _caffe2_nvrtc is None:
|
||||
path = glob.glob('{}/lib/libcaffe2_nvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
|
||||
_caffe2_nvrtc = ctypes.cdll.LoadLibrary(path)
|
||||
result = _caffe2_nvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
|
||||
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
|
||||
return bool(active[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user