mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add get_stream_from_external API for CUDA backend (#143799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143799 Approved by: https://github.com/albanD, https://github.com/EikanWang ghstack dependencies: #142347, #141119, #141123
This commit is contained in:
committed by
PyTorch MergeBot
parent
8f6c4d1732
commit
3848de55ed
@ -950,6 +950,9 @@ class TestCudaMultiGPU(TestCase):
|
||||
ext_stream = torch.cuda.ExternalStream(stream_v)
|
||||
self.assertEqual(stream_v, ext_stream.cuda_stream)
|
||||
self.assertEqual(ext_stream.device.index, device.idx)
|
||||
ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
|
||||
self.assertEqual(stream_v, ext_stream.cuda_stream)
|
||||
self.assertEqual(ext_stream.device.index, device.idx)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
|
||||
def test_external_streams_multi_device(self):
|
||||
@ -958,6 +961,9 @@ class TestCudaMultiGPU(TestCase):
|
||||
ext_stream = torch.cuda.ExternalStream(stream_v, device=device)
|
||||
self.assertEqual(stream_v, ext_stream.cuda_stream)
|
||||
self.assertEqual(ext_stream.device.index, device.idx)
|
||||
ext_stream = torch.cuda.get_stream_from_external(stream_v, device)
|
||||
self.assertEqual(stream_v, ext_stream.cuda_stream)
|
||||
self.assertEqual(ext_stream.device.index, device.idx)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
||||
def test_caching_pinned_memory_multi_gpu(self):
|
||||
|
Reference in New Issue
Block a user