mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	expose fast get_current_stream (#78165)
Expose fast no-frills version of getting raw `cudaStream_t` in python (200 ns instead of 4 us) Pull Request resolved: https://github.com/pytorch/pytorch/pull/78165 Approved by: https://github.com/SherlockNoMad, https://github.com/soumith, https://github.com/gchanan
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							49e15b578a
						
					
				
				
					commit
					4ea176ea57
				
			| @ -910,6 +910,7 @@ def _multiprocessing_init() -> None: ... | ||||
|  | ||||
| # Defined in torch/csrc/cuda/Module.cpp | ||||
| def _cuda_getCurrentStream(device: _int) -> _int: ... | ||||
| def _cuda_getCurrentRawStream(device: _int) -> _int: ... | ||||
| def _cuda_getDefaultStream(device: _int) -> _int: ... | ||||
| def _cuda_getCurrentBlasHandle() -> _int: ... | ||||
| def _cuda_setDevice(device: _int) -> None: ... | ||||
|  | ||||
| @ -150,6 +150,18 @@ PyObject * THCPModule_getCurrentStream_wrap( | ||||
|   END_HANDLE_TH_ERRORS | ||||
| } | ||||
|  | ||||
| PyObject * THCPModule_getCurrentStream_raw( | ||||
|     PyObject * /* unused */, PyObject *device_index) { | ||||
|   HANDLE_TH_ERRORS | ||||
|   THPUtils_assert( | ||||
|     THPUtils_checkLong(device_index), "invalid argument to getCurrentStream"); | ||||
|   int64_t device = THPUtils_unpackLong(device_index); | ||||
|   return PyLong_FromVoidPtr( | ||||
|     at::cuda::getCurrentCUDAStream(device).stream()); | ||||
|   END_HANDLE_TH_ERRORS | ||||
| } | ||||
|  | ||||
|  | ||||
| PyObject * THCPModule_getDefaultStream_wrap( | ||||
|     PyObject * /* unused */, PyObject *device_index) { | ||||
|   HANDLE_TH_ERRORS | ||||
| @ -668,6 +680,8 @@ static struct PyMethodDef _THCPModule_methods[] = { | ||||
|   {"_cuda_isInBadFork", THCPModule_isInBadFork, METH_NOARGS, nullptr}, | ||||
|   {"_cuda_getCurrentStream", | ||||
|     THCPModule_getCurrentStream_wrap, METH_O, nullptr}, | ||||
|   {"_cuda_getCurrentRawStream", | ||||
|     THCPModule_getCurrentStream_raw, METH_O, nullptr}, | ||||
|   {"_cuda_getDefaultStream", | ||||
|     THCPModule_getDefaultStream_wrap, METH_O, nullptr}, | ||||
|   {"_cuda_getCurrentBlasHandle", THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr}, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user