mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support with statement on torch.Stream (#140138)
# Motivation We propose to support Python with statement on `torch.Stream`. This is a benefit for all accelerators when writing device-agnostic code. The device-specific stream will also be supported because they are generally derived from `torch.Stream`. With this PR, we can do like this ```python s1= torch.Stream() # Set s1 to the current stream torch.accelerator.set_stream(s1) with torch.Stream() as s2: # Inside with statement, we set s2 to the current stream assert torch.accelerator.current_stream() == s2 # Here the current stream should be s1 assert torch.accelerator.current_stream() == s1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/140138 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
04cb19d225
commit
6de110b862
@ -79,6 +79,29 @@ class TestAccelerator(TestCase):
|
||||
):
|
||||
torch.accelerator.current_stream(other_device)
|
||||
|
||||
def test_stream_context_manager(self):
|
||||
prev_stream = torch.accelerator.current_stream()
|
||||
with torch.Stream() as s:
|
||||
self.assertEqual(torch.accelerator.current_stream(), s)
|
||||
self.assertEqual(torch.accelerator.current_stream(), prev_stream)
|
||||
|
||||
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
|
||||
def test_multi_device_stream_context_manager(self):
|
||||
src_device = 0
|
||||
dst_device = 1
|
||||
torch.accelerator.set_device_index(src_device)
|
||||
src_prev_stream = torch.accelerator.current_stream()
|
||||
dst_prev_stream = torch.accelerator.current_stream(dst_device)
|
||||
with torch.Stream(dst_device) as dst_stream:
|
||||
self.assertEqual(torch.accelerator.current_device_index(), dst_device)
|
||||
self.assertEqual(torch.accelerator.current_stream(), dst_stream)
|
||||
self.assertEqual(
|
||||
torch.accelerator.current_stream(src_device), src_prev_stream
|
||||
)
|
||||
self.assertEqual(torch.accelerator.current_device_index(), src_device)
|
||||
self.assertEqual(torch.accelerator.current_stream(), src_prev_stream)
|
||||
self.assertEqual(torch.accelerator.current_stream(dst_device), dst_prev_stream)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user