Add get_stream_from_external API for CUDA backend (#143799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143799
Approved by: https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #142347, #141119, #141123
Author:    Yu, Guangye
Date:      2024-12-31 09:55:42 +00:00
Committed: PyTorch MergeBot
Parent:    8f6c4d1732
Commit:    3848de55ed

5 changed files with 53 additions and 0 deletions
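For orientation before the diffs: a minimal usage sketch of the new API. The `ctypes` loading of the CUDA runtime is illustrative and not part of this commit; any library that exposes the integer value of a `cudaStream_t` works the same way.

```python
import ctypes

import torch

# Create a stream outside PyTorch via the CUDA runtime. The library name
# varies by install (e.g. libcudart.so.12 on Linux); adjust as needed.
cudart = ctypes.CDLL("libcudart.so")
raw = ctypes.c_void_p()
assert cudart.cudaStreamCreate(ctypes.byref(raw)) == 0  # 0 == cudaSuccess

# Wrap the external handle. PyTorch does not take ownership: `raw` must
# stay alive for as long as `ext` is in use.
ext = torch.cuda.get_stream_from_external(raw.value, device=0)

with torch.cuda.stream(ext):
    y = torch.randn(8, device="cuda").sum()
ext.synchronize()

cudart.cudaStreamDestroy(raw)
```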

View File

@@ -23,6 +23,7 @@ torch.cuda
     get_device_name
     get_device_properties
     get_gencode_flags
+    get_stream_from_external
     get_sync_debug_mode
     init
     ipc_collect

View File

@@ -950,6 +950,9 @@ class TestCudaMultiGPU(TestCase):
             ext_stream = torch.cuda.ExternalStream(stream_v)
             self.assertEqual(stream_v, ext_stream.cuda_stream)
             self.assertEqual(ext_stream.device.index, device.idx)
+            ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
+            self.assertEqual(stream_v, ext_stream.cuda_stream)
+            self.assertEqual(ext_stream.device.index, device.idx)

     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_external_streams_multi_device(self):
@@ -958,6 +961,9 @@ class TestCudaMultiGPU(TestCase):
             ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
             self.assertEqual(stream_v, ext_stream.cuda_stream)
             self.assertEqual(ext_stream.device.index, device.idx)
+            ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
+            self.assertEqual(stream_v, ext_stream.cuda_stream)
+            self.assertEqual(ext_stream.device.index, device.idx)

     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_caching_pinned_memory_multi_gpu(self):

View File

@@ -1827,6 +1827,7 @@ def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
 def _cuda_getCurrentRawStream(device: _int) -> _int: ...
 def _cuda_getDefaultStream(device: _int) -> Tuple: ...
+def _cuda_getStreamFromExternal(data_ptr: _int, device_index: _int) -> Tuple: ...
 def _cuda_getCurrentBlasHandle() -> _int: ...
 def _cuda_clearCublasWorkspaces() -> None: ...
 def _cuda_setDevice(device: _int) -> None: ...

View File

@@ -1867,6 +1867,21 @@ PyObject* THCPModule_benchmarkLimitCuDNN(PyObject* _unused, PyObject* noargs) {
   return THPUtils_packInt32(at::globalContext().benchmarkLimitCuDNN());
 }
 
+static void initCudaMethodBindings(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+  m.def(
+      "_cuda_getStreamFromExternal",
+      [](uintptr_t data_ptr, c10::DeviceIndex device_index) {
+        cudaStream_t ext_stream =
+            // NOLINTNEXTLINE(performance-no-int-to-ptr)
+            reinterpret_cast<cudaStream_t>(reinterpret_cast<void*>(data_ptr));
+        at::cuda::CUDAStream stream =
+            c10::cuda::getStreamFromExternal(ext_stream, device_index);
+        return std::make_tuple(
+            stream.id(), stream.device_index(), stream.device_type());
+      });
+}
+
 // NOLINTNEXTLINE(*-c-arrays*, *-global-variables)
 static struct PyMethodDef _THCPModule_methods[] = {
     {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr},
@@ -2127,6 +2142,7 @@ void initModule(PyObject* module) {
   shared::initGdsBindings(module);
   registerCudaDeviceProperties(module);
   registerCudaPluggableAllocator(module);
+  initCudaMethodBindings(module);
 }
 
 } // namespace torch::cuda
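A side note on the binding above (not part of the diff): it returns a bare `(stream_id, device_index, device_type)` tuple, matching the convention of the neighboring `_cuda_getCurrentStream`/`_cuda_getDefaultStream` stubs, and the Python wrapper in the last file below rebuilds a `torch.cuda.Stream` from it. A hedged sketch of that round-trip, using a stream PyTorch itself created so no external library is needed:

```python
import torch

src = torch.cuda.Stream()  # its raw handle stands in for an "external" one
streamdata = torch._C._cuda_getStreamFromExternal(
    src.cuda_stream, src.device.index
)
wrapped = torch.cuda.Stream(
    stream_id=streamdata[0],
    device_index=streamdata[1],
    device_type=streamdata[2],
)
# Both objects refer to the same underlying cudaStream_t, which is exactly
# what the tests above assert via `ext_stream.cuda_stream`.
assert wrapped.cuda_stream == src.cuda_stream
```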

View File

@@ -1036,6 +1036,34 @@ def default_stream(device: Optional[_device_t] = None) -> Stream:
     )
 
 
+def get_stream_from_external(
+    data_ptr: int, device: Optional[_device_t] = None
+) -> Stream:
+    r"""Return a :class:`Stream` that wraps an externally allocated CUDA stream.
+
+    This function wraps streams allocated in other libraries in order to
+    facilitate data exchange and multi-library interactions.
+
+    .. note:: This function doesn't manage the stream's life cycle; it is the
+       user's responsibility to keep the referenced stream alive while the
+       returned stream is in use.
+
+    Args:
+        data_ptr (int): Integer representation of the externally allocated
+            `cudaStream_t` value.
+        device (torch.device or int, optional): the device on which the stream
+            was originally allocated. If the device is specified incorrectly,
+            subsequent launches on this stream may fail.
+    """
+    _lazy_init()
+    streamdata = torch._C._cuda_getStreamFromExternal(
+        data_ptr, _get_device_index(device, optional=True)
+    )
+    return Stream(
+        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
+    )
+
+
 def current_blas_handle():
     r"""Return cublasHandle_t pointer to current cuBLAS handle"""
     _lazy_init()
@@ -1669,6 +1697,7 @@ __all__ = [
     "get_per_process_memory_fraction",
     "get_rng_state",
     "get_rng_state_all",
+    "get_stream_from_external",
     "get_sync_debug_mode",
     "graph",
     "graph_pool_handle",