mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
expose fast get_current_stream (#78165)
Expose fast no-frills version of getting raw `cudaStream_t` in python (200 ns instead of 4 us) Pull Request resolved: https://github.com/pytorch/pytorch/pull/78165 Approved by: https://github.com/SherlockNoMad, https://github.com/soumith, https://github.com/gchanan
This commit is contained in:
committed by
PyTorch MergeBot
parent
49e15b578a
commit
4ea176ea57
@ -910,6 +910,7 @@ def _multiprocessing_init() -> None: ...
|
||||
|
||||
# Defined in torch/csrc/cuda/Module.cpp
|
||||
def _cuda_getCurrentStream(device: _int) -> _int: ...
|
||||
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
|
||||
def _cuda_getDefaultStream(device: _int) -> _int: ...
|
||||
def _cuda_getCurrentBlasHandle() -> _int: ...
|
||||
def _cuda_setDevice(device: _int) -> None: ...
|
||||
|
||||
@ -150,6 +150,18 @@ PyObject * THCPModule_getCurrentStream_wrap(
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject * THCPModule_getCurrentStream_raw(
|
||||
PyObject * /* unused */, PyObject *device_index) {
|
||||
HANDLE_TH_ERRORS
|
||||
THPUtils_assert(
|
||||
THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
|
||||
int64_t device = THPUtils_unpackLong(device_index);
|
||||
return PyLong_FromVoidPtr(
|
||||
at::cuda::getCurrentCUDAStream(device).stream());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
PyObject * THCPModule_getDefaultStream_wrap(
|
||||
PyObject * /* unused */, PyObject *device_index) {
|
||||
HANDLE_TH_ERRORS
|
||||
@ -668,6 +680,8 @@ static struct PyMethodDef _THCPModule_methods[] = {
|
||||
{"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr},
|
||||
{"_cuda_getCurrentStream",
|
||||
THCPModule_getCurrentStream_wrap, METH_O, nullptr},
|
||||
{"_cuda_getCurrentRawStream",
|
||||
THCPModule_getCurrentStream_raw, METH_O, nullptr},
|
||||
{"_cuda_getDefaultStream",
|
||||
THCPModule_getDefaultStream_wrap, METH_O, nullptr},
|
||||
{"_cuda_getCurrentBlasHandle", THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr},
|
||||
|
||||
Reference in New Issue
Block a user