Compare commits

...

5 Commits

Author SHA1 Message Date
c4a8bfe5c6 update 2025-09-23 07:09:52 -07:00
48155a6188 uipdate 2025-09-23 06:59:49 -07:00
c1689aacf6 update 2025-09-22 23:17:03 -07:00
99dbfbdb7a update 2025-09-22 23:16:21 -07:00
b0460db289 Implement CUDA stream protocol 2025-09-22 23:14:10 -07:00
2 changed files with 28 additions and 0 deletions

View File

@ -1007,6 +1007,24 @@ print(t.is_pinned())
s.record_event(e)
self.assertTrue("torch.cuda.Event" in e.__repr__())
def test_cuda_stream_protocol(self):
stream = torch.cuda.Stream()
self.assertTrue(hasattr(stream, "__cuda_stream__"))
result = stream.__cuda_stream__()
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)
self.assertEqual(result[0], 0) # Protocol version
self.assertEqual(result[1], stream.cuda_stream) # Stream handle
external_stream = torch.cuda.ExternalStream(stream.cuda_stream)
external_result = external_stream.__cuda_stream__()
self.assertEqual(external_result[0], 0)
self.assertEqual(external_result[1], external_stream.cuda_stream)
def test_events(self):
stream = torch.cuda.current_stream()
event = torch.cuda.Event(enable_timing=True)

View File

@ -116,6 +116,16 @@ class Stream(torch._C._CudaStreamBase):
def __repr__(self):
return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
def __cuda_stream__(self):
"""Implements the CUDA Stream Protocol:
https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
Returns:
tuple: A 2-tuple of (version, handle) where version is the protocol version
and handle is the address of cudaStream_t (CUDA) or hipStream_t (ROCm) as a Python int.
"""
return (0, self.cuda_stream)
class ExternalStream(Stream):
r"""Wrapper around an externally allocated CUDA stream.