Add get_stream_from_external API for CUDA backend (#143799)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143799
Approved by: https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #142347, #141119, #141123
This commit is contained in:
Yu, Guangye
2024-12-31 09:55:42 +00:00
committed by PyTorch MergeBot
parent 8f6c4d1732
commit 3848de55ed
5 changed files with 53 additions and 0 deletions

View File

@ -950,6 +950,9 @@ class TestCudaMultiGPU(TestCase):
ext_stream = torch.cuda.ExternalStream(stream_v)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_external_streams_multi_device(self):
@ -958,6 +961,9 @@ class TestCudaMultiGPU(TestCase):
ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
self.assertEqual(stream_v, ext_stream.cuda_stream)
self.assertEqual(ext_stream.device.index, device.idx)
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_caching_pinned_memory_multi_gpu(self):