expose fast get_current_stream (#78165)

Expose fast no-frills version of getting raw `cudaStream_t` in python (200 ns instead of 4 us)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78165
Approved by: https://github.com/SherlockNoMad, https://github.com/soumith, https://github.com/gchanan
This commit is contained in:
Natalia Gimelshein
2022-05-24 15:54:47 +00:00
committed by PyTorch MergeBot
parent 49e15b578a
commit 4ea176ea57
2 changed files with 15 additions and 0 deletions

View File

@ -910,6 +910,7 @@ def _multiprocessing_init() -> None: ...
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> _int: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> _int: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_setDevice(device: _int) -> None: ...

View File

@ -150,6 +150,18 @@ PyObject * THCPModule_getCurrentStream_wrap(
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_getCurrentStream_raw(
PyObject * /* unused */, PyObject *device_index) {
HANDLE_TH_ERRORS
THPUtils_assert(
THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
int64_t device = THPUtils_unpackLong(device_index);
return PyLong_FromVoidPtr(
at::cuda::getCurrentCUDAStream(device).stream());
END_HANDLE_TH_ERRORS
}
PyObject * THCPModule_getDefaultStream_wrap(
PyObject * /* unused */, PyObject *device_index) {
HANDLE_TH_ERRORS
@ -668,6 +680,8 @@ static struct PyMethodDef _THCPModule_methods[] = {
{"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr},
{"_cuda_getCurrentStream",
THCPModule_getCurrentStream_wrap, METH_O, nullptr},
{"_cuda_getCurrentRawStream",
THCPModule_getCurrentStream_raw, METH_O, nullptr},
{"_cuda_getDefaultStream",
THCPModule_getDefaultStream_wrap, METH_O, nullptr},
{"_cuda_getCurrentBlasHandle", THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr},