mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Implement CUDA stream protocol (#163614)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/163614 Approved by: https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
2a9745de3c
commit
fc84743707
@ -1007,6 +1007,24 @@ print(t.is_pinned())
|
||||
s.record_event(e)
|
||||
self.assertTrue("torch.cuda.Event" in e.__repr__())
|
||||
|
||||
def test_cuda_stream_protocol(self):
|
||||
stream = torch.cuda.Stream()
|
||||
|
||||
self.assertTrue(hasattr(stream, "__cuda_stream__"))
|
||||
|
||||
result = stream.__cuda_stream__()
|
||||
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0], 0) # Protocol version
|
||||
self.assertEqual(result[1], stream.cuda_stream) # Stream handle
|
||||
|
||||
external_stream = torch.cuda.ExternalStream(stream.cuda_stream)
|
||||
external_result = external_stream.__cuda_stream__()
|
||||
|
||||
self.assertEqual(external_result[0], 0)
|
||||
self.assertEqual(external_result[1], external_stream.cuda_stream)
|
||||
|
||||
def test_events(self):
|
||||
stream = torch.cuda.current_stream()
|
||||
event = torch.cuda.Event(enable_timing=True)
|
||||
|
@ -116,6 +116,16 @@ class Stream(torch._C._CudaStreamBase):
|
||||
def __repr__(self):
|
||||
return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
|
||||
|
||||
def __cuda_stream__(self):
|
||||
"""Implements the CUDA Stream Protocol:
|
||||
https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
|
||||
|
||||
Returns:
|
||||
tuple: A 2-tuple of (version, handle) where version is the protocol version
|
||||
and handle is the address of cudaStream_t (CUDA) or hipStream_t (ROCm) as a Python int.
|
||||
"""
|
||||
return (0, self.cuda_stream)
|
||||
|
||||
|
||||
class ExternalStream(Stream):
|
||||
r"""Wrapper around an externally allocated CUDA stream.
|
||||
|
Reference in New Issue
Block a user